Einstein sum is convenient but slow

It may be worth your time to optimize your einsum() code.

The einsum() function in pytorch and numpy is a convenient tool for multiplying tensors and matrices. The programmer does not need to worry about transposing or permuting the given tensors into the right shape before multiplication, which is an error-prone step. Moreover, the numerous, easy-to-forget pytorch/numpy syntaxes for transposition and permutation can be confusing.

For example, when multiplying two matrices $A$ and $B$, the second dimension of $A$ must match the first dimension of $B$. If $A$ and $B$ have shapes $(m,k)$ and $(n,k)$, respectively, one must transpose $B$ before calling matmul(). With einsum(), on the other hand, all one needs to do is specify the rule 'mk,nk->mn'; the function internally gets the shapes (contractions) right.
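
As a minimal pytorch sketch with placeholder shapes, the two approaches look like this:

```python
import torch

m, n, k = 3, 4, 5
A = torch.randn(m, k)
B = torch.randn(n, k)

# Explicit transpose, then matrix multiplication: (m, k) x (k, n) -> (m, n)
C1 = torch.matmul(A, B.T)

# The same product as an Einstein sum; no manual transpose needed
C2 = torch.einsum('mk,nk->mn', A, B)

print(torch.allclose(C1, C2))  # True, up to floating-point tolerance
```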

This convenience comes with a performance cost. The slowness of einsum() stems from the fact that different contraction orderings have different time and storage costs, and the default ordering is not necessarily optimal. However, finding the optimal ordering is NP-hard, unlike the matrix chain multiplication problem, whose optimal ordering can be found in $O(n^3)$ time with dynamic programming, or even in $O(n \log n)$ time with more sophisticated algorithms. Moreover, the time cost comes not only from the contraction complexity but also from memory layout and cache behavior.

Here are a few actions to take:

  • If you really want to use einsum(), turn on its optimization: for pytorch, install the package opt_einsum; for numpy, pass optimize=True to np.einsum() (see the sketch after this list).
  • Hand-optimize the contraction yourself.
  • Ask ChatGPT or another AI tool to optimize the code.
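
As a quick illustration of the first option, here is a minimal numpy sketch with placeholder matrices. The optimize=True flag asks np.einsum() to search for a cheaper contraction order; for pytorch, installing opt_einsum lets torch.einsum() do the same.

```python
import numpy as np

A = np.random.randn(200, 300)
B = np.random.randn(300, 400)
C = np.random.randn(400, 50)

# Default behavior: no contraction-path optimization
out_default = np.einsum('ij,jk,kl->il', A, B, C)

# optimize=True lets numpy pick a cheaper contraction order
out_opt = np.einsum('ij,jk,kl->il', A, B, C, optimize=True)

print(np.allclose(out_default, out_opt))  # True
```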

In what follows, I show an example where hand-optimization beats the pytorch/numpy optimization, while ChatGPT produces incorrect code.

I encountered this problem when playing with SAN, a Graph Transformer that fuses edge features into the attention matrix. The three tensors to be multiplied are:

  • the key tensor $K$ of shape $(n, h, d)$, where $n$ is the number of nodes (a similar notion to the sequence length in sequence models), $h$ is the number of attention heads, and $d$ is the embedding dimension;
  • the query tensor $Q$, having the same shape as $K$; and
  • the edge embedding tensor $e$ of shape $(h, d)$.

Here, I simplify the edge tensor. Its precise shape is actually $(c, h, d)$, where $c$ is the number of edge types, because each type has $h$ embedding vectors of length $d$. To demonstrate the slowness of einsum(), however, the shape $(h, d)$ suffices.

For a typical attention computation, we multiply $K$ with the transpose of $Q$ for each head, contracting along the dimension denoted by $d$. Now, to incorporate the edge features, the contraction also involves the per-head edge embedding. The Einstein summation is thus 'ihd,jhd,hd->ijh', producing an output tensor of shape $(n, n, h)$.
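
In pytorch, the straightforward version is a single einsum() call. A minimal sketch with placeholder sizes:

```python
import torch

n, h, d = 1000, 8, 64     # placeholder sizes
K = torch.randn(n, h, d)  # key tensor
Q = torch.randn(n, h, d)  # query tensor
e = torch.randn(h, d)     # simplified edge embedding tensor

# out[i, j, h] = sum over d of K[i, h, d] * Q[j, h, d] * e[h, d]
out = torch.einsum('ihd,jhd,hd->ijh', K, Q, e)
print(out.shape)  # torch.Size([1000, 1000, 8])
```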

The default approach to this Einstein sum broadcasts all three tensors to the shape $(n, n, h, d)$ (that is, 'ijhd') and then contracts the last dimension. This incurs a huge amount of intermediate memory (roughly $n$ times the size of the inputs) and is far too slow.
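
Conceptually, and reusing the tensors from the sketch above, this unoptimized path amounts to materializing the full $(n, n, h, d)$ intermediate:

```python
# Broadcast K, Q, and e to a common (n, n, h, d) shape, then sum over d.
# The intermediate tensor tmp is the memory hog.
tmp = K[:, None, :, :] * Q[None, :, :, :] * e  # shape (n, n, h, d)
out_naive = tmp.sum(dim=-1)                    # shape (n, n, h)
```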

The optimization I do is to first multiply $Q$ with $e$ element-wise, broadcasting $e$ to match the shape of $Q$ (which is $(n, h, d)$), and then multiply the result with $K$ as a batch matrix-matrix multiplication (see the sketch after the steps below):

  1. Permute the element-wise product of $Q$ and $e$ to shape $(h, d, n)$.
  2. Permute $K$ to shape $(h, n, d)$.
  3. Multiply the above two tensors by treating $h$ as the batch dimension; that is, $(h, n, d) \times (h, d, n) \to (h, n, n)$.
  4. Permute the product so that it has shape $(n, n, h)$.
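
A minimal sketch of these four steps, reusing K, Q, and e from the earlier sketch; torch.bmm() carries out the batched matrix multiplication over the head dimension:

```python
# Element-wise product; e broadcasts from (h, d) to (n, h, d)
Qe = Q * e

# Step 1: permute the product to (h, d, n)
Qe_perm = Qe.permute(1, 2, 0)

# Step 2: permute K to (h, n, d)
K_perm = K.permute(1, 0, 2)

# Step 3: batched matrix multiplication, with h as the batch dimension:
#         (h, n, d) x (h, d, n) -> (h, n, n)
prod = torch.bmm(K_perm, Qe_perm)

# Step 4: permute to the output shape (n, n, h)
out_fast = prod.permute(1, 2, 0)

# Sanity check against the plain einsum() version
print(torch.allclose(out_fast, torch.einsum('ihd,jhd,hd->ijh', K, Q, e), atol=1e-4))
```

The only large tensors ever created here are the $(n, h, d)$ product Qe and the $(h, n, n)$ result; the $(n, n, h, d)$ intermediate of the naive approach never exists.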

This optimization produces the fastest code I have seen.

The experiments are done in the following Jupyter notebook. You can download it here.



