PyTorch 之 nn.Parameter

在 PyTorch 中,nn.Parameter 是一个类,用于将张量包装成可学习的参数。它是 torch.Tensor 的子类,但被设计成可以被优化器更新的参数。通过将张量包装成 nn.Parameter,你可以告诉 PyTorch 这是一个模型参数,从而在训练时自动进行梯度计算和优化。

使用方法:

首先,你需要导入相应的模块:

import torch
import torch.nn as nn

然后,可以使用 nn.Parameter 类来创建可学习的参数。以下是一个简单的示例:

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        
        # 创建一个可学习的参数,大小为 (3, 3)
        self.weight = nn.Parameter(torch.rand(3, 3))
        
    def forward(self, x):
        # 在前向传播中使用参数
        output = torch.matmul(x, self.weight)
        return output

在上面的示例中,self.weight 被定义为一个 nn.Parameter,它是一个 3x3 的矩阵。当你训练这个模型时,self.weight 将会被优化器更新。

为什么使用 nn.Parameter

  1. 自动梯度计算: 将张量包装成 nn.Parameter 后,PyTorch 将会自动追踪对该参数的操作,从而可以进行自动梯度计算。

  2. 与优化器的集成: 在模型的 parameters() 方法中,nn.Parameter 对象会被自动识别为模型的参数,可以方便地与优化器集成。

  3. 清晰的模型定义: 将可学习的参数显式地声明为 nn.Parameter 使得模型的定义更加清晰和可读。

示例使用:

# 创建模型
model = MyModel()

# 打印模型的参数
for param in model.parameters():
    print(param)

# 假设有输入张量 x
x = torch.rand(3, 3)

# 计算模型输出
output = model(x)

# 打印输出
print(output)

在实际使用中,你可以通过 model.parameters() 获取模型的所有参数,并将其传递给优化器进行训练。