您的位置:首页 > 房产 > 建筑 > Einsum(Einstein summation convention)

Einsum(Einstein summation convention)

2025/1/5 17:02:44 来源:https://blog.csdn.net/weixin_48524215/article/details/142032537  浏览:    关键词:Einsum(Einstein summation convention)

Einsum(Einstein summation convention)

笔记来源:
Permute和Reshape嫌麻烦?einsum来帮忙!

The Einstein summation convention is a notational shorthand used in tensor calculus, particularly in the fields of physics and mathematics, to simplify the representation of sums over indices in tensor equations. This convention is widely used in general relativity and other areas involving tensors.

爱因斯坦求和约定(Einstein summation convention)的一种函数,它通过指定的索引规则执行张量操作。爱因斯坦求和表示法使得高维张量运算更加直观灵活,比如可以用于更复杂的张量运算,而不仅限于简单的矩阵乘法

numpy.einsum(subscripts,*operands,out=None,dtype=None,order='K',casting='safe',optimize=False)
torch.einsum(equation,*operands)
tf.einsum(equation,*inputs,**kwargs)

1.1 什么是einsum?运算规则是什么?

在求和公式中,某些下标在等式两边都有出现(例如下标 i i i)而有些下标(被求和的维度)只出现在一侧(例如下标 j j j


所以即便是省略求和符号也不会产生歧义,即我们仍然知道哪个维度被求和了

这种运算与变量是什么并没有关系,因此上式可以进一步简化

1.在不同输入之间重复出现的索引表示沿着这一维度进行乘法(例如 k k k
2.只出现在输入中的索引表示在这一维度上求和(例如输出有 i , j i,j i,j,也就是说 k k k只出现在输入中)

C = torch.einsum('ik,kj->ij',A,B)
# 箭头和箭头右侧的可以省略
C = torch.einsum('ik,kj',A,B)
#等价于 
C = torch.matmul(A, B)

3.输出中维度的顺序可以是任意的(例如 i j ij ij j i ji ji
这里 C j i C_{ji} Cji就是 C i j C_{ij} Cij的转置

C = torch.einsum('ik,kj->ji',A,B)
# 省略号可以用于broadcasting,也就是忽略不关心的维度,只对最后两个维度进行计算
C = torch.einsum('...ik,...kj->ji',A,B)

1.2 Einsum怎么用?

向量外积



C = torch.einsum('i,j->ij',a,b)

提取对角元素


a = torch.einsum('kk->k',A)

1.3 Einum在多头注意力的应用

原版本

qkv:torch.Tensor self.qkv(x) #B,patches,3*dim
qkv = qkv.reshape(B, patches, 3, self.n_heads, self.head_dim)
qkv = qkv.permute(2,0,3,1,4)
q:torch.Tensor = qkv[0]
k:torch.Tensor = qkv[1]
v:torch.Tensor = qkv[2]
k_t = k.transpose(-2,-1)
attn = torch.softmax(q k_t self.scale,dim=-1)
attn = self.attn_drop(attn)
wa = attn @ v
wa = wa.transpose(1,2)
wa = wa.flatten(2)

使用Einum简化

q,k,v = map(lambda t:rearrange(t,'b n (h d)->b h n d',h=self.num_heads),qkv)
attn = torch.einsum('bijc,bikc -bijk',q,k)*self.scale
attn = attn.softmax(dim=-1)
x = torch.einsum('bijk,bikc -bijc',attn,v)
x = rearrange(x,'b i jc->b j (i c)')

(1)使用EINOPS库中的rearrange操作

q,k,v = map(lambda t:rearrange(t,'b n (h d)->b h n d',h=self.num_heads),qkv)


(2)q乘k转置除以缩放比例

attn = torch.einsum('bijc,bikc -bijk',q,k)*self.scale


(3)softmax得到attention数值

attn = attn.softmax(dim=-1)


(4)attention值对v加权

x = torch.einsum('bijk,bikc -bijc',attn,v)


(5)将x的维度还原为输入的形式

x = rearrange(x,'b i jc->b j (i c)')

[ B , h e a d , N , C / / h e a d ] − > [ B , N , C ] [B,head,N,C//head]->[B,N,C] [B,head,N,C//head]>[B,N,C]

1.4 Einsum优缺点

优点:

  1. 一次调用、一个函数完成多个操作
  2. 有时比多个Permute和Transpose操作组合的可读性高
  3. 可以避免生成中间变量

缺点:
求和表达式复杂时耗费内存,导致性能问题

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com