Let us try torch.einsum() first (on the CPU, without a GPU). We will try np.einsum() later and see that the observations are similar.

Define three tensors and check the PyTorch version.

In [1]:
import torch
n, h, d = 101, 10, 39
K = torch.rand(n,h,d)
Q = torch.rand(n,h,d)
e = torch.rand(h,d)

print(torch.__version__)
2.9.1
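
For reference, the contraction we are about to profile, torch.einsum('ihd,jhd,hd->ijh', K, Q, e), computes out[i,j,h] = sum_d K[i,h,d] * Q[j,h,d] * e[h,d]. A naive loop version makes the semantics explicit (a sketch for clarity only; the triple Python loop is far too slow to actually use at this size):

def einsum_reference(K, Q, e):
    # out[i, j, h] = sum over d of K[i, h, d] * Q[j, h, d] * e[h, d]
    n, h, d = K.shape
    out = torch.zeros(n, n, h)
    for i in range(n):
        for j in range(n):
            for k in range(h):
                out[i, j, k] = (K[i, k] * Q[j, k] * e[k]).sum()
    return out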

Profile torch.einsum() with PyTorch's opt_einsum optimization disabled. The call is slow and memory-hungry: the unoptimized path materializes the full (101, 101, 10, 39) elementwise product, which accounts for the ~15 MB allocated by aten::mul in the table below.

In [2]:
%pip install opt-einsum

def original_einsum():
    out = torch.einsum('ihd,jhd,hd->ijh', K, Q, e)
    return out

from torch.profiler import profile, ProfilerActivity
with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True) as prof:
    torch.backends.opt_einsum.enabled = False
    out = original_einsum()
print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
Requirement already satisfied: opt-einsum in /Users/jiechen/anaconda3/envs/torch/lib/python3.12/site-packages (3.4.0)
Note: you may need to restart the kernel to use updated packages.
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
               aten::mul        14.33%       3.102ms        14.33%       3.102ms       3.102ms      15.18 MB      15.18 MB             1  
               aten::bmm        32.36%       7.004ms        33.46%       7.243ms       7.243ms     398.48 KB     398.48 KB             1  
         aten::unsqueeze         7.15%       1.547ms        10.77%       2.331ms     582.792us           0 B           0 B             4  
        aten::as_strided         3.72%     804.667us         3.72%     804.667us      20.117us           0 B           0 B            40  
           aten::permute         2.78%     600.706us         2.81%     608.706us     101.451us           0 B           0 B             6  
           aten::reshape         3.54%     766.084us        16.84%       3.645ms       1.822ms           0 B           0 B             2  
    aten::_reshape_alias         1.65%     356.458us         1.65%     356.458us     356.458us           0 B           0 B             1  
              aten::view        11.68%       2.529ms        11.68%       2.529ms     842.972us           0 B           0 B             3  
      aten::resolve_conj         0.72%     154.790us         0.72%     154.790us      12.899us           0 B           0 B            12  
            aten::select         0.33%      71.580us         0.39%      83.997us       2.800us           0 B           0 B            30  
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 21.644ms

Profile torch.einsum() with PyTorch's opt_einsum optimization enabled. The call is much faster, and memory consumption drops significantly: e is now contracted into a (101, 10, 39) operand first (the ~154 KB aten::mul below), so the huge four-index intermediate is never formed.

In [3]:
with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True) as prof:
    torch.backends.opt_einsum.enabled = True
    out = original_einsum()
print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
               aten::bmm        86.88%       1.179ms        91.80%       1.246ms       1.246ms     398.48 KB     398.48 KB             1  
               aten::mul         3.15%      42.792us         3.15%      42.792us      42.792us     153.87 KB     153.87 KB             1  
         aten::unsqueeze         0.62%       8.420us         0.87%      11.752us       2.938us           0 B           0 B             4  
        aten::as_strided         1.90%      25.830us         1.90%      25.830us       0.646us           0 B           0 B            40  
           aten::permute         0.89%      12.082us         1.18%      16.082us       2.680us           0 B           0 B             6  
           aten::reshape         0.40%       5.418us         0.60%       8.209us       4.104us           0 B           0 B             2  
    aten::_reshape_alias         0.21%       2.791us         0.21%       2.791us       1.396us           0 B           0 B             2  
      aten::resolve_conj         0.39%       5.333us         0.39%       5.333us       0.242us           0 B           0 B            22  
            aten::select         3.17%      43.044us         4.53%      61.542us       2.051us           0 B           0 B            30  
              aten::view         0.34%       4.667us         0.34%       4.667us       2.333us           0 B           0 B             2  
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.357ms
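
To see what the optimizer buys us, we can ask the opt-einsum package (installed above) for the contraction path it plans. A quick sketch using contract_path, which reports the pairwise contraction order and the size of the largest intermediate:

import opt_einsum as oe
path, info = oe.contract_path('ihd,jhd,hd->ijh', K, Q, e)
print(info)  # planned pairwise contractions and largest intermediate size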

Profile the ChatGPT-suggested implementation.

ChatGPT prompt:

Optimize the following python code, where K, Q, e are tensors of shape (101, 10, 39), (101, 10, 39), (10, 39), respectively.

out = torch.einsum('ihd,jhd,hd->ijh', K, Q, e)

The suggested code is incorrect: torch.matmul batches over the shared leading dimension, pairing K[i] only with Q_weighted[i], so it returns a (101, 10, 10) tensor instead of the required (101, 101, 10).

In [4]:
def chatgpt_optimized():
    Q_weighted = Q * e                # (101, 10, 39)
    out = torch.matmul(
        K,                            # (101, 10, 39)
        Q_weighted.transpose(-1, -2)  # (101, 39, 10)
    )                                 # (101, 10, 10)
    out = out.permute(0, 2, 1)        # (101, 10, 10)
    return out

with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True) as prof:
    out = chatgpt_optimized()
try:
    # The correct result has shape (n, n, h) = (101, 101, 10); this version returns (101, 10, 10)
    assert out.shape == (n, n, h)
except AssertionError:
    print("Assertion failed: Output shape is incorrect")
# print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
Assertion failed: Output shape is incorrect

Profile our hand-optimized code: fold e into Q elementwise, then perform a single matmul batched over the head dimension h. It is even faster than the PyTorch-optimized torch.einsum(), with the same memory consumption.

In [5]:
def hand_optimized():
    Qe = Q * e                                  # fold e into Q: (101, 10, 39)
    out = K.permute(1,0,2) @ Qe.permute(1,2,0)  # (10, 101, 39) @ (10, 39, 101) -> (10, 101, 101), batched over h
    out = out.permute(1,2,0)                    # back to (101, 101, 10)
    return out

with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True) as prof:
    out = hand_optimized()
print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
               aten::bmm        61.43%     246.388us        78.02%     312.923us     312.923us     398.48 KB     398.48 KB             1  
               aten::mul        11.57%      46.411us        11.57%      46.411us      46.411us     153.87 KB     153.87 KB             1  
           aten::permute         2.46%       9.873us         3.29%      13.206us       4.402us           0 B           0 B             3  
        aten::as_strided         5.59%      22.411us         5.59%      22.411us       0.640us           0 B           0 B            35  
            aten::matmul         2.22%       8.915us        85.14%     341.462us     341.462us     398.48 KB           0 B             1  
            aten::expand         1.17%       4.708us         1.88%       7.541us       3.770us           0 B           0 B             2  
           aten::reshape         1.46%       5.875us         2.25%       9.041us       4.520us           0 B           0 B             2  
    aten::_reshape_alias         0.79%       3.166us         0.79%       3.166us       1.583us           0 B           0 B             2  
      aten::resolve_conj         1.88%       7.542us         1.88%       7.542us       0.343us           0 B           0 B            22  
            aten::select        10.66%      42.748us        14.71%      58.993us       1.966us           0 B           0 B            30  
------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 401.079us
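
As a sanity check (a minimal sketch, assuming the tensors and functions defined above are still in scope), we can confirm that the hand-optimized version matches torch.einsum() numerically:

# Both computations should agree up to floating-point rounding
assert torch.allclose(hand_optimized(), original_einsum())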

Now, let us try np.einsum().

Define the NumPy arrays and check the NumPy version.

In [6]:
%reset -f
%pip install memory_profiler
%reload_ext memory_profiler

import numpy as np
n, h, d = 101, 10, 39
K = np.random.rand(n,h,d)
Q = np.random.rand(n,h,d)
e = np.random.rand(h,d)

print(np.__version__)
Requirement already satisfied: memory_profiler in /Users/jiechen/anaconda3/envs/torch/lib/python3.12/site-packages (0.61.0)
Requirement already satisfied: psutil in /Users/jiechen/anaconda3/envs/torch/lib/python3.12/site-packages (from memory_profiler) (7.2.0)
Note: you may need to restart the kernel to use updated packages.
2.4.0

Profile np.einsum() without path optimization. The call is already faster than the unoptimized torch.einsum(); still, we will improve it below.

In [7]:
def original_einsum():
    out = np.einsum('ihd,jhd,hd->ijh', K, Q, e)
    return out

%timeit original_einsum()
# %memit original_einsum() # %memit fails here, perhaps due to a package version issue
3 ms ± 41.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Profile np.einsum() with optimization enabled (optimize=True). The call is roughly 30x faster than the unoptimized version.

In [8]:
def einsum_opt():
    out = np.einsum('ihd,jhd,hd->ijh', K, Q, e, optimize=True)
    return out

%timeit einsum_opt()
# %memit einsum_opt() # %memit fails here, perhaps due to a package version issue
99.8 μs ± 3.58 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
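
To inspect the contraction order NumPy picks, np.einsum_path reports the planned path along with the theoretical speedup and largest intermediate (a quick sketch on the same operands):

path, info = np.einsum_path('ihd,jhd,hd->ijh', K, Q, e, optimize='optimal')
print(info)  # scaling, largest intermediate, and chosen contraction order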

Profile the ChatGPT-suggested implementation.

ChatGPT prompt:

Optimize the following python code, where K, Q, e are numpy arrays of shape (101, 10, 39), (101, 10, 39), (10, 39), respectively.

out = np.einsum('ihd,jhd,hd->ijh', K, Q, e)

The suggested code is incorrect; it is nearly identical to the PyTorch version and suffers from the same shape problem.

In [9]:
def chatgpt_optimized():
    Q_weighted = Q * e                   # (101, 10, 39)
    out = np.matmul(
        K,                               # (101, 10, 39)
        np.swapaxes(Q_weighted, -1, -2)  # (101, 39, 10)
    )                                    # (101, 10, 10)
    out = np.transpose(out, (0, 2, 1))   # (101, 10, 10)
    return out

try:
    assert chatgpt_optimized().shape == (n, n, h)
except AssertionError:
    print(f"Assertion failed: Output shape is incorrect")
else:
    %timeit chatgpt_optimized()
    # %memit chatgpt_optimized() # %memit fails here, perhaps due to a package version issue
Assertion failed: Output shape is incorrect

Profile our hand-optimized code. It is even faster than the NumPy-optimized np.einsum().

In [10]:
def hand_optimized():
    Qe = Q * e                           # fold e into Q: (101, 10, 39)
    # np.permute_dims (NumPy >= 2.0) is the array-API name for np.transpose
    out = np.permute_dims(K, (1,0,2)) @ np.permute_dims(Qe, (1,2,0))  # (10, 101, 39) @ (10, 39, 101) -> (10, 101, 101)
    out = np.permute_dims(out, (1,2,0))  # back to (101, 101, 10)
    return out

%timeit hand_optimized()
# %memit hand_optimized() # %memit fails here, perhaps due to a package version issue
81.3 μs ± 1.22 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
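
As with the PyTorch version, a quick sanity check (a sketch, assuming the arrays and functions above are still defined) confirms that the hand-optimized code matches np.einsum():

# Both computations should agree up to floating-point rounding
assert np.allclose(hand_optimized(), original_einsum())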