在深度学习中,我们经常会遇到需要对张量进行形状变换的情况。PyTorch 提供了多种方法来改变张量的形状,包括 reshape
, view
, transpose
和permute
。本文总结了其它博客的精华,详细介绍这些方法的原理和应用场景。
目录
一、张量的存储方式
在理解这些方法之前,我们需要先了解张量在内存中的存储方式。张量的数据存储分为两部分:头信息区(Tensor)和存储区(Storage)。头信息区包含张量的形状(size)、步长(stride)等信息;而存储区则存放实际的数据。
如果我们对A进行截取、转置或修改等操作后赋值给B,则B的数据共享A的存储区,存储区的数据数量没变,变化的只是B的头信息区对数据的索引方式。如果听说过浅拷贝和深拷贝的话,很容易明白这种方式其实类似浅拷贝。
步长(Stride):步长是指从一个元素跳转到下一个元素所需的偏移量,通常以字节数表示。
例如,在二维张量中,第一个维度的步长表示从一行跳转到下一行的偏移量,第二个维度的步长表示在同一行内从一个元素跳转到下一个元素的偏移量。
假设我们有一个形状为 (2, 3) 的张量 a
,包含元素 [1, 2, 3, 4, 5, 6]
。
import torch # 创建一个 (2, 3) 的张量 a = torch.tensor([[1, 2, 3], [4, 5, 6]]) print("原始张量 a:") print(a)
输出结果类似于:
原始张量 a: tensor([[1, 2, 3], [4, 5, 6]])
张量 a 的形状:(2, 3),
张量 a 的步长:a.stride()
返回 (3, 1)
- 第一个维度的步长(行):
3
,意味着从第一行到第二行需要跳过 3 个元素。这是因为每个元素占据一定的内存空间(例如 4 字节),因此从第一行跳转到第二行需要跳过 3 * 4 字节。 - 第二个维度的步长(列):
1
,意味着在同一行内从一个元素到下一个元素只需要跳过 1 个元素。这意味着从一个元素到下一个元素只需要跳过 1 * 4 字节。
之所以提及步长是因为底层的实现是利用了这种方式,适当了解一下,后续的讲解不会涉及太多步长的知识。
二、 reshape
reshape
方法会尝试在不复制数据的情况下改变张量的形状。如果新的形状与原始形状的元素总数相同,那么 reshape
会改变张量的步长,使得新的形状可以正确地访问数据。
reshape()在新形状满足一定条件时会共享相同一份数据,否则会复制一份新的数据。(reshape操作可以作用在连续和非连续数据上,如果作用在非连续数据上也会带来内存拷贝,变成内存连续的数据。连续的知识点在view中会讲到)
下面默认讲解reshape满足条件下的共享情况
reshape是如何实现改变尺度操作的呢,参考了一篇帖子,总结了下面这些知识点。
import torch # 使用torch.arange生成值从0开始,步长为2的张量,直到22(不包括22) a = torch.arange(0, 24, 2) # 输出张量a print(a) # 输出: tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22])
这个a非常简单, 就是一个数组, 但是其内部其实是这样组织的
a行代表内存中数组的存储状态, 上面的i
行代表辅助的索引轴,借助这个轴, 就可以访问到数组的任何一个元素, 比如a[3], 就是索引到元素6。
import torch # 假设a是之前创建的一维张量 a = torch.tensor([0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22]) # 使用reshape方法重塑a为2x6的二维张量 b = torch.reshape(a, (2, 6)) # 输出b print(b) # 输出 tensor([[ 0, 2, 4, 6, 8, 10], [12, 14, 16, 18, 20, 22]])
此时貌似是a里面的元素改变了, 变成了2维, 其实改变的只是轴。 此时pytorch会建立两个轴,假设两轴为 i、j
,i
的取值为 0 - 1,j
的取值为 0 - 5,示意图如下:
此时使用b[1][2]就可以获取到元素16(这时候的两个轴取值为1,2)
无论shape如何变化, 变化的是视图而已, 底下的缓冲区数据始终未变。
现在进阶一下看看三个轴的例子:
c = a.reshape(2,3,2) print(c) # 输出 array([[[ 0, 2], [ 4, 6], [ 8, 10]], [[12, 14], [16, 18], [20, 22]]])
数组 c 形状为(2, 3, 2)有三个轴,取值分别为 0 - 1,0 - 2,0 - 1,示意图如下所示:
下一个轴的循环范围是以上一轴的同一元素为边界,图中我框出来了,应该能看明白。
此时c[0,2,1]可以得到元素10,跟代码里输出的维度进行比较,确实是10.
reshape 后的张量,仅仅是原来张量的视图 view,并没有发生复制元素的行为,这样才能保证 reshape 操作更为高效。 也就是说即使reshape了, 两者指向的是底层数据还是一样, 如果改变其中一个, 另一个也会跟着改变。
import torch a = torch.tensor([0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22]) # 使用reshape方法改变a的形状为2x6的二维张量b b = a.reshape(2, 3, 2) # 修改b中的一个元素 b[1][2] = 99 # 输出a print("修改后的a:", a) 修改后的a: tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 99, 18, 20, 22])
三、view
view类似于reshape,将tensor转换为指定的shape,原始的data不改变。返回的tensor与原始的tensor共享存储区。返回的tensor的size和stride必须与原始的tensor兼容。每个新的tensor的维度必须是原始维度的子空间,或满足连续条件。view操作只作用在连续内存上,仅仅按照行重新排列下标,不改变数据的内存分布。
否则需要先使用contiguous()方法将原始tensor转换为满足连续条件的tensor,然后就可以使用view方法进行shape变换了。(这就是我们在深度学习代码中常见到在view之前使用contiguous()的原因。contiguous()方法开辟了一个新的存储区,并改变了原始存储区数据的存放顺序,类似于深拷贝。)
或直接使用reshape方法进行维度变换,但在不连续情况下,这种方法变换后的tensor就不是与原始tensor共享内存了,而是被重新开辟了一个空间。
这部分的详细内容可以参考下面的链接,讲得很好。
PyTorch:view() 与 reshape() 区别详解_pytorch view reshape-CSDN博客
四、transpose
transpose
方法用于交换张量的两个维度。与 permute
不同,transpose
只能交换两个维度。transpose
方法接收两个整数参数,表示要交换的维度。
举个例子
import torch # 创建形状为 (2, 3, 2) 的三维张量 x = torch.arange(12).view(2, 3, 2) print("原始张量 x:") print(x) # 使用 transpose 方法交换第0维和第1维 x_transposed = x.transpose(0, 1) print("\n交换第0维和第1维后的张量 x_transposed:") print(x_transposed) # 输出 原始张量 x: tensor([[[ 0, 1], [ 2, 3], [ 4, 5]], [[ 6, 7], [ 8, 9], [10, 11]]]) 交换第0维和第1维后的张量 x_transposed: tensor([[[ 0, 1], [ 6, 7]], [[ 2, 3], [ 8, 9]], [[ 4, 5], [10, 11]]])
使用transpose ()进行变换,其实就是交换了坐标轴。如:x.transpose(0,1)
,其实就是将x的第一维与第二维索引数字交换。最后的shape为:(3,2,2)。
原理如下
原先的数据的索引和数据对应情况为:
x[0][0][0] = 0 x[1][0][0] = 6 x[0][0][1] = 1 x[1][0][1] = 7 x[0][1][0] = 2 x[1][1][0] = 8 x[0][1][1] = 3 x[1][1][1] = 9 x[0][2][0] = 4 x[1][2][0] = 10 x[0][2][1] = 5 x[1][2][1] = 11
交换数据的索引,对应的值还是不变,即交换了坐标轴,如[0][2][1] —>[2][0][1]
x[0][0][0] = 0 x[0][1][0] = 6 x[0][0][1] = 1 x[0][1][1] = 7 x[1][0][0] = 2 x[1][1][0] = 8 x[1][0][1] = 3 x[1][1][1] = 9 x[2][0][0] = 4 x[2][1][0] = 10 x[2][0][1] = 5 x[2][1][1] = 11
此时变换坐标后和上面打印的张量一致。
例子中,我们使用的shape是(2, 3, 2),可以理解成:2通道的图片,每张图层是3 * 2大小
,正常渲染是先把第一个通道的图片把3 * 2个像素点绘制,再绘制第二个通道的3 * 2像素。
在使用transpose(0, 1)后,新的数据是shape是(3,2,2),可以理解成每张图层是3 * 2大小,2通道的图片
,原先的是先绘制一个通道数据,如今变换后的数据是每次将一个坐标的不同通道的像素进行一次性绘制。
如图:
五、 permute
permute
方法用于交换张量的维度。与 transpose
类似,但可以处理任意数量的维度。
permute() 函数一次可以进行多个维度的交换或者可以成为维度重新排列,参数是 0, 1, 2, 3, … ,随着待转换张量的阶数上升参数越来越多,本质上可以理解为多个 transpose() 操作的叠加,因此理解 permute() 函数的关键在于理解 transpose() 函数。
例如,对于一个三维张量,使用permute(1, 2, 0)
意味着:
- 原始张量的第二维(索引1)将成为新张量的第一维。
- 原始张量的第三维(索引2)将成为新张量的第一维。
- 原始张量的第一维(索引0)将成为新张量的第三维。
这相当于先执行一个将第一维和第二维互换,然后再将结果张量的第一维和第三维互换,最终达到重新排列维度的目的。