LLM Foundations (early book draft)

by Dan Wilhelm

In this book, we develop a low-level understanding of Large Language Models (LLMs). To understand large circuits, we apply lessons from mathematics and from the analysis of small circuits.

This book has an accompanying YouTube channel and project GitHub repo (coming soon).

To get the most out of this book, we recommend the reader be fluent in Python and know the basics of NumPy, linear algebra, and machine learning. Therefore, our target audience includes researchers, CS students, and software engineers with an interest in LLMs.

Currently, our focus is on analysis rather than training.

This book is hosted on GitHub, and suggestions/revisions are welcome. On every page in the upper right, there are icons for visiting the repo and editing the current page.

Introduction

Our thesis is that understanding small circuits leads to techniques for understanding large circuits.

What is "understanding"?

Unfortunately, machine learning educators often introduce neural networks as "black box models." As an educator myself, I've found this often dissuades students from believing an understanding could exist. So, what do we mean by "understanding"?

  • First, neural networks are well-known to (often poorly) approximate some "true" function. As we'll see in chapter 3, a six-neuron neural network is required to approximate a modified square wave. Although the neural net describes a complex non-linear equation, the "true" structure is a simple square wave. By reconstructing this human-understandable "true" structure, we represent the problem simply and gain an "understanding."

  • Second, from engineering we know that complex systems work because they have guiding principles of design and operation. For example, in chapter 1 we will create a small circuit that solves cryptograms. There are only so many algorithms for solving cryptograms, namely by frequency analysis of the letters. Therefore, any circuit that does solve them must use one of these algorithms. By knowing which algorithm is used and the implementing mechanism, we have gained an "understanding" of the circuit.

  • Finally, we subscribe to the adage that we only demonstrate true understanding if we can build something from scratch. Toward this end, we often will analyze pre-trained circuits to discover their principles of operation. Using these, we'll attempt to reconstruct the weights by hand. If we obtain a similar output, then we will claim to have "understood" the circuit.

Acknowledgements

The first chapter's Caesar cipher problem was originally posed by Callum McDougall as an interpretability challenge as part of his ARENA bootcamp. For more practice, you can explore his PyTorch-trained transformer model of the same problem.

Chapter 1: Solving Cryptograms by Designing a Transformer

In this chapter, we will encode natural language statistics in transformer weights. To gain a strong understanding, we'll handcraft a transformer circuit step-by-step from first principles. The result will solve cryptograms encoded using a Caesar cipher (a fixed-letter rotation).

Cryptogram Intro

A Caesar cipher is a code where each letter has been rotated forward by a fixed number. Given rotated text (ciphertext), the challenge is to determine the original rotation number (and thereby recover the original plaintext).

The plaintext a bay becomes ciphertext d edb with rotation 3.

Note that only the letters a-z will be rotated. For example, the space character in "a bay" is not rotated.

A transformer only accepts a fixed vocabulary of possible tokens. In this problem, we will allow 27 tokens -- the lowercase letters a-z and the space character.


Python Implementation

We encourage you to code along in your own notebook! For this chapter, visit GitHub to download three Project Gutenberg text files and our minimal plotting functions in llm_plots.py. Place these in the same directory as your notebook, then run the notebook server from that directory so Python can find the data files and the plotting module.

Here are the Python packages we'll use:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression

from llm_plots import rowplot, implot, listplot

np.set_printoptions(suppress=True)  # suppress scientific notation

Now let's implement the Caesar rotation in Python! Our model requires indexes into the vocabulary as input. So instead of storing ciphertext as a string, we will represent it as an array of rotated numbers (rotnums). For clarity, all strings will be assumed plaintext and all lists of numbers ciphertext.

The ciphertext d edb will be stored as [3, 26, 4, 3, 1], where 'a'=0 and ' '=26.

For now, we will use classes primarily to help us organize and encapsulate constants. The class CaesarDataset will load text and rotate it, starting with:

class CaesarDataset:
    SEQ_LEN = 32     # chars per cryptogram
    N_ROTS = 26      # total letters/rotations
    
    # VOCAB: Must begin with a-z and include space.
    VOCAB = np.array(list('abcdefghijklmnopqrstuvwxyz '))
    VOCAB_SET = set(VOCAB)                         # O(1) membership test
    VOCAB_IDX = {l:i for i,l in enumerate(VOCAB)}  # O(1) char->index

    @staticmethod
    def plaintext_to_rotnums(plaintext, rot=0):
        '''Rotate letters forward, returning rotated numbers.
        Input characters must be in VOCAB.'''
        vocab_idxs = np.array([CaesarDataset.VOCAB_IDX[c] for c in plaintext])
        return np.where(vocab_idxs < CaesarDataset.N_ROTS, 
                        (vocab_idxs + rot) % CaesarDataset.N_ROTS,
                        vocab_idxs)  # do not rotate if not a-z

    @staticmethod
    def rotnums_to_plaintext(rotnums, rot=0):
        '''Rotate numbers backward, returning plaintext.
        Input numbers must be less than the VOCAB size.'''
        return ''.join(CaesarDataset.VOCAB[
            np.where(rotnums < CaesarDataset.N_ROTS,
                     (rotnums - rot) % CaesarDataset.N_ROTS,
                     rotnums)])  # do not rotate if not a-z

Example Run

rot = 5
plaintext = 'the quick brown fox jumps over the lazy dog'
rotnums = CaesarDataset.plaintext_to_rotnums(plaintext, rot)

print('plaintext:', plaintext)
print('ciphertext:', CaesarDataset.rotnums_to_plaintext(rotnums))
print('plaintext (hopefully!):',
        CaesarDataset.rotnums_to_plaintext(rotnums, rot))

This outputs:

plaintext: the quick brown fox jumps over the lazy dog
ciphertext: ymj vznhp gwtbs ktc ozrux tajw ymj qfed itl
plaintext (hopefully!): the quick brown fox jumps over the lazy dog

Letter Frequency Solution

In this section, we'll solve cryptograms by comparing the ciphertext letter frequencies to those expected with each rotation. We'll prove that in this particular problem, comparing the distributions using dot products gives the same result as the sum of squared errors (SSE).

Letter Frequency

Humans solve cryptograms by analyzing letter frequency.

Since e is the most common English letter, we can assume the most common ciphertext letter is a rotated e.

More generally, we can match the letter frequencies of our ciphertext to the letter frequencies we'd expect to see for each rotation. The "best match" will indicate the most likely rotation.

Suppose the most common English letters are etainos. In our ciphertext, the most common letters are fubjopt -- one letter higher. Hence, we can predict a rotation of 1.

From Wikipedia, here are the letter frequencies for "Texts":

# English language frequencies from "Texts" for the letters a-z
base_freq = [0.082, 0.015, 0.028, 0.043, 0.127, 0.022, 0.020, 0.061, 0.070, 
             0.0015, 0.0077, 0.04, 0.024, 0.067, 0.075, 0.019, 0.00095, 0.060, 
             0.063, 0.091, 0.028, 0.0098, 0.024, 0.0015, 0.020, 0.00074]

For a rotation of 2 where A=>C, B=>D, etc., these frequencies are cyclically rotated:

# Expected ciphertext letter frequencies for each rotation.
# Row r holds base_freq cyclically shifted forward by r positions.
N_LETTERS = 26
rot_freqs = np.array([base_freq[-rot:] + base_freq[:-rot] for rot in range(N_LETTERS)])

Formalizing the letter frequency

Let's make this more formal. Suppose we have a vocabulary $V$ containing $|V|$ tokens. In this problem, $V$ comprises a-z and the space character, so $|V| = 27$. We will only rotate a-z, so we'll define the number of rotations $R = 26$.

We define the rotation-0 letter frequency vector $\mathbf{f}^{(0)}$ as the base English letter frequencies (above called base_freq). We define $\mathbf{f}^{(r)}$ as the letter frequencies when rotated by $r$ (above called rot_freqs).

For the $i$th letter, we define $f^{(r)}_i \triangleq f^{(0)}_{(i - r) \bmod 26}$. This defines a one-to-one map between the elements of $\mathbf{f}^{(0)}$ and $\mathbf{f}^{(r)}$. Hence, the set of values in each vector is identical, making $\lVert\mathbf{f}^{(r)}\rVert = \lVert\mathbf{f}^{(0)}\rVert$ for any $r$.

As we'll see below, because these magnitudes are equal they will cancel out when we compare the distance between vectors.

Distribution distances

To solve the cryptogram, we must find which of the rotation letter distributions is "closest" to that of the ciphertext. To do so, we will compute the "distance" between the ciphertext letter frequencies and each of the expected frequency distributions.

Although there are several methods for comparing frequency distributions, in this chapter we'll start with the simplest -- sum of squared errors (SSE).

Sum of Squared Errors (SSE)

We will compare each letter frequency individually, penalizing larger errors more. We'll do this by squaring these errors and summing the squared errors.

Based on this, here is our definition, where $\mathbf{c}$ is the ciphertext letter frequency vector:

$$\mathrm{SSE}(\mathbf{c}, \mathbf{f}^{(r)}) \triangleq \sum_{i=1}^{26} \left(c_i - f^{(r)}_i\right)^2$$

Expanded:

$$\mathrm{SSE}(\mathbf{c}, \mathbf{f}^{(r)}) = \sum_i c_i^2 - 2\sum_i c_i f^{(r)}_i + \sum_i \left(f^{(r)}_i\right)^2 = \lVert\mathbf{c}\rVert^2 - 2\,\mathbf{c}\cdot\mathbf{f}^{(r)} + \lVert\mathbf{f}^{(r)}\rVert^2$$

Squaring the errors gives several benefits: First, it makes all differences positive. Second, it magnifies large differences in our total error. Third, compared to absolute value, it is differentiable (for potential minimization) and easier to prove things about!

Comparing SSEs

Although we could compute all SSEs, there is a faster and easier way. It turns out we can simply take the dot product of $\mathbf{c}$ with each $\mathbf{f}^{(r)}$:

Theorem 1. For a given ciphertext letter distribution $\mathbf{c}$ and rotation distributions $\mathbf{f}^{(r)}$ and $\mathbf{f}^{(s)}$, then $\mathrm{SSE}(\mathbf{c}, \mathbf{f}^{(r)}) < \mathrm{SSE}(\mathbf{c}, \mathbf{f}^{(s)})$ if and only if $\mathbf{c}\cdot\mathbf{f}^{(r)} > \mathbf{c}\cdot\mathbf{f}^{(s)}$.

Proof. We showed earlier that $\lVert\mathbf{f}^{(r)}\rVert = \lVert\mathbf{f}^{(0)}\rVert$ for any $r$. Then:

$$\begin{aligned}
\mathrm{SSE}(\mathbf{c}, \mathbf{f}^{(r)}) < \mathrm{SSE}(\mathbf{c}, \mathbf{f}^{(s)})
&\iff \lVert\mathbf{c}\rVert^2 - 2\,\mathbf{c}\cdot\mathbf{f}^{(r)} + \lVert\mathbf{f}^{(r)}\rVert^2 < \lVert\mathbf{c}\rVert^2 - 2\,\mathbf{c}\cdot\mathbf{f}^{(s)} + \lVert\mathbf{f}^{(s)}\rVert^2 \\
&\iff -2\,\mathbf{c}\cdot\mathbf{f}^{(r)} < -2\,\mathbf{c}\cdot\mathbf{f}^{(s)} \\
&\iff \mathbf{c}\cdot\mathbf{f}^{(r)} > \mathbf{c}\cdot\mathbf{f}^{(s)}.
\end{aligned}$$

Let's define a score function $\mathrm{score}(r) \triangleq \mathbf{c}\cdot\mathbf{f}^{(r)}$. By definition, the rotation with the smallest SSE has a smaller SSE than any other rotation. From Theorem 1, the rotation with the largest score has the smallest SSE (and vice versa). For more discussion of score vs. distance vs. similarity, see Appendix B.

Note this does not hold in general -- only when all class vectors have the same magnitude.

Note that we can simultaneously compute the dot products using matrix algebra. Given a matrix $F$ where column $r$ is $\mathbf{f}^{(r)}$ and a row vector of ciphertext frequencies $\mathbf{c}$, then $\mathbf{c}F$ gives a row vector of all 26 dot products.
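
As a quick check of Theorem 1, here is a minimal sketch (it assumes the CaesarDataset class and the base_freq/rot_freqs arrays defined above; the variable names c, sse, and scores are ours):

# Encode a sentence with a secret rotation, then compare SSE vs. dot-product scores.
rot = 7
rotnums = CaesarDataset.plaintext_to_rotnums(
    'the quick brown fox jumps over the lazy dog', rot)

letters = rotnums[rotnums < CaesarDataset.N_ROTS]       # drop spaces
c = np.bincount(letters, minlength=26) / len(letters)   # ciphertext letter frequencies

sse = ((rot_freqs - c)**2).sum(axis=1)   # one SSE per candidate rotation
scores = c @ rot_freqs.T                 # one dot product per candidate rotation

print(np.argmin(sse) == np.argmax(scores))   # True -- both pick the same rotation
print(np.argmax(scores))                     # hopefully 7!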

Visualizing the dot product

Interestingly, the SSE can now be interpreted as related to the "angle" between vectors. This is because the dot product of unit vectors is the cosine of the angle between them -- see Appendix B for more details!

Imagine the x-axis depicting the letter frequency of e and the y-axis the frequency of b. Then, the rotation 0 class vector is $(f^{(0)}_e, f^{(0)}_b)$ and the rotation 5 class vector is $(f^{(5)}_e, f^{(5)}_b)$.

Our ciphertext token frequencies will then also be a vector. We compare its angle to that of each of the two rotation classes, and the class vector with largest dot product (typically smallest angle) determines the predicted class.

(Figure: the x- and y-axes represent the e and b letter frequencies, with the two rotation class vectors and a ciphertext frequency vector drawn in relation.)

Transformer Implementation

In this section, we'll take our letter frequency solution and implement it using the transformer architecture as follows:

---
title: One-layer Attention-only Transformer
---
stateDiagram-v2
    Embedding: Embedding
    Embedding: 1. One-hot encode each token as a vector of numbers.
    Attention: Attention Block
    Attention: 2. Take the mean to obtain ciphertext frequencies.
    Unembedding: Unembedding
    Unembedding: 3. Compare the ciphertext letter frequencies to each rotation's expected frequencies.

    [*] --> Embedding: "d edb" (5 tokens)
    Embedding --> Attention
    Attention --> Unembedding
    Unembedding --> [*]: largest score indicates rotation 3 ("a bay")

Embeddings

A transformer takes as input a list of tokens. In our case, each of the 27 vocabulary items is a token. That said, a token does not have to be a single character. It can represent a chunk of characters or even a more abstract concept such as the start or end of text.

The first transformer operation -- the embedding -- takes each token and replaces it with a vector of $K$ numbers. These $K$ positions are often called channels since they convey information.

The residual stream

The output of the embedding stage is a matrix with each row representing a token and each column representing a channel. This gives the resulting matrix a shape of (n_toks, K). This matrix begins the transformer's residual stream, which is read from and written to as it passes through each stage of the transformer.

One-hot encoding

How should we represent each of the 27 possible tokens? Recall our objective is to compute the ciphertext letter frequencies and that each letter is a token. So, a reasonable start would be to dedicate a single channel per token. If channel $i$ contains 1, it indicates token $i$; if channel $i$ contains 0, it indicates it is not token $i$.

This encoding scheme is called a one-hot encoding.

For example, suppose we one-hot encode three tokens a, b, and c using $K = 3$ channels. We'll map a = [1, 0, 0], b = [0, 1, 0], and c = [0, 0, 1]. (Stacked, these comprise the embedding matrix, which here is the identity matrix.)

The embedding of the string abc provides our input to the model:

$$X = \begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1 \end{bmatrix}$$

Note this has a convenient property -- we can count the letters by summing down the rows! This gives the vector $[1, 1, 1]$.
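
Scaling this up to our full 27-token vocabulary, here is a minimal sketch of the one-hot embedding (it assumes the CaesarDataset class from earlier; the variable names embeds and X are ours):

# The one-hot embedding matrix is simply the 27x27 identity.
embeds = np.identity(len(CaesarDataset.VOCAB))

rotnums = CaesarDataset.plaintext_to_rotnums('a bay', rot=3)  # [3, 26, 4, 3, 1]
X = embeds[rotnums]       # shape (n_toks, K) = (5, 27), one row per token

print(X.sum(axis=0))      # summing down the rows counts each token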

Typically embeddings are not one-hot encoded, since this requires one channel per vocabulary item ($K = |V|$). One-hot encoding also neglects an opportunity to store extra information per token. For example, we could dedicate a channel to indicate even vs. odd, or we could ensure that "similar" tokens are nearby in embedding space.

At the end of this chapter, we'll show that nearly any embedding matrix can solve our cryptograms! This is what makes interpretability so difficult. To us, a one-hot encoding makes the algorithm easy to understand. However, a computer is likely to choose an arbitrary embedding which obscures the underlying algorithm.

Attention

Attention is often presented as the most complex block in a transformer. It is called attention because it applies a weight to each prior token vector. To pay more attention to some prior tokens, we can apply larger weights. Hence, attention is effectively a weighted mean of the input $X$.

For this problem we must compute the token frequencies. Luckily for us, these are found by taking the (uniformly-weighted) mean of the rows of the input $X$, where $\mathbf{x}_i$ is the $i$th row:

$$\mathbf{y} = \frac{1}{n_{\text{toks}}} \sum_{i=1}^{n_{\text{toks}}} \mathbf{x}_i$$

Using the one-hot encoding from the last section, we embed the tokens bab as:

$$X = \begin{bmatrix} 0 & 1 & 0 \\ 1 & 0 & 0 \\ 0 & 1 & 0 \end{bmatrix}$$

Now in the attention step, we take the (uniform) mean across the rows to compute the token frequencies: $\left[\tfrac{1}{3}, \tfrac{2}{3}, 0\right]$.

For simplicity, we will assume the attention weights are uniform. In later chapters, we'll discuss the math behind computing the weights and show that uniform weights are possible.

The residual stream

We mentioned in the last section that the residual stream always has shape (n_toks, K), in both the input and output of attention.

Yet if attention takes the mean across all tokens, why doesn't this result in a single $K$-dimensional output vector?

It turns out that attention takes the mean of each increasing subset of tokens. So the first output row will contain the mean of only the first token's embedding (i.e. itself). The second row will contain the mean of the first two embeddings, and so on.

Let's see a concrete example:

Suppose we have the input embeddings of 'bab':

$$X = \begin{bmatrix} 0 & 1 & 0 \\ 1 & 0 & 0 \\ 0 & 1 & 0 \end{bmatrix}$$

Then, the output of attention is:

$$Y = \begin{bmatrix} 0 & 1 & 0 \\ \tfrac{1}{2} & \tfrac{1}{2} & 0 \\ \tfrac{1}{3} & \tfrac{2}{3} & 0 \end{bmatrix}$$

The first row is the first row of $X$. The second row is the mean of the first two rows of $X$. And the final row is the mean of all three rows of $X$.

It is also possible to compute this efficiently using linear algebra:

We define a lower-triangular (n_toks, n_toks)-shape weight matrix $W$. Row $i$ specifies how to weight the first $i$ token embeddings to compute the $i$th row of the output matrix $Y$:

$$W = \begin{bmatrix} 1 & 0 & 0 \\ \tfrac{1}{2} & \tfrac{1}{2} & 0 \\ \tfrac{1}{3} & \tfrac{1}{3} & \tfrac{1}{3} \end{bmatrix} \quad \text{(for three tokens)}$$

Note that row $i$ weights the first $i$ tokens equally with probability $\tfrac{1}{i}$. The final row of $W$ corresponds to the full sequence, where each token embedding is weighted uniformly as $\tfrac{1}{n_{\text{toks}}}$.

To apply the weight matrix to the token embeddings $X$, we use matrix multiplication: $Y = WX$.

In attention, each row of the weight matrix will always sum to 1. This is a consequence of how the matrix is computed. We'll use this fact below to simplify some of the math.

Python Implementation

n_toks = X.shape[0]
W = np.tri(n_toks) * (1. / np.arange(1, n_toks+1)).reshape(-1,1)

# (n_toks, K) = (n_toks, n_toks) @ (n_toks, K)
Y = W @ X
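
For example, running this on the 'bab' embeddings from the worked example above (a quick check; X is the same one-hot matrix as before):

X = np.array([[0., 1., 0.],   # 'b'
              [1., 0., 0.],   # 'a'
              [0., 1., 0.]])  # 'b'

n_toks = X.shape[0]
W = np.tri(n_toks) * (1. / np.arange(1, n_toks+1)).reshape(-1,1)

print(W)       # [[1, 0, 0], [1/2, 1/2, 0], [1/3, 1/3, 1/3]]
print(W @ X)   # final row holds the token frequencies [1/3, 2/3, 0]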

Technical Details

The attention block can do more than merely take an average of the inputs. However, this is a convenient first-approximation of its workings.

Below, we'll walk through the conditions under which the attention block reduces to a direct weighted mean of its inputs.

  1. In attention, the inputs $X$ first undergo a linear projection $W_V$ (the value projection) plus bias $\mathbf{b}_V$. We'll define this as $V \triangleq XW_V + \mathbf{b}_V$.

  2. $V$ is then weighted by $W$ and subjected to a second linear projection $W_O$ (the output projection) plus bias $\mathbf{b}_O$. We'll call this the attention block output $Y \triangleq (WV)W_O + \mathbf{b}_O$.

In many attention implementations, there is no second (output) projection. In this text, however, we are preparing the reader for multi-headed attention, where the purpose of the output projection is to provide a final mixing step after concatenating the outputs of numerous heads (although the math below still applies!). We'll look at this in more depth in later chapters.

The two linear projections above are back-to-back, so they can be combined into a single effective transformation. By using the associative and distributive properties of matrix multiplication, we can (step-by-step) rewrite these two transformations as one in terms of the weighted inputs $WX$:

$$\begin{aligned}
Y &= \big(W(XW_V + B_V)\big)W_O + \mathbf{b}_O \\
  &= (WX)W_VW_O + (WB_V)W_O + \mathbf{b}_O \\
  &= (WX)(W_VW_O) + \mathbf{b}_VW_O + \mathbf{b}_O
\end{aligned}$$

where $B_V$ is $\mathbf{b}_V$ broadcast to an (n_toks, K) matrix (one copy of $\mathbf{b}_V$ per row).

The attention output is now in terms of the weighted inputs $WX$ rather than the weighted transformed inputs $WV$!

Note we combined the two transforms above into a single linear transformation with bias: the effective projection $W_{VO} \triangleq W_VW_O$ and effective bias $\mathbf{b}_{VO} \triangleq \mathbf{b}_VW_O + \mathbf{b}_O$.

To obtain the final step, we assert that $WB_V = B_V$. To explain this, note that each row of $W$ sums to 1. Since $\mathbf{b}_V$ is a $1 \times K$ vector, it is broadcast into the (n_toks, K) matrix $B_V$, whose rows are all identical. Each entry of $WB_V$ is therefore a weighted mean of identical values: for example, the first entry of row $i$ is $\sum_j W_{ij}\, b_{V,1} = b_{V,1} \sum_j W_{ij} = b_{V,1}$. Therefore, $WB_V = B_V$.

A few important observations:

  1. Attention is typically described as weighting the transformed inputs ($WV$). However, we showed it can just as accurately be described as directly weighting the original inputs ($WX$).

  2. Interestingly, we showed that the value biases are not needed! We can equivalently "fold" them into the output biases. Just let $W_{VO} = W_VW_O$ and $\mathbf{b}_{VO} = \mathbf{b}_VW_O + \mathbf{b}_O$, giving the single transform $Y = (WX)W_{VO} + \mathbf{b}_{VO}$. This identity is useful both for ease of interpretability and for reducing multiplications.

  3. In this way, it becomes clear exactly how we can choose $W_V$, $W_O$, and the biases to return solely uniformly-weighted inputs. Just set $W_V$ and $W_O$ such that $W_VW_O = I$ (the identity), then zero the effective output bias ($\mathbf{b}_{VO} = \mathbf{0}$).

Python Implementation

K = X.shape[1]    # num channels
VO = np.identity(K)   # effective value-output projection (W_V @ W_O = I)
bo = np.zeros(K)      # effective output bias (b_V @ W_O + b_O = 0)

# (n_toks, K) = (n_toks, n_toks) @ (n_toks, K) @ (K, K) + (K,)
Y = W @ X @ VO + bo
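
As a sanity check, here is a sketch that numerically verifies the bias-folding identity above with random projections (the names W_V, W_O, b_v, and b_o are ours; W and X are from the attention example earlier):

# Folding the value bias into the output bias leaves the attention output unchanged.
rng = np.random.default_rng(0)
W_V, W_O = rng.normal(size=(K, K)), rng.normal(size=(K, K))
b_v, b_o = rng.normal(size=K), rng.normal(size=K)

Y_two_step = (W @ (X @ W_V + b_v)) @ W_O + b_o           # value proj, weight, output proj
Y_folded   = (W @ X) @ (W_V @ W_O) + (b_v @ W_O + b_o)   # single effective transform

print(np.allclose(Y_two_step, Y_folded))   # True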

Unembeddings

Finally, we must predict the rotation of the entire ciphertext. To do this, we'll use the technique from the first section.

  1. First, we'll take the dot product of the final residual stream row (the mean over all tokens) with each rotation's expected frequency distribution.
  2. Then, the largest dot product (called a logit) will indicate the most likely rotation class.

Let's continue the example from the first section:

Suppose for rotation 0 we expect the frequencies of the letters e and b to be $f^{(0)}_e$ and $f^{(0)}_b$, respectively. As we discussed in the first section, for rotation 5 we expect the e and b frequencies to be $f^{(5)}_e$ and $f^{(5)}_b$.

In our ciphertext, the respective frequencies are $c_e$ and $c_b$. This makes the dot product for rotation 0: $c_e f^{(0)}_e + c_b f^{(0)}_b$, and for rotation 5: $c_e f^{(5)}_e + c_b f^{(5)}_b$.

Hence, a given rotation's dot product is comparatively largest when the ciphertext letter frequencies follow that rotation's expected distribution.

As before, transformers use linear algebra to compute these as follows:

Similar to the embedding matrix, we combine each class's expected frequencies into an unembedding matrix $U$ of shape (K, n_classes), where column $r$ is the expected frequency vector $\mathbf{f}^{(r)}$ for rotation class $r$. To compute the final logits, we use matrix multiplication:

$$\text{logits} = YU$$

Typically there will also be unembed biases $\mathbf{b}_U$ (one per rotation class), making the final logits:

$$\text{logits} = YU + \mathbf{b}_U$$

Note there are two conspicuous ways to make a large dot product:

  1. Align the signs/magnitudes so that large positives match with large positives and large negatives match with large negatives.
  2. Use a few particularly large outliers to influence the result, similarly to a gate.

In LLMs, the first method seems much more common. However, there are some examples of trained models using outliers as gates.

Python Implementation

# (n_toks, n_classes) = (n_toks, K) @ (K, n_classes) + (n_classes,)
logits = Y @ unembeds + bu
pred_class = np.argmax(logits[-1])  # final row incl. all tokens
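
Putting the three stages together, here is a minimal end-to-end sketch of the handcrafted transformer (it assumes CaesarDataset and rot_freqs from earlier; the function and variable names are ours):

def predict_rotation(rotnums):
    '''One-hot embed, apply uniform attention, then unembed to score each rotation.'''
    n_vocab = len(CaesarDataset.VOCAB)    # 27
    n_toks = len(rotnums)

    # 1. Embedding: one-hot encode each token (embedding matrix = identity).
    X = np.identity(n_vocab)[rotnums]                        # (n_toks, 27)

    # 2. Attention: running (uniform) means of the token embeddings.
    W = np.tri(n_toks) * (1. / np.arange(1, n_toks+1)).reshape(-1,1)
    Y = W @ X                                                # (n_toks, 27)

    # 3. Unembedding: compare against each rotation's expected letter frequencies.
    unembeds = np.zeros((n_vocab, CaesarDataset.N_ROTS))     # (27, 26)
    unembeds[:CaesarDataset.N_ROTS, :] = rot_freqs.T         # space channel scores 0
    logits = Y @ unembeds                                    # (n_toks, 26)
    return np.argmax(logits[-1])                             # final row sees all tokens

rot = 11
rotnums = CaesarDataset.plaintext_to_rotnums(
    'the quick brown fox jumps over the lazy dog', rot)
print(predict_rotation(rotnums))   # hopefully 11!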

Improving the Accuracy

Arbitrary Embeddings

Chapter 2: Revealing the "True" Structure from Trained Feedforward Nets

In an LLM, each layer comprises two blocks: the attention block and the feedforward block, also known as a multi-layer perceptron (MLP).

In this chapter, we'll find a minimal solution to a circuit proposed in Stephen Wolfram's "What Is ChatGPT Doing ... and Why Does It Work?".

Training algorithms often have trouble finding minimal-sized circuits, likely because the solution space is so small. Indeed, Wolfram found that approximating a wave function via training required a hidden layer, but we'll see it's possible using a single layer.

Along the way, we'll investigate how a one-layer network can approximate any function arbitrarily well. In doing so, we'll see that neural networks essentially approximate a "true function" that likely has a simpler representation than a neural net.

Understanding this, it is enticing to believe that large LLMs noisily approximate some "true" structure of language which has yet to be discovered.

A Single-Layered Network

The Math

First, let's examine the math behind a single-layered neural network.

Suppose we are given $K$ inputs (scalars) and $N$ neurons. Then, the output of neuron $j$ is a linear combination of the $K$-dimensional input $\mathbf{x}$ with the neuron's $K$-dimensional weights $\mathbf{w}_j$, plus a scalar bias $b_j$. An activation function $\sigma$ is then applied to the resulting sum:

$$y_j = \sigma\!\left(\mathbf{w}_j \cdot \mathbf{x} + b_j\right) = \sigma\!\left(\sum_{i=1}^{K} w_{ji} x_i + b_j\right)$$

Alternatively, we can compute all $N$ outputs simultaneously with a single matrix multiplication:

$$\mathbf{y} = \sigma\!\left(\mathbf{x} W + \mathbf{b}\right), \qquad W \in \mathbb{R}^{K \times N},\; \mathbf{b} \in \mathbb{R}^{N}$$

As we saw in Chapter 1, multiple linear transforms applied back-to-back can be collapsed into a single linear transform. Therefore, non-terminal layers use a non-linear activation $\sigma$. Depending on the function, this allows us to warp the space (e.g. the logistic function squishes all outputs into the range $(0, 1)$), apply gating (e.g. the ReLU function "turns off" an output if the weighted input sum is negative), and more.

The feedforward block of a transformer

The feedforward block consists of two neuron layers. It takes as input the residual stream, which is an (n_toks, K) matrix. (That is, n_toks tokens/sequence positions, each represented by a $K$-dimensional vector.) The first layer consists of $N_{\text{hidden}}$ neurons, each with a non-linear activation function, followed by a second layer of $K$ neurons with an identity activation function.
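
To make this concrete, here is a minimal NumPy sketch of a feedforward block (the names relu, feedforward_block, and N_hidden are ours, and the weights are random rather than trained):

def relu(z):
    return np.maximum(0., z)

def feedforward_block(X, W1, b1, W2, b2):
    '''X: (n_toks, K) residual stream; W1: (K, N_hidden); W2: (N_hidden, K).'''
    H = relu(X @ W1 + b1)    # first layer: N_hidden neurons, non-linear activation
    return H @ W2 + b2       # second layer: K neurons, identity activation

rng = np.random.default_rng(0)
n_toks, K, N_hidden = 5, 8, 32
X = rng.normal(size=(n_toks, K))
out = feedforward_block(X,
                        rng.normal(size=(K, N_hidden)), np.zeros(N_hidden),
                        rng.normal(size=(N_hidden, K)), np.zeros(K))
print(out.shape)             # (5, 8) -- same shape as the residual stream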

Mathematical Preliminaries

This book is intended to be entirely self-contained. So, we are providing a minimal set of definitions needed to understand the text.

  • We indicate a definition by the symbol $\triangleq$. This often defines a new notation or operator that cannot be derived from other statements.

Given statements $P$ and $Q$:

  • (" implies ") indicates that logically follows from (but not necessarily in the opposite direction).
  • (" if and only if ") indicates that follows from and follows from (in both directions).

Sets

A set $V$ is a collection of unique elements. We denote membership by $v \in V$ ("$v$ is an element of $V$" or "$v$ is in $V$"). We denote the size of the set (its cardinality) by $|V|$. The empty set is denoted $\emptyset$.

For example, we may define a vocabulary $V = \{\text{a}, \text{b}, \dots, \text{z}, \text{space}\}$, which has cardinality $|V| = 27$. Then, $\text{a} \in V$ while $\text{A} \notin V$.

Set operations. Given sets $A$ and $B$, we define:

  • The union $A \cup B \triangleq \{x : x \in A \text{ or } x \in B\}$.
  • The intersection $A \cap B \triangleq \{x : x \in A \text{ and } x \in B\}$.
  • The set difference $A \setminus B \triangleq \{x : x \in A \text{ and } x \notin B\}$.

The above uses set-builder notation to define each set in terms of $A$ and $B$. We read the union definition as "the set of all $x$ such that $x$ is in $A$ or $x$ is in $B$".

Special sets. We denote $\mathbb{R}$ as the set of all real numbers. For example: $3.14 \in \mathbb{R}$. We define $\mathbb{Z}$ as the set of all integers and $\mathbb{Z}^+$ as the set of all positive integers. From these, we define the natural numbers $\mathbb{N} \triangleq \mathbb{Z}^+ \cup \{0\}$.

Tuples and the Cartesian product

Tuples are often used in conjunction with sets and vectors. An $n$-tuple is an ordered collection of $n$ elements. The $1$-tuple $(a)$ is a singleton, and the $2$-tuple $(a, b)$ an ordered pair. For example, a 2D point $(x, y)$ is an ordered pair.

Cartesian product. We can redundantly define the set of all reals using set-builder notation as:

$$\mathbb{R} = \{x : x \in \mathbb{R}\}$$

Let's generalize this to the set of all points (ordered pairs) in 2D space by using the Cartesian product of $\mathbb{R}$ with itself:

$$\mathbb{R} \times \mathbb{R} = \mathbb{R}^2 \triangleq \{(x, y) : x \in \mathbb{R} \text{ and } y \in \mathbb{R}\}$$

For any set $A$, we can generally define the $n$-fold Cartesian product of $A$ to produce the set of all $n$-tuples with elements in $A$:

$$A^n \triangleq \{(a_1, a_2, \dots, a_n) : a_i \in A\}$$

So the ordered pair $(x, y) \in \mathbb{R}^2$, and more generally a $k$-tuple of reals is an element of $\mathbb{R}^k$.

Vectors

In this book, we will only work with vectors defined on the reals. So, we refer to a real number as a scalar. Vectors will be written in lowercase boldface, e.g. $\mathbf{v}$, while elements of vectors (scalars) will be lowercase non-boldface, e.g. $v_1$.

Definition. A $k$-dimensional vector $\mathbf{v} \in \mathbb{R}^k$ is a $k$-tuple with scalar elements:

$$\mathbf{v} \triangleq (v_1, v_2, \dots, v_k)$$

For any two $k$-dimensional vectors $\mathbf{u}$ and $\mathbf{v}$ and scalar $a$, we define the following:

Definition. Vector addition is defined elementwise such that $\mathbf{u} + \mathbf{v} \triangleq (u_1 + v_1, u_2 + v_2, \dots, u_k + v_k)$.

Definition. Scalar-vector multiplication is defined as $a\mathbf{v} \triangleq (av_1, av_2, \dots, av_k)$.

Definition. Vector subtraction is defined elementwise as $\mathbf{u} - \mathbf{v} \triangleq (u_1 - v_1, u_2 - v_2, \dots, u_k - v_k)$.

Comparing Vectors

In this book, we'll use the terms distance, score, and similarity. For any $k$-dimensional vectors $\mathbf{u}$ and $\mathbf{v}$:
  • A distance measure $d$ such as Euclidean distance requires smaller values to indicate "more similar". We assume for any distance measure that $d(\mathbf{u}, \mathbf{v}) = 0$ if and only if $\mathbf{u} = \mathbf{v}$, that $d(\mathbf{u}, \mathbf{v}) \ge 0$, and that $d(\mathbf{u}, \mathbf{v}) = d(\mathbf{v}, \mathbf{u})$.

  • A similarity measure such as cosine similarity requires larger values to indicate "more similar", but no upper bound is required. Often, similarity measures can be converted to distance measures, and vice versa. For example, since cosine similarity is bounded by -1 and 1, we can define the cosine distance as $1 - s_{\cos}(\mathbf{u}, \mathbf{v})$.

  • A score is an arbitrary function where a larger score indicates a "better" match. All similarity measures are scores, and all negated distance metrics are scores. This terminology was invented to avoid misunderstandings that may arise when comparing, say, Euclidean distance with cosine similarity.

The Euclidean distance

In machine learning, we frequently must compare how far one vector is from another. There are many ways of doing this. Among the simplest (and most common) is the Euclidean distance. This is the "straight line" distance we're familiar with in the real world.

Its definition is based on the Pythagorean Theorem. Given a right-angled triangle with side lengths $a$ and $b$ and hypotenuse length $c$, then $a^2 + b^2 = c^2$. We can use this to compute the distance ("hypotenuse") between two two-dimensional points.

Suppose we want to know the "straight-line" distance between two points $(u_1, u_2)$ and $(v_1, v_2)$. We can connect the points with a right triangle whose legs have lengths $|u_1 - v_1|$ and $|u_2 - v_2|$. Then, by the Pythagorean Theorem, the "straight-line" distance between them (the hypotenuse) is $\sqrt{(u_1 - v_1)^2 + (u_2 - v_2)^2}$.

Applying the Pythagorean Theorem a second time, we can derive the formula for three-dimensional vectors: $\sqrt{(u_1 - v_1)^2 + (u_2 - v_2)^2 + (u_3 - v_3)^2}$.

Noticing a pattern, we can define the general $k$-dimensional Euclidean distance as follows:

Definition. Given a $k$-dimensional vector $\mathbf{v}$, the magnitude of $\mathbf{v}$ is

$$\lVert\mathbf{v}\rVert \triangleq \sqrt{\sum_{i=1}^{k} v_i^2}$$

Definition. Given $k$-dimensional vectors $\mathbf{u}$ and $\mathbf{v}$, the Euclidean distance is

$$d(\mathbf{u}, \mathbf{v}) \triangleq \lVert\mathbf{u} - \mathbf{v}\rVert = \sqrt{\sum_{i=1}^{k} (u_i - v_i)^2}$$
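
Here is a small NumPy sketch of both definitions (the example vectors are ours):

import numpy as np

u = np.array([1., 2., 2.])
v = np.array([4., 6., 2.])

magnitude = np.sqrt(np.sum(u**2))          # same as np.linalg.norm(u) -> 3.0
euclidean = np.sqrt(np.sum((u - v)**2))    # same as np.linalg.norm(u - v) -> 5.0
print(magnitude, euclidean)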

We will see that some distance measures (including Euclidean distance) are referred to as metrics:

Definition. A distance metric satisfies the following properties. For any $k$-dimensional vectors $\mathbf{u}$, $\mathbf{v}$, and $\mathbf{w}$:

  • Identity. $d(\mathbf{u}, \mathbf{v}) = 0$ if and only if $\mathbf{u} = \mathbf{v}$;
  • Positivity. $d(\mathbf{u}, \mathbf{v}) \ge 0$;
  • Symmetry. $d(\mathbf{u}, \mathbf{v}) = d(\mathbf{v}, \mathbf{u})$;
  • Triangle inequality. $d(\mathbf{u}, \mathbf{w}) \le d(\mathbf{u}, \mathbf{v}) + d(\mathbf{v}, \mathbf{w})$.

The triangle inequality ensures that the shortest distance between two points is a straight line.

Dot product

Definition. Given two $k$-dimensional vectors $\mathbf{u}$ and $\mathbf{v}$, the dot product is the sum of their elementwise products: $\mathbf{u} \cdot \mathbf{v} \triangleq \sum_{i=1}^{k} u_i v_i$.

In machine learning, the dot product often is used as a similarity measure. However, as we'll see below it is somewhat imperfect. It is affected both by the angle between the vectors and their magnitudes.

Definition. Given a $k$-dimensional vector $\mathbf{v}$, the unit vector of $\mathbf{v}$ is $\hat{\mathbf{v}} \triangleq \mathbf{v} / \lVert\mathbf{v}\rVert$. This vector is in the direction of $\mathbf{v}$ but has magnitude 1, placing it on the unit circle.

In many textbooks, the dot product is defined directly in terms of the angle between two vectors. Where $\mathbf{u}, \mathbf{v}$ are vectors, $\lVert\mathbf{u}\rVert, \lVert\mathbf{v}\rVert$ are their magnitudes, and $\theta$ is the angle between the vectors:

$$\mathbf{u} \cdot \mathbf{v} = \lVert\mathbf{u}\rVert\,\lVert\mathbf{v}\rVert\cos\theta$$

Hence, if the vectors are unit vectors, the dot product by itself is the cosine of the angle between them.

Cosine similarity

The cosine similarity is the cosine of the angle between two vectors. It is useful when direction is more important than magnitude. For example, someone who always rates movies with the highest score points in the same direction as someone who always rates them with a middling score (perhaps the origin represents rating every movie zero) -- their ratings convey no information about the movies!

The cosine similarity $s_{\cos}(\mathbf{u}, \mathbf{v}) \triangleq \hat{\mathbf{u}} \cdot \hat{\mathbf{v}} = \dfrac{\mathbf{u} \cdot \mathbf{v}}{\lVert\mathbf{u}\rVert\,\lVert\mathbf{v}\rVert}$ is bounded between $[-1, 1]$, with a larger score indicating more similar.

  • $s = 1$. Same direction, since $\cos(0°) = 1$, e.g. $(1, 0)$ and $(2, 0)$.
  • $s = -1$. Opposite direction, since $\cos(180°) = -1$, e.g. $(1, 0)$ and $(-2, 0)$.
  • $s = 0$. Orthogonal (e.g. a 90-degree angle), since $\cos(90°) = 0$, e.g. $(1, 0)$ and $(0, 3)$.

This measure can be efficiently computed, since it is based on the dot product.

Hence, the dot product by itself is effectively an "un-normalized" cosine similarity. This means it can be affected by magnitude!

This is especially apparent in modern LLMs if the embeddings are passed directly to the unembedding. Using cosine similarity, each token matches best with itself. Merely using the dot product, however, causes many tokens to match best with other tokens that have larger-magnitude embeddings.
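
Here is a small sketch contrasting the two measures (the example vectors and the cosine_similarity helper are ours):

def cosine_similarity(u, v):
    return (u @ v) / (np.linalg.norm(u) * np.linalg.norm(v))

u = np.array([1., 0.])
v = np.array([10., 10.])    # large magnitude, 45 degrees away from u
w = np.array([1., 0.1])     # small magnitude, nearly parallel to u

print(u @ v, u @ w)                 # 10.0 vs 1.0  -- the dot product favors the big vector
print(cosine_similarity(u, v),
      cosine_similarity(u, w))      # ~0.71 vs ~0.995 -- cosine favors the aligned vector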

Exercises

  1. By applying the Pythagorean Theorem twice, derive the three-dimensional Euclidean distance formula. That is, prove that for any three-dimensional vectors $\mathbf{u}$ and $\mathbf{v}$, $d(\mathbf{u}, \mathbf{v}) = \sqrt{(u_1 - v_1)^2 + (u_2 - v_2)^2 + (u_3 - v_3)^2}$.