浏览 190
扫码
混淆矩阵是机器学习和统计学中用于评估分类模型性能的重要工具。在PyTorch中,我们可以使用混淆矩阵来可视化模型在测试集上的预测结果,以便更直观地了解模型的性能。
下面是一个示例代码,用于生成混淆矩阵并将其可视化:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
# 加载测试集
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# 定义模型
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
# 定义卷积神经网络结构
def forward(self, x):
# 前向传播
return x
# 加载模型
model = CNN()
model.load_state_dict(torch.load('model.pth'))
model.eval()
# 使用混淆矩阵评估模型性能
y_true = []
y_pred = []
for inputs, labels in test_loader:
outputs = model(inputs)
_, predicted = torch.max(outputs, 1)
y_true.extend(labels.numpy())
y_pred.extend(predicted.numpy())
# 生成混淆矩阵
cm = confusion_matrix(y_true, y_pred)
# 可视化混淆矩阵
plt.figure(figsize=(8, 6))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)
for i in range(len(classes)):
for j in range(len(classes)):
plt.text(j, i, str(cm[i, j]), ha='center', va='center')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()
在上面的代码中,我们首先加载了测试集数据,然后定义了一个卷积神经网络模型。接着加载了预先训练好的模型,并对测试集进行预测,将真实标签和预测标签存储起来。最后,使用sklearn库中的confusion_matrix函数生成混淆矩阵,并使用matplotlib库将混淆矩阵可视化。
通过混淆矩阵,我们可以直观地看到模型在每个类别上的预测准确率和错误率,从而更全面地评估模型的性能。希望这个教程能帮助你理解如何在PyTorch中使用混淆矩阵评估模型性能。