PyTorch-matmul函数

pytorch 中两个张量的乘法分为两种

  1. 两个张量对应元素相乘(Element-wise),在 pytorch 中通过 torch.mul 函数(或 * 运算符)实现
  2. 两个张量矩阵相乘(Matrix Product),在 pytorch 中通过 torch.matmul 函数是按

torch.matmul(input, other, out=None) 函数对 inputother 两个张量进行矩阵相乘。 torch.matmul 函数根据传入参数的张量维度有很多重载函数

这里将 input 命名为 aother 命名为 b

  • 若 a 为 1D 张量,b 为 1D 张量,则返回两个张量的点积(此时的 torch.matmul 不支持 out 参数)

点积在数学中,又称为数量积(dot product; sclar product),是指接收在实数 R 上的两个 1D 张量并返回一个实数值 0D 张量的二元运算

若 1D 张量 $a=[1,2]$,1D 张量 $b=[3,4]$,则:
$$
a\cdot b=1\times 3 + 2\times 4=11
$$

1
2
3
4
5
6
7
8
9
import torch

# a,b 都是 1D 张量
a = torch.tensor([1.,2.])
b = torch.tensor([3.,4.])

res = torch.matmul(a, b)
print(res)
# tensor(11.)
  • 若 a 为 2D 张量,b 为 2D 张量,则返回两个张量的矩阵乘积

矩阵相乘最重要的方法是一般矩阵乘积。它只有在第一个 2D 张量(矩阵)的列数和第二个 2D 张量(矩阵)的行数相同时才有意义

若 2D 张量 $a=\left[ \begin{array}{c} 1 & 2 \ 3 & 4 \end{array} \right]$,2D 张量 $b=\left[ \begin{array}{c} 5 & 6 & 7 \ 8 & 9 & 10 \end{array} \right]$,则
$$
\begin{align}
a\times b&= \left[ \begin{array}{c} 1 & 2 \ 3 & 4 \end{array} \right] \times \left[ \begin{array}{c} 5 & 6 & 7 \ 8 & 9 & 10 \end{array} \right]\
&=\left[ \begin{array}{c} 21 & 24 & 27 \ 47 & 54 & 61 \end{array} \right]
\end{align}
$$

1
2
3
4
5
6
7
8
9
10
import torch

# a,b 都是 2D 张量
a = torch.tensor([[1., 2.], [3., 4.]])
b = torch.tensor([[5., 6., 7.], [8., 9., 10.]])

res = torch.matmul(a, b)
print(res)
# tensor([[21., 24., 27.],
# [47., 54., 61.]])
  • 若 a 为 1D 张量,b 为 2D 张量, torch.matmul 函数
    • 首先,在 1D 张量 a 的前面插入一个长度为1的新维度变成 2D 张量
    • 然后,在满足第一个 2D 张量的列数和第二个 2D 张量的行数相同的条件下,两个 2D 张量矩阵乘积,否则会报错
    • 最后,将矩阵乘积结果中长度为1的维度删除作为最后结果

简单来说,先将 1D 张量 a 扩展成 2D 张量,满足矩阵乘积的条件下,将两个 2D 张量进行矩阵乘积的运算
$$
\left[ \begin{array}{c} 1 & 2 \end{array} \right] \times \left[ \begin{array}{c} 5 & 6 & 7 \ 8 & 9 & 10 \end{array} \right]\
=[\begin{array}{c} 1\times 5 + 2\times 8 & 1\times 6+2\times 9 & 1\times 7+2\times 10
\end{array}]\
=\begin{array}{c}[21&24&27]\end{array}
$$
此时得到的是形状为 (1,3) 的 2D 张量,最后将前面插入长度为1的新维度删除即为最终 torch.matmul(a, b) 函数返回的结果

1
2
3
4
5
6
7
8
9
10
11
12
import torch

# a,b 都是 2D 张量
a = torch.tensor([1., 2.])
b = torch.tensor([[5., 6., 7.], [8., 9., 10.]])

res = torch.matmul(a, b)
print(res.shape)
# tensor.Size([3])

print(res)
# tensor([21., 24., 27.]
  • 若 a 为 2D 张量,b 为 1D 张量, torch.matmul 函数
    • 首先,在 1D 张量 b 的后面插入一个长度为1的新维度变成 2D 张量
    • 然后,在满足第一个 2D 张量的列数和第二个 2D 张量的行数相同的条件下,两个 2D 张量矩阵乘积,否则会报错
    • 最后,将矩阵乘积结果中长度为1的维度删除作为最后结果
1
2
3
4
5
6
7
8
9
10
11
12
import torch

# a 是 2D, b 是 1D
a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
b = torch.tensor([7., 8., 9.])

res = torch.matmul(a, b)
print(res.shape)
# tensor.Size([2])

print(res)
# tensor([50., 122.]
  • 如果 a 和 b 至少有一个 ND 张量(N>2)。针对多维张量,我们只关注每个张量的后两个维度,将每个张量的后两个维度按照矩阵乘积进行运算,其余的维度都可以认为是批量维度

比如 4D 张量 a 的形状为 $(j\times 1\times n\times m)$,而张量 b 的形状为 $(k\times m\times p)$,将两个张量的后两个维度进行矩阵乘积运算,即 $(n\times m)\times (m\times p)=(n\times p)$,张量 a 的批量维度为 $(j\times 1)$,而张量 b 的批量维度为 $(k,)$ ,张量 a 的批量维度和张量 b 的批量维度可以进行广播成 $(j\times k)$,即最后返回的张量维度为 $(j\times k\times n\times m)$

1
2
3
4
5
6
7
8
import torch

a = torch.randn(10, 3, 4)
b = torch.randn(4)

res = torch.matmul(a, b)
print(res.shape)
# tensor.Size([10, 3])

此时 a 的形状为 (10, 3, 4) 的 3D 张量,而 b 为形状为 (4,) 的 1D 张量,首先计算 (3, 4) 与 (4,) 的矩阵乘法,按照前面介绍的可以先插入一个长度为1的新维度,然后再进行矩阵乘积,计算完成之后将长度为1的维度删除,最终得到的是形状为 (3,) 的 1D 张量,张量 a 的批量维度为 (10,),而张量 b 的批量维度为 (1,),进行广播最终的返回的张量维度为 (10, 3)