Recap: emma example

The following code cell imports the data, builds the bigram training set (inputs xs and labels ys), defines the weights W, and performs the forward pass.

# import data, init bigram nn (training inputs + labels), define W, forward pass
from IPython.display import HTML
from utils.matmul_viz import show_matmul
 
words = open('data/names.txt', 'r').read().splitlines()
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
 
import torch
import torch.nn.functional as F
 
# create the training set of bigrams (x,y)
xs, ys = [], []
ex = 1  # example counter (renamed so it doesn't shadow Python's built-in iter)
 
print("first word, 'emma', contains 5 training examples:")
print('eg #', '  input (x)', '-> target (y)')
for w in words[:1]:
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2 in zip(chs, chs[1:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    print(f'{ex}          {ch1}     ->      {ch2}')
    ex += 1
    xs.append(ix1)
    ys.append(ix2)
    
xs = torch.tensor(xs)
ys = torch.tensor(ys)
print('\nnn inputs  (xs)    :', xs)
print('targets/labels (ys):', ys)
 
# randomly initialize 27 neurons' weights. each neuron receives 27 inputs
g = torch.Generator().manual_seed(2147483647) # fixed seed for reproducibility; a different seed gives different random initial `probs`
W = torch.randn((27, 27), generator=g)
 
# forward pass (last 2 lines `counts` and `probs` are applying softmax)
xenc = F.one_hot(xs, num_classes=27).float() # input to the network: one-hot encoding
logits = xenc @ W # predict log-counts
counts = logits.exp() # counts, equivalent to N
probs = counts / counts.sum(1, keepdims=True) # probabilities for next character
 
print('\nprobs.shape:', probs.shape) # same shape as logits and counts: (5 examples, 27 characters)
first word, 'emma', contains 5 training examples:
eg #   input (x) -> target (y)
1          .     ->      e
2          e     ->      m
3          m     ->      m
4          m     ->      a
5          a     ->      .
 
nn inputs  (xs)    : tensor([ 0,  5, 13, 13,  1])
targets/labels (ys): tensor([ 5, 13, 13,  1,  0])
 
probs.shape: torch.Size([5, 27])
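Why does `xenc @ W` act as "predict log-counts"? Multiplying a one-hot row vector by W selects exactly one row of W. Row `W[ix]` therefore holds the 27 logits for input character `ix`. The sketch below checks this with a toy W from an arbitrary seed. The names `W_demo`, `ix`, `via_matmul` and `via_lookup` are illustrative only.

```python
import torch
import torch.nn.functional as F

# Toy weight matrix with the same shape as W above (27 inputs x 27 neurons)
g = torch.Generator().manual_seed(42)
W_demo = torch.randn((27, 27), generator=g)

ix = torch.tensor([0, 5, 13, 13, 1])          # the 'emma' inputs (xs)
xenc = F.one_hot(ix, num_classes=27).float()  # shape (5, 27)

# Each one-hot row has a single 1.0, so the matmul just picks out row W_demo[ix[i]]
via_matmul = xenc @ W_demo                    # shape (5, 27)
via_lookup = W_demo[ix]                       # plain row indexing

print(torch.allclose(via_matmul, via_lookup))  # True
```

In practice this means the one-hot matmul is a differentiable table lookup. Each row of W plays the role that a row of log-counts plays in the counting model.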

Diagnostic of loss function value

Next, inspect the negative log likelihood (nll) for each training example, as well as the average nll across the 5 examples. The average is the network's loss.

# inspection loop
nlls = torch.zeros(5)
for i in range(5):
  # i-th bigram:
  x = xs[i].item() # input character index
  y = ys[i].item() # label character index
  print('--------')
  print(f'bigram ex. {i+1}              : "{itos[x]}{itos[y]}" (indexes {x},{y})')
  print('nn input                  :', x)
  print('nn output (probs)         :', probs[i])
  print('label (actual next char)  :', y)
  p = probs[i, y]
  print('nn\'s prob for CORRECT char:', p.item())
  logp = torch.log(p)
  print('log likelihood            :', logp.item())
  nll = -logp
  print('negative log likelihood   :', nll.item())
  nlls[i] = nll
 
print('=========')
print('average negative log likelihood, i.e. loss =', nlls.mean().item())
--------
bigram ex. 1              : ".e" (indexes 0,5)
nn input                  : 0
nn output (probs)         : tensor([0.0607, 0.0100, 0.0123, 0.0042, 0.0168, 0.0123, 0.0027, 0.0232, 0.0137,
        0.0313, 0.0079, 0.0278, 0.0091, 0.0082, 0.0500, 0.2378, 0.0603, 0.0025,
        0.0249, 0.0055, 0.0339, 0.0109, 0.0029, 0.0198, 0.0118, 0.1537, 0.1459])
label (actual next char)  : 5
nn's prob for CORRECT char: 0.012286250479519367
log likelihood            : -4.3992743492126465
negative log likelihood   : 4.3992743492126465
--------
bigram ex. 2              : "em" (indexes 5,13)
nn input                  : 5
nn output (probs)         : tensor([0.0290, 0.0796, 0.0248, 0.0521, 0.1989, 0.0289, 0.0094, 0.0335, 0.0097,
        0.0301, 0.0702, 0.0228, 0.0115, 0.0181, 0.0108, 0.0315, 0.0291, 0.0045,
        0.0916, 0.0215, 0.0486, 0.0300, 0.0501, 0.0027, 0.0118, 0.0022, 0.0472])
label (actual next char)  : 13
nn's prob for CORRECT char: 0.018050704151391983
log likelihood            : -4.014570713043213
negative log likelihood   : 4.014570713043213
--------
bigram ex. 3              : "mm" (indexes 13,13)
nn input                  : 13
nn output (probs)         : tensor([0.0312, 0.0737, 0.0484, 0.0333, 0.0674, 0.0200, 0.0263, 0.0249, 0.1226,
        0.0164, 0.0075, 0.0789, 0.0131, 0.0267, 0.0147, 0.0112, 0.0585, 0.0121,
        0.0650, 0.0058, 0.0208, 0.0078, 0.0133, 0.0203, 0.1204, 0.0469, 0.0126])
label (actual next char)  : 13
nn's prob for CORRECT char: 0.026691533625125885
log likelihood            : -3.623408794403076
negative log likelihood   : 3.623408794403076
--------
bigram ex. 4              : "ma" (indexes 13,1)
nn input                  : 13
nn output (probs)         : tensor([0.0312, 0.0737, 0.0484, 0.0333, 0.0674, 0.0200, 0.0263, 0.0249, 0.1226,
        0.0164, 0.0075, 0.0789, 0.0131, 0.0267, 0.0147, 0.0112, 0.0585, 0.0121,
        0.0650, 0.0058, 0.0208, 0.0078, 0.0133, 0.0203, 0.1204, 0.0469, 0.0126])
label (actual next char)  : 1
nn's prob for CORRECT char: 0.07367684692144394
log likelihood            : -2.6080667972564697
negative log likelihood   : 2.6080667972564697
--------
bigram ex. 5              : "a." (indexes 1,0)
nn input                  : 1
nn output (probs)         : tensor([0.0150, 0.0086, 0.0396, 0.0100, 0.0606, 0.0308, 0.1084, 0.0131, 0.0125,
        0.0048, 0.1024, 0.0086, 0.0988, 0.0112, 0.0232, 0.0207, 0.0408, 0.0078,
        0.0899, 0.0531, 0.0463, 0.0309, 0.0051, 0.0329, 0.0654, 0.0503, 0.0091])
label (actual next char)  : 0
nn's prob for CORRECT char: 0.01497753243893385
log likelihood            : -4.2012038230896
negative log likelihood   : 4.2012038230896
=========
average negative log likelihood, i.e. loss = 3.7693049907684326
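The loss arithmetic can be reproduced in plain Python from the five "prob for CORRECT char" values printed above. Take nll = -log(p) per example and average them. A perfect prediction (p = 1) would contribute -log(1) = 0, which is why pushing these probabilities toward 1 drives the loss down.

```python
import math

# Probabilities the network assigned to the correct next character (copied from the output above)
p_correct = [0.012286250479519367, 0.018050704151391983, 0.026691533625125885,
             0.07367684692144394, 0.01497753243893385]

nlls = [-math.log(p) for p in p_correct]  # negative log likelihood per example
loss = sum(nlls) / len(nlls)              # average nll = the loss

print(round(loss, 4))  # 3.7693
```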

What we want gradient descent to do

For the first name, emma, these are the 5 bigram training examples. For each one, we know the correct next character.

#   bigram   input xs[i]   label ys[i]   P(correct next char)
1   .e       . (idx 0)     e (idx 5)     probs[0, 5]
2   em       e (idx 5)     m (idx 13)    probs[1, 13]
3   mm       m (idx 13)    m (idx 13)    probs[2, 13]
4   ma       m (idx 13)    a (idx 1)     probs[3, 1]
5   a.       a (idx 1)     . (idx 0)     probs[4, 0]

The NN initially assigns essentially random probabilities to every possible next character, including the correct one.

Goal of gradient descent: nudge W so that each of these correct-character probabilities moves closer to 1.

# for `emma` (5 examples) nn initially has random (bad) probabilities for correct next char
print('xs:', xs) # nn inputs
print('ys:', ys) # labels (desired outputs)
 
print('\nprobs.shape:', probs.shape) 
print(
    '\nex. 1 ".e": inspect P("e" | "."), i.e. `probs[0, 5]`:', probs[0, 5], 
    '\nex. 2 "em": inspect P("m" | "e"), i.e. `probs[1, 13]`:', probs[1, 13],
    '\nex. 3 "mm": inspect P("m" | "m"), i.e. `probs[2, 13]`:', probs[2, 13],
    '\nex. 4 "ma": inspect P("a" | "m"), i.e. `probs[3, 1]`:', probs[3, 1], 
    '\nex. 5 "a.": inspect P("." | "a"), i.e. `probs[4, 0]`:', probs[4, 0], 
)
xs: tensor([ 0,  5, 13, 13,  1])
ys: tensor([ 5, 13, 13,  1,  0])
 
probs.shape: torch.Size([5, 27])
 
ex. 1 ".e": inspect P("e" | "."), i.e. `probs[0, 5]`: tensor(0.0123) 
ex. 2 "em": inspect P("m" | "e"), i.e. `probs[1, 13]`: tensor(0.0181) 
ex. 3 "mm": inspect P("m" | "m"), i.e. `probs[2, 13]`: tensor(0.0267) 
ex. 4 "ma": inspect P("a" | "m"), i.e. `probs[3, 1]`: tensor(0.0737) 
ex. 5 "a.": inspect P("." | "a"), i.e. `probs[4, 0]`: tensor(0.0150)

Loss function

This NN has 1 linear layer and 1 non-linearity (softmax), feeding into the loss function: average negative log likelihood (nll). Every operation in this pipeline is differentiable. Backpropagation can therefore compute, via the chain rule, the gradient of the loss with respect to every node in the computation graph, including each weight in W.

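Gradient descent then improves the loss by repeating these steps:

  • Forward pass: compute probs and the loss
  • Reset all gradients to zero
  • Backward pass: loss.backward() computes the gradient of the loss with respect to each weight in W
  • Update: nudge each weight in W a small step (the learning rate) in the opposite direction to its gradient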
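For softmax followed by average nll, the gradient that backpropagation delivers to the logits has a well-known closed form: (probs - one_hot(y)) / n. It raises the logit of the correct character and lowers every other logit in proportion to its probability. The minimal check below compares this formula with autograd. It uses random toy logits rather than the notebook's W.

```python
import torch
import torch.nn.functional as F

g = torch.Generator().manual_seed(0)
logits = torch.randn((5, 27), generator=g, requires_grad=True)  # toy logits (leaf tensor)
ys_demo = torch.tensor([5, 13, 13, 1, 0])                        # the 'emma' labels

# Same pipeline as the notebook: exp -> normalise rows -> average nll
counts = logits.exp()
probs = counts / counts.sum(1, keepdim=True)
loss = -probs[torch.arange(5), ys_demo].log().mean()
loss.backward()

# Closed form: (softmax - one_hot(label)) / number_of_examples
expected = (probs.detach() - F.one_hot(ys_demo, num_classes=27).float()) / 5
print(torch.allclose(logits.grad, expected, atol=1e-6))  # True
```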
# i - current loss (avg nll), vectorised computation
print('loss = average nll =', -probs[torch.arange(5), ys].log().mean().item())
loss = average nll = 3.7693049907684326

Compare the vectorised loss above with the output of the diagnostic loop earlier: both give ~3.769.
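The vectorised version works because of advanced indexing. `probs[torch.arange(5), ys]` pairs row `i` with column `ys[i]`, which yields exactly the per-example `probs[i, ys[i]]` values from the table above. The sketch below checks this on a toy `probs_demo` built from random numbers. It is an illustration, not the notebook's actual probs.

```python
import torch

# Toy stand-in for `probs` (5 examples x 27 characters); each row sums to 1
g = torch.Generator().manual_seed(0)
probs_demo = torch.rand((5, 27), generator=g)
probs_demo = probs_demo / probs_demo.sum(1, keepdim=True)
ys_demo = torch.tensor([5, 13, 13, 1, 0])  # the 'emma' labels

# Advanced indexing: row i paired with column ys_demo[i]
picked = probs_demo[torch.arange(5), ys_demo]

# Equivalent explicit loop, as in the diagnostic cell
looped = torch.stack([probs_demo[i, ys_demo[i]] for i in range(5)])

print(torch.equal(picked, looped))  # True
loss_demo = -picked.log().mean()    # same recipe as the vectorised loss
```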

Forward pass

# rand init 27 neurons' weights (with grad param). each neuron receives 27 inputs
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g, requires_grad=True) # ensure gradients

Also see backpropagation in PyTorch note: 09_pytorch_gradient_descent

# i - forward pass (softmax, and loss calculation)
xenc = F.one_hot(xs, num_classes=27).float() # input to the network: one-hot encoding
logits = xenc @ W # predict log-counts
counts = logits.exp() # counts, equivalent to N
probs = counts / counts.sum(1, keepdims=True) # probabilities for next character
loss = -probs[torch.arange(5), ys].log().mean() # average negative log-likelihood
 
print(loss.item()) # loss = 3.77
3.7693049907684326
  • PyTorch tracks the forward pass and builds a computation graph of all the (differentiable) mathematical operations.
  • The .backward() call applies the chain rule node by node, from loss all the way back to the leaf tensor W (the NN parameters), and stores the result in W.grad. Intermediate tensors (logits, counts, probs) do not keep a .grad by default.
    • No loop over parameters is needed (as we did in 09_pytorch_gradient_descent)
    • PyTorch treats the W tensor as a single parameter, so all 729 elements are updated in parallel
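The distinction between leaf and non-leaf tensors can be seen directly. `.backward()` computes gradients through every node, but by default only leaf tensors created with `requires_grad=True` (here W) keep them in `.grad`. An intermediate such as `logits` keeps its gradient only if you opt in with `retain_grad()`. The sketch below reuses the 'emma' setup and also checks one chain-rule step through the matmul.

```python
import torch
import torch.nn.functional as F

# Same setup as the 'emma' example (seed and shapes)
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g, requires_grad=True)  # leaf tensor (created directly)
xs = torch.tensor([0, 5, 13, 13, 1])
ys = torch.tensor([5, 13, 13, 1, 0])

xenc = F.one_hot(xs, num_classes=27).float()
logits = xenc @ W     # non-leaf: produced by an operation on W
logits.retain_grad()  # opt in to keeping this intermediate gradient
counts = logits.exp() # non-leaf without retain_grad(): counts.grad stays None
probs = counts / counts.sum(1, keepdim=True)
loss = -probs[torch.arange(5), ys].log().mean()
loss.backward()

print(W.is_leaf, logits.is_leaf)  # True False
print(W.grad.shape)               # torch.Size([27, 27])
print(logits.grad.shape)          # torch.Size([5, 27]), kept only because of retain_grad()

# Chain rule through the matmul: dL/dW = xenc^T @ dL/dlogits
print(torch.allclose(W.grad, xenc.T @ logits.grad))  # True
```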

Backward pass and manual gradient descent

# i - backward pass
W.grad = None # reset grads (setting to None is cheaper than filling with zeros; PyTorch treats None as 0)
loss.backward()
print('W.shape      :', W.shape, '\nW.grad.shape :',W.grad.shape)
 
# manual gradient descent
W.data += -0.1 * W.grad
W.shape      : torch.Size([27, 27]) 
W.grad.shape : torch.Size([27, 27])
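  • Backpropagation (loss.backward()) computes W.grad. Each element W.grad[i, j] is the partial derivative of the loss (average negative log likelihood) with respect to weight W[i, j]. It tells us in which direction, and how strongly, that weight influences the loss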
  • The gradient descent step nudged each weight in W in the opposite direction to its gradient
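Because the inputs are one-hot, only the rows of W that belong to characters actually seen as inputs receive a gradient. For 'emma' these are '.', 'a', 'e' and 'm', i.e. rows 0, 1, 5 and 13. The sketch below verifies this. It then confirms that one step of size 0.1 against the gradient lowers the loss, using the same seed and step size as the cell above.

```python
import torch
import torch.nn.functional as F

g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g, requires_grad=True)
xs = torch.tensor([0, 5, 13, 13, 1])  # 'emma' inputs: '.', 'e', 'm', 'm', 'a'
ys = torch.tensor([5, 13, 13, 1, 0])

xenc = F.one_hot(xs, num_classes=27).float()
counts = (xenc @ W).exp()
probs = counts / counts.sum(1, keepdim=True)
loss = -probs[torch.arange(5), ys].log().mean()
loss.backward()

# Rows of W.grad that received any gradient: only the input characters present in xs
nonzero_rows = (W.grad.abs().sum(1) > 0).nonzero().flatten().tolist()
print(nonzero_rows)  # [0, 1, 5, 13]

# One gradient descent step (learning rate 0.1) and a fresh forward pass
with torch.no_grad():
    W_new = W - 0.1 * W.grad
    counts_new = (xenc @ W_new).exp()
    probs_new = counts_new / counts_new.sum(1, keepdim=True)
    loss_new = -probs_new[torch.arange(5), ys].log().mean()
print(loss_new.item() < loss.item())  # True
```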

Run the forward pass again (then another backward pass, update and forward pass) below to confirm the loss keeps decreasing: 3.77 → 3.75 → 3.73. The network assigns higher and higher probabilities to the correct next characters, so its performance is improving.

# forward pass confirms loss decreases to 3.75
xenc = F.one_hot(xs, num_classes=27).float() # input to the network: one-hot encoding
logits = xenc @ W # predict log-counts
counts = logits.exp() # counts, equivalent to N
probs = counts / counts.sum(1, keepdims=True) # probabilities for next character
loss = -probs[torch.arange(5), ys].log().mean() # average negative log-likelihood
 
print(loss.item()) # 3.75
3.7492129802703857
# another backward pass -> another forward pass -> loss decreases further to 3.73
W.grad = None
loss.backward()         # backward pass
W.data += -0.1 * W.grad # manual gradient descent
 
# forward pass
xenc = F.one_hot(xs, num_classes=27).float() # input to the network: one-hot encoding
logits = xenc @ W # predict log-counts
counts = logits.exp() # counts, equivalent to N
probs = counts / counts.sum(1, keepdims=True) # probabilities for next character
loss = -probs[torch.arange(5), ys].log().mean() # average negative log-likelihood
 
print(loss.item()) # 3.73
3.7291626930236816

Putting it all together

# i - initialise the dataset (over all names, not just `emma`)
xs, ys = [], []
for w in words:
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2 in zip(chs, chs[1:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        xs.append(ix1)
        ys.append(ix2)
xs = torch.tensor(xs)
ys = torch.tensor(ys)
num = xs.nelement()
print('number of examples: ', num)
 
# initialize the 'network'
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g, requires_grad=True)
number of examples:  228146

Perform 100 iterations of gradient descent

Note the loss from the NN approach converges to ~2.49, similar to the explicit (naive counting) bigram approach in 03_loss_function_and_smoothing. The printed loss also includes the small regularisation term discussed below.

Naive counting works because the bigram problem is so simple: count every bigram into N, then normalise each row of N into probabilities P.
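This is why the two approaches can agree. With one-hot inputs, each row of W acts as the log-counts for one input character. Exponentiating and normalising (softmax) turns log-counts back into the counting model's row of P. The pure-Python sketch below uses a hypothetical 4-character row of counts with +1 smoothing, so that log() is always defined. It illustrates the correspondence and is not a claim about the exact W the training finds.

```python
import math

# Hypothetical counts for one row of N (how often each next character followed one input),
# with +1 smoothing so every count is positive
row_counts = [3 + 1, 0 + 1, 5 + 1, 1 + 1]

# Counting approach: normalise the counts directly
P_row = [c / sum(row_counts) for c in row_counts]

# NN approach: if this row of W held log-counts, exp() + normalise (softmax) recovers P
W_row = [math.log(c) for c in row_counts]
exps = [math.exp(w) for w in W_row]
probs_row = [e / sum(exps) for e in exps]

print(all(abs(a - b) < 1e-12 for a, b in zip(P_row, probs_row)))  # True
```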

Gradient descent is significantly more flexible and scales to more complex problems (and more complex NNs), for example:

  • Conditioning on more than just 1 prior character
  • Building up to attention and the transformer model
# i - gradient descent 100 iterations (output shows loss minimisation!)
for k in range(100):
    
    # forward pass
    xenc = F.one_hot(xs, num_classes=27).float() # input to the network: one-hot encoding
    logits = xenc @ W # predict log-counts
    counts = logits.exp() # counts, equivalent to N
    probs = counts / counts.sum(1, keepdims=True) # probabilities for next character
    loss = -probs[torch.arange(num), ys].log().mean() + 0.01*(W**2).mean()
    print(f'iter {k}: loss = {loss.item():.4f}')
    
    # backward pass
    W.grad = None # reset the gradient (None is treated as 0)
    loss.backward()
    
    # update
    W.data += -50 * W.grad
iter 0: loss = 3.7686
iter 1: loss = 3.3788
iter 2: loss = 3.1611
iter 3: loss = 3.0272
iter 4: loss = 2.9345
iter 5: loss = 2.8672
iter 6: loss = 2.8167
iter 7: loss = 2.7771
iter 8: loss = 2.7453
iter 9: loss = 2.7188
iter 10: loss = 2.6965
iter 11: loss = 2.6774
iter 12: loss = 2.6608
iter 13: loss = 2.6464
iter 14: loss = 2.6337
iter 15: loss = 2.6225
iter 16: loss = 2.6125
iter 17: loss = 2.6037
iter 18: loss = 2.5958
iter 19: loss = 2.5887
iter 20: loss = 2.5823
iter 21: loss = 2.5764
iter 22: loss = 2.5711
iter 23: loss = 2.5663
iter 24: loss = 2.5618
iter 25: loss = 2.5577
iter 26: loss = 2.5539
iter 27: loss = 2.5504
iter 28: loss = 2.5472
iter 29: loss = 2.5442
iter 30: loss = 2.5414
iter 31: loss = 2.5387
iter 32: loss = 2.5363
iter 33: loss = 2.5340
iter 34: loss = 2.5318
iter 35: loss = 2.5298
iter 36: loss = 2.5279
iter 37: loss = 2.5261
iter 38: loss = 2.5244
iter 39: loss = 2.5228
iter 40: loss = 2.5213
iter 41: loss = 2.5198
iter 42: loss = 2.5185
iter 43: loss = 2.5172
iter 44: loss = 2.5160
iter 45: loss = 2.5148
iter 46: loss = 2.5137
iter 47: loss = 2.5127
iter 48: loss = 2.5117
iter 49: loss = 2.5108
iter 50: loss = 2.5099
iter 51: loss = 2.5090
iter 52: loss = 2.5082
iter 53: loss = 2.5074
iter 54: loss = 2.5066
iter 55: loss = 2.5059
iter 56: loss = 2.5052
iter 57: loss = 2.5045
iter 58: loss = 2.5039
iter 59: loss = 2.5033
iter 60: loss = 2.5027
iter 61: loss = 2.5021
iter 62: loss = 2.5016
iter 63: loss = 2.5011
iter 64: loss = 2.5006
iter 65: loss = 2.5001
iter 66: loss = 2.4996
iter 67: loss = 2.4992
iter 68: loss = 2.4987
iter 69: loss = 2.4983
iter 70: loss = 2.4979
iter 71: loss = 2.4975
iter 72: loss = 2.4971
iter 73: loss = 2.4967
iter 74: loss = 2.4964
iter 75: loss = 2.4960
iter 76: loss = 2.4957
iter 77: loss = 2.4954
iter 78: loss = 2.4950
iter 79: loss = 2.4947
iter 80: loss = 2.4944
iter 81: loss = 2.4941
iter 82: loss = 2.4939
iter 83: loss = 2.4936
iter 84: loss = 2.4933
iter 85: loss = 2.4931
iter 86: loss = 2.4928
iter 87: loss = 2.4926
iter 88: loss = 2.4923
iter 89: loss = 2.4921
iter 90: loss = 2.4919
iter 91: loss = 2.4917
iter 92: loss = 2.4915
iter 93: loss = 2.4913
iter 94: loss = 2.4911
iter 95: loss = 2.4909
iter 96: loss = 2.4907
iter 97: loss = 2.4905
iter 98: loss = 2.4903
iter 99: loss = 2.4901

Regularisation of loss (smoothing)

Note the + 0.01*(W**2).mean() term in the loss above. This is L2 regularisation (weight decay):

  • W**2 — squares every element of W (27×27 = 729 values)
  • .mean() — averages all 729 squared values into a single scalar
  • 0.01 * ... — scales it down so it doesn’t overpower the actual loss

It gets added to the NLL loss, so minimising the total loss now has two competing objectives:

  • Minimise NLL: assign high probability to the correct next characters
  • Minimise W**2: keep all weights close to zero
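The pull toward zero is visible in the gradient of the penalty. For 0.01 * (W**2).mean() over 729 elements, the gradient is 0.02 * W / 729, proportional to each weight. A descent step on this term alone therefore scales every weight by the same factor slightly below 1, which is why it is also called weight decay. The minimal check below uses a random toy W and the notebook's learning rate of 50.

```python
import torch

g = torch.Generator().manual_seed(0)
W = torch.randn((27, 27), generator=g, requires_grad=True)  # toy weights

reg = 0.01 * (W**2).mean()  # the L2 term added to the loss above
reg.backward()

# d/dW of 0.01 * mean(W**2) = 0.01 * 2 * W / 729: proportional to W itself
expected = 0.02 * W.detach() / W.numel()
print(torch.allclose(W.grad, expected))  # True

# A step on this term alone shrinks every weight toward zero by the same factor
lr = 50
W_after = W.detach() - lr * W.grad
print(torch.allclose(W_after, W.detach() * (1 - lr * 0.02 / 729)))  # True
```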

Why do we want weights close to zero?

A network with no regularisation is free to push weights to large values to fit the training data as closely as possible. Large weights produce large logits and therefore very peaked, overconfident probability distributions. Penalising large weights (via W**2) pushes the network toward smaller weights, which produce more uniform probability distributions. This essentially pulls the model toward the “I’m not sure” prior, with more equal probabilities across all next characters. It plays a similar role to smoothing (adding fake counts to N) in the counting approach.

  • The 0.01 is a hyperparameter called the regularisation strength. Too high and it overwhelms the NLL loss, making all predictions uniform; too low and it has little effect.
  • In the bigram context specifically: when W is all zeros, xenc @ W is all zeros, and softmax of all zeros is the uniform 1/27 for every character. That is the maximally uncertain prediction, and L2 regularisation nudges the model toward this baseline.
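The uniform baseline is easy to quantify. With all-zero logits every character gets probability 1/27, and the loss is -log(1/27) = log(27) ≈ 3.2958. The randomly initialised W earlier scored ~3.77, which is worse than this "know nothing" model. A plain-Python sketch:

```python
import math

# With W all zeros, every logit is 0, so exp() gives 1 for each of the 27 characters
counts = [math.exp(0.0)] * 27
probs = [c / sum(counts) for c in counts]
print(probs[0])  # 1/27, about 0.037

# Loss of this maximally uncertain model: -log(1/27) = log(27)
uniform_loss = -math.log(probs[0])
print(round(uniform_loss, 4))  # 3.2958
```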

Generate sample

Note the names sampled from the NN are identical to those sampled from the explicit bigram model.

# i - finally, sample from the 'neural net' model
g = torch.Generator().manual_seed(2147483647)
 
for i in range(5):
    
    out = []
    ix = 0
    while True:
        
        # ----------
        # BEFORE:
        # p = P[ix]
        # ----------
        # NOW:
        xenc = F.one_hot(torch.tensor([ix]), num_classes=27).float()
        logits = xenc @ W # predict log-counts
        counts = logits.exp() # counts, equivalent to N
        p = counts / counts.sum(1, keepdims=True) # probabilities for next character
        # ----------
        
        ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        out.append(itos[ix])
        if ix == 0:
            break
    print(''.join(out))
cexze.
momasurailezityha.
konimittain.
llayn.
ka.
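One plausible reason the samples match is that `torch.multinomial` is deterministic for a given generator state. The sampling cell re-seeds `g` with 2147483647, assumed here to be the same seed as the counting model's sampling loop, and the trained probabilities are very close to P. So each draw very likely lands on the same character. The sketch below only demonstrates the determinism part, using a hypothetical 3-character distribution.

```python
import torch

p = torch.tensor([[0.1, 0.6, 0.3]])  # hypothetical next-character distribution (1 row)

# Two generators seeded identically produce identical streams of draws
g1 = torch.Generator().manual_seed(2147483647)
g2 = torch.Generator().manual_seed(2147483647)
draws1 = [torch.multinomial(p, num_samples=1, replacement=True, generator=g1).item() for _ in range(10)]
draws2 = [torch.multinomial(p, num_samples=1, replacement=True, generator=g2).item() for _ in range(10)]

print(draws1 == draws2)  # True
```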

Sources