浏览 99
扫码
在PyTorch中定义模型结构通常需要创建一个继承自nn.Module
的类。这个类需要实现两个方法:__init__
方法用来定义模型的结构,forward
方法用来定义模型的前向传播过程。下面是一个简单的例子,展示如何定义一个基本的神经网络模型:
import torch
import torch.nn as nn
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(784, 128) # 输入层到隐藏层的全连接层
self.relu = nn.ReLU() # 激活函数
self.fc2 = nn.Linear(128, 10) # 隐藏层到输出层的全连接层
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
# 创建模型实例
model = SimpleNN()
在这个例子中,我们定义了一个包含一个隐藏层的全连接神经网络模型。模型的输入维度是784,输出维度是10。在__init__
方法中,我们定义了两个全连接层和一个ReLU激活函数。在forward
方法中,我们定义了模型的前向传播过程。
除了上面这种简单的方式,还可以使用nn.Sequential
来定义模型结构,这种方式更加简洁。下面是一个使用nn.Sequential
的例子:
import torch
import torch.nn as nn
model = nn.Sequential(
nn.Linear(784, 128), # 输入层到隐藏层的全连接层
nn.ReLU(), # 激活函数
nn.Linear(128, 10) # 隐藏层到输出层的全连接层
)
使用nn.Sequential
方式定义模型结构时,不需要定义forward
方法,PyTorch会自动根据模型的结构生成前向传播过程。
无论是使用自定义的类还是nn.Sequential
,都可以很方便地定义PyTorch模型的结构。建议根据具体的任务需求选择合适的方式。