Summary: A technique to represent categorical integer indices as binary vectors so that neural networks cannot misinterpret them as ordinal numbers.
Why integers can’t be fed directly to a neural network
Raw integer indices for categorical data should not be fed directly into a neural network. Integers imply false ordinal relationships: the network would interpret 'c' = 3 as literally three times 'a' = 1, which is meaningless for categorical data. Downstream multiplicative and non-linear operations would only compound this spurious ordering.
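A minimal sketch of the problem (the weight and the index labels here are made up purely for illustration): a linear operation applied to raw integer indices produces outputs that scale with the arbitrary index value.

```python
import torch

# Hypothetical integer labels: 'a' = 1, 'c' = 3 (arbitrary, no real order)
a = torch.tensor([1.0])
c = torch.tensor([3.0])

w = torch.tensor([0.5])   # a single weight of a toy linear layer
print(a * w, c * w)       # tensor([0.5000]) tensor([1.5000])
# The output for 'c' is exactly three times the output for 'a', a relationship
# that exists only because of the arbitrary integer labels, not the data.
```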
One-hot vectors make every category equidistant from every other: each vector lies on a different axis of $\mathbb{R}^C$ (where $C$ is the number of classes), so no pair of categories is “closer” or “further” than any other.
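A quick numerical check of the equidistance claim (a minimal sketch; the class count is arbitrary): every pair of distinct one-hot vectors is at Euclidean distance $\sqrt{2}$.

```python
import torch
import torch.nn.functional as F

C = 5                                          # arbitrary number of classes for the demo
onehots = F.one_hot(torch.arange(C)).float()   # (C, C) identity matrix
dists = torch.cdist(onehots, onehots)          # pairwise Euclidean distances
print(dists)                                   # every off-diagonal entry is sqrt(2) ≈ 1.4142
```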
Encoding
For a vocabulary of size $C$, the one-hot vector for index $i$ is the standard basis vector $e_i \in \mathbb{R}^C$:

$$(e_i)_j = \begin{cases} 1 & \text{if } j = i \\ 0 & \text{otherwise} \end{cases}$$
A batch of $N$ examples becomes a matrix $X \in \{0,1\}^{N \times C}$, one row per example. The dtype must be float (not int) so gradients can flow during backpropagation.
import torch.nn.functional as F
xenc = F.one_hot(xs, num_classes=C).float()  # (N, C)

GPT-3 example: encoding "The cat sat" ($C = 50{,}257$)
Token indices (BPE):
"The"→ 464,"cat"→ 3797,"sat"→ 3332. Each becomes a length-50,257 vector.For example,
"The":Stacked as a batch of :
Each row contains one
1and 50,256 zeros. The full matrix holds 150,771 floats — 99.998% of which are zero. This is why embeddings replace one-hot encoding in practice: an embedding lookup is the same row-select operation (see row-select insight), but without materialising the sparse matrix.
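The same encoding in PyTorch (a minimal sketch; the token IDs are the BPE indices quoted above, and the tokenisation step itself is assumed to have already happened):

```python
import torch
import torch.nn.functional as F

C = 50_257                                   # GPT-3 BPE vocabulary size
xs = torch.tensor([464, 3797, 3332])         # "The", "cat", "sat"

xenc = F.one_hot(xs, num_classes=C).float()  # (3, 50257)
print(xenc.shape)                            # torch.Size([3, 50257])
print(xenc.numel())                          # 150771 floats in total
print((xenc == 0).sum())                     # 150768 zeros (~99.998% of all entries)
```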
Neat insight: one-hot × weight matrix = row select
Multiplying a one-hot vector $e_i$ by a weight matrix $W \in \mathbb{R}^{C \times d}$ is algebraically equivalent to selecting a single row (the $i$-th row) from $W$:

$$e_i^\top W = W_{i,:}$$

GPT-3 example: $W \in \mathbb{R}^{50257 \times d}$, selecting row 464 ("The"):

$$e_{464}^\top W = W_{464,:}$$

GPT-3 batch: $N = 3$ tokens ("The cat sat"), each row of the product is a row-select:

$$X W = \begin{bmatrix} e_{464}^\top \\ e_{3797}^\top \\ e_{3332}^\top \end{bmatrix} W = \begin{bmatrix} W_{464,:} \\ W_{3797,:} \\ W_{3332,:} \end{bmatrix}$$

Each output row is a direct copy of the corresponding row of $W$ — three simultaneous lookups, no arithmetic.
The single 1 in each one-hot vector picks out its aligned row of $W$, and the zeros wipe out every other row. In effect no arithmetic is needed; it's a lookup.
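A numerical check of the equivalence (a minimal sketch; the vocabulary size, output width, and random weights are arbitrary):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
C, d = 27, 4                                  # arbitrary vocab size and output width
W = torch.randn(C, d)

xs = torch.tensor([0, 5, 13])                 # arbitrary example indices
xenc = F.one_hot(xs, num_classes=C).float()   # (3, 27)

matmul_result = xenc @ W                      # (3, 27) @ (27, 4) -> (3, 4)
indexed_result = W[xs]                        # direct row indexing (what an embedding lookup does)

print(torch.allclose(matmul_result, indexed_result))  # True
```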
Implication for bigram language models
This means xenc @ W (where xenc is a one-hot batch) produces one row of $W$ per training example — effectively treating each row of $W$ as the learned log-counts for that input category. When gradient descent converges, W.exp() (normalised row-wise) recovers the same probabilities that an explicit count-based frequency table would give; the two approaches are identical in their final result and differ only in how they arrive there (counting vs. gradient descent).
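A condensed sketch of that setup in the style of a character-level bigram model (assumptions: a 27-character vocabulary, made-up bigram pairs, and a plain softmax over the logits; this is one gradient step, not a full training loop):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
C = 27                                        # e.g. 26 letters + a boundary token
W = torch.randn(C, C, requires_grad=True)     # row i = learned log-counts for the char following i

xs = torch.tensor([0, 5, 13, 5])              # made-up input character indices
ys = torch.tensor([5, 13, 5, 0])              # made-up target next-character indices

xenc = F.one_hot(xs, num_classes=C).float()   # (4, 27)
logits = xenc @ W                             # each row is a row of W (row-select)
counts = logits.exp()                         # "counts" analogous to a frequency table
probs = counts / counts.sum(1, keepdim=True)  # row-normalise into probabilities

loss = -probs[torch.arange(len(ys)), ys].log().mean()  # negative log-likelihood
loss.backward()                               # gradient descent nudges W toward the count table
```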
Sources
- Relevant Jupyter notebooks: