NumPYPyTorch 中都有都遵循相同的法则。

逐元素广播

当对两个形状不同的张量进行运算时,广播会尝试“自动扩展”那个维度较小的张量,使其形状与另一个张量匹配,从而能够进行逐元素的运算。这个“扩展”是虚拟的,并不会真的在内存中复制数据,因此非常高效。

当对两个张量 A 和 B 进行运算时,系统会从最右边的维度开始,逐个比较它们的维度大小。

  1. 对齐维度: 如果两个张量的维度数量不同,系统会在维度较少的那个张量的左侧补上大小为 1 的维度,直到它们的维度数量相同。
    • 例如:A 的形状是 (3, 4)B 的形状是 (4,)
    • 系统会将 B 的形状看作 (1, 4)
  2. 兼容性检查: 在对齐维度后,从右到左逐个检查每一对维度的大小。对于每一对维度,必须满足以下三个条件之一,否则就会报错:
    • 两个维度的大小相等
    • 其中一个维度的大小是 1
    • (其中一个维度不存在,已经被规则 1 补为 1 了)。

矩阵乘法广播

传统的矩阵乘法就是只有两个维度参与(分别是 [m, k], [k, n])。但是在实际生产中,往往还会同时进行多个批次,也就是 [b, m, k], [b, k, n] ,同时进行 b 个矩阵乘法。

正是基于这种事实,矩阵乘法广播的规则只应用于除了最后两个维度之外的“批次维度” (Batch Dimensions)。最后两个维度被保留用于执行实际的、标准的矩阵乘法。