阅读量:4
深度学习早停(Early Stopping)训练策略
早停(Early Stopping)是一种防止深度学习模型过拟合的正则化技术。在训练过程中,当模型在验证集上的性能不再显著提高时,早停策略会提前停止训练。这样可以避免模型在训练集上表现得越来越好,但在验证集上表现变差。
早停策略的步骤
- 划分数据集:将数据集分为训练集和验证集。
- 定义监控指标:通常是验证集上的损失或精度。
- 设定耐心值(Patience):耐心值表示在验证指标不再改善的情况下,允许继续训练的最大次数。
- 训练模型:在每个训练轮次后,计算验证集上的指标。如果在耐心值内验证指标没有改善,则停止训练。
示例代码实现
我们使用TensorFlow和Keras来实现早停策略。假设我们使用一个简单的全连接神经网络来分类MNIST手写数字数据集。
import tensorflow as tf from tensorflow.keras.datasets import mnist from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Dense, Flatten from tensorflow.keras.callbacks import EarlyStopping # 加载MNIST数据集 (x_train, y_train), (x_val, y_val) = mnist.load_data() # 数据归一化处理 x_train = x_train / 255.0 x_val = x_val / 255.0 # 定义模型 model = Sequential([ Flatten(input_shape=(28, 28)), # 将28x28的图片展平为一维向量 Dense(128, activation='relu'), # 第一个全连接层,128个神经元,激活函数为ReLU Dense(10, activation='softmax') # 输出层,10个神经元(10个类别),激活函数为softmax ]) # 编译模型 model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) # 定义早停回调函数 early_stopping = EarlyStopping( monitor='val_loss', # 监控验证集上的损失 patience=3, # 如果验证集上的损失在3个轮次内没有改善,则停止训练 restore_best_weights=True # 恢复验证集损失最好的模型权重 ) # 训练模型 history = model.fit( x_train, y_train, # 训练数据 epochs=50, # 最大训练轮次 validation_data=(x_val, y_val),# 验证数据 callbacks=[early_stopping] # 早停回调函数 )
代码解释
- 导入必要的库:导入TensorFlow和Keras相关的模块。
- 加载数据集:加载MNIST手写数字数据集,并划分为训练集和验证集。
- 数据归一化处理:将数据归一化到0-1范围内。
- 定义模型:使用Keras的Sequential API定义一个简单的全连接神经网络。
- 编译模型:指定优化器、损失函数和评估指标。
- 定义早停回调函数:使用Keras的EarlyStopping回调函数,设定监控指标为验证集上的损失,耐心值为3,训练过程中恢复验证集上损失最小的模型权重。
- 训练模型:调用
model.fit
方法训练模型,同时传入早停回调函数。模型会在验证损失不再改善时提前停止训练。
这个例子演示了如何使用早停策略来防止模型过拟合,从而提高模型在验证集上的性能。
pytorch代码
以下是一个使用PyTorch实现早停策略的例子,同样使用MNIST手写数字数据集。
使用PyTorch实现早停策略
import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import DataLoader, random_split # 定义一个简单的全连接神经网络 class SimpleNN(nn.Module): def __init__(self): super(SimpleNN, self).__init__() self.flatten = nn.Flatten() # 将输入展平为一维 self.fc1 = nn.Linear(28 * 28, 128) # 定义一个全连接层,输入大小为28*28,输出大小为128 self.relu = nn.ReLU() # 定义ReLU激活函数 self.fc2 = nn.Linear(128, 10) # 定义另一个全连接层,输入大小为128,输出大小为10(对应10个类别) self.softmax = nn.Softmax(dim=1) # 定义Softmax输出层,沿着维度1进行 def forward(self, x): x = self.flatten(x) # 将输入展平 x = self.fc1(x) # 输入到第一个全连接层 x = self.relu(x) # 通过ReLU激活函数 x = self.fc2(x) # 输入到第二个全连接层 x = self.softmax(x) # 通过Softmax激活函数 return x # 数据预处理:转换为张量并归一化到[-1, 1]范围内 transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]) # 加载MNIST数据集 train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform) val_size = 10000 # 验证集大小 train_size = len(train_dataset) - val_size # 训练集大小 train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size]) # 划分训练集和验证集 train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) # 训练集数据加载器 val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False) # 验证集数据加载器 # 初始化模型、损失函数和优化器 model = SimpleNN() # 创建模型实例 criterion = nn.CrossEntropyLoss() # 定义交叉熵损失函数 optimizer = optim.Adam(model.parameters(), lr=0.001) # 使用Adam优化器 # 定义早停策略 class EarlyStopping: def __init__(self, patience=3, delta=0): self.patience = patience # 设置耐心值,表示验证损失可以不改善的最大次数 self.delta = delta # 设置阈值,如果损失改善小于该值则认为没有改善 self.best_loss = None # 初始化最佳损失为None self.counter = 0 # 初始化计数器为0 self.early_stop = False # 初始化早停标志为False self.best_model_state = None # 初始化最佳模型状态为None def __call__(self, val_loss, model): if self.best_loss is None: # 如果最佳损失为None,说明是第一次调用 self.best_loss = val_loss # 将当前验证损失设为最佳损失 self.best_model_state = model.state_dict() # 保存模型的当前状态 elif val_loss > self.best_loss + self.delta: # 如果当前验证损失没有改善 self.counter += 1 # 计数器加1 if self.counter >= self.patience: # 如果计数器达到耐心值 self.early_stop = True # 设置早停标志为True model.load_state_dict(self.best_model_state) # 恢复模型到最佳状态 else: # 如果验证损失改善了 self.best_loss = val_loss # 更新最佳损失 self.best_model_state = model.state_dict() # 保存模型的当前状态 self.counter = 0 # 重置计数器 early_stopping = EarlyStopping(patience=3, delta=0.01) # 创建早停策略实例 # 训练模型 num_epochs = 50 # 最大训练轮次 for epoch in range(num_epochs): model.train() # 设置模型为训练模式 for batch in train_loader: images, labels = batch # 获取一批数据和标签 outputs = model(images) # 将数据输入模型,获得输出 loss = criterion(outputs, labels) # 计算损失 optimizer.zero_grad() # 清空梯度 loss.backward() # 反向传播 optimizer.step() # 更新模型参数 # 验证模型 model.eval() # 设置模型为评估模式 val_loss = 0.0 # 初始化验证损失 with torch.no_grad(): # 禁用梯度计算 for batch in val_loader: images, labels = batch # 获取一批数据和标签 outputs = model(images) # 将数据输入模型,获得输出 loss = criterion(outputs, labels) # 计算损失 val_loss += loss.item() # 累加损失 val_loss /= len(val_loader) # 计算验证集上的平均损失 print(f'Epoch {epoch+1}, Validation Loss: {val_loss}') # 打印当前轮次的验证损失 # 检查早停条件 early_stopping(val_loss, model) # 调用早停策略 if early_stopping.early_stop: # 如果早停标志为True print("Early stopping") # 打印早停信息 break # 退出训练循环 # 模型训练完成
代码解释
- 定义模型:定义一个简单的全连接神经网络,包括展平层、全连接层、ReLU激活函数和Softmax输出层。
- 数据预处理:使用
transforms
对MNIST数据集进行标准化处理。 - 加载数据集:下载MNIST数据集,并将其划分为训练集和验证集。
- 初始化模型、损失函数和优化器:创建模型实例,定义交叉熵损失函数,并使用Adam优化器。
- 定义早停策略类:创建
EarlyStopping
类,包含早停所需的参数和逻辑。在验证损失不再改善时,保存模型的最佳状态,并在达到耐心值后停止训练。 - 训练模型:在每个训练轮次后,计算验证集上的损失,并使用早停策略检查是否需要停止训练。
这个PyTorch示例展示了如何实现早停策略,以防止模型过拟合并提高验证集上的性能。