My reconstruction of the results from this paper. An 8-layer GPT model is trained to predict valid moves in the board game Othello, i.e. it sees only the sequence of moves, never the board itself. Probes are then trained to investigate whether the model forms an internal representation of the board state. The causality of this internal model is then tested by intervening on the model's intermediate activations and checking whether its predictions for valid moves update accordingly.
Code to simulate games is found in generate_data.ipynb
and the game logic in othello.py
. Moves are selected uniformly at random from the set of valid moves (no attempt to learn good moves, only valid ones).
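The real game logic lives in othello.py; below is a minimal, self-contained sketch of the core routine such a simulator needs, computing the set of valid moves for a player. The board encoding (0 = empty, 1 = black, -1 = white) and function name are illustrative assumptions, not the notebook's actual code.

```python
import numpy as np

# Assumed encoding: 0 = empty, 1 = black, -1 = white.
DIRECTIONS = [(-1, -1), (-1, 0), (-1, 1), (0, -1),
              (0, 1), (1, -1), (1, 0), (1, 1)]

def valid_moves(board, player):
    """Return all cells where `player` may legally move.

    A move is valid on an empty cell if, in at least one direction,
    a contiguous run of opponent pieces is followed by one of ours.
    """
    moves = []
    for r in range(8):
        for c in range(8):
            if board[r, c] != 0:
                continue
            for dr, dc in DIRECTIONS:
                rr, cc = r + dr, c + dc
                seen_opponent = False
                while 0 <= rr < 8 and 0 <= cc < 8 and board[rr, cc] == -player:
                    seen_opponent = True
                    rr, cc = rr + dr, cc + dc
                if seen_opponent and 0 <= rr < 8 and 0 <= cc < 8 \
                        and board[rr, cc] == player:
                    moves.append((r, c))
                    break
    return moves

# Standard starting position: white on (3,3)/(4,4), black on (3,4)/(4,3).
board = np.zeros((8, 8), dtype=int)
board[3, 3] = board[4, 4] = -1
board[3, 4] = board[4, 3] = 1
print(valid_moves(board, 1))  # [(2, 3), (3, 2), (4, 5), (5, 4)]
```

A random game is then just a loop that repeatedly picks one of these moves at random, applies the flips, and passes the turn.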
Games are stored as numpy arrays and capped at a fixed length.
The training loop is in train.ipynb
. It uses the default minGPT training loop, edited as in othello_world to train in epochs (sampling without replacement) rather than the default sampling with replacement. The training set contains 1048576 games; training on a 3090 (vast.ai instance) took ~2 hours to achieve a training loss of 2.02 and argmax accuracy of 99.17%. Below is an example of a typical board state, with orange squares indicating valid moves available to black, while on the right is a heatmap of the logits output by the model.
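The switch from sampling with replacement to proper epochs amounts to reshuffling a permutation of the dataset indices once per epoch; a minimal sketch (function and parameter names are illustrative, not the actual othello_world code):

```python
import numpy as np

def epoch_batches(n_games, batch_size, n_epochs, seed=0):
    """Yield batches of dataset indices, sampling without replacement.

    Each epoch visits every game exactly once in a fresh random order,
    unlike minGPT's default of drawing indices with replacement.
    """
    rng = np.random.default_rng(seed)
    for _ in range(n_epochs):
        order = rng.permutation(n_games)
        for start in range(0, n_games, batch_size):
            yield order[start:start + batch_size]

# Usage: with 8 games and batch size 3, each epoch is 3 batches
# (sizes 3, 3, 2) that together cover every index exactly once.
batches = list(epoch_batches(n_games=8, batch_size=3, n_epochs=2))
print(sorted(np.concatenate(batches[:3]).tolist()))  # [0, 1, 2, 3, 4, 5, 6, 7]
```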
Probes are trained in board_probes.ipynb
. They are trained on 100000 games to predict the board state from the activations at each layer. Rather than the non-linear probes of the original project, linear probes are used, following this insight from Neel Nanda: flip the board state after each turn so it encodes 'my piece' vs 'your piece' rather than 'black piece' vs 'white piece'. Below, an example board state is shown on the left, and on the right a heatmap of the most likely category predicted by the probe for each cell (blue = empty, pink = my piece, yellow = opponent's piece), with white to play.
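The relabelling behind the linear probes can be sketched as follows. The colour encoding and function name here are assumptions for illustration, not the notebook's exact ones:

```python
import numpy as np

# Assumed encoding: 0 = empty, 1 = black, 2 = white.
EMPTY, BLACK, WHITE = 0, 1, 2

def to_relative(board, move_index):
    """Relabel absolute colours as 'mine'/'yours' for the player to move.

    Black moves on even move indices. From the current player's point of
    view, their own pieces become 1 and the opponent's become 2 -- the
    labelling under which board state is linearly decodable from activations.
    """
    black_to_move = (move_index % 2 == 0)
    rel = board.copy()
    if not black_to_move:  # white to move: swap the two colour labels
        rel[board == BLACK] = WHITE
        rel[board == WHITE] = BLACK
    return rel

board = np.array([[EMPTY, BLACK],
                  [WHITE, BLACK]])
print(to_relative(board, 0))  # black to move: labels unchanged
print(to_relative(board, 1))  # white to move: 1s and 2s swapped
```

The probe itself is then just a linear map from the residual-stream activation to 3 logits (empty / mine / yours) per board cell, trained with cross-entropy against these relative labels.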
An example intervention. The white piece in cell (2,1) has its component in the 'opponent (white)' direction removed and a component in the 'mine (black)' direction inserted. Here the modification is made only to the activations coming out of the 5th layer, and only at the final token position. The original logits output by the model are shown on the left, and the logits resulting from this intervention on the right. We see that the model has correctly identified that cell (1,0) is no longer a valid move.
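In terms of the probe's direction vectors, this intervention amounts to projecting out one class direction and adding another. A hedged sketch, where the function name, the insertion scale `alpha`, and the dimension are assumptions:

```python
import numpy as np

def intervene(activation, d_opponent, d_mine, alpha=1.0):
    """Edit one cell's representation inside a residual-stream activation.

    Removes the activation's component along the probe's 'opponent piece'
    direction and adds a component along the 'my piece' direction, so that
    the probe (and, if the representation is causal, the model itself)
    reads the cell as holding the current player's piece instead.
    """
    d_opp = d_opponent / np.linalg.norm(d_opponent)
    act = activation - (activation @ d_opp) * d_opp  # project out 'opponent'
    return act + alpha * d_mine                      # insert 'mine'

# Toy usage with random vectors standing in for real probe directions.
rng = np.random.default_rng(0)
d_opp = rng.normal(size=512)
d_mine = rng.normal(size=512)
act = rng.normal(size=512)
new_act = intervene(act, d_opp, d_mine)
print(new_act.shape)  # (512,)
```

In the real run this replacement is applied to the layer-5 activations at the final token before the remaining layers process them, and the change in the valid-move logits is read off the model's output.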