Introduction

In deep learning projects, saving models is a seemingly simple but critically important step. A small oversight can wipe out hours or even days of training. This article takes a deep look at best practices, advanced techniques, and common pitfalls of saving deep learning models, to help you build a reliable model-persistence strategy.

1. Model Saving Basics

1.1 Why Model Saving Deserves Attention

Saving a model is more than writing training results to disk. It also affects:

  • Resuming interrupted training: during long runs, you can recover from the most recent checkpoint after hardware failure, power loss, or manual interruption
  • Model version management: comparing the performance of different model versions and picking the best one
  • Deployment and sharing: delivering a trained model to production or to teammates
  • Debugging and analysis: saving intermediate states to analyze problems during training

1.2 Common Model Saving Formats

PyTorch format (.pt/.pth)

PyTorch mainly uses the .pt or .pth extension and saves the model's serialized state dictionary (state_dict).

import torch
import torch.nn as nn

# Define a simple model (completed from the truncated original)
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(784, 10)

    def forward(self, x):
        return self.fc(x)

torch.save(SimpleModel().state_dict(), 'simple_model.pth')

TensorFlow/Keras format (.h5/.keras)

Keras provides the .h5 and .keras formats, saving either the entire model or only the weights.

import tensorflow as tf
from tensorflow import keras

# Define a simple model
model = keras.Sequential([
    keras.layers.Dense(64, activation='relu', input_shape=(784,)),
    keras.layers.Dense(10, activation='softmax')
])

# Save the entire model (architecture, weights, and optimizer state)
model.save('my_model.keras')

1.3 What to Save? (Weights, Architecture, Optimizer State)

A complete model save usually includes:

  • Model weights: the parameters learned during training
  • Model architecture: the structural definition of the model
  • Optimizer state: the internal state of optimizers such as Adam (momentum, learning rate schedule, etc.)
  • Training state: epoch count, learning rate, custom metrics, and so on

2. Basic Saving Techniques

2.1 Saving the Full Model vs. Saving Only the Weights

PyTorch in practice

# Option 1: save only the weights (recommended)
torch.save(model.state_dict(), 'model_weights.pth')

# Option 2: save the full model (including the architecture)
torch.save(model, 'full_model.pth')

# Option 3: save a checkpoint (with additional training information)
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    'accuracy': accuracy
}
torch.save(checkpoint, 'checkpoint.pth')

Comparison

  • Weights only: small files, highly portable, convenient for transfer learning, but requires the original model definition
  • Full model: larger files and depends on the original code, but simple to load
  • Checkpoint: the most flexible option and supports resuming training, but must be managed manually
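For completeness, the load-side counterparts of these options can be sketched as follows (a minimal sketch: SimpleModel and the file names are placeholders for your own model class and paths):

```python
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

model = SimpleModel()
torch.save(model.state_dict(), 'model_weights.pth')

# Option 1: weights only -- you must rebuild the architecture first
restored = SimpleModel()
restored.load_state_dict(torch.load('model_weights.pth'))

# Option 2: full model -- loaded = torch.load('full_model.pth')
# (newer PyTorch may require weights_only=False for pickled modules)

# Option 3: checkpoint -- restore each piece explicitly
optimizer = torch.optim.Adam(model.parameters())
torch.save({'epoch': 3,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()}, 'checkpoint.pth')
ckpt = torch.load('checkpoint.pth')
restored.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
start_epoch = ckpt['epoch'] + 1
```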

Keras in practice

# Option 1: save only the weights
model.save_weights('model_weights.h5')

# Option 2: save the full model (recommended for deployment)
model.save('full_model.keras')

# Option 3: save in the SavedModel format (TensorFlow-specific)
tf.saved_model.save(model, 'saved_model_dir')

2.2 Best Practice: Checkpointing

Checkpointing is a core best practice in deep learning training: it lets us save the model state periodically while training is in progress.

PyTorch checkpoint implementation

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import os

# Define the model
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.fc1 = nn.Linear(32 * 26 * 26, 10)
    
    def forward(self, x):
        x = self.conv1(x)
        x = torch.relu(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

# Checkpoint saving inside the training loop
# (a validate() function is assumed to be defined elsewhere)
def train_with_checkpoints(model, train_loader, val_loader, optimizer, epochs=10, save_dir='./checkpoints'):
    os.makedirs(save_dir, exist_ok=True)
    best_accuracy = 0.0
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = nn.CrossEntropyLoss()(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        # Validation phase
        accuracy = validate(model, val_loader)
        
        # Build the checkpoint
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': total_loss / len(train_loader),
            'accuracy': accuracy,
        }
        
        # Save the latest checkpoint
        torch.save(checkpoint, os.path.join(save_dir, 'latest_checkpoint.pth'))
        
        # Save the best checkpoint
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(checkpoint, os.path.join(save_dir, 'best_model.pth'))
            print(f"New best model saved with accuracy: {accuracy:.4f}")
        
        # Save a historical checkpoint every 5 epochs
        if epoch % 5 == 0:
            torch.save(checkpoint, os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pth'))
        
        print(f"Epoch {epoch}: Loss = {total_loss / len(train_loader):.4f}, Accuracy = {accuracy:.4f}")

# Resume training
def resume_training(model, checkpoint_path, optimizer):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    previous_accuracy = checkpoint['accuracy']
    
    print(f"Resuming training from epoch {start_epoch}")
    print(f"Previous accuracy: {previous_accuracy:.4f}")
    
    return start_epoch, previous_accuracy

# Usage example (train_loader and val_loader assumed to be defined)
model = CNNModel()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Resume from a checkpoint if one exists
checkpoint_path = './checkpoints/latest_checkpoint.pth'
if os.path.exists(checkpoint_path):
    start_epoch, prev_acc = resume_training(model, checkpoint_path, optimizer)
else:
    start_epoch = 0

train_with_checkpoints(model, train_loader, val_loader, optimizer, epochs=10, save_dir='./checkpoints')

Keras callback implementation

Keras provides the powerful ModelCheckpoint callback, which can manage checkpoints automatically.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# Define the model
def create_model():
    model = keras.Sequential([
        keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

# Checkpoint callback for the best model
checkpoint_callback = ModelCheckpoint(
    filepath='checkpoints/epoch_{epoch:02d}_val_acc_{val_accuracy:.2f}.keras',
    save_weights_only=False,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    verbose=1
)

# Always keep the latest model
latest_checkpoint = ModelCheckpoint(
    filepath='checkpoints/latest_model.keras',
    save_weights_only=False,
    save_freq='epoch'  # save once per epoch
)

# Early stopping callback (combined with checkpointing)
early_stopping = EarlyStopping(
    monitor='val_accuracy',
    patience=5,
    restore_best_weights=True
)

# Train the model (train/val arrays assumed to be defined)
model = create_model()
history = model.fit(
    train_images, train_labels,
    epochs=50,
    validation_data=(val_images, val_labels),
    callbacks=[checkpoint_callback, latest_checkpoint, early_stopping]
)

2.3 Automatic Saving Strategies

Time-based saving

import datetime

def save_with_timestamp(model, prefix='model'):
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{prefix}_{timestamp}.pth"
    torch.save(model.state_dict(), filename)
    return filename

Metric-based saving

def save_on_metric(model, current_metric, best_metric, metric_name='accuracy'):
    if current_metric > best_metric:
        best_metric = current_metric
        torch.save(model.state_dict(), f'best_{metric_name}.pth')
        print(f"New best {metric_name}: {best_metric:.4f}")
    return best_metric

3. Advanced Saving Techniques

3.1 Saving Models in Distributed Training

Distributed training requires special care when saving models.

PyTorch DistributedDataParallel (DDP)

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def save_distributed_model(model, optimizer, epoch, save_path):
    # Save only on rank 0
    if dist.get_rank() == 0:
        # A DDP model wraps the real model in .module
        if isinstance(model, DDP):
            model_state = model.module.state_dict()
        else:
            model_state = model.state_dict()
        
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model_state,
            'optimizer_state_dict': optimizer.state_dict(),
        }
        torch.save(checkpoint, save_path)
        print(f"Model saved to {save_path}")

def load_distributed_model(model, optimizer, checkpoint_path):
    # Map to CPU first so any process can load regardless of device
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    
    # DDP models need the .module indirection
    if isinstance(model, DDP):
        model.module.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint['model_state_dict'])
    
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']

TensorFlow distribution strategies

import tensorflow as tf

# Create a distribution strategy
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = create_model()
    optimizer = tf.keras.optimizers.Adam()

# Saving under a distribution strategy
def save_distributed_tf_model(model, save_path):
    # Save in the SavedModel format
    tf.saved_model.save(model, save_path)
    # Or save in the Keras format
    model.save(save_path + '.keras')

# Loading a distributed model
def load_distributed_tf_model(save_path):
    # Load the SavedModel
    loaded_model = tf.saved_model.load(save_path)
    # Or load the Keras model
    loaded_model = keras.models.load_model(save_path + '.keras')
    return loaded_model

3.2 Saving Models in Mixed Precision Training

Mixed precision training mixes float16 and float32 computation, which significantly speeds up training and reduces GPU memory usage.

import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

def train_mixed_precision(model, train_loader, optimizer, epochs=10):
    # Initialize the GradScaler
    scaler = GradScaler()
    
    for epoch in range(epochs):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.cuda(), target.cuda()
            
            optimizer.zero_grad()
            
            # Forward pass under autocast
            with autocast():
                output = model(data)
                loss = nn.CrossEntropyLoss()(output, target)
            
            # Backward pass through the scaler
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        
        # Saving needs no special precision handling
        if epoch % 5 == 0:
            # The master weights in the state_dict remain float32 under AMP;
            # to resume exactly, also checkpoint scaler.state_dict()
            torch.save(model.state_dict(), f'model_mixed_precision_epoch_{epoch}.pth')
            print(f"Saved mixed precision model at epoch {epoch}")

# Loading a mixed precision model (no special handling needed)
model = CNNModel().cuda()
model.load_state_dict(torch.load('model_mixed_precision_epoch_5.pth'))

3.3 Model Compression and Quantized Saving

PyTorch quantized saving

import os
import torch
import torch.quantization as quantization

def quantize_and_save_model(model, save_path):
    # Prepare for quantization
    model.eval()
    model.qconfig = quantization.get_default_qconfig('fbgemm')
    quantized_model = quantization.prepare(model, inplace=False)
    
    # Calibrate with a small amount of data (calibration_loader assumed defined)
    with torch.no_grad():
        for data, _ in calibration_loader:
            quantized_model(data)
    
    # Convert to a quantized model
    quantized_model = quantization.convert(quantized_model, inplace=False)
    
    # Save the quantized model
    torch.save(quantized_model.state_dict(), save_path)
    print(f"Quantized model saved. Size reduction: {get_model_size(model)} -> {get_model_size(quantized_model)}")

def get_model_size(model):
    torch.save(model.state_dict(), "temp.p")
    size = os.path.getsize("temp.p") / 1e6
    os.remove("temp.p")
    return f"{size:.2f} MB"

TensorFlow quantized saving

import tensorflow as tf

def quantize_and_save_tf_model(model, save_path):
    # Convert to a quantized TensorFlow Lite model
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    
    # Configure quantization
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]
    
    # Convert and write to disk (the original snippet was truncated here)
    tflite_model = converter.convert()
    with open(save_path, 'wb') as f:
        f.write(tflite_model)

3.4 Cross-Framework Model Saving with ONNX

ONNX (Open Neural Network Exchange) is an open format that allows models to be transferred between frameworks.

import torch
import onnx
import onnxruntime as ort

def convert_to_onnx(model, input_sample, save_path):
    """
    Convert a PyTorch model to the ONNX format.
    """
    model.eval()
    
    # Export to ONNX
    torch.onnx.export(
        model,
        input_sample,
        save_path,
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        }
    )
    
    print(f"Model converted to ONNX format: {save_path}")
    
    # Validate the ONNX model
    onnx_model = onnx.load(save_path)
    onnx.checker.check_model(onnx_model)
    print("ONNX model validation passed")

def run_onnx_inference(onnx_path, input_data):
    """
    Run inference with ONNX Runtime.
    """
    session = ort.InferenceSession(onnx_path)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    
    result = session.run([output_name], {input_name: input_data.numpy()})
    return result

# Usage example
model = CNNModel()
input_sample = torch.randn(1, 1, 28, 28)
convert_to_onnx(model, input_sample, 'model.onnx')

# Load and run
result = run_onnx_inference('model.onnx', input_sample)

4. Common Pitfalls and How to Avoid Them

4.1 Pitfall 1: Forgetting to Save the Optimizer State

Problem: only the model weights are saved, so when training resumes the optimizer state is gone, losing the learning rate, momentum, and related information.

Solution

# ❌ Wrong
torch.save(model.state_dict(), 'model.pth')  # optimizer state is lost

# ✅ Correct
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
    'scaler_state_dict': scaler.state_dict() if scaler else None,  # mixed precision
}
torch.save(checkpoint, 'checkpoint.pth')

4.2 Pitfall 2: Loading Weights After Changing the Architecture

Problem: after modifying the model architecture, loading the old weights directly fails.

Solution

# Option 1: partial loading (recommended)
def load_partial_weights(model, checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    model_dict = model.state_dict()
    
    # Keep only keys whose name and shape both match
    pretrained_dict = {k: v for k, v in checkpoint['model_state_dict'].items() 
                      if k in model_dict and v.shape == model_dict[k].shape}
    
    # Merge into the model's dict
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    
    # Report what was loaded
    print(f"Loaded {len(pretrained_dict)}/{len(model_dict)} parameters")
    return model

# Option 2: use strict=False
model.load_state_dict(checkpoint['model_state_dict'], strict=False)

4.3 Pitfall 3: Device Mismatches

Problem: a model trained on a GPU is loaded directly on a CPU, or vice versa.

Solution

# ✅ Safe loading
def safe_load(model, checkpoint_path, device='cpu'):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    # Handle full checkpoints as well as bare state_dicts
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    
    return model

# Usage example
model = CNNModel()
# Map automatically to whatever device is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
safe_load(model, 'checkpoint.pth', device)

4.4 Pitfall 4: File Permissions and Disk Space

Problem: insufficient disk space or permission problems cause saves to fail.

Solution

import shutil
import os

def safe_save(model, save_path, max_checkpoints=5):
    # Check disk space (keep at least 1 GB free)
    free_space = shutil.disk_usage(os.path.dirname(save_path) or '.').free / (1024**3)
    if free_space < 1:
        # Delete old checkpoints
        cleanup_old_checkpoints(os.path.dirname(save_path), max_checkpoints)
    
    # Save to a temporary file, then rename atomically
    temp_path = save_path + '.tmp'
    try:
        torch.save(model.state_dict(), temp_path)
        os.replace(temp_path, save_path)  # atomic on the same filesystem
        print(f"Successfully saved to {save_path}")
    except Exception as e:
        print(f"Save failed: {e}")
        if os.path.exists(temp_path):
            os.remove(temp_path)

def cleanup_old_checkpoints(checkpoint_dir, keep_last=5):
    """Keep the newest N checkpoints and delete the rest."""
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pth')]
    checkpoints.sort(key=lambda x: os.path.getctime(os.path.join(checkpoint_dir, x)))
    
    if len(checkpoints) > keep_last:
        for old_checkpoint in checkpoints[:-keep_last]:
            os.remove(os.path.join(checkpoint_dir, old_checkpoint))
            print(f"Removed old checkpoint: {old_checkpoint}")

4.5 Pitfall 5: Race Conditions with Multiple Threads/Processes

Problem: several processes writing the same file at once can corrupt it.

Solution

import contextlib
import fcntl  # POSIX-only

@contextlib.contextmanager
def file_lock(lock_path):
    """A file lock to prevent concurrent writes (non-blocking: fails immediately if held)."""
    lock_file = open(lock_path, 'w')
    try:
        # Try to acquire the lock
        fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
        yield
    except BlockingIOError:
        raise RuntimeError("Could not acquire lock: another process is writing")
    finally:
        fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
        lock_file.close()

def save_with_lock(model, save_path):
    lock_path = save_path + '.lock'
    with file_lock(lock_path):
        torch.save(model.state_dict(), save_path)

4.6 Pitfall 6: Version Incompatibility

Problem: after upgrading PyTorch/TensorFlow, old models fail to load.

Solution

# Record version information at save time
import datetime
import sys
import torch

def save_with_version_info(model, save_path):
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'pytorch_version': torch.__version__,
        'python_version': sys.version,
        'model_architecture': str(model.__class__.__name__),
        'save_date': datetime.datetime.now().isoformat(),
    }
    torch.save(checkpoint, save_path)

def load_with_version_check(save_path, model):
    checkpoint = torch.load(save_path)
    
    # Check the version
    if 'pytorch_version' in checkpoint:
        saved_version = checkpoint['pytorch_version']
        current_version = torch.__version__
        print(f"Model saved with PyTorch {saved_version}, current: {current_version}")
        
        # Warn if the versions differ
        if saved_version != current_version:
            print("⚠️  Version mismatch! Model may not load correctly.")
    
    model.load_state_dict(checkpoint['model_state_dict'])
    return model

4.7 Pitfall 7: Serializing Custom Layers/Functions

Problem: pickle can fail on lambda functions or custom classes.

Solution

# ❌ Wrong: using a lambda
model = nn.Sequential(
    nn.Linear(10, 20),
    lambda x: torch.relu(x)  # cannot be serialized
)

# ✅ Correct: wrap it in an nn.Module
class ReLULayer(nn.Module):
    def forward(self, x):
        return torch.relu(x)

model = nn.Sequential(
    nn.Linear(10, 20),
    ReLULayer()
)

# For custom classes, implement __getstate__ and __setstate__
class CustomLayer(nn.Module):
    def __init__(self, param):
        super().__init__()
        self.param = nn.Parameter(torch.tensor(param))
    
    def forward(self, x):
        return x * self.param
    
    def __getstate__(self):
        # Custom serialization: pickle the instance dict plus extra info
        state = self.__dict__.copy()
        state['custom_info'] = 'some_info'
        return state
    
    def __setstate__(self, state):
        # Custom deserialization: restore extra attributes first
        self.custom_info = state.pop('custom_info', '')
        self.__dict__.update(state)

5. Model Version Management and Experiment Tracking

5.1 Version Management with MLflow

import mlflow
import mlflow.pytorch

def train_with_mlflow_tracking(model, train_loader, optimizer, epochs=10):
    mlflow.set_experiment("mnist_cnn")
    
    with mlflow.start_run():
        # Log parameters
        mlflow.log_param("epochs", epochs)
        mlflow.log_param("learning_rate", optimizer.param_groups[0]['lr'])
        
        for epoch in range(epochs):
            # Training code (train_one_epoch, validate, val_loader assumed defined)
            loss = train_one_epoch(model, train_loader, optimizer)
            accuracy = validate(model, val_loader)
            
            # Log metrics
            mlflow.log_metric("loss", loss, step=epoch)
            mlflow.log_metric("accuracy", accuracy, step=epoch)
            
            # Automatically log the model to MLflow
            if accuracy > 0.95:
                mlflow.pytorch.log_model(model, f"models/epoch_{epoch}")
                mlflow.log_artifact("model_architecture.py")
        
        # Log the final model
        mlflow.pytorch.log_model(model, "final_model")
        print(f"Model logged to MLflow: {mlflow.active_run().info.run_id}")

# Load an MLflow model
def load_mlflow_model(run_id):
    model_uri = f"runs:/{run_id}/final_model"
    loaded_model = mlflow.pytorch.load_model(model_uri)
    return loaded_model

5.2 Tracking with Weights & Biases (wandb)

import wandb

def train_with_wandb(model, train_loader, optimizer, epochs=10):
    # Initialize wandb
    wandb.init(
        project="mnist-cnn",
        config={
            "learning_rate": 0.001,
            "architecture": "CNN",
            "dataset": "MNIST",
            "epochs": epochs,
        }
    )
    
    # Log model gradients and parameters
    wandb.watch(model, log="all", log_freq=100)
    
    for epoch in range(epochs):
        # Training (train_one_epoch, validate, val_loader assumed defined)
        loss = train_one_epoch(model, train_loader, optimizer)
        accuracy = validate(model, val_loader)
        
        # Log metrics
        wandb.log({
            "train_loss": loss,
            "val_accuracy": accuracy,
            "epoch": epoch
        })
        
        # Save the model as an artifact
        if epoch % 5 == 0:
            torch.save(model.state_dict(), f'model_epoch_{epoch}.pth')
            artifact = wandb.Artifact(
                f'model-epoch-{epoch}',
                type='model',
                description=f'Model at epoch {epoch}'
            )
            artifact.add_file(f'model_epoch_{epoch}.pth')
            wandb.log_artifact(artifact)
    
    # Save the final model
    torch.save(model.state_dict(), 'final_model.pth')
    artifact = wandb.Artifact(
        'final-model',
        type='model',
        description='Final trained model'
    )
    artifact.add_file('final_model.pth')
    wandb.log_artifact(artifact)
    
    wandb.finish()

# Load a wandb model
def load_wandb_model(run_path, artifact_name="final-model:latest"):
    api = wandb.Api()
    artifact = api.artifact(run_path + "/" + artifact_name)
    artifact_dir = artifact.download()
    
    model = CNNModel()
    model.load_state_dict(torch.load(artifact_dir + "/final_model.pth"))
    return model

5.3 A Custom Version-Management Script

import datetime
import hashlib
import json
import os
import torch

class ModelVersionManager:
    def __init__(self, base_dir='./model_registry'):
        self.base_dir = base_dir
        os.makedirs(base_dir, exist_ok=True)
        self.metadata_file = os.path.join(base_dir, 'metadata.json')
        self.metadata = self.load_metadata()
    
    def load_metadata(self):
        if os.path.exists(self.metadata_file):
            with open(self.metadata_file, 'r') as f:
                return json.load(f)
        return {}
    
    def save_metadata(self):
        with open(self.metadata_file, 'w') as f:
            json.dump(self.metadata, f, indent=2)
    
    def register_model(self, model, metrics, hyperparams, tags=None):
        # Compute a short model hash (crude: MD5 of the state_dict's repr)
        model_hash = hashlib.md5(str(model.state_dict()).encode()).hexdigest()[:8]
        
        # Create the version record
        version = {
            'hash': model_hash,
            'metrics': metrics,
            'hyperparams': hyperparams,
            'tags': tags or [],
            'timestamp': datetime.datetime.now().isoformat(),
            'file_path': f'model_{model_hash}.pth'
        }
        
        # Save the model
        save_path = os.path.join(self.base_dir, version['file_path'])
        torch.save(model.state_dict(), save_path)
        
        # Update the metadata
        version_id = f"v{len(self.metadata) + 1}"
        self.metadata[version_id] = version
        self.save_metadata()
        
        print(f"Model registered as {version_id} (hash: {model_hash})")
        return version_id
    
    def get_best_model(self, metric='accuracy', mode='max'):
        """Return the best (version_id, value) according to a metric."""
        if not self.metadata:
            return None, None
        
        best_version = None
        best_value = -float('inf') if mode == 'max' else float('inf')
        
        for version_id, info in self.metadata.items():
            value = info['metrics'].get(metric)
            if value is None:
                continue
            
            if mode == 'max' and value > best_value:
                best_value = value
                best_version = version_id
            elif mode == 'min' and value < best_value:
                best_value = value
                best_version = version_id
        
        return best_version, best_value
    
    def load_model(self, version_id, model_class):
        """Load the model for a given version."""
        if version_id not in self.metadata:
            raise ValueError(f"Version {version_id} not found")
        
        info = self.metadata[version_id]
        model_path = os.path.join(self.base_dir, info['file_path'])
        
        model = model_class()
        model.load_state_dict(torch.load(model_path))
        
        return model, info

# Usage example
manager = ModelVersionManager()

# Train and register several models
for i in range(3):
    model = CNNModel()
    optimizer = optim.Adam(model.parameters(), lr=0.001 * (i+1))
    
    # training...
    accuracy = 0.9 + i * 0.02  # simulate different performance
    
    version_id = manager.register_model(
        model,
        metrics={'accuracy': accuracy, 'loss': 1.0 - i*0.1},
        hyperparams={'lr': 0.001 * (i+1), 'batch_size': 64},
        tags=['experiment', f'run_{i}']
    )

# Get the best model
best_version, best_accuracy = manager.get_best_model('accuracy')
print(f"Best model: {best_version} with accuracy {best_accuracy}")

# Load the best model
best_model, info = manager.load_model(best_version, CNNModel)
print(f"Loaded model info: {info}")

6. Production Deployment Strategies

6.1 Model Serialization Best Practices

TorchScript (PyTorch)

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 20)
    
    def forward(self, x):
        return torch.relu(self.linear(x))

# Method 1: tracing
model = SimpleModel()
model.eval()
example_input = torch.randn(1, 10)

# Trace the model
traced_model = torch.jit.trace(model, example_input)
traced_model.save('model_traced.pt')

# Load
loaded_model = torch.jit.load('model_traced.pt')
output = loaded_model(torch.randn(1, 10))

# Method 2: scripting
scripted_model = torch.jit.script(model)
scripted_model.save('model_scripted.pt')

# Difference: tracing records the execution path for one specific input,
# while scripting parses the code logic itself
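The difference matters once a model contains data-dependent control flow. A minimal sketch (the Gate module below is a hypothetical example):

```python
import torch
import torch.nn as nn

class Gate(nn.Module):
    def forward(self, x):
        # Data-dependent branch: tracing bakes in whichever path the
        # example input takes; scripting preserves the `if`.
        if x.sum() > 0:
            return x * 2
        return x + 10

model = Gate()
traced = torch.jit.trace(model, torch.ones(2))    # emits a TracerWarning
scripted = torch.jit.script(model)

neg = -torch.ones(2)
print(traced(neg).tolist())    # [-2.0, -2.0] -- wrong branch, frozen at trace time
print(scripted(neg).tolist())  # [9.0, 9.0]   -- correct branch
```

Prefer scripting (or rewriting the control flow) whenever the forward pass branches on tensor values.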

TensorFlow SavedModel

import tensorflow as tf

# Save as a SavedModel
tf.saved_model.save(model, 'saved_model_dir')

# Load
loaded_model = tf.saved_model.load('saved_model_dir')
inference_func = loaded_model.signatures['serving_default']

# Or use the Keras model API
loaded_keras_model = tf.keras.models.load_model('saved_model_dir')

6.2 Serving Models

Using TorchServe

# 1. Install TorchServe
pip install torchserve torch-model-archiver

# 2. Package the model
torch-model-archiver --model-name mnist_model \
                     --version 1.0 \
                     --serialized-file model.pth \
                     --handler image_classifier \
                     --export-path model_store

# 3. Start the server
torchserve --start --model-store model_store --models mnist_model=mnist_model.mar

# 4. Send an inference request
curl http://localhost:8080/predictions/mnist_model -T sample_image.png

Deploying with FastAPI

from fastapi import FastAPI, File, UploadFile
import io
import torch
from PIL import Image
from torchvision import transforms

app = FastAPI()

# The model is loaded once, globally
model = None

@app.on_event("startup")
async def load_model():
    global model
    model = torch.jit.load('model_scripted.pt')
    model.eval()

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    # Read the image
    contents = await file.read()
    image = Image.open(io.BytesIO(contents))
    
    # Preprocess
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    input_tensor = transform(image).unsqueeze(0)
    
    # Inference
    with torch.no_grad():
        output = model(input_tensor)
    
    return {"prediction": output.argmax().item()}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

6.3 Model Encryption and Security

from cryptography.fernet import Fernet
import io
import torch

def generate_key():
    """Generate an encryption key."""
    return Fernet.generate_key()

def encrypt_model(model_state_dict, key):
    """Encrypt a model state_dict."""
    f = Fernet(key)
    # Serialize the state_dict to bytes via an in-memory buffer
    buffer = io.BytesIO()
    torch.save(model_state_dict, buffer)
    encrypted = f.encrypt(buffer.getvalue())
    return encrypted

def decrypt_model(encrypted_data, key, model):
    """Decrypt a model."""
    f = Fernet(key)
    decrypted = f.decrypt(encrypted_data)
    # Deserialize from bytes
    state_dict = torch.load(io.BytesIO(decrypted))
    model.load_state_dict(state_dict)
    return model

# Usage example
key = generate_key()
encrypted = encrypt_model(model.state_dict(), key)

# Save the encrypted model
with open('model_encrypted.bin', 'wb') as f:
    f.write(encrypted)

# Load and decrypt
with open('model_encrypted.bin', 'rb') as f:
    encrypted_data = f.read()
model = decrypt_model(encrypted_data, key, CNNModel())

7. Performance Optimization Tips

7.1 Asynchronous Saving

import queue
import threading

class AsyncModelSaver:
    def __init__(self):
        self.save_queue = queue.Queue()
        self.save_thread = threading.Thread(target=self._save_worker, daemon=True)
        self.save_thread.start()
    
    def _save_worker(self):
        while True:
            try:
                model_state, path = self.save_queue.get()
                torch.save(model_state, path)
                print(f"Async save completed: {path}")
            except Exception as e:
                print(f"Async save failed: {e}")
            finally:
                self.save_queue.task_done()
    
    def save_async(self, model_state, path):
        """Non-blocking save. Pass a copy of the state_dict if training keeps mutating it."""
        self.save_queue.put((model_state, path))
    
    def wait_completion(self):
        """Block until all queued saves have finished."""
        self.save_queue.join()

# Usage example
saver = AsyncModelSaver()

# Inside the training loop
for epoch in range(epochs):
    # training...
    if epoch % 5 == 0:
        # Non-blocking save
        saver.save_async(model.state_dict(), f'checkpoint_epoch_{epoch}.pth')

# Wait for all saves to complete after training
saver.wait_completion()

7.2 Incremental Saving

def save_model_incremental(model, base_path, threshold=0.1):
    """Save only the parameters that changed since the last full checkpoint."""
    current_state = model.state_dict()
    full_path = f"{base_path}_full.pth"
    
    # Compare against the last full checkpoint, if there is one
    if os.path.exists(full_path):
        last_state = torch.load(full_path)
        diff_state = {key: value for key, value in current_state.items()
                      if key not in last_state or not torch.equal(value, last_state[key])}
        
        # If the change is small, save only the diff
        if len(diff_state) < len(current_state) * threshold:  # fewer than 10% changed
            torch.save(diff_state, f"{base_path}_diff.pth")
            print(f"Saved incremental diff: {len(diff_state)} params changed")
            return
    
    # Too many changes (or no baseline yet): save the full model
    torch.save(current_state, full_path)
    print("Saved full model (too many changes)")

8. Failure Recovery and Disaster Response

8.1 Automatic Recovery Script

import os
import time

def auto_resume_training(script_path, checkpoint_dir, max_retries=3):
    """
    Automatically restart a training script, resuming from checkpoints.
    """
    retry_count = 0
    while retry_count < max_retries:
        try:
            # Look for an existing checkpoint
            latest_checkpoint = find_latest_checkpoint(checkpoint_dir)
            
            if latest_checkpoint:
                print(f"Resuming from {latest_checkpoint}")
                cmd = f"python {script_path} --resume {latest_checkpoint}"
            else:
                print("Starting fresh training")
                cmd = f"python {script_path}"
            
            # Run the training
            exit_code = os.system(cmd)
            
            if exit_code == 0:
                print("Training completed successfully")
                return True
            else:
                print(f"Training failed with exit code {exit_code}")
                retry_count += 1
                
        except Exception as e:
            print(f"Error: {e}")
            retry_count += 1
        
        if retry_count < max_retries:
            print(f"Retrying in 60 seconds... (Attempt {retry_count}/{max_retries})")
            time.sleep(60)
    
    print("Max retries reached. Training failed.")
    return False

def find_latest_checkpoint(checkpoint_dir):
    """Find the most recent checkpoint."""
    if not os.path.exists(checkpoint_dir):
        return None
    
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pth')]
    if not checkpoints:
        return None
    
    # Sort by creation time, newest first
    checkpoints.sort(key=lambda x: os.path.getctime(
        os.path.join(checkpoint_dir, x)
    ), reverse=True)
    
    return os.path.join(checkpoint_dir, checkpoints[0])

# Usage
# auto_resume_training('train.py', './checkpoints')

8.2 Detecting Corrupted Files

def verify_checkpoint_integrity(checkpoint_path):
    """
    Verify that a checkpoint file is intact.
    """
    try:
        # Try to load it
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        
        # Check required fields
        required_keys = ['model_state_dict', 'optimizer_state_dict', 'epoch']
        for key in required_keys:
            if key not in checkpoint:
                print(f"Warning: Missing key '{key}'")
                return False
        
        # Check that the weights are valid
        model_state = checkpoint['model_state_dict']
        for key, value in model_state.items():
            if torch.isnan(value).any() or torch.isinf(value).any():
                print(f"Invalid values in {key}: NaN or Inf detected")
                return False
        
        print("Checkpoint integrity check passed")
        return True
        
    except Exception as e:
        print(f"Checkpoint corrupted: {e}")
        return False

# Periodic checks
def periodic_integrity_check(checkpoint_dir, interval_hours=1):
    import time
    
    import schedule  # third-party: pip install schedule
    
    def job():
        for file in os.listdir(checkpoint_dir):
            if file.endswith('.pth'):
                path = os.path.join(checkpoint_dir, file)
                if not verify_checkpoint_integrity(path):
                    print(f"Corrupted checkpoint: {path}")
                    # trigger an alert or delete the file here
    
    schedule.every(interval_hours).hours.do(job)
    
    while True:
        schedule.run_pending()
        time.sleep(60)

9. Summary and Checklist

9.1 Model Saving Checklist

Before training an important model, confirm that:

  • [ ] Saving strategy: use checkpoints rather than saving only the final model
  • [ ] Complete information: include model weights, optimizer state, epoch, learning rate, etc.
  • [ ] Device compatibility: use map_location to ensure cross-device loading
  • [ ] Version control: record framework version, Python version, and model architecture
  • [ ] Exception handling: wrap saves in try/except so a failed save does not kill training
  • [ ] Disk space: clean up old checkpoints regularly and monitor disk usage
  • [ ] File permissions: ensure write permission and use atomic operations
  • [ ] Backup strategy: back up important models to cloud storage or a separate disk
  • [ ] Encryption: store sensitive models encrypted
  • [ ] Documentation: record model performance, hyperparameters, and training environment

9.2 Quick Reference: Best Practices by Scenario

Scenario | Recommended format | Must save | Extra advice
Continued training | Checkpoint (.pth) | model + optimizer + epoch | save the latest every epoch, keep history every 5 epochs
Final deployment | TorchScript (.pt) | inference essentials only | strip unnecessary metadata, quantize/compress
Transfer learning | Weights only (.pth) | model.state_dict() | save in a portable format for easy sharing
Distributed training | Checkpoint + DDP wrapper | module.state_dict() required | save on rank 0 only to avoid duplication
Research experiments | Full checkpoint + metadata | everything + config | track with MLflow/wandb
Production | ONNX / SavedModel | optimized inference model | consider encryption and version management

9.3 Performance Comparison

Method | File size | Load speed | Flexibility | Recommended for
Weights only | small | fast | high | transfer learning, research
Full model | large | fast | low | quick prototyping, deployment
Checkpoint | largest | medium | highest | long training runs, resumption
TorchScript | medium | fast | low | production deployment
Quantized model | smallest | fastest | low | mobile, edge devices
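As a quick sanity check of the file-size column, the two PyTorch save styles can be compared directly on disk (a minimal sketch; the exact numbers depend on the model):

```python
import os
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))

torch.save(model.state_dict(), 'weights_only.pth')  # parameters only
torch.save(model, 'full_model.pth')                  # parameters + pickled module

for path in ('weights_only.pth', 'full_model.pth'):
    print(f"{path}: {os.path.getsize(path) / 1024:.1f} KB")
```

The full-model file carries the same tensors plus the pickled module structure, so it is at least as large as the weights-only file.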

10. Frequently Asked Questions

Q1: Why did my model's performance drop after loading? A: Possible causes: 1) the optimizer state was not saved, so learning rate information was lost; 2) inconsistent data preprocessing; 3) the model was left in train() mode (active dropout/batch-norm updates) during evaluation; 4) precision differences between devices (CPU vs GPU).

Q2: How can I save a model without interrupting training? A: Use asynchronous saving (see Section 7.1), checkpoint callbacks (Keras), or a custom saving thread.

Q3: What if the model file is too large? A: 1) save only the weights; 2) quantize; 3) use incremental saving; 4) compress the file (e.g. zip/gzip); 5) store it in cloud storage.
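For option 4, a state_dict can be gzip-compressed on its way to disk (a minimal sketch; the savings vary, since float32 weights compress poorly):

```python
import gzip
import io
import torch
import torch.nn as nn

model = nn.Linear(64, 32)

# Serialize to memory, then gzip to disk
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
with gzip.open('model.pth.gz', 'wb') as f:
    f.write(buffer.getvalue())

# Load it back: gunzip into memory, then torch.load
with gzip.open('model.pth.gz', 'rb') as f:
    state_dict = torch.load(io.BytesIO(f.read()))

model.load_state_dict(state_dict)
```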

Q4: How can I share a model safely? A: 1) use the ONNX format for cross-framework compatibility; 2) encrypt sensitive models; 3) record versions and dependencies; 4) provide example loading code.

Q5: Should the learning rate be reset when resuming training? A: No. If the optimizer state was saved, the learning rate continues from the checkpoint; if it was not, set it manually or re-initialize it through the learning rate scheduler.
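To make Q5 concrete, saving and restoring the scheduler state alongside the optimizer keeps the learning rate exactly where it left off (a minimal sketch with hypothetical hyperparameters):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

# Simulate 3 epochs: the lr decays 0.1 -> 0.05 after epoch 2
for _ in range(3):
    optimizer.step()
    scheduler.step()

torch.save({'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict()}, 'resume.pth')

# Fresh objects, as if the process had restarted
optimizer2 = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=2, gamma=0.5)
ckpt = torch.load('resume.pth')
optimizer2.load_state_dict(ckpt['optimizer'])
scheduler2.load_state_dict(ckpt['scheduler'])

print(scheduler2.get_last_lr())  # the decayed 0.05, not the initial 0.1
```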

Closing Remarks

Model saving is a critical but often overlooked part of the deep learning workflow. By adopting the checkpointing strategies, version management, exception handling, and production deployment practices described in this article, you can significantly improve the reliability and efficiency of your projects.

Remember: a good model-saving strategy is not an afterthought but a design decision made before training. Planning your saving strategy before you start will save your project a great deal of time and resources.


This document will be updated continuously; feedback and suggestions are welcome. Last updated: 2024.