为什么在 Einsum 之前进行标量乘法更快?

huangapple go评论54阅读模式
英文:

Why Is Scalar Multiply Before Einsum Faster?

问题

在TensorFlow Keras的Multi-Head Attention实现中,与首先评估分子不同,他们首先评估Q/√dₖ,并添加了以下注释:

> 注意:在einsum的较小端应用标量乘法可以提高XLA性能,但可能会在Transformer注意头中引入轻微的数值差异。

这样做为何更快?在einsum之后进行除法不会同样快吗?

英文:

In the TensorFlow Keras implementation of Multi-Head Attention, instead of evaluating the numerator first like in

为什么在 Einsum 之前进行标量乘法更快?

they evaluate Q/√dₖ first and put comment
> Note: Applying scalar multiply at the smaller end of einsum improves
> XLA performance, but may introduce slight numeric differences in
> the Transformer attention head.

How is it faster this way? Wouldn't the division after einsum be equally as fast?

答案1

得分: 1

以下是翻译好的部分:

"这个评论的建议是key中的元素数量少于以下方程中的queryattention_scores中的元素数量。

attention_scores = tf.einsum(self._dot_product_equation, key, query)

给定维度:

query形状为`(B, T, N, key_dim)`的投影查询张量
key形状为`(B, S, N, key_dim)`的投影关键字张量

假设_dot_product_equation只是执行批次矩阵乘法,如果Q是T x N,而K是S x N,则乘积Q @ K.TT x S,如果S > N,预计左侧的乘法数量将较小。

但无论如何,除非S > T * N(或XLA存在错误),否则这不应该是主要部分。"

英文:

What the comment suggest is that the the number of elements in key is less than the number of elements in query or attention_scores in the following equation.

attention_scores = tf.einsum(self._dot_product_equation, key, query)

Given the dimensions

            query: Projected query `Tensor` of shape `(B, T, N, key_dim)`.
            key: Projected key `Tensor` of shape `(B, S, N, key_dim)`.

Assuming that _dot_product_equation is simply doing the batched matrix multiplication, if Q is T x N, and Q is S x N, the product Q @ K.T is T x S, if S > N the number of multiplications is expected to be smaller on the left.

But either way that should not be the dominant part except if S > T * N (or XLA has a bug).

huangapple
  • 本文由 发表于 2023年5月21日 07:05:43
  • 转载请务必保留本文链接:https://go.coder-hub.com/76297659.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定