Attention

Attention is often presented as the most complex block in a transformer. It is called attention because it applies a weight to each prior token vector. To pay more attention to some prior tokens, we can apply a larger weight. Hence, attention is effectively a weighted mean of the input $X$.

For this problem we must compute the token frequencies. Luckily for us, these are found by taking the (uniformly-weighted) mean of the input $X$, where $x_i$ is the $i$th row:

$$\mathrm{mean}(X) = \frac{1}{n} \sum_{i=1}^{n} x_i$$

In the last section, we embedded the tokens 'bab' as:

$$X = \begin{bmatrix} 0 & 1 \\ 1 & 0 \\ 0 & 1 \end{bmatrix}$$

Now in the attention step, we take the weighted mean across the rows to compute the token frequencies:

$$\mathrm{mean}(X) = \begin{bmatrix} \tfrac{1}{3} & \tfrac{2}{3} \end{bmatrix}$$
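As a quick numerical check, here is a minimal numpy sketch. The one-hot embedding of 'bab' with columns ordered (a, b) is an assumption carried over from the example above:

import numpy as np

# One-hot embeddings of 'bab', columns ordered (a, b).
X = np.array([[0., 1.],   # b
              [1., 0.],   # a
              [0., 1.]])  # b

# Uniformly-weighted mean across rows gives the token frequencies.
freqs = X.mean(axis=0)
print(freqs)  # [0.33333333 0.66666667]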

For simplicity, we will assume the attention weights are uniform. In later chapters, we'll discuss the math behind computing the weights and show that uniform weights are possible.

The residual stream

We mentioned in the last section that the residual stream always has shape $n \times K$ in both the input and output of attention.

Yet if attention takes the mean across all tokens, why doesn't this result in a $1 \times K$-dimensional output?

It turns out that attention takes the mean of each increasing subset of tokens. So the first output row will contain the mean of only the first token's embedding (i.e. itself). The second row will contain the mean of the first two embeddings, and so on.

Let's see a concrete example:

Suppose we have the input embeddings of 'bab':

$$X = \begin{bmatrix} 0 & 1 \\ 1 & 0 \\ 0 & 1 \end{bmatrix}$$

Then, the output of attention is:

$$Y = \begin{bmatrix} 0 & 1 \\ \tfrac{1}{2} & \tfrac{1}{2} \\ \tfrac{1}{3} & \tfrac{2}{3} \end{bmatrix}$$

The first row is the first row of $X$. The second row is the mean of the first two rows of $X$. And the final row is the mean of all three rows.
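A small sketch of this "increasing subset" view, using a plain loop over prefixes (reusing X from above):

# Row i of the output is the mean of rows 0..i of X.
Y = np.stack([X[:i+1].mean(axis=0) for i in range(X.shape[0])])
print(Y)
# [[0.         1.        ]
#  [0.5        0.5       ]
#  [0.33333333 0.66666667]]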

It is also possible to compute this efficiently using linear algebra:

We define a lower-triangular $n \times n$-shape weight matrix $W$. Row $i$ specifies how to weight the first $i$ token embeddings to compute the $i$th row of the output matrix $Y$:

$$W = \begin{bmatrix} 1 & 0 & 0 \\ \tfrac{1}{2} & \tfrac{1}{2} & 0 \\ \tfrac{1}{3} & \tfrac{1}{3} & \tfrac{1}{3} \end{bmatrix}$$

Note that row $i$ weights the first $i$ tokens equally with probability $\tfrac{1}{i}$. The final row of $W$ corresponds to $i = n$, where each token embedding is weighted uniformly as $\tfrac{1}{n}$.

To apply the weight matrix $W$ to the token embeddings $X$, we use matrix multiplication:

$$Y = WX$$

In attention, each row of the weight matrix will always sum to 1. This is a consequence of how the matrix is computed. We'll use this fact below to simplify some of the math.

Python Implementation

import numpy as np

n_toks = X.shape[0]

# Lower-triangular matrix of ones, with row i scaled by 1/(i+1)
# so that each row sums to 1.
W = np.tri(n_toks) * (1. / np.arange(1, n_toks + 1)).reshape(-1, 1)

# (n_toks, K) = (n_toks, n_toks) @ (n_toks, K)
Y = W @ X
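As a sanity check, here is a brief sketch verifying that each row of W sums to 1 and that W @ X matches the prefix means computed by hand:

# Every row of W sums to 1 ...
assert np.allclose(W.sum(axis=1), 1.0)

# ... so W @ X reproduces the running (prefix) means.
expected = np.stack([X[:i+1].mean(axis=0) for i in range(n_toks)])
assert np.allclose(Y, expected)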

Technical Details

The attention block can do more than merely take an average of the inputs. However, this is a convenient first approximation of its workings.

Below, we'll walk through the conditions under which the attention block reduces to a direct weighted mean of its inputs.

  1. In attention, the inputs $X$ first undergo a linear projection (value projection) $V$ plus bias $b_v$. We'll define this as $X' = XV + b_v$.

  2. $X'$ is then weighted by $W$ and subjected to a second linear projection (output projection) $O$ plus bias $b_o$. We'll call this the attention block output $Y = (WX')O + b_o$ (see the sketch after this list).
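Here is a minimal numpy sketch of these two steps. The names V, b_v, O, and b_o mirror the symbols above, the random values are purely illustrative, and X and W are reused from the earlier implementation:

import numpy as np

rng = np.random.default_rng(0)
n_toks, K = X.shape

V   = rng.normal(size=(K, K))  # value projection (illustrative values)
b_v = rng.normal(size=K)       # value bias
O   = rng.normal(size=(K, K))  # output projection
b_o = rng.normal(size=K)       # output bias

X_prime = X @ V + b_v          # step 1: value projection
Y = (W @ X_prime) @ O + b_o    # step 2: weight, then output projection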

In many attention implementations, there is no second (output) projection. In this text, however, we are preparing the reader for multi-headed attention, where the purpose of the output projection is to provide a final projection after concatenating together the outputs of numerous heads (although the math below still applies!). We'll look at this in more depth in later chapters.

The two linear projections above are back-to-back, so they can be combined into a single effective transformation. By using the associative and distributive properties of matrix multiplication, we can (step-by-step) rewrite these two transformations as one in terms of the weighted inputs $WX$:

$$\begin{aligned}
Y &= (WX')O + b_o \\
  &= (W(XV + b_v))O + b_o \\
  &= (WXV + Wb_v)O + b_o \\
  &= (WX)(VO) + (Wb_v)O + b_o \\
  &= (WX)(VO) + b_vO + b_o
\end{aligned}$$

The attention output is now in terms of the weighted inputs $WX$ rather than the weighted transformed inputs $WX'$!

Note we combined the two transforms above into a single linear transformation with bias (with projection $VO$ and bias $b_vO + b_o$).

To obtain the final step, we assert that $Wb_v = b_v$. To explain this, note that each row of $W$ sums to 1. Since $b_v$ is a $1 \times K$ matrix, it is broadcast into an $n \times K$ matrix whose every row is $b_v$; each row of $Wb_v$ is then $b_v$ multiplied by the corresponding row sum of $W$. For example, the first output row is $\left(\sum_j W_{1j}\right) b_v = 1 \cdot b_v = b_v$. Therefore, $Wb_v = b_v$.
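A quick numerical check of both the broadcasting identity and the combined transformation, reusing the illustrative V, b_v, O, and b_o from the sketch above:

# Each row of W sums to 1, so broadcasting b_v across rows and
# left-multiplying by W leaves it unchanged.
B = np.broadcast_to(b_v, (n_toks, K))
assert np.allclose(W @ B, B)

# Consequently, the two-projection form and the combined form agree.
Y_two_step = (W @ (X @ V + b_v)) @ O + b_o
Y_combined = (W @ X) @ (V @ O) + b_v @ O + b_o
assert np.allclose(Y_two_step, Y_combined)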

A few important observations:

  1. Attention is typically described as weighting the transformed inputs ($WX'$). However, we showed it can just as accurately be described as directly weighting the original inputs ($WX$).

  2. Interestingly, we showed that the value biases are not needed! We equivalently can "fold" them into the output biases. Just let $V' = VO$ and $b' = b_vO + b_o$, giving a single transform $Y = (WX)V' + b'$. This identity is useful both for ease of interpretability and for reducing multiplications.

  3. In this way, it becomes clear exactly how we can choose $V$, $O$, and $b_o$ to return solely uniformly-weighted inputs. Just set $V$ and $O$ such that $VO = I$ (the identity), then zero the effective output bias ($b_vO + b_o = 0$).

Python Implementation

K = X.shape[1]    # num channels

# Choose the effective projection V @ O to be the identity and the
# effective output bias to be zero (W and X reused from earlier).
VO = np.identity(K)
bo = np.zeros(K)

# (n_toks, K) = (n_toks, n_toks) @ (n_toks, K) @ (K, K) + (K,)
Y = W @ X @ VO + bo
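As a closing check, a brief sketch (reusing the variables above) confirming that this configuration reduces the block to the uniformly-weighted means from the start of the section:

# Since V @ O = I and the effective bias is zero, the attention block
# returns exactly the uniformly-weighted prefix means.
expected = np.stack([X[:i+1].mean(axis=0) for i in range(X.shape[0])])
assert np.allclose(Y, expected)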