浏览 62
扫码
迁移学习是一种常见的机器学习技术,它利用在一个任务上训练好的模型,在另一个相关任务上进行微调。微调是指在新任务上对预训练模型进行一定程度的调整,以适应新任务的特点。在这个教程中,我们将学习如何使用PyTorch进行迁移学习和微调。
- 安装PyTorch和相关库 首先,确保已经安装了PyTorch和torchvision库,可以使用以下命令进行安装:
pip install torch torchvision
- 加载预训练模型 在PyTorch中,可以使用torchvision模块加载预训练的模型,例如ResNet、VGG等。在这个例子中,我们将加载一个在ImageNet数据集上预训练的ResNet模型:
import torch
import torchvision.models as models
model = models.resnet18(pretrained=True)
- 替换最后一层全连接层 预训练的模型通常包含一个全连接层,其输出维度与原始任务的类别数相匹配。为了适应新任务的类别数,需要将最后一层全连接层替换为新的全连接层。代码如下:
num_classes = 10 # 新任务的类别数
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
- 微调模型 接下来,我们定义损失函数和优化器,并对模型进行微调:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
# 微调模型
model.train()
for epoch in range(num_epochs):
for images, labels in train_loader:
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
- 评估模型性能 在微调完成后,可以对模型在验证集上进行评估:
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in val_loader:
outputs = model(images)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = correct / total
print('Validation accuracy: {:.2f}%'.format(100 * accuracy))
通过以上步骤,我们完成了使用PyTorch进行迁移学习和微调的教程。希朇对你有所帮助!