Let us try torch.einsum() first, on the CPU only. We will try np.einsum() later and see that the observations are similar.
Define three tensors and print the PyTorch version.
import torch
n, h, d = 101, 10, 39
K = torch.rand(n,h,d)
Q = torch.rand(n,h,d)
e = torch.rand(h,d)
print(torch.__version__)
2.9.1
Profile torch.einsum() with the opt_einsum contraction-path optimization disabled. The call is slow (21.644 ms total CPU time) and memory-hungry: a single aten::mul allocates 15.18 MB.
%pip install opt-einsum
def original_einsum():
    out = torch.einsum('ihd,jhd,hd->ijh', K, Q, e)
    return out

from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True) as prof:
    torch.backends.opt_einsum.enabled = False
    out = original_einsum()
print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls
------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::mul 14.33% 3.102ms 14.33% 3.102ms 3.102ms 15.18 MB 15.18 MB 1
aten::bmm 32.36% 7.004ms 33.46% 7.243ms 7.243ms 398.48 KB 398.48 KB 1
aten::unsqueeze 7.15% 1.547ms 10.77% 2.331ms 582.792us 0 B 0 B 4
aten::as_strided 3.72% 804.667us 3.72% 804.667us 20.117us 0 B 0 B 40
aten::permute 2.78% 600.706us 2.81% 608.706us 101.451us 0 B 0 B 6
aten::reshape 3.54% 766.084us 16.84% 3.645ms 1.822ms 0 B 0 B 2
aten::_reshape_alias 1.65% 356.458us 1.65% 356.458us 356.458us 0 B 0 B 1
aten::view 11.68% 2.529ms 11.68% 2.529ms 842.972us 0 B 0 B 3
aten::resolve_conj 0.72% 154.790us 0.72% 154.790us 12.899us 0 B 0 B 12
aten::select 0.33% 71.580us 0.39% 83.997us 2.800us 0 B 0 B 30
------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 21.644ms
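The flag toggled in the profile above is global state, so it is easy to leave it switched off by accident. A minimal sketch of a helper that sets the flag for one call and then restores it is given here; the helper name einsum_with_flag is our own, and we assume that enabling the flag only makes sense when the opt-einsum package is importable.

```python
import torch

torch.manual_seed(0)
n, h, d = 101, 10, 39
K, Q, e = torch.rand(n, h, d), torch.rand(n, h, d), torch.rand(h, d)

def einsum_with_flag(enabled):
    """Run the einsum with the opt_einsum backend on or off (hypothetical helper)."""
    backend = torch.backends.opt_einsum
    previous = backend.enabled
    # Only request the optimization when the opt-einsum package is available;
    # asking for it without the package may raise an error.
    backend.enabled = bool(enabled and backend.is_available())
    try:
        return torch.einsum('ihd,jhd,hd->ijh', K, Q, e)
    finally:
        backend.enabled = previous  # restore the global setting

print("opt_einsum available:", torch.backends.opt_einsum.is_available())

out_off = einsum_with_flag(False)
out_on = einsum_with_flag(True)

# Both settings should give the same values up to float32 rounding.
same = torch.allclose(out_off, out_on, rtol=1e-4, atol=1e-5)
print(out_off.shape, same)
```

The contraction order changes only the cost, not the mathematical result, so the two outputs should agree up to rounding.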
Profile torch.einsum() with the opt_einsum optimization enabled. The call is about 16 times faster (1.357 ms versus 21.644 ms), and the memory allocated by aten::mul drops from 15.18 MB to 153.87 KB.
with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True) as prof:
    torch.backends.opt_einsum.enabled = True
    out = original_einsum()
print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls
------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::bmm 86.88% 1.179ms 91.80% 1.246ms 1.246ms 398.48 KB 398.48 KB 1
aten::mul 3.15% 42.792us 3.15% 42.792us 42.792us 153.87 KB 153.87 KB 1
aten::unsqueeze 0.62% 8.420us 0.87% 11.752us 2.938us 0 B 0 B 4
aten::as_strided 1.90% 25.830us 1.90% 25.830us 0.646us 0 B 0 B 40
aten::permute 0.89% 12.082us 1.18% 16.082us 2.680us 0 B 0 B 6
aten::reshape 0.40% 5.418us 0.60% 8.209us 4.104us 0 B 0 B 2
aten::_reshape_alias 0.21% 2.791us 0.21% 2.791us 1.396us 0 B 0 B 2
aten::resolve_conj 0.39% 5.333us 0.39% 5.333us 0.242us 0 B 0 B 22
aten::select 3.17% 43.044us 4.53% 61.542us 2.051us 0 B 0 B 30
aten::view 0.34% 4.667us 0.34% 4.667us 2.333us 0 B 0 B 2
------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 1.357ms
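The memory figures in the two profiles can be reproduced with simple arithmetic. The sketch below assumes float32 storage (the default of torch.rand) and assumes that the unoptimized run multiplies K and Q first, giving an intermediate indexed by (i, j, h, d), while the optimized run multiplies Q and e first, giving an intermediate indexed by (j, h, d). The matching numbers are consistent with that reading, although the profiler output alone does not prove it.

```python
n, h, d = 101, 10, 39
BYTES_PER_FLOAT32 = 4  # torch.rand produces float32 by default

def size_bytes(*shape):
    """Bytes needed for a dense float32 tensor of the given shape."""
    total = BYTES_PER_FLOAT32
    for dim in shape:
        total *= dim
    return total

kq_intermediate = size_bytes(n, n, h, d)  # K*Q broadcast over (i, j, h, d)
qe_intermediate = size_bytes(n, h, d)     # Q*e over (j, h, d)
output = size_bytes(n, n, h)              # final result over (i, j, h)

print(f"K*Q intermediate: {kq_intermediate / 2**20:.2f} MB")  # 15.18 MB
print(f"Q*e intermediate: {qe_intermediate / 2**10:.2f} KB")  # 153.87 KB
print(f"output:           {output / 2**10:.2f} KB")           # 398.48 KB
```

The ratio between the two intermediates is exactly n = 101, which is why the saving grows with the sequence length.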
Profile the ChatGPT-suggested implementation.
ChatGPT prompt:
Optimize the following python code, where K, Q, e are tensors of shape (101, 10, 39), (101, 10, 39), (10, 39), respectively.
out = torch.einsum('ihd,jhd,hd->ijh', K, Q, e)
The suggested code is incorrect: it returns a tensor of shape (101, 10, 10), whereas the einsum produces shape (101, 101, 10).
def chatgpt_optimized():
    Q_weighted = Q * e  # (101, 10, 39)
    out = torch.matmul(
        K,                             # (101, 10, 39)
        Q_weighted.transpose(-1, -2)   # (101, 39, 10)
    )  # (101, 10, 10)
    out = out.permute(0, 2, 1)  # (101, 10, 10)
    return out

with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True) as prof:
    chatgpt_optimized()
try:
    assert chatgpt_optimized().shape == (n, n, h)
except AssertionError:
    print("Assertion failed: Output shape is incorrect")
# print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
Assertion failed: Output shape is incorrect
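To see what goes wrong, it helps to write down what the suggested code actually computes. In the sketch below, torch.matmul treats the leading axis of both operands as a batch axis, so row i of K is only ever paired with row i of Q and the contraction runs across heads instead of across positions. The einsum string 'ihd,igd,gd->igh' is our own reading of the suggested code, not something stated in the ChatGPT answer.

```python
import torch

torch.manual_seed(0)
n, h, d = 101, 10, 39
K, Q, e = torch.rand(n, h, d), torch.rand(n, h, d), torch.rand(h, d)

def chatgpt_optimized():
    Q_weighted = Q * e                                    # (n, h, d)
    out = torch.matmul(K, Q_weighted.transpose(-1, -2))  # (n, h, h), batched over i
    return out.permute(0, 2, 1)                           # (n, h, h)

suggested = chatgpt_optimized()
print(suggested.shape)  # torch.Size([101, 10, 10])

# Our reading of the code: i is a shared batch index, and the two head
# indices h and g are paired with each other.
what_it_computes = torch.einsum('ihd,igd,gd->igh', K, Q, e)
matches = torch.allclose(suggested, what_it_computes, rtol=1e-4, atol=1e-5)
print(matches)
```

The intended result pairs every position i with every position j within the same head, so the head axis must be the batch axis of the matrix product.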
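Profile our hand-optimized code. In this run it is more than three times faster than the optimized torch.einsum() (401.079 us versus 1.357 ms), while the memory consumption is the same (398.48 KB for aten::bmm and 153.87 KB for aten::mul).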
def hand_optimized():
    Qe = Q * e
    out = K.permute(1,0,2) @ Qe.permute(1,2,0)
    out = out.permute(1,2,0)
    return out

with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True) as prof:
    out = hand_optimized()
print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls
------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::bmm 61.43% 246.388us 78.02% 312.923us 312.923us 398.48 KB 398.48 KB 1
aten::mul 11.57% 46.411us 11.57% 46.411us 46.411us 153.87 KB 153.87 KB 1
aten::permute 2.46% 9.873us 3.29% 13.206us 4.402us 0 B 0 B 3
aten::as_strided 5.59% 22.411us 5.59% 22.411us 0.640us 0 B 0 B 35
aten::matmul 2.22% 8.915us 85.14% 341.462us 341.462us 398.48 KB 0 B 1
aten::expand 1.17% 4.708us 1.88% 7.541us 3.770us 0 B 0 B 2
aten::reshape 1.46% 5.875us 2.25% 9.041us 4.520us 0 B 0 B 2
aten::_reshape_alias 0.79% 3.166us 0.79% 3.166us 1.583us 0 B 0 B 2
aten::resolve_conj 1.88% 7.542us 1.88% 7.542us 0.343us 0 B 0 B 22
aten::select 10.66% 42.748us 14.71% 58.993us 1.966us 0 B 0 B 30
------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 401.079us
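The hand-optimized code rests on the identity out[i,j,h] = sum_d K[i,h,d] * (Q[j,h,d] * e[h,d]), which is one (n x d) by (d x n) matrix product per head h. The notebook does not verify this numerically, so here is a small check against the reference einsum. The tolerances are our own choice for float32 data.

```python
import torch

torch.manual_seed(0)
n, h, d = 101, 10, 39
K, Q, e = torch.rand(n, h, d), torch.rand(n, h, d), torch.rand(h, d)

def hand_optimized():
    Qe = Q * e                                      # (n, h, d), indexed [j, h, d]
    out = K.permute(1, 0, 2) @ Qe.permute(1, 2, 0)  # (h, n, d) @ (h, d, n) -> (h, n, n)
    return out.permute(1, 2, 0)                     # (n, n, h), indexed [i, j, h]

reference = torch.einsum('ihd,jhd,hd->ijh', K, Q, e)
candidate = hand_optimized()

max_abs_diff = (candidate - reference).abs().max().item()
agrees = torch.allclose(candidate, reference, rtol=1e-4, atol=1e-5)
print(candidate.shape, agrees, max_abs_diff)
```

Note that the final permute returns a view, so the result is not contiguous in memory; calling .contiguous() on it would add a copy that the profile above does not include.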
Now, let us try np.einsum().
Define the numpy arrays and also check the numpy version.
%reset -f
%pip install memory_profiler
%reload_ext memory_profiler
import numpy as np
n, h, d = 101, 10, 39
K = np.random.rand(n,h,d)
Q = np.random.rand(n,h,d)
e = np.random.rand(h,d)
print(np.__version__)
2.4.0
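Profile np.einsum() without NumPy's contraction-path optimization. At 3 ms per call it is faster than the unoptimized torch.einsum() run above (21.644 ms), although the two numbers come from different measurement tools. We will still improve on it below.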
def original_einsum():
    out = np.einsum('ihd,jhd,hd->ijh', K, Q, e)
    return out

%timeit original_einsum()
# %memit original_einsum() # There is something wrong, perhaps due to package version
3 ms ± 41.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Profile np.einsum() with optimize=True. The call is about 30 times faster than the previous one (99.8 μs versus 3 ms).
def einsum_opt():
    out = np.einsum('ihd,jhd,hd->ijh', K, Q, e, optimize=True)
    return out

%timeit einsum_opt()
# %memit einsum_opt() # There is something wrong, perhaps due to package version
99.8 μs ± 3.58 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
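To see which contraction order NumPy picks, np.einsum_path reports the planned pairwise contractions together with a cost estimate. The sketch below asks for the 'greedy' strategy explicitly, which is our choice for illustration; the exact path and report text may differ between NumPy versions, so no expected output is shown.

```python
import numpy as np

rng = np.random.default_rng(0)
n, h, d = 101, 10, 39
K, Q, e = rng.random((n, h, d)), rng.random((n, h, d)), rng.random((h, d))

# path is a list starting with the string 'einsum_path', followed by tuples
# naming which operands are contracted at each step.
path, report = np.einsum_path('ihd,jhd,hd->ijh', K, Q, e, optimize='greedy')
print(path)
print(report)

# A precomputed path can be passed back to np.einsum, which skips the
# path search on every call.
plain = np.einsum('ihd,jhd,hd->ijh', K, Q, e)
planned = np.einsum('ihd,jhd,hd->ijh', K, Q, e, optimize=path)
same = np.allclose(plain, planned)
print(same)
```

The report lists the estimated FLOP count and the largest intermediate, which makes it easy to compare against the naive single-step contraction.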
Profile the ChatGPT-suggested implementation.
ChatGPT prompt:
Optimize the following python code, where K, Q, e are numpy arrays of shape (101, 10, 39), (101, 10, 39), (10, 39), respectively.
out = np.einsum('ihd,jhd,hd->ijh', K, Q, e)
The suggested code is incorrect. It is nearly identical to the PyTorch version and suffers from the same problem: the output has shape (101, 10, 10) instead of (101, 101, 10).
def chatgpt_optimized():
    Q_weighted = Q * e  # (101, 10, 39)
    out = np.matmul(
        K,                                 # (101, 10, 39)
        np.swapaxes(Q_weighted, -1, -2)    # (101, 39, 10)
    )  # (101, 10, 10)
    out = np.transpose(out, (0, 2, 1))  # (101, 10, 10)
    return out

try:
    assert chatgpt_optimized().shape == (n, n, h)
except AssertionError:
    print("Assertion failed: Output shape is incorrect")
else:
    %timeit chatgpt_optimized()
    # %memit chatgpt_optimized() # There is something wrong, perhaps due to package version
Assertion failed: Output shape is incorrect
Profile our hand-optimized code. It is somewhat faster than np.einsum() with optimize=True (81.3 μs versus 99.8 μs).
def hand_optimized():
    Qe = Q * e
    out = np.permute_dims(K, (1,0,2)) @ np.permute_dims(Qe, (1,2,0))
    out = np.permute_dims(out, (1,2,0))
    return out

%timeit hand_optimized()
# %memit hand_optimized() # There is something wrong, perhaps due to package version
81.3 μs ± 1.22 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
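The %timeit magic only works inside IPython. For readers running a plain Python script, the comparison can be sketched with the standard timeit module, together with a correctness check that the notebook leaves implicit. The sketch uses np.transpose with an axes tuple, which performs the same permutation as np.permute_dims and is also available in older NumPy releases; the repeat counts are arbitrary choices, and absolute timings will vary by machine.

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
n, h, d = 101, 10, 39
K, Q, e = rng.random((n, h, d)), rng.random((n, h, d)), rng.random((h, d))

def einsum_default():
    return np.einsum('ihd,jhd,hd->ijh', K, Q, e)

def einsum_opt():
    return np.einsum('ihd,jhd,hd->ijh', K, Q, e, optimize=True)

def hand_optimized():
    Qe = Q * e                                                  # (n, h, d)
    out = np.transpose(K, (1, 0, 2)) @ np.transpose(Qe, (1, 2, 0))  # (h, n, n)
    return np.transpose(out, (1, 2, 0))                         # (n, n, h)

# All variants must agree with the reference before timing means anything.
reference = einsum_default()
all_agree = all(np.allclose(fn(), reference) for fn in (einsum_opt, hand_optimized))

timings = {}
for fn in (einsum_default, einsum_opt, hand_optimized):
    # Best of 5 repeats, each averaging over 20 calls.
    best = min(timeit.repeat(fn, number=20, repeat=5)) / 20
    timings[fn.__name__] = best
    print(f"{fn.__name__}: {best * 1e6:.1f} us per call")
```

Taking the minimum over repeats reduces the influence of background load, which matters when the differences are on the order of tens of microseconds.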