阅读量:2
在PyTorch中,view()
函数用于调整张量的形状。它的使用方式如下:
output = input.view(*shape)
这里的input
是输入张量,shape
是一个元组,用于指定调整后的形状。具体来说:
shape
中的每个元素可以是一个具体的维度大小,或者-1表示根据其他维度的大小自动计算。- 调整后的张量和原始张量共享内存空间,即它们指向相同的数据。
下面是一些示例:
import torch x = torch.tensor([[1, 2, 3], [4, 5, 6]]) # 将x的形状调整为(3, 2) output = x.view(3, 2) print(output) # 输出: # tensor([[1, 2], # [3, 4], # [5, 6]]) # 将x的形状调整为(6, -1),其中-1表示自动计算 output = x.view(6, -1) print(output) # 输出: # tensor([[1], # [2], # [3], # [4], # [5], # [6]]) # 将x的形状调整为(1, 6) output = x.view(1, 6) print(output) # 输出: # tensor([[1, 2, 3, 4, 5, 6]])
需要注意的是,调整后的形状必须和原始张量的元素总数保持一致,否则会抛出错误。