Why einsum?

NumPy's and PyTorch's einsum functions let you express any combination of
tensor contractions, transposes, and sums in a compact index notation
inspired by the Einstein summation convention.

Instead of:

# Batched matrix multiply: the roles of the axes are implicit
result = torch.bmm(A, B)

You write:

result = torch.einsum('bik,bkj->bij', A, B)
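
As a quick sanity check, the two forms can be compared on random inputs. This is a minimal sketch: the shapes (4, 3, 5) and (4, 5, 2) are arbitrary choices for illustration, not anything required by einsum.

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 3, 5)   # axes: (batch, i, k)
B = torch.randn(4, 5, 2)   # axes: (batch, k, j)

via_bmm = torch.bmm(A, B)
via_einsum = torch.einsum('bik,bkj->bij', A, B)

print(via_einsum.shape)  # torch.Size([4, 3, 2])
print(torch.allclose(via_bmm, via_einsum, atol=1e-5))
```

The einsum string documents the axis layout in place, which the bmm call leaves to the reader.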

Reading the Notation

The string 'bik,bkj->bij' is read as:

Part                  Meaning
bik                   First tensor has axes batch, i, k
bkj                   Second tensor has axes batch, k, j
->bij                 Output has axes batch, i, j
k absent in output    Sum over k (contraction)
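
One way to internalize the reading above is to write the same expression as explicit loops. The helper name bmm_loops below is hypothetical and exists only for illustration; it is far slower than einsum and is meant as a reference for the semantics, not as an implementation to use.

```python
import torch

def bmm_loops(A, B):
    """Reference loops for 'bik,bkj->bij' (hypothetical helper, illustration only)."""
    nb, ni, nk = A.shape
    nb2, nk2, nj = B.shape
    # Indices shared between operands must have matching sizes.
    assert nb == nb2 and nk == nk2
    out = torch.zeros(nb, ni, nj, dtype=A.dtype)
    for b in range(nb):              # b, i, j appear in the output: one loop each
        for i in range(ni):
            for j in range(nj):
                for k in range(nk):  # k is absent from the output: accumulated
                    out[b, i, j] += A[b, i, k] * B[b, k, j]
    return out

torch.manual_seed(0)
A = torch.randn(2, 3, 4)
B = torch.randn(2, 4, 5)

loops = bmm_loops(A, B)
fast = torch.einsum('bik,bkj->bij', A, B)
print(torch.allclose(loops, fast, atol=1e-5))
```

Every letter in the string becomes one loop; the letters missing from the right-hand side are the ones whose loop accumulates into the output.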

Common Patterns

# Matrix multiply
torch.einsum('ij,jk->ik', A, B)

# Batch matrix multiply
torch.einsum('bij,bjk->bik', A, B)

# Dot product
torch.einsum('i,i->', a, b)

# Outer product
torch.einsum('i,j->ij', a, b)

# Trace
torch.einsum('ii->', A)

# Transpose
torch.einsum('ij->ji', A)

# Element-wise then sum (like torch.sum(A * B))
torch.einsum('ij,ij->', A, B)
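
Each pattern above has a conventional PyTorch counterpart, so they can be checked side by side. The sketch below assumes small random inputs; the names checks, A3, and B3 are introduced here for illustration only.

```python
import torch

torch.manual_seed(0)
A = torch.randn(3, 3)
B = torch.randn(3, 3)
a = torch.randn(3)
b = torch.randn(3)
A3 = torch.randn(2, 3, 4)   # batched operands for the bmm pattern
B3 = torch.randn(2, 4, 5)

# Each entry pairs the einsum form with its conventional equivalent.
checks = {
    'matmul':    (torch.einsum('ij,jk->ik', A, B), A @ B),
    'bmm':       (torch.einsum('bij,bjk->bik', A3, B3), torch.bmm(A3, B3)),
    'dot':       (torch.einsum('i,i->', a, b), torch.dot(a, b)),
    'outer':     (torch.einsum('i,j->ij', a, b), torch.outer(a, b)),
    'trace':     (torch.einsum('ii->', A), torch.trace(A)),
    'transpose': (torch.einsum('ij->ji', A), A.T),
    'sum_prod':  (torch.einsum('ij,ij->', A, B), torch.sum(A * B)),
}

for name, (got, want) in checks.items():
    print(name, torch.allclose(got, want, atol=1e-5))
```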

Practical Example — Attention Scores

The scaled dot-product attention score matrix is a perfect einsum use case:

# Q: (batch, heads, seq, d_k)
# K: (batch, heads, seq, d_k)
d_k = Q.shape[-1]
scores = torch.einsum('bhid,bhjd->bhij', Q, K) / d_k ** 0.5
# scores: (batch, heads, seq, seq)

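The einsum form is clean and explicit about which axes are contracted.
Its speed is usually comparable to torch.matmul with an explicit transpose,
since einsum is typically lowered to batched matrix multiplies internally,
so benchmark on your own shapes before assuming either one is faster.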
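
For comparison, here is the same score computation written with torch.matmul and a transpose of the last two axes of K. This is a sketch under assumed sizes (batch=2, heads=4, seq=6, d_k=8), chosen only to keep the example small.

```python
import torch

torch.manual_seed(0)
batch, heads, seq, d_k = 2, 4, 6, 8   # assumed sizes, for illustration
Q = torch.randn(batch, heads, seq, d_k)
K = torch.randn(batch, heads, seq, d_k)

# einsum: contract the feature axis d, keep query position i and key position j
scores_einsum = torch.einsum('bhid,bhjd->bhij', Q, K) / d_k ** 0.5

# matmul: swap the last two axes of K so the feature axes line up
scores_matmul = torch.matmul(Q, K.transpose(-2, -1)) / d_k ** 0.5

print(scores_einsum.shape)  # torch.Size([2, 4, 6, 6])
print(torch.allclose(scores_einsum, scores_matmul, atol=1e-5))
```

The einsum version names the query and key positions (i and j) directly, while the matmul version relies on the reader knowing which axes transpose(-2, -1) swaps.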

Tips

  • Any index that appears in the inputs but not in the output is summed over; when two inputs share it, that sum is a contraction.
  • Indices that appear in the output are free: they are kept as output axes.
  • Use opt_einsum to find an efficient contraction order when an expression has three or more operands; with only two operands there is nothing to reorder.
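
The index rules and the effect of contraction order can both be seen in a small sketch. The shapes are assumptions picked to make the cost gap obvious, and the multiply-add counts in the comments are hand estimates for dense matrix products. The two orderings are written out manually here; a library such as opt_einsum searches for a cheap ordering automatically.

```python
import torch

torch.manual_seed(0)
A = torch.randn(3, 4)

# j is absent from the output, so it is summed even though only one input has it.
row_sums = torch.einsum('ij->i', A)     # same values as A.sum(dim=1)
# Both indices appear in the output, so both axes are kept.
unchanged = torch.einsum('ij->ij', A)

# Three-operand chain 'ij,jk,kl->il' with deliberately lopsided shapes.
X = torch.randn(10, 1000, dtype=torch.float64)
Y = torch.randn(1000, 10, dtype=torch.float64)
Z = torch.randn(10, 1000, dtype=torch.float64)

# (X @ Y) first: 10*1000*10 + 10*10*1000 = 200,000 multiply-adds
left_first = (X @ Y) @ Z
# (Y @ Z) first: 1000*10*1000 + 10*1000*1000 = 20,000,000 multiply-adds
right_first = X @ (Y @ Z)

chained = torch.einsum('ij,jk,kl->il', X, Y, Z)

print(torch.allclose(left_first, right_first))
print(torch.allclose(chained, left_first))
```

Both orderings give the same result up to floating-point rounding, but the estimated work differs by a factor of about 100, which is why ordering matters once an expression has more than two operands.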