PyTorch 之 nn.Parameter
在 PyTorch 中,nn.Parameter
是一个类,用于将张量包装成可学习的参数。它是 torch.Tensor
的子类,但被设计成可以被优化器更新的参数。通过将张量包装成 nn.Parameter
,你可以告诉 PyTorch 这是一个模型参数,从而在训练时自动进行梯度计算和优化。
使用方法:
首先,你需要导入相应的模块:
import torch
import torch.nn as nn
然后,可以使用 nn.Parameter
类来创建可学习的参数。以下是一个简单的示例:
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
# 创建一个可学习的参数,大小为 (3, 3)
self.weight = nn.Parameter(torch.rand(3, 3))
def forward(self, x):
# 在前向传播中使用参数
output = torch.matmul(x, self.weight)
return output
在上面的示例中,self.weight
被定义为一个 nn.Parameter
,它是一个 3x3 的矩阵。当你训练这个模型时,self.weight
将会被优化器更新。
为什么使用 nn.Parameter
:
-
自动梯度计算: 将张量包装成
nn.Parameter
后,PyTorch 将会自动追踪对该参数的操作,从而可以进行自动梯度计算。 -
与优化器的集成: 在模型的
parameters()
方法中,nn.Parameter
对象会被自动识别为模型的参数,可以方便地与优化器集成。 -
清晰的模型定义: 将可学习的参数显式地声明为
nn.Parameter
使得模型的定义更加清晰和可读。
示例使用:
# 创建模型
model = MyModel()
# 打印模型的参数
for param in model.parameters():
print(param)
# 假设有输入张量 x
x = torch.rand(3, 3)
# 计算模型输出
output = model(x)
# 打印输出
print(output)
在实际使用中,你可以通过 model.parameters()
获取模型的所有参数,并将其传递给优化器进行训练。