pytorch怎么保存最佳模型

avatar
作者
筋斗云
阅读量:0

在PyTorch中保存最佳模型通常是通过保存模型的参数和优化器状态来实现的。以下是一个示例代码,演示了如何保存最佳模型:

import torch import torch.nn as nn import torch.optim as optim  # 定义模型 class MyModel(nn.Module):     def __init__(self):         super(MyModel, self).__init__()         self.fc = nn.Linear(10, 1)      def forward(self, x):         return self.fc(x)  model = MyModel()  # 定义损失函数和优化器 criterion = nn.MSELoss() optimizer = optim.Adam(model.parameters(), lr=0.001)  # 训练模型 best_loss = float('inf') for epoch in range(num_epochs):     # 训练过程     train_loss = 0.0     for inputs, labels in train_loader:         optimizer.zero_grad()         outputs = model(inputs)         loss = criterion(outputs, labels)         loss.backward()         optimizer.step()         train_loss += loss.item()      train_loss /= len(train_loader)      # 保存最佳模型     if train_loss < best_loss:         best_loss = train_loss         torch.save({             'model_state_dict': model.state_dict(),             'optimizer_state_dict': optimizer.state_dict(),             'best_loss': best_loss         }, 'best_model.pth') 

在上面的示例中,我们首先定义了一个模型、损失函数和优化器。然后在训练过程中,我们通过比较当前训练损失和最佳损失来保存最佳模型。当训练损失小于最佳损失时,我们保存模型的状态字典和优化器的状态字典,并将最佳损失更新为当前训练损失。

最后,我们可以通过加载best_model.pth文件来恢复最佳模型的状态,并继续使用该模型进行推理或进一步的训练。

    广告一刻

    为您即时展示最新活动产品广告消息,让您随时掌握产品活动新动态!