1. torch.einsum(‘bnd, bmd->bnm’, x, y)
torch.einsum(‘bnd, bmd->bnm’, x, y) 表示的是对张量 x 和 y 进行特定的求和和维度变换。
具体来说,这个操作的输入是两个形状为 [b, n, d] 和 [b, m, d] 的张量 x 和 y,输出是一个形状为 [b, n, m] 的张量 z。其计算过程可以理解为:对于每个 b,z[b, n, m] 等于 x[b, n, :] 和 y[b, m, :] 之间的点积。
为了用普通的 torch 操作符来替代 einsum,我们可以通过 torch.matmul 函数实现。这个函数可以用来执行批量矩阵乘法,并且能够很好地替代这个 einsum 操作。
具体实现如下:
import torch# 假设 x 和 y 的形状分别为 (b, n, d) 和 (b, m, d)
x = torch.randn(10, 20, 30) # 举例
y = torch.randn(10, 15, 30) # 举例# einsum: z = torch.einsum('bnd, bmd->bnm', x, y)
# 可以转换为以下操作:
z = torch.matmul(x, y.transpose(-1, -2)) # z 的形状为 (b, n, m)# 检查 z 的形状是否正确
print(z.shape)