pytorch train学习率衰减

在PyTorch中,学习率衰减是一种优化技术,用于在训练过程中逐渐减小学习率。这有助于模型在训练初期快速收敛,同时在后期获得更精确的解。以下是在PyTorch中实现学习率衰减的几种方法:

使用torch.optim.lr_scheduler模块中的StepLR类:

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
# 创建一个简单的模型
model = torch.nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
# 创建一个StepLR调度器,每隔10个epoch将学习率乘以0.1
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
# 训练循环
for epoch in range(100):
    # 训练模型...
    
    # 更新学习率
    scheduler.step()

使用torch.optim.lr_scheduler模块中的ReduceLROnPlateau类:

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
# 创建一个简单的模型
model = torch.nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
# 创建一个ReduceLROnPlateau调度器,当验证集上的损失不再降低时,将学习率乘以0.1
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)
# 训练循环
for epoch in range(100):
    # 训练模型...
    
    # 计算验证集上的损失
    val_loss = ...
    
    # 更新学习率
    scheduler.step(val_loss)

使用自定义学习率衰减函数:

import torch
import torch.optim as optim
# 创建一个简单的模型
model = torch.nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
def custom_lr_decay(epoch, lr):
    decay_rate = 0.1
    decay_step = 10
    if epoch % decay_step == 0:
        return lr * decay_rate
    else:
        return lr
# 训练循环
for epoch in range(100):
    # 训练模型...
    
    # 更新学习率
    lr = custom_lr_decay(epoch, optimizer.param_groups[0]['lr'])
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

这些方法都可以帮助你在PyTorch中实现学习率衰减。你可以根据自己的需求选择合适的方法。