What is makemore

  • makemore takes one text file as input, where each line is assumed to be one training thing, and generates more things like it.
    • Inspiring topic for the tutorial: baby name generation!
  • makemore is a character-level language model (LM) so, when trained, it simply predicts the next character in the sequence.
    • Several LMs architectures can do this. I will implement the following:
      • Bigram (one char predicts the next, with a lookup table of counts)
      • Bag of words
      • MLP
      • RNN
      • GRU
      • Transformer (equivalent to GPT-2, on the character-level)
  • Extensions. What might come after this?
    • word-level (tokenisation)?
    • document-level?
    • images and image-text networks (DALLE, Stable Diffusion)

Construct Bigram object

# i - import data and inspect
words = open('data/names.txt', 'r').read().splitlines()
print('first 10 names:', words[:10])
print('total count:   ', len(words), 'names')
print('shortest name: ', min(len(w) for w in words), 'chars')
print('longest name:  ', max(len(w) for w in words), 'chars')
first 10 names: ['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia', 'harper', 'evelyn']
total count:    32033 names
shortest name:  2 chars
longest name:   15 chars

Each line has many many training examples

Bigram

“Bigram” refers to a sequence of two adjacent elements (typically words, characters, or tokens) from a text string, representing a type of -gram where

This character-level bigram model is extremely simple. Only two characters are operated on at a time. It only uses a single character to predict the next.

Create frequency table of all bigrams (counts)

# i - store counts of all bigrams
b = {} # initialise dict of bigram counts
for w in words:
    chs = ['<S>'] + list(w) + ['<E>']  # prepend start '<S>', append end '<E>' special chars
    for ch1, ch2 in zip(chs, chs[1:]): # 2-char sliding window over each w
        bigram = (ch1, ch2)  # tuple
        b[bigram] = b.get(bigram, 0) + 1  # iterate counter (initialise 0 if first encounter)
# i - b.items() is type `dict_items`. It returns tuples (k: bigrams, v: counts) 
sorted(b.items(), key = lambda v: -v[1]) # enforce desc sort on counts (2nd item of tuple: value)
[(('n', '<E>'), 6763),
 (('a', '<E>'), 6640),
 (('a', 'n'), 5438),
 (('<S>', 'a'), 4410),
 (('e', '<E>'), 3983),
 (('a', 'r'), 3264),
 (('e', 'l'), 3248),
 (('r', 'i'), 3033),
 (('n', 'a'), 2977),
 (('<S>', 'k'), 2963),
 (('l', 'e'), 2921),
 (('e', 'n'), 2675),
 (('l', 'a'), 2623),
 (('m', 'a'), 2590),
 (('<S>', 'm'), 2538),
 (('a', 'l'), 2528),
 (('i', '<E>'), 2489),
 (('l', 'i'), 2480),
 (('i', 'a'), 2445),
 (('<S>', 'j'), 2422),
 (('o', 'n'), 2411),
 (('h', '<E>'), 2409),
 (('r', 'a'), 2356),
 (('a', 'h'), 2332),
 (('h', 'a'), 2244),
 (('y', 'a'), 2143),
 (('i', 'n'), 2126),
 (('<S>', 's'), 2055),
 (('a', 'y'), 2050),
 (('y', '<E>'), 2007),
 (('e', 'r'), 1958),
 (('n', 'n'), 1906),
 (('y', 'n'), 1826),
 (('k', 'a'), 1731),
 (('n', 'i'), 1725),
 (('r', 'e'), 1697),
 (('<S>', 'd'), 1690),
 (('i', 'e'), 1653),
 (('a', 'i'), 1650),
 (('<S>', 'r'), 1639),
 (('a', 'm'), 1634),
 (('l', 'y'), 1588),
 (('<S>', 'l'), 1572),
 (('<S>', 'c'), 1542),
 (('<S>', 'e'), 1531),
 (('j', 'a'), 1473),
 (('r', '<E>'), 1377),
 (('n', 'e'), 1359),
 (('l', 'l'), 1345),
 (('i', 'l'), 1345),
 (('i', 's'), 1316),
 (('l', '<E>'), 1314),
 (('<S>', 't'), 1308),
 (('<S>', 'b'), 1306),
 (('d', 'a'), 1303),
 (('s', 'h'), 1285),
 (('d', 'e'), 1283),
 (('e', 'e'), 1271),
 (('m', 'i'), 1256),
 (('s', 'a'), 1201),
 (('s', '<E>'), 1169),
 (('<S>', 'n'), 1146),
 (('a', 's'), 1118),
 (('y', 'l'), 1104),
 (('e', 'y'), 1070),
 (('o', 'r'), 1059),
 (('a', 'd'), 1042),
 (('t', 'a'), 1027),
 (('<S>', 'z'), 929),
 (('v', 'i'), 911),
 (('k', 'e'), 895),
 (('s', 'e'), 884),
 (('<S>', 'h'), 874),
 (('r', 'o'), 869),
 (('e', 's'), 861),
 (('z', 'a'), 860),
 (('o', '<E>'), 855),
 (('i', 'r'), 849),
 (('b', 'r'), 842),
 (('a', 'v'), 834),
 (('m', 'e'), 818),
 (('e', 'i'), 818),
 (('c', 'a'), 815),
 (('i', 'y'), 779),
 (('r', 'y'), 773),
 (('e', 'm'), 769),
 (('s', 't'), 765),
 (('h', 'i'), 729),
 (('t', 'e'), 716),
 (('n', 'd'), 704),
 (('l', 'o'), 692),
 (('a', 'e'), 692),
 (('a', 't'), 687),
 (('s', 'i'), 684),
 (('e', 'a'), 679),
 (('d', 'i'), 674),
 (('h', 'e'), 674),
 (('<S>', 'g'), 669),
 (('t', 'o'), 667),
 (('c', 'h'), 664),
 (('b', 'e'), 655),
 (('t', 'h'), 647),
 (('v', 'a'), 642),
 (('o', 'l'), 619),
 (('<S>', 'i'), 591),
 (('i', 'o'), 588),
 (('e', 't'), 580),
 (('v', 'e'), 568),
 (('a', 'k'), 568),
 (('a', 'a'), 556),
 (('c', 'e'), 551),
 (('a', 'b'), 541),
 (('i', 't'), 541),
 (('<S>', 'y'), 535),
 (('t', 'i'), 532),
 (('s', 'o'), 531),
 (('m', '<E>'), 516),
 (('d', '<E>'), 516),
 (('<S>', 'p'), 515),
 (('i', 'c'), 509),
 (('k', 'i'), 509),
 (('o', 's'), 504),
 (('n', 'o'), 496),
 (('t', '<E>'), 483),
 (('j', 'o'), 479),
 (('u', 's'), 474),
 (('a', 'c'), 470),
 (('n', 'y'), 465),
 (('e', 'v'), 463),
 (('s', 's'), 461),
 (('m', 'o'), 452),
 (('i', 'k'), 445),
 (('n', 't'), 443),
 (('i', 'd'), 440),
 (('j', 'e'), 440),
 (('a', 'z'), 435),
 (('i', 'g'), 428),
 (('i', 'm'), 427),
 (('r', 'r'), 425),
 (('d', 'r'), 424),
 (('<S>', 'f'), 417),
 (('u', 'r'), 414),
 (('r', 'l'), 413),
 (('y', 's'), 401),
 (('<S>', 'o'), 394),
 (('e', 'd'), 384),
 (('a', 'u'), 381),
 (('c', 'o'), 380),
 (('k', 'y'), 379),
 (('d', 'o'), 378),
 (('<S>', 'v'), 376),
 (('t', 't'), 374),
 (('z', 'e'), 373),
 (('z', 'i'), 364),
 (('k', '<E>'), 363),
 (('g', 'h'), 360),
 (('t', 'r'), 352),
 (('k', 'o'), 344),
 (('t', 'y'), 341),
 (('g', 'e'), 334),
 (('g', 'a'), 330),
 (('l', 'u'), 324),
 (('b', 'a'), 321),
 (('d', 'y'), 317),
 (('c', 'k'), 316),
 (('<S>', 'w'), 307),
 (('k', 'h'), 307),
 (('u', 'l'), 301),
 (('y', 'e'), 301),
 (('y', 'r'), 291),
 (('m', 'y'), 287),
 (('h', 'o'), 287),
 (('w', 'a'), 280),
 (('s', 'l'), 279),
 (('n', 's'), 278),
 (('i', 'z'), 277),
 (('u', 'n'), 275),
 (('o', 'u'), 275),
 (('n', 'g'), 273),
 (('y', 'd'), 272),
 (('c', 'i'), 271),
 (('y', 'o'), 271),
 (('i', 'v'), 269),
 (('e', 'o'), 269),
 (('o', 'm'), 261),
 (('r', 'u'), 252),
 (('f', 'a'), 242),
 (('b', 'i'), 217),
 (('s', 'y'), 215),
 (('n', 'c'), 213),
 (('h', 'y'), 213),
 (('p', 'a'), 209),
 (('r', 't'), 208),
 (('q', 'u'), 206),
 (('p', 'h'), 204),
 (('h', 'r'), 204),
 (('j', 'u'), 202),
 (('g', 'r'), 201),
 (('p', 'e'), 197),
 (('n', 'l'), 195),
 (('y', 'i'), 192),
 (('g', 'i'), 190),
 (('o', 'd'), 190),
 (('r', 's'), 190),
 (('r', 'd'), 187),
 (('h', 'l'), 185),
 (('s', 'u'), 185),
 (('a', 'x'), 182),
 (('e', 'z'), 181),
 (('e', 'k'), 178),
 (('o', 'v'), 176),
 (('a', 'j'), 175),
 (('o', 'h'), 171),
 (('u', 'e'), 169),
 (('m', 'm'), 168),
 (('a', 'g'), 168),
 (('h', 'u'), 166),
 (('x', '<E>'), 164),
 (('u', 'a'), 163),
 (('r', 'm'), 162),
 (('a', 'w'), 161),
 (('f', 'i'), 160),
 (('z', '<E>'), 160),
 (('u', '<E>'), 155),
 (('u', 'm'), 154),
 (('e', 'c'), 153),
 (('v', 'o'), 153),
 (('e', 'h'), 152),
 (('p', 'r'), 151),
 (('d', 'd'), 149),
 (('o', 'a'), 149),
 (('w', 'e'), 149),
 (('w', 'i'), 148),
 (('y', 'm'), 148),
 (('z', 'y'), 147),
 (('n', 'z'), 145),
 (('y', 'u'), 141),
 (('r', 'n'), 140),
 (('o', 'b'), 140),
 (('k', 'l'), 139),
 (('m', 'u'), 139),
 (('l', 'd'), 138),
 (('h', 'n'), 138),
 (('u', 'd'), 136),
 (('<S>', 'x'), 134),
 (('t', 'l'), 134),
 (('a', 'f'), 134),
 (('o', 'e'), 132),
 (('e', 'x'), 132),
 (('e', 'g'), 125),
 (('f', 'e'), 123),
 (('z', 'l'), 123),
 (('u', 'i'), 121),
 (('v', 'y'), 121),
 (('e', 'b'), 121),
 (('r', 'h'), 121),
 (('j', 'i'), 119),
 (('o', 't'), 118),
 (('d', 'h'), 118),
 (('h', 'm'), 117),
 (('c', 'l'), 116),
 (('o', 'o'), 115),
 (('y', 'c'), 115),
 (('o', 'w'), 114),
 (('o', 'c'), 114),
 (('f', 'r'), 114),
 (('b', '<E>'), 114),
 (('m', 'b'), 112),
 (('z', 'o'), 110),
 (('i', 'b'), 110),
 (('i', 'u'), 109),
 (('k', 'r'), 109),
 (('g', '<E>'), 108),
 (('y', 'v'), 106),
 (('t', 'z'), 105),
 (('b', 'o'), 105),
 (('c', 'y'), 104),
 (('y', 't'), 104),
 (('u', 'b'), 103),
 (('u', 'c'), 103),
 (('x', 'a'), 103),
 (('b', 'l'), 103),
 (('o', 'y'), 103),
 (('x', 'i'), 102),
 (('i', 'f'), 101),
 (('r', 'c'), 99),
 (('c', '<E>'), 97),
 (('m', 'r'), 97),
 (('n', 'u'), 96),
 (('o', 'p'), 95),
 (('i', 'h'), 95),
 (('k', 's'), 95),
 (('l', 's'), 94),
 (('u', 'k'), 93),
 (('<S>', 'q'), 92),
 (('d', 'u'), 92),
 (('s', 'm'), 90),
 (('r', 'k'), 90),
 (('i', 'x'), 89),
 (('v', '<E>'), 88),
 (('y', 'k'), 86),
 (('u', 'w'), 86),
 (('g', 'u'), 85),
 (('b', 'y'), 83),
 (('e', 'p'), 83),
 (('g', 'o'), 83),
 (('s', 'k'), 82),
 (('u', 't'), 82),
 (('a', 'p'), 82),
 (('e', 'f'), 82),
 (('i', 'i'), 82),
 (('r', 'v'), 80),
 (('f', '<E>'), 80),
 (('t', 'u'), 78),
 (('y', 'z'), 78),
 (('<S>', 'u'), 78),
 (('l', 't'), 77),
 (('r', 'g'), 76),
 (('c', 'r'), 76),
 (('i', 'j'), 76),
 (('w', 'y'), 73),
 (('z', 'u'), 73),
 (('l', 'v'), 72),
 (('h', 't'), 71),
 (('j', '<E>'), 71),
 (('x', 't'), 70),
 (('o', 'i'), 69),
 (('e', 'u'), 69),
 (('o', 'k'), 68),
 (('b', 'd'), 65),
 (('a', 'o'), 63),
 (('p', 'i'), 61),
 (('s', 'c'), 60),
 (('d', 'l'), 60),
 (('l', 'm'), 60),
 (('a', 'q'), 60),
 (('f', 'o'), 60),
 (('p', 'o'), 59),
 (('n', 'k'), 58),
 (('w', 'n'), 58),
 (('u', 'h'), 58),
 (('e', 'j'), 55),
 (('n', 'v'), 55),
 (('s', 'r'), 55),
 (('o', 'z'), 54),
 (('i', 'p'), 53),
 (('l', 'b'), 52),
 (('i', 'q'), 52),
 (('w', '<E>'), 51),
 (('m', 'c'), 51),
 (('s', 'p'), 51),
 (('e', 'w'), 50),
 (('k', 'u'), 50),
 (('v', 'r'), 48),
 (('u', 'g'), 47),
 (('o', 'x'), 45),
 (('u', 'z'), 45),
 (('z', 'z'), 45),
 (('j', 'h'), 45),
 (('b', 'u'), 45),
 (('o', 'g'), 44),
 (('n', 'r'), 44),
 (('f', 'f'), 44),
 (('n', 'j'), 44),
 (('z', 'h'), 43),
 (('c', 'c'), 42),
 (('r', 'b'), 41),
 (('x', 'o'), 41),
 (('b', 'h'), 41),
 (('p', 'p'), 39),
 (('x', 'l'), 39),
 (('h', 'v'), 39),
 (('b', 'b'), 38),
 (('m', 'p'), 38),
 (('x', 'x'), 38),
 (('u', 'v'), 37),
 (('x', 'e'), 36),
 (('w', 'o'), 36),
 (('c', 't'), 35),
 (('z', 'm'), 35),
 (('t', 's'), 35),
 (('m', 's'), 35),
 (('c', 'u'), 35),
 (('o', 'f'), 34),
 (('u', 'x'), 34),
 (('k', 'w'), 34),
 (('p', '<E>'), 33),
 (('g', 'l'), 32),
 (('z', 'r'), 32),
 (('d', 'n'), 31),
 (('g', 't'), 31),
 (('g', 'y'), 31),
 (('h', 's'), 31),
 (('x', 's'), 31),
 (('g', 's'), 30),
 (('x', 'y'), 30),
 (('y', 'g'), 30),
 (('d', 'm'), 30),
 (('d', 's'), 29),
 (('h', 'k'), 29),
 (('y', 'x'), 28),
 (('q', '<E>'), 28),
 (('g', 'n'), 27),
 (('y', 'b'), 27),
 (('g', 'w'), 26),
 (('n', 'h'), 26),
 (('k', 'n'), 26),
 (('g', 'g'), 25),
 (('d', 'g'), 25),
 (('l', 'c'), 25),
 (('r', 'j'), 25),
 (('w', 'u'), 25),
 (('l', 'k'), 24),
 (('m', 'd'), 24),
 (('s', 'w'), 24),
 (('s', 'n'), 24),
 (('h', 'd'), 24),
 (('w', 'h'), 23),
 (('y', 'j'), 23),
 (('y', 'y'), 23),
 (('r', 'z'), 23),
 (('d', 'w'), 23),
 (('w', 'r'), 22),
 (('t', 'n'), 22),
 (('l', 'f'), 22),
 (('y', 'h'), 22),
 (('r', 'w'), 21),
 (('s', 'b'), 21),
 (('m', 'n'), 20),
 (('f', 'l'), 20),
 (('w', 's'), 20),
 (('k', 'k'), 20),
 (('h', 'z'), 20),
 (('g', 'd'), 19),
 (('l', 'h'), 19),
 (('n', 'm'), 19),
 (('x', 'z'), 19),
 (('u', 'f'), 19),
 (('f', 't'), 18),
 (('l', 'r'), 18),
 (('p', 't'), 17),
 (('t', 'c'), 17),
 (('k', 't'), 17),
 (('d', 'v'), 17),
 (('u', 'p'), 16),
 (('p', 'l'), 16),
 (('l', 'w'), 16),
 (('p', 's'), 16),
 (('o', 'j'), 16),
 (('r', 'q'), 16),
 (('y', 'p'), 15),
 (('l', 'p'), 15),
 (('t', 'v'), 15),
 (('r', 'p'), 14),
 (('l', 'n'), 14),
 (('e', 'q'), 14),
 (('f', 'y'), 14),
 (('s', 'v'), 14),
 (('u', 'j'), 14),
 (('v', 'l'), 14),
 (('q', 'a'), 13),
 (('u', 'y'), 13),
 (('q', 'i'), 13),
 (('w', 'l'), 13),
 (('p', 'y'), 12),
 (('y', 'f'), 12),
 (('c', 'q'), 11),
 (('j', 'r'), 11),
 (('n', 'w'), 11),
 (('n', 'f'), 11),
 (('t', 'w'), 11),
 (('m', 'z'), 11),
 (('u', 'o'), 10),
 (('f', 'u'), 10),
 (('l', 'z'), 10),
 (('h', 'w'), 10),
 (('u', 'q'), 10),
 (('j', 'y'), 10),
 (('s', 'z'), 10),
 (('s', 'd'), 9),
 (('j', 'l'), 9),
 (('d', 'j'), 9),
 (('k', 'm'), 9),
 (('r', 'f'), 9),
 (('h', 'j'), 9),
 (('v', 'n'), 8),
 (('n', 'b'), 8),
 (('i', 'w'), 8),
 (('h', 'b'), 8),
 (('b', 's'), 8),
 (('w', 't'), 8),
 (('w', 'd'), 8),
 (('v', 'v'), 7),
 (('v', 'u'), 7),
 (('j', 's'), 7),
 (('m', 'j'), 7),
 (('f', 's'), 6),
 (('l', 'g'), 6),
 (('l', 'j'), 6),
 (('j', 'w'), 6),
 (('n', 'x'), 6),
 (('y', 'q'), 6),
 (('w', 'k'), 6),
 (('g', 'm'), 6),
 (('x', 'u'), 5),
 (('m', 'h'), 5),
 (('m', 'l'), 5),
 (('j', 'm'), 5),
 (('c', 's'), 5),
 (('j', 'v'), 5),
 (('n', 'p'), 5),
 (('d', 'f'), 5),
 (('x', 'd'), 5),
 (('z', 'b'), 4),
 (('f', 'n'), 4),
 (('x', 'c'), 4),
 (('m', 't'), 4),
 (('t', 'm'), 4),
 (('z', 'n'), 4),
 (('z', 't'), 4),
 (('p', 'u'), 4),
 (('c', 'z'), 4),
 (('b', 'n'), 4),
 (('z', 's'), 4),
 (('f', 'w'), 4),
 (('d', 't'), 4),
 (('j', 'd'), 4),
 (('j', 'c'), 4),
 (('y', 'w'), 4),
 (('v', 'k'), 3),
 (('x', 'w'), 3),
 (('t', 'j'), 3),
 (('c', 'j'), 3),
 (('q', 'w'), 3),
 (('g', 'b'), 3),
 (('o', 'q'), 3),
 (('r', 'x'), 3),
 (('d', 'c'), 3),
 (('g', 'j'), 3),
 (('x', 'f'), 3),
 (('z', 'w'), 3),
 (('d', 'k'), 3),
 (('u', 'u'), 3),
 (('m', 'v'), 3),
 (('c', 'x'), 3),
 (('l', 'q'), 3),
 (('p', 'b'), 2),
 (('t', 'g'), 2),
 (('q', 's'), 2),
 (('t', 'x'), 2),
 (('f', 'k'), 2),
 (('b', 't'), 2),
 (('j', 'n'), 2),
 (('k', 'c'), 2),
 (('z', 'k'), 2),
 (('s', 'j'), 2),
 (('s', 'f'), 2),
 (('z', 'j'), 2),
 (('n', 'q'), 2),
 (('f', 'z'), 2),
 (('h', 'g'), 2),
 (('w', 'w'), 2),
 (('k', 'j'), 2),
 (('j', 'k'), 2),
 (('w', 'm'), 2),
 (('z', 'c'), 2),
 (('z', 'v'), 2),
 (('w', 'f'), 2),
 (('q', 'm'), 2),
 (('k', 'z'), 2),
 (('j', 'j'), 2),
 (('z', 'p'), 2),
 (('j', 't'), 2),
 (('k', 'b'), 2),
 (('m', 'w'), 2),
 (('h', 'f'), 2),
 (('c', 'g'), 2),
 (('t', 'f'), 2),
 (('h', 'c'), 2),
 (('q', 'o'), 2),
 (('k', 'd'), 2),
 (('k', 'v'), 2),
 (('s', 'g'), 2),
 (('z', 'd'), 2),
 (('q', 'r'), 1),
 (('d', 'z'), 1),
 (('p', 'j'), 1),
 (('q', 'l'), 1),
 (('p', 'f'), 1),
 (('q', 'e'), 1),
 (('b', 'c'), 1),
 (('c', 'd'), 1),
 (('m', 'f'), 1),
 (('p', 'n'), 1),
 (('w', 'b'), 1),
 (('p', 'c'), 1),
 (('h', 'p'), 1),
 (('f', 'h'), 1),
 (('b', 'j'), 1),
 (('f', 'g'), 1),
 (('z', 'g'), 1),
 (('c', 'p'), 1),
 (('p', 'k'), 1),
 (('p', 'm'), 1),
 (('x', 'n'), 1),
 (('s', 'q'), 1),
 (('k', 'f'), 1),
 (('m', 'k'), 1),
 (('x', 'h'), 1),
 (('g', 'f'), 1),
 (('v', 'b'), 1),
 (('j', 'p'), 1),
 (('g', 'z'), 1),
 (('v', 'd'), 1),
 (('d', 'b'), 1),
 (('v', 'h'), 1),
 (('h', 'h'), 1),
 (('g', 'v'), 1),
 (('d', 'q'), 1),
 (('x', 'b'), 1),
 (('w', 'z'), 1),
 (('h', 'q'), 1),
 (('j', 'b'), 1),
 (('x', 'm'), 1),
 (('w', 'g'), 1),
 (('t', 'b'), 1),
 (('z', 'x'), 1)]

Convert bigram dict PyTorch tensor (2D array)

We need a 28 x 28 2D array. But this means some redundant data is stored:

  • Since <S> is always 1st char, it never follows anything.
    • Would lead to a blank col of bigrams like a<S>, b<S>, …, z<S>, <S><S>, <E><S>
  • Since <S> is always last char, it never preceeds anything.
    • Would lead to a blank row of bigrams like <E>a, <E>b, …, <E>z, <E><S>, <E><E>
  • The only potentially legal one is <S><E> (i.e. a line with no text)

Solution: Have a single special token: .. This reduces the heatmap from 28 x 28 27 x 27.

# i - initialise 27 x 27 tensor (2D array), and create lookup tables (maps)
import torch
N = torch.zeros((27, 27), dtype=torch.int32) # init counts array: 32-bit ints (26 chars + EOL)
 
chars = sorted(list(set(''.join(words))))  # sorted list: unique (26) chars in full dataset
stoi = {s:i+1 for i,s in enumerate(chars)} # map (dict type): `str`->`int`. 'a'=1, ..., 'z'=26
stoi['.'] = 0                              # map: '.'=0. Treat '<S>' and '<E>' as the same!
itos = {i:s for s,i in stoi.items()}       # map (reverse): invert `stoi` dict

Recreate the frequency map, but storing the bigram counts in the tensor object, N. This will help visualise it as a 2D array.

# i - recreate freq map, but storing bigram counts in PyTorch tensor: N
for w in words:
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2 in zip(chs, chs[1:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    N[ix1, ix2] += 1

Visualise bigram counts as heatmap

In the visualisation below:

  • Row 1: Counts for first letters
  • Col 1: Counts for last letters
  • All other cells: Counts for all other bigrams
# i - visualise array (heatmap)
import matplotlib.pyplot as plt
%matplotlib inline
 
plt.figure(figsize=(16,16))
plt.imshow(N, cmap='Blues') # show entire array image
 
# iterate over each cell in array
for i in range(27):
    for j in range(27):
        chstr = itos[i] + itos[j] # create char strings (e.g. 'ac', 'gb', 'a.')
        plt.text(j, i, chstr, ha="center", va="bottom", color='gray')       # write char string
        plt.text(j, i, N[i, j].item(), ha="center", va="top", color='gray') # write count (item)
plt.axis('off');
plot

Sources