发布时间:2024-09-23 14:01
如果问pytorch中最强大的一个数学函数是什么?
我会说是torch.einsum:爱因斯坦求和函数。
它几乎是一个"万能函数":能实现超过一万种功能的函数。
不仅如此,和其它pytorch中的函数一样,torch.einsum是支持求导和反向传播的,并且计算效率非常高。
einsum 提供了一套既简洁又优雅的规则,可实现包括但不限于:内积,外积,矩阵乘法,转置和张量收缩(tensor contraction)等张量操作,熟练掌握 einsum 可以很方便的实现复杂的张量操作,而且不容易出错。
尤其是在一些包括batch维度的高阶张量的相关计算中,若使用普通的矩阵乘法、求和、转置等算子来实现很容易出现维度匹配等问题,但换成einsum则会特别简单。
套用一句深度学习paper标题当中非常时髦的话术,torch.einsum is all you needed !
公众号后台回复关键词:einsum,获取本文源代码链接。
顾名思义,einsum这个函数的思想起源于家喻户晓的小爱同学:爱因斯坦~。
很久很久以前,小爱同学在捣鼓广义相对论。广义相对论表述各种物理量用的都是张量。
比如描述时空有一个四维时空度规张量,描述电磁场有一个电磁张量,描述运动的有能量动量张量。
在理论物理学家中,小爱同学的数学基础不算特别好,在捣鼓这些张量的时候,他遇到了一个比较头疼的问题:公式太长太复杂了。
有没有什么办法让这些张量运算公式稍微显得对人类友好一些呢,能不能减少一些那种扭曲的求和符号呢?
小爱发现,求和导致维度收缩,因此求和符号操作的指标总是只出现在公式的一边。
例如在我们熟悉的矩阵乘法中
k这个下标被求和了,求和导致了这个维度的消失,所以它只出现在右边而不出现在左边。
这种只出现在张量公式的一边的下标被称之为哑指标,反之为自由指标。
小爱同学脑瓜子滴溜一转,反正这种只出现在一边的哑指标一定是被求和求掉的,干脆把对应的求和符号省略得了。
这就是爱因斯坦求和约定:
只出现在公式一边的指标叫做哑指标,针对哑指标的求和符号可以省略。
公式立刻清爽了很多。
这个公式表达的含义如下:
C这个张量的第i行第j列由这个张量的第i行第k列和这个张量的第k行第j列相乘,这样得到的是一个三维张量, 其元素为,然后对在维度k上求和得到。
公式展现形式中除了省去了求和符号,还省去了乘法符号(代数通识)。
借鉴爱因斯坦求和约定表达张量运算的清爽整洁,numpy、tensorflow和 torch等库中都引入了 einsum这个函数。
上述矩阵乘法可以被einsum这个函数表述成
C = torch.einsum("ik,kj->ij",A,B)
这个函数的规则原理非常简洁,3句话说明白。
1,用元素计算公式来表达张量运算。
2,只出现在元素计算公式箭头左边的指标叫做哑指标。
3,省略元素计算公式中对哑指标的求和符号。