
The Unreasonable Usefulness of numpy's einsum

Published: 11/3/2024
Categories: python, numpy, datascience, ai
Author: kylepena

Introduction

I'd like to introduce you to the most useful function in Python: np.einsum.

With np.einsum (and its counterparts in TensorFlow and JAX), you can write complicated matrix and tensor operations in an extremely clear and succinct way. I've also found that its clarity and succinctness relieve a lot of the mental overhead that comes with working with tensors.

And it's actually fairly simple to learn and use. Here's how it works:

np.einsum takes a subscripts string argument and one or more operands:

np.einsum(subscripts: str, *operands: np.ndarray)

The subscripts argument is a "mini-language" that tells numpy how to manipulate and combine the axes of the operands. It's a little difficult to read at first, but it's not bad when you get the hang of it.

Single Operands

For a first example, let's use np.einsum to swap the axes of a matrix A (i.e., take its transpose):

M = np.einsum('ij->ji', A)

The letters i and j are bound to the first and second axes of A. Numpy binds letters to axes in the order they appear, but it doesn't care which letters you use, as long as you use them consistently. We could have used a and b, for example, and it works the same way:

M = np.einsum('ab->ba', A)
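
If you want to convince yourself that these formulas really do produce the transpose, a quick check against A.T works (the small example matrix here is my own, just for illustration):

import numpy as np

# a small example matrix, made up for illustration
A = np.asarray([[1,2,3],
                [4,5,6]])

M = np.einsum('ij->ji', A)
print(np.array_equal(M, A.T))   # True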

However, you must supply as many letters as there are axes in the operand. There are two axes in A, so you must supply two distinct letters. The next example won't work because the subscripts formula only has one letter to bind, i:

# broken
M = np.einsum('i->i', A)

On the other hand, if the operand does indeed have only one axis (in other words, it is a vector), then the single-letter subscript formula works just fine, although it isn't very useful because it leaves the vector a as-is:

m = np.einsum('i->i', a)

Summing Over Axes

But what about this operation? There's no i on the right-hand-side. Is this valid?

c = np.einsum('i->', a)

Surprisingly, yes!

Here is the first key to understanding the essence of np.einsum: If an axis is omitted from the right-hand-side, then the axis is summed over.

[Diagram: sum over i]

Code:

c = 0
I = len(a)
for i in range(I):
   c += a[i]

The summing-over behavior isn't limited to a single axis. For example, you can sum over two axes at once by using this subscript formula: c = np.einsum('ij->', A):

[Diagram: sum over i and j]

Here is the corresponding Python code:

c = 0
I,J = A.shape
for i in range(I):
   for j in range(J):
      c += A[i,j]

But it doesn't stop there - we can get creative and sum over some axes while leaving others alone. For example, np.einsum('ij->i', A) sums each row of matrix A, leaving a vector of row sums of length I:

[Diagram: sum over j]

Code:

I,J = A.shape
r = np.zeros(I)
for i in range(I):
   for j in range(J):
      r[i] += A[i,j]

Likewise, np.einsum('ij->j', A) sums each column of A, producing a vector of column sums of length J.

[Diagram: sum over i]

Code:

I,J = A.shape
r = np.zeros(J)
for i in range(I):
   for j in range(J):
      r[j] += A[i,j]
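
As a quick sanity check (an aside of my own, using a made-up matrix), both subscript formulas agree with numpy's built-in axis sums:

import numpy as np

A = np.asarray([[1,2,3],
                [4,5,6]])

row_sums = np.einsum('ij->i', A)   # sums over j
col_sums = np.einsum('ij->j', A)   # sums over i

print(np.array_equal(row_sums, A.sum(axis=1)))   # True
print(np.array_equal(col_sums, A.sum(axis=0)))   # True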

Two Operands

There's a limit to what we can do with a single operand. Things get a lot more interesting (and useful) with two operands.

Let's suppose you have two vectors a = [a_1, a_2, ...] and b = [b_1, b_2, ...].

If len(a) == len(b), we can compute the inner product (also called the dot product) like this:

a = np.asarray([4,5,6])
b = np.asarray([1,2,3])
c = np.einsum('i,i->', a, b)
>> c := 32

Two things are happening here simultaneously:

  1. Because i is bound to both a and b, a and b are "lined up" and then multiplied together: a[i] * b[i].
  2. Because the index i is excluded from the right-hand-side, axis i is summed over in order to eliminate it.

If you put (1) and (2) together, you get the classic inner product.

[Diagram: inner product]

Code:

c = 0
I = len(a)
for i in range(I):
   c += a[i] * b[i]

Now suppose we didn't omit i from the subscript formula. Then we would multiply each a[i] and b[i] but not sum over i:

a = np.asarray([4,5,6])
b = np.asarray([1,2,3])
c = np.einsum('i,i->i', a, b)
>> c := np.asarray([4,10,18])

[Diagram: pointwise product]

Code:

I = len(a)
c = np.zeros(I)
for i in range(I):
   c[i] = a[i] * b[i]

This is also called element-wise multiplication (or the Hadamard Product for matrices), and is typically done via the numpy method np.multiply.
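
As a quick check (reusing the small example vectors from above), the 'i,i->i' formula, np.multiply, and the * operator all agree:

import numpy as np

a = np.asarray([4,5,6])
b = np.asarray([1,2,3])

print(np.einsum('i,i->i', a, b))   # [ 4 10 18]
print(np.multiply(a, b))           # [ 4 10 18]
print(a * b)                       # [ 4 10 18]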

Finally, let's suppose we bound a and b to different letters and included both axes - i and j - in the output. This is called the outer product.

a = np.asarray([4,5,6])
b = np.asarray([1,2,3])
C = np.einsum('i,j->ij', a, b)
>> C := np.asarray([[4,8,12],[5,10,15],[6,12,18]])

In this subscript formula, the axes of a and b are bound to separate letters, and thus are treated as separate "loop variables". Therefore C has entries a[i] * b[j] for all i and j, arranged into a matrix.

[Diagram: outer product]

Code:

I = len(a)
J = len(b)
C = np.zeros((I,J))
for i in range(I):
   for j in range(J):
      C[i,j] = a[i] * b[j]

Three Operands

Taking the outer product a step further, here's a three-operand version:

M = np.einsum('i,j,k->ijk', a, b, c) 

[Diagram: three-axis outer product]

The equivalent Python code for our three-operand outer product is:

I = len(a)
J = len(b)
K = len(c)
M = np.zeros((I,J,K))
for i in range(I):
   for j in range(J):
      for k in range(K):
         M[i,j,k] = a[i] * b[j] * c[k]

Going even further, there's nothing stopping us from omitting axes to sum over them in addition to transposing the result by writing ki instead of ik on the right-hand-side of ->:

M = np.einsum('i,j,k->ki', a, b, c)

The equivalent Python code would read:

I = len(a)
J = len(b)
K = len(c)
M = np.zeros((K,I))
for i in range(I):
   for j in range(J):
      for k in range(K):
         M[k,i] += a[i] * b[j] * c[k]

Now I hope you can begin to see how you can specify complicated tensor operations rather easily. What's more, I can readily read off the above operation straight from the subscripts: "the outer product of three vectors, with the middle axis summed over, and the final result transposed". Pretty neat, but is this just academic? I don't think so.
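
If you'd like to verify that reading, here is a small sketch of my own (with made-up vectors) that assembles the same result from more familiar pieces - the outer product of c and a, scaled by the sum over b:

import numpy as np

a = np.asarray([4,5,6])
b = np.asarray([1,2,3])
c = np.asarray([7,8])

M = np.einsum('i,j,k->ki', a, b, c)

# M[k,i] = sum_j a[i]*b[j]*c[k] = a[i]*c[k]*sum(b)
M_check = np.sum(b) * np.outer(c, a)
print(np.array_equal(M, M_check))   # True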

A Practical Example

For a practical example, let's implement the equation at the heart of LLMs, from the classic paper "Attention is All You Need".

Eq. 1 describes the Attention Mechanism:

Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

We'll focus our attention on the term QK^T, because softmax isn't computable by np.einsum and the scaling factor 1/sqrt(d_k) is trivial to apply.

The QK^T term represents the dot products of m queries with n keys. Q is a collection of m d-dimensional row vectors stacked into a matrix, so Q has shape m x d. Likewise, K is a collection of n d-dimensional row vectors stacked into a matrix, so K has shape n x d.

The product between a single Q and K would be written as:

np.einsum('md,nd->mn', Q, K)

Note that because of the way we wrote our subscripts equation, we avoided having to transpose K prior to matrix multiplication!

[Diagram: QK^T]
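
To see that this really is just an ordinary matmul-with-transpose, here's a small sketch of my own (the shapes are invented for illustration):

import numpy as np

m, n, d = 4, 5, 8   # made-up sizes for illustration
rng = np.random.default_rng(0)
Q = rng.normal(size=(m, d))
K = rng.normal(size=(n, d))

scores = np.einsum('md,nd->mn', Q, K)
print(np.allclose(scores, Q @ K.T))   # True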

So, that seems pretty straightforward - in fact, it's just a traditional matrix multiplication. However, we're not done yet. Attention Is All You Need uses multi-head attention, which means we really have k such matrix multiplies happening simultaneously over an indexed collection of Q matrices and K matrices.

To make things a bit clearer, we might rewrite the product as Q_i K_i^T.

That means we have an additional axis i for both Q and K.

And what's more, if we are in a training setting, we are probably executing a batch of such multi-headed attention operations.

So presumably we would want to perform the operation over a batch of examples along a batch axis b. Thus, the complete product would be something like:

batch_multihead_QKt = np.einsum('bimd,bind->bimn', Q, K, optimize = True)

I'm going to skip the diagram here because we're dealing with 4-axis tensors. But you might be able to picture "stacking" the earlier diagram to get our multi-head axis i, and then "stacking" those "stacks" to get our batch axis b.

It's difficult for me to see how we would implement such an operation with any combination of the other numpy methods. Yet, with a little bit of inspection, it's clear what's happening: over a batch, over a collection of matrices Q and K, perform the matrix multiplication QK^T.

Now, isn't that wonderful?

Small Update

It was pointed out to me that the Attention example boils down to slicewise matmul, and that is supported by matmul out-of-the-box. Fair point! Originally, this post was using a different example that doesn't boil down to slicewise matmul, but it was far less topical than LLMs.
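
For completeness, here's a rough sketch (with invented shapes) of that equivalence - the batched, multi-head product expressed via broadcast matmul:

import numpy as np

b, h, m, n, d = 2, 3, 4, 5, 8   # batch, heads, queries, keys, dim (made-up sizes)
rng = np.random.default_rng(0)
Q = rng.normal(size=(b, h, m, d))
K = rng.normal(size=(b, h, n, d))

via_einsum = np.einsum('bimd,bind->bimn', Q, K, optimize=True)
via_matmul = Q @ np.swapaxes(K, -1, -2)   # slicewise matmul over the leading axes

print(np.allclose(via_einsum, via_matmul))   # True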

Also, it was pointed out to me that the performance of np.einsum greatly suffers unless you specifically set optimize=True in the parameter list. I investigated this topic in detail in my follow-up blog post here.

Shameless Plug

After doing the founder mode grind for a year, I'm looking for work. I've got over 15 years experience in a wide variety of technical fields and programming languages and also experience managing teams. Math and statistics are focus areas. DM me on Twitter or on LinkedIn and let's talk!
