Chapter 1: Solving Cryptograms by Designing a Transformer
In this chapter, we will encode natural language statistics in transformer weights. To gain a strong understanding, we'll handcraft a transformer circuit step-by-step from first principles. The result will solve cryptograms encoded using a Caesar cipher (a fixed rotation of the alphabet).
Cryptogram Intro
A Caesar cipher is a code where each letter has been rotated forward by a fixed number of positions in the alphabet. Given rotated text (ciphertext), the challenge is to determine the original rotation number (and thereby recover the original plaintext).
The plaintext "a bay" becomes the ciphertext "d edb" with rotation 3.
Note that only the letters a-z will be rotated. For example, the space character in "a bay" is not rotated.
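To make the rule concrete, here is a minimal pure-Python sketch of the rotation on strings, assuming lowercase input. The function name caesar_rotate is hypothetical and not part of this chapter's code; characters outside a-z, such as the space, pass through unchanged.

```python
def caesar_rotate(text: str, rot: int) -> str:
    """Hypothetical helper: rotate each letter a-z forward by `rot` positions.

    Assumes lowercase input; any other character (e.g. a space) is left as-is.
    """
    out = []
    for c in text:
        if 'a' <= c <= 'z':
            # Shift within the 26-letter alphabet, wrapping from 'z' back to 'a'
            out.append(chr((ord(c) - ord('a') + rot) % 26 + ord('a')))
        else:
            out.append(c)  # not a letter: leave unrotated
    return ''.join(out)


print(caesar_rotate('a bay', 3))   # d edb
print(caesar_rotate('d edb', -3))  # a bay
```

Rotating by -rot (equivalently 26 - rot) undoes the cipher, which is why knowing the rotation number is enough to recover the plaintext.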
A transformer only accepts a fixed vocabulary of possible tokens. In this problem, we will allow 27 tokens: the 26 lowercase letters a-z and the space character.
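A minimal sketch of this 27-token vocabulary as a lookup table, assuming tokens are numbered in the order listed (letters first, then the space). The names vocab, char_to_id, encode, and decode are illustrative only.

```python
import string

# Hypothetical 27-token vocabulary: the 26 lowercase letters followed by a space
vocab = list(string.ascii_lowercase) + [' ']
char_to_id = {ch: i for i, ch in enumerate(vocab)}


def encode(text: str):
    """Map each character to its token id (KeyError for characters outside the vocabulary)."""
    return [char_to_id[ch] for ch in text]


def decode(ids):
    """Map token ids back to characters."""
    return ''.join(vocab[i] for i in ids)


print(len(vocab))       # 27
print(encode('a bay'))  # [0, 26, 1, 0, 24]
```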
Python Implementation
We encourage you to code along in your own notebook! For this chapter, visit GitHub to download three Project Gutenberg text files and our minimal plotting functions in llm_plots.py. Place these in the same directory as your notebook, then run the notebook server from this directory so that Python can import llm_plots.py and find the text files.
Here are the Python packages we'll use:
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from llm_plots import rowplot, implot, listplot  # plotting helpers from llm_plots.py
np.set_printoptions(suppress=True)  # suppress scientific notation
```
Now let's implement the Caesar rotation in Python! Our model requires indices into the vocabulary as input, so instead of storing ciphertext as a string, we will represent it as an array of rotated numbers (rotnums). For clarity, we will assume that all strings are plaintext and all arrays of numbers are ciphertext.
The ciphertext "d edb" will be stored as [3, 26, 4, 3, 1], where 'a'=0 and ' '=26.
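As a worked example of the arithmetic (a standalone sketch, separate from the class below): undoing rotation 3 means subtracting 3 modulo 26 from each letter index while leaving the space index 26 alone. The variable names here are illustrative.

```python
import numpy as np

vocab = np.array(list('abcdefghijklmnopqrstuvwxyz '))
ciphertext_nums = np.array([3, 26, 4, 3, 1])  # "d edb"
rot = 3

# Letters (indices 0-25) are rotated back modulo 26; the space (26) is kept unchanged
plain_nums = np.where(ciphertext_nums < 26,
                      (ciphertext_nums - rot) % 26,
                      ciphertext_nums)

print(plain_nums.tolist())         # [0, 26, 1, 0, 24]
print(''.join(vocab[plain_nums]))  # a bay
```

Note that np.where evaluates both branches, so (26 - 3) % 26 is computed for the space as well, but the condition then selects the original 26. NumPy's % also returns a non-negative result here, so 1 - 3 wraps to 24 ('y').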
For now, we will use classes primarily to help us organize and encapsulate constants. The class CaesarDataset will load text and rotate it, starting with:
```python
class CaesarDataset:
    SEQ_LEN = 32  # chars per cryptogram
    N_ROTS = 26   # total letters/rotations
    # VOCAB: Must begin with a-z and include space.
    VOCAB = np.array(list('abcdefghijklmnopqrstuvwxyz '))
    VOCAB_SET = set(VOCAB)  # O(1) membership test
    VOCAB_IDX = {l: i for i, l in enumerate(VOCAB)}  # O(1) char->index

    @staticmethod
    def plaintext_to_rotnums(plaintext, rot=0):
        '''Rotate letters forward, returning rotated numbers.
        Input characters must be in VOCAB.'''
        vocab_idxs = np.array([CaesarDataset.VOCAB_IDX[c] for c in plaintext],
                              dtype=int)
        return np.where(vocab_idxs < CaesarDataset.N_ROTS,
                        (vocab_idxs + rot) % CaesarDataset.N_ROTS,
                        vocab_idxs)  # do not rotate if not a-z

    @staticmethod
    def rotnums_to_plaintext(rotnums, rot=0):
        '''Rotate numbers backward, returning plaintext.
        Input numbers must be less than the VOCAB size.'''
        rotnums = np.asarray(rotnums)  # also accept plain Python lists
        return ''.join(CaesarDataset.VOCAB[
            np.where(rotnums < CaesarDataset.N_ROTS,
                     (rotnums - rot) % CaesarDataset.N_ROTS,
                     rotnums)])  # do not rotate if not a-z
```
Example Run
```python
rot = 5
plaintext = 'the quick brown fox jumps over the lazy dog'
rotnums = CaesarDataset.plaintext_to_rotnums(plaintext, rot)
print('plaintext:', plaintext)
print('ciphertext:', CaesarDataset.rotnums_to_plaintext(rotnums))
print('plaintext (hopefully!):',
      CaesarDataset.rotnums_to_plaintext(rotnums, rot))
```
This outputs:
```
plaintext: the quick brown fox jumps over the lazy dog
ciphertext: ymj vznhp gwtbs ktc ozrux tajw ymj qfed itl
plaintext (hopefully!): the quick brown fox jumps over the lazy dog
```
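Because there are only N_ROTS = 26 possible rotations, a ciphertext on its own is compatible with 26 candidate plaintexts. The sketch below is a standalone illustration rather than part of CaesarDataset (the helper name unrotate is hypothetical); it lists every candidate for the ciphertext above. Only one of them reads as English here, and picking it out is the task our handcrafted transformer will take on.

```python
import numpy as np

VOCAB = np.array(list('abcdefghijklmnopqrstuvwxyz '))
VOCAB_IDX = {ch: i for i, ch in enumerate('abcdefghijklmnopqrstuvwxyz ')}
N_ROTS = 26


def unrotate(rotnums, rot):
    """Hypothetical helper: rotate letter indices backward by `rot`, leave the space alone."""
    rotnums = np.asarray(rotnums)
    return ''.join(VOCAB[np.where(rotnums < N_ROTS,
                                  (rotnums - rot) % N_ROTS,
                                  rotnums)])


ciphertext = 'ymj vznhp gwtbs ktc ozrux tajw ymj qfed itl'
rotnums = np.array([VOCAB_IDX[c] for c in ciphertext])

# One candidate plaintext per possible rotation
candidates = [unrotate(rotnums, r) for r in range(N_ROTS)]
for r, text in enumerate(candidates):
    print(f'{r:2d}: {text}')
```

Row 0 is just the ciphertext itself and row 5 recovers the original sentence; the remaining rows are gibberish.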