Unembeddings

Finally, we must predict the rotation of the entire ciphertext. To do this, we'll use the technique from the first section.

  1. First, we'll take the dot product of the final residual stream row (the last row of Y) with each rotation's expected frequency distribution.
  2. Then, the largest of these dot products (each is called a logit) indicates the most likely rotation class.

Let's continue the example from the first section:

Suppose for rotation 0 we expect frequencies for the letters e and b of $f_{e,0}$ and $f_{b,0}$, respectively. As we discussed in the first section, for rotation 5 we expect the e and b frequencies to take different values, $f_{e,5}$ and $f_{b,5}$.

In our ciphertext, the observed frequencies of e and b are $\hat{f}_e$ and $\hat{f}_b$. This makes the dot product for rotation 0 $\hat{f}_e f_{e,0} + \hat{f}_b f_{b,0}$, and for rotation 5 $\hat{f}_e f_{e,5} + \hat{f}_b f_{b,5}$.

Hence, a given rotation's dot product is comparatively largest when the ciphertext's letter frequencies follow that rotation's expected distribution.
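To make this concrete, here is a minimal sketch of those two dot products in Python. The frequency values below are made up for illustration; they are not the numbers from the first section.

import numpy as np

# Hypothetical expected frequencies for ciphertext letters e and b.
# (Illustrative values only; the real model uses all 26 letters.)
expected_rot0 = np.array([0.12, 0.02])   # rotation 0: e common, b rare
expected_rot5 = np.array([0.01, 0.07])   # rotation 5: different expectations

# Observed e and b frequencies in our ciphertext.
observed = np.array([0.11, 0.03])

print(observed @ expected_rot0)  # larger  -> rotation 0 fits the ciphertext
print(observed @ expected_rot5)  # smaller -> rotation 5 fits poorly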

As before, transformers use linear algebra to compute all of these dot products at once:

Similar to the embedding matrix, we combine each class's expected frequencies into an unembedding matrix $W_U$, where each column represents a rotation class. To compute the final logits, we use matrix multiplication:

$$\text{logits} = Y W_U$$

Typically there will also be unembed biases $b_U$ (one per rotation class), making the final logits:

$$\text{logits} = Y W_U + b_U$$
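As a sketch of how such a matrix could be assembled (assuming K = 26 letter-frequency features and 26 rotation classes; the variable names and approximate frequency values here are illustrative, not taken from the original code):

import numpy as np

n_letters = 26   # K: one frequency feature per letter
n_classes = 26   # one class per possible rotation

# Approximate English letter frequencies for a..z (illustrative).
english_freqs = np.array([
    0.082, 0.015, 0.028, 0.043, 0.127, 0.022, 0.020, 0.061, 0.070,
    0.002, 0.008, 0.040, 0.024, 0.067, 0.075, 0.019, 0.001, 0.060,
    0.063, 0.091, 0.028, 0.010, 0.024, 0.002, 0.020, 0.001,
])

# Column c holds rotation c's expected ciphertext frequencies: the
# English distribution shifted c places around the alphabet.
unembeds = np.stack([np.roll(english_freqs, c) for c in range(n_classes)], axis=1)
bu = np.zeros(n_classes)  # unembed biases, one per rotation class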

Note there are two conspicuous ways to make a large dot product:

  1. Align the signs/magnitudes so that large positives match with large positives and large negatives match with large negatives.
  2. Use a few particularly large outlier values to dominate the result, much like a gate.

In LLMs, the first method seems much more common. However, there are some examples of trained models using outliers as gates.
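A quick illustration of both patterns, with made-up vectors:

import numpy as np

x = np.array([1.0, -2.0, 0.5, -0.5])

# 1. Sign/magnitude alignment: every component pulls in the same direction.
aligned = np.array([0.9, -1.8, 0.4, -0.6])

# 2. Outlier as a gate: one huge weight dominates, the rest barely matter.
gated = np.array([0.0, 0.0, 50.0, 0.0])

print(x @ aligned)  # large because every term contributes positively
print(x @ gated)    # large because a single term dominates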

Python Implementation

# (n_toks, n_classes) = (n_toks, K) @ (K, n_classes) + (n_classes,)
logits = Y @ unembeds + bu
pred_class = np.argmax(logits[-1])  # final row, which has seen all tokens
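For reference, a self-contained version with random stand-in data (the array contents are placeholders; the shapes follow the comment above):

import numpy as np

n_toks, K, n_classes = 50, 26, 26
Y = np.random.randn(n_toks, K)            # final residual stream rows
unembeds = np.random.randn(K, n_classes)  # one column per rotation class
bu = np.zeros(n_classes)                  # unembed biases

logits = Y @ unembeds + bu                # (n_toks, n_classes)
pred_class = np.argmax(logits[-1])        # rotation predicted from the final row
print(pred_class)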