pytorch 中两个张量的乘法分为两种
- 两个张量对应元素相乘(Element-wise),在 pytorch 中通过
torch.mul
函数(或 * 运算符)实现 - 两个张量矩阵相乘(Matrix Product),在 pytorch 中通过
torch.matmul
函数是按
torch.matmul(input, other, out=None)
函数对 input
和 other
两个张量进行矩阵相乘。 torch.matmul
函数根据传入参数的张量维度有很多重载函数。
这里将 input
命名为 a
, other
命名为 b
- 若 a 为 1D 张量,b 为 1D 张量,则返回两个张量的点积(此时的
torch.matmul
不支持out
参数)
点积在数学中,又称为数量积(dot product; sclar product),是指接收在实数 R 上的两个 1D 张量并返回一个实数值 0D 张量的二元运算
若 1D 张量 $a=[1,2]$,1D 张量 $b=[3,4]$,则:
$$
a\cdot b=1\times 3 + 2\times 4=11
$$
1 | import torch |
- 若 a 为 2D 张量,b 为 2D 张量,则返回两个张量的矩阵乘积
矩阵相乘最重要的方法是一般矩阵乘积。它只有在第一个 2D 张量(矩阵)的列数和第二个 2D 张量(矩阵)的行数相同时才有意义
若 2D 张量 $a=\left[ \begin{array}{c} 1 & 2 \ 3 & 4 \end{array} \right]$,2D 张量 $b=\left[ \begin{array}{c} 5 & 6 & 7 \ 8 & 9 & 10 \end{array} \right]$,则
$$
\begin{align}
a\times b&= \left[ \begin{array}{c} 1 & 2 \ 3 & 4 \end{array} \right] \times \left[ \begin{array}{c} 5 & 6 & 7 \ 8 & 9 & 10 \end{array} \right]\
&=\left[ \begin{array}{c} 21 & 24 & 27 \ 47 & 54 & 61 \end{array} \right]
\end{align}
$$
1 | import torch |
- 若 a 为 1D 张量,b 为 2D 张量,
torch.matmul
函数- 首先,在 1D 张量 a 的前面插入一个长度为1的新维度变成 2D 张量
- 然后,在满足第一个 2D 张量的列数和第二个 2D 张量的行数相同的条件下,两个 2D 张量矩阵乘积,否则会报错
- 最后,将矩阵乘积结果中长度为1的维度删除作为最后结果
简单来说,先将 1D 张量 a 扩展成 2D 张量,满足矩阵乘积的条件下,将两个 2D 张量进行矩阵乘积的运算
$$
\left[ \begin{array}{c} 1 & 2 \end{array} \right] \times \left[ \begin{array}{c} 5 & 6 & 7 \ 8 & 9 & 10 \end{array} \right]\
=[\begin{array}{c} 1\times 5 + 2\times 8 & 1\times 6+2\times 9 & 1\times 7+2\times 10
\end{array}]\
=\begin{array}{c}[21&24&27]\end{array}
$$
此时得到的是形状为 (1,3) 的 2D 张量,最后将前面插入长度为1的新维度删除即为最终torch.matmul(a, b)
函数返回的结果
1 | import torch |
- 若 a 为 2D 张量,b 为 1D 张量,
torch.matmul
函数- 首先,在 1D 张量 b 的后面插入一个长度为1的新维度变成 2D 张量
- 然后,在满足第一个 2D 张量的列数和第二个 2D 张量的行数相同的条件下,两个 2D 张量矩阵乘积,否则会报错
- 最后,将矩阵乘积结果中长度为1的维度删除作为最后结果
1 | import torch |
- 如果 a 和 b 至少有一个 ND 张量(N>2)。针对多维张量,我们只关注每个张量的后两个维度,将每个张量的后两个维度按照矩阵乘积进行运算,其余的维度都可以认为是批量维度
比如 4D 张量 a 的形状为 $(j\times 1\times n\times m)$,而张量 b 的形状为 $(k\times m\times p)$,将两个张量的后两个维度进行矩阵乘积运算,即 $(n\times m)\times (m\times p)=(n\times p)$,张量 a 的批量维度为 $(j\times 1)$,而张量 b 的批量维度为 $(k,)$ ,张量 a 的批量维度和张量 b 的批量维度可以进行广播成 $(j\times k)$,即最后返回的张量维度为 $(j\times k\times n\times m)$
1 | import torch |
此时 a 的形状为 (10, 3, 4) 的 3D 张量,而 b 为形状为 (4,) 的 1D 张量,首先计算 (3, 4) 与 (4,) 的矩阵乘法,按照前面介绍的可以先插入一个长度为1的新维度,然后再进行矩阵乘积,计算完成之后将长度为1的维度删除,最终得到的是形状为 (3,) 的 1D 张量,张量 a 的批量维度为 (10,),而张量 b 的批量维度为 (1,),进行广播最终的返回的张量维度为 (10, 3)