阅读量:0
nn.Linear
是 PyTorch 中的一个类,用来定义一个线性变换(线性层)的操作。
具体来说,nn.Linear
用于定义一个线性映射,将输入张量的每个元素与权重矩阵相乘,并加上偏置向量。其功能可以总结如下:
线性变换:将输入张量与权重矩阵相乘,得到输出张量。输入张量的形状为
(batch_size, input_size)
,权重矩阵的形状为(output_size, input_size)
。输出张量的形状为(batch_size, output_size)
。加偏置:将输出张量加上一个偏置向量,该偏置向量的形状为
(output_size,)
。偏置向量会被广播到每个样本的输出上。自动创建参数:
nn.Linear
创建线性层时会自动创建权重矩阵和偏置向量,并将它们保存在模型的参数列表中。自动梯度计算:通过 PyTorch 的自动求导机制,
nn.Linear
可以自动计算权重矩阵和偏置向量的梯度,并进行优化。
nn.Linear
通常在神经网络模型中被用作全连接层(全连接神经网络),用来将输入特征映射到输出特征。