Unembeddings
Finally, we must predict the rotation of the entire ciphertext. To do this, we'll use the technique from the first section.
- First, we'll take the dot product of the final residual stream row (the last row of $Y$) with each rotation's expected frequency distribution.
- Then, the largest of these dot products (each one is called a logit) will indicate the most likely rotation class.
Let's continue the example from the first section. Suppose that for rotation 0 we expect the letter e to be common and the letter b to be rare in the ciphertext, while for rotation 5 (as we discussed in the first section) the expected frequencies of e and b are quite different. If the e and b frequencies we actually observe in our ciphertext closely match rotation 0's expectations, then the dot product for rotation 0 will be large and the dot product for rotation 5 will be small.
Hence, any given rotation's dot product is comparatively largest when the ciphertext's letter frequencies follow that rotation's expected distribution.
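To make this concrete, here is a minimal numeric sketch. The frequencies below are made-up illustrative values, not the exact numbers from the first section:

import numpy as np

# Made-up expected frequencies for the letters e and b under each rotation.
expected_rot0 = np.array([0.12, 0.02])  # rotation 0: e common, b rare
expected_rot5 = np.array([0.02, 0.12])  # rotation 5: a different expectation

# Observed frequencies of e and b in our ciphertext.
observed = np.array([0.11, 0.03])

print(observed @ expected_rot0)  # ~0.0138 -> larger: matches rotation 0's expectations
print(observed @ expected_rot5)  # ~0.0058 -> smaller: mismatched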
As before, transformers use linear algebra to compute all of these dot products at once:
Similar to the embedding matrix, we combine each class's expected frequencies into an unembedding matrix $W_U$ of shape $(K, n_{\text{classes}})$, where each column represents a rotation class. To compute the final logits, we use matrix multiplication:

$$\text{logits} = Y\,W_U$$
Typically there will also be unembed biases $b_U$ (one per rotation class), making the final logits:

$$\text{logits} = Y\,W_U + b_U$$
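Written element-wise with this notation, each logit is just one of the per-rotation dot products described above, plus that rotation's bias:

$$\text{logits}_{t,\,c} = \sum_{k=1}^{K} Y_{t,k}\,(W_U)_{k,c} + (b_U)_c$$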
Note there are two conspicuous ways to make a large dot product:
- Align the signs/magnitudes so that large positives match with large positives and large negatives match with large negatives.
- Use a few particularly large outliers to influence the result, similarly to a gate.
In LLMs, the first method seems much more common. However, there are some examples of trained models using outliers as gates.
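As a rough illustration of the difference (all numbers made up):

import numpy as np

x = np.array([1.0, -1.0, 1.0, -1.0])        # a made-up residual stream row
aligned = np.array([2.0, -2.0, 2.0, -2.0])  # signs/magnitudes match x everywhere
gated = np.array([8.0, 0.0, 0.0, 0.0])      # one large outlier does the work

print(x @ aligned)  # 8.0 -- large via broad alignment
print(x @ gated)    # 8.0 -- equally large via a single outlier acting like a gate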
Python Implementation
# (n_toks, n_classes) = (n_toks, K) @ (K, n_classes) + (n_classes,)
logits = Y @ unembeds + bu
pred_class = np.argmax(logits[-1])  # final row, which incorporates info from all tokens
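As a quick sanity check, here is a self-contained version with placeholder arrays of the right shapes (the sizes and random values are illustrative, not the model's actual weights):

import numpy as np

n_toks, K, n_classes = 12, 26, 26           # illustrative sizes
rng = np.random.default_rng(0)

Y = rng.normal(size=(n_toks, K))            # stand-in for the final residual stream
unembeds = rng.normal(size=(K, n_classes))  # stand-in unembedding matrix
bu = np.zeros(n_classes)                    # unembed biases

logits = Y @ unembeds + bu                  # (n_toks, n_classes)
pred_class = int(np.argmax(logits[-1]))     # predicted rotation from the final row
print(pred_class)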