Tensorflow中高维矩阵的乘法运算tf.matmul(tf.linalg.matmul)详悉

avatar
作者
猴君
阅读量:0

1.问题由来

 在tensorflow框架下,经常会用到矩阵的乘法运算,特别是高(多)维的矩阵运算,在这些矩阵运算时,经常使用到其中的tf.matmul或tf.linalg.matmul等函数。但高维矩阵在内部怎么运算的?其内部的参数是怎么实现的在tensorflow给出的介绍仍然存在表达不明的问题,所以在此作进一步的阐释。

声明:本博客里的数组乘法运算是指矩阵乘法运算,不是对应元素相乘。所述高维代表矩阵的维度\geq3维。

2.高维矩阵的乘法运算规则

2.1 运算条件

两矩阵的维数相同:len(a.shape)=len(b.shape) 
n-2个维度都一致:a.shape[0]=b.shape[0],...,a.shape[-3]=b.shape[-3]
最后两个维度满足矩阵乘法运算:a.shape[-1]=b.shape[-2]
具体地,假设a.shape=(n_{1},n_{2},...,n_{L}) ,b.shape=(m_{1},m_{2},...,m_{L}),则tf.matmul(a,b) 能运算的条件如下图(箭头表示相等):

 2.2 使用tf.matmul(tf.linalg.matmul)时存在的问题

 按照上文的规则使用tf.matmul(tf.linalg.matmul)时,又会存在各种问题。以tf.linalg.matmul为例,其关键参数设置如下,\mathbf{a}\mathbf{b}表征2个高维矩阵,transpose_a和transpose_b可以理解为分别对\mathbf{a}\mathbf{b}这2个矩阵的转置操作。我们假设\mathbf{a}\mathbf{b}都是4维矩阵,并设维度分别为[a,b,c,d][e,f,g,h]。tensorflow中,第1维一般是batchsize。那么,tf.linalg.matmul(a,b,transpose_b=True)是不是对矩阵\mathbf{b}的真正转置呢?即tf.linalg.matmul(a,b,transpose_b=True)是维度维[a,b,c,d]的矩阵\mathbf{a}与维度为[h,g,f,e]的矩阵\mathbf{b}直接的矩阵运算呢?

tf.linalg.matmul(     a,     b,     transpose_a=False,     transpose_b=False,     adjoint_a=False,     adjoint_b=False,     a_is_sparse=False,     b_is_sparse=False,     output_type=None,     grad_a=False,     grad_b=False,     name=None )

 如果我们直接看tensorflow给出的解释如下

 

 直观的理解确实如前文所述,其实不然。

这里的transpose_a / transpose_b=True并不是执行传统数学意义上的转置操作,而是仅对高维矩阵上的最后两个维度的转置,其它维度仍保持不变。这是通过调用tf.linalg.matrix_transpose实现的。具体如下

tf.linalg.matrix_transpose(     a, name='matrix_transpose', conjugate=False )

tensorflow文档中对其的描述如下,即转置矩阵\mathbf{a}的最后2个维度。

Transposes last two dimensions of tensor a.

至于后续的运算可以在相关文档中查阅得到。

参考文档

tensorflow中高维数组乘法运算_高位矩阵乘法 tensorflow-CSDN博客

TensorFlow中矩阵乘操作tf.matmul(或tf.linalg.matmul)和矩阵元素乘tf.multiply(或tf.math.multiply)用法对比-CSDN博客

广告一刻

为您即时展示最新活动产品广告消息,让您随时掌握产品活动新动态!