Сохранение & Загрузка моделей
Тема дорожной карты · Глубокое обучение
Сохранение и загрузка моделей — это критически важные этапы в процессе обучения глубоких нейронных сетей. Они позволяют сохранить текущее состояние модели, оптимизатора, шедулера и других важных компонентов, что делает возможным восстановление обучения после перезагрузки системы или сбоя. Важно использовать правильные методы для сохранения и загрузки, чтобы избежать потери прогресса и обеспечить воспроизводимость результатов.
Сохранение состояния модели обычно осуществляется через state_dict (Python-словарь тензоров), который содержит все параметры модели. Это позволяет переносить модель между различными версиями кода без потери обучения. В PyTorch это делается с помощью команд torch.save(model.state_dict(), path) для сохранения и model.load_state_dict(torch.load(path)) для загрузки. Для полноценного чекпоинта необходимо также сохранять состояние оптимизатора, шедулера и случайного числового генератора, чтобы корректно возобновить обучение. В TensorFlow/Keras используются методы model.save() для сохранения модели в формате SavedModel или однофайлового файла .keras.
Для удобства шаринга моделей можно использовать платформы, такие как Hugging Face Hub, или экспортировать модель в форматы ONNX или safetensors, что обеспечивает безопасную загрузку и использование модели на других системах.
Как это работает
Сохранение и загрузка моделей охватывают важные аспекты инженерной реальности обучения: параллельная загрузка данных с помощью Dataset и DataLoader, цикл обучения (forward pass, вычисление потерь, backward pass, шаги оптимизатора, обнуление градиентов), сохранение чекпоинтов (state dict модели, состояние оптимизатора, состояние шедулера, состояние случайного числового генератора), загрузка модели позже, отладка моделей, которые не сходятся. Используются различные инструменты, такие как torchinfo для просмотра формы тензоров, tqdm для отслеживания прогресса, wandb или MLflow для отслеживания метрик, git-lfs или DVC для версионирования данных.
Когда применять
Сохранение и загрузка моделей особенно полезны в следующих ситуациях: использование тренеров (Lightning, HuggingFace) вместо ручного управления циклом обучения, что помогает избежать ошибок в цикле; частое сохранение чекпоинтов (например, каждые несколько эпох или после каждого батча), чтобы предотвратить потерю обучения при сбое системы; версионирование всех компонентов обучения (код, данные, гиперпараметры, среда), что обеспечивает воспроизводимость результатов; отладка моделей глубокой нейронной сети как любого другого кода, начиная с переобучения на одном батче (если это невозможно, значит, что-то сломано в пайплайне).
Типичные ошибки
Типичными ошибками при сохранении и загрузке моделей являются: забывание вызвать optimizer.zero_grad() (что приводит к аккумулированию градиентов и потере обучения); не установка правильного режима работы модели (например, model.train() или model.eval()), что может привести к неправильному поведению BatchNorm или Dropout; загрузка чекпоинта без соответствующего состояния случайного числового генератора (что приводит к другим результатам); развертывание экспериментов без оценочного набора данных (что делает невозможным оценку качества модели).