- Prev: 01_build_mlp
- Next: 03_overfitting_train_val_test
# import, only use first 5 names (32 training examples), specify network parameters C, W1, b1, W2, b2
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline
# import data
words = open('data/names.txt', 'r').read().splitlines()
# build the vocabulary of characters, and mappings to/from integers
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
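# e.g. stoi['a'] == 1, stoi['.'] == 0, itos[0] == '.'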
# build mini-data set X and Y of first 5 names (32x 3-char training examples)
block_size = 3 # context length: how many characters do we take to predict the next one?
X, Y = [], [] # X: NN input training examples, Y: labels for each input in X
for w in words[:5]:
    print(w)
    context = [0] * block_size
    for ch in w + '.':
        ix = stoi[ch]
        X.append(context)
        Y.append(ix)
        print(''.join(itos[i] for i in context), '->', itos[ix])
        context = context[1:] + [ix] # crop and append
X = torch.tensor(X)
Y = torch.tensor(Y)
# print('\nX.shape:', X.shape, 'X.dtype:', X.dtype)
# print('Y.shape:', Y.shape, 'Y.dtype:', Y.dtype)
print('\nX:', X.shape, '-> Y:', Y.shape)
g = torch.Generator().manual_seed(2147483647) # for reproducibility
# define parameters (3,481 in total)
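# parameter count: 27*2 (C) + 6*100 (W1) + 100 (b1) + 100*27 (W2) + 27 (b2) = 3,481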
C = torch.randn((27, 2), generator=g) # embedding matrix (lookup table for input tokens)
W1 = torch.randn((6, 100), generator=g) # hidden layer's incoming weights: 6 inputs to layer, 100 hidden neurons in layer
b1 = torch.randn(100, generator=g) # 100 biases live "in" hidden layer's neurons
W2 = torch.randn((100, 27), generator=g) # output layer's incoming weights: 100 inputs to layer, 27 output neurons in layer
b2 = torch.randn(27, generator=g) # 27 biases live "in" output layer's neurons
parameters = [C, W1, b1, W2, b2] # list of all parameters (makes it easier to count)
print('num. of parameters:', sum(p.nelement() for p in parameters)) # total parameter count in network: 3,481
emma
... -> e
..e -> m
.em -> m
emm -> a
mma -> .
olivia
... -> o
..o -> l
.ol -> i
oli -> v
liv -> i
ivi -> a
via -> .
ava
... -> a
..a -> v
.av -> a
ava -> .
isabella
... -> i
..i -> s
.is -> a
isa -> b
sab -> e
abe -> l
bel -> l
ell -> a
lla -> .
sophia
... -> s
..s -> o
.so -> p
sop -> h
oph -> i
phi -> a
hia -> .
X: torch.Size([32, 3]) -> Y: torch.Size([32])
num. of parameters: 3481
Overfitting one batch
We now use a small sample (batch): just 5 names (32 training examples) instead of all 32,033 names (228,146 training examples).
- Since we have 3,481 parameters fitting only 32 examples (over 100 parameters per data point), we are overfitting that single batch of data.
- Too many parameters for too few data points, so expect a very low loss.
# ensure all 3,481 parameters have gradient (to enable optimisation)
for p in parameters:
    p.requires_grad = True
# i - run 1000 training iterations
for _ in range(1000):
    # forward pass
    emb = C[X] # (32, 3, 2) -> (32, 6) on next line via emb.view(-1, 6)
    h = torch.tanh(emb.view(-1, 6) @ W1 + b1) # (32, 100)
    logits = h @ W2 + b2 # (32, 27)
    loss = F.cross_entropy(logits, Y) # simpler!
    print('loss:', loss.item())
    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()
    # gradient descent (update parameters)
    for p in parameters:
        p.data += -0.1 * p.grad
print(loss.item())
loss: 17.76971435546875
loss: 13.656402587890625
loss: 11.298770904541016
loss: 9.4524564743042
loss: 7.984263896942139
... (output truncated) ...
loss: 0.2561572194099426
loss: 0.2561384439468384
0.2561384439468384
Loss approaches ~0.256. Why not 0 (complete overfit)?
This is because our training batch contains only 5 names.
- Take this example: we have five first-token predictions:
  '...' -> 'e', '...' -> 'o', '...' -> 'a', '...' -> 'i', '...' -> 's'
- 'e', 'o', 'a', 'i' and 's' are all valid outcomes (next tokens) for the exact same input '...' (the first token). This is why we cannot get loss = 0.
Note however, when there is a unique input for a unique output, the predicted index (logits.max(1) below) does match the label Y.
- In these cases we overfit, getting the exact result.
# i - compare model prediction indices vs desired label indices
print(logits.max(1)) # maximum value = model's prediction of next token
print(Y) # desired labels (correct answer)
torch.return_types.max(
values=tensor([13.3348, 17.7905, 20.6013, 20.6120, 16.7355, 13.3348, 15.9984, 14.1723,
15.9146, 18.3614, 15.9396, 20.9265, 13.3348, 17.1089, 17.1319, 20.0601,
13.3348, 16.5892, 15.1017, 17.0581, 18.5861, 15.9670, 10.8740, 10.6871,
15.5056, 13.3348, 16.1794, 16.9743, 12.7426, 16.2008, 19.0846, 16.0195],
grad_fn=<MaxBackward0>),
indices=tensor([19, 13, 13, 1, 0, 19, 12, 9, 22, 9, 1, 0, 19, 22, 1, 0, 19, 19,
1, 2, 5, 12, 12, 1, 0, 19, 15, 16, 8, 9, 1, 0]))
tensor([ 5, 13, 13, 1, 0, 15, 12, 9, 22, 9, 1, 0, 1, 22, 1, 0, 9, 19,
1, 2, 5, 12, 12, 1, 0, 19, 15, 16, 8, 9, 1, 0])
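In the output above, 28 of the 32 predicted indices match Y; the four mismatches all occur at the ambiguous first-token context '...'. As an optional follow-up (hypothetical, not a cell from the original notebook), the agreement can be quantified directly:
pred = logits.max(1).indices # predicted next-token index per example
print((pred == Y).sum().item(), 'of', Y.numel(), 'examples match') # 28 of 32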
Train full dataset
# rebuild data set for all 32,033 names (228,146 training examples)
block_size = 3 # context length: how many characters do we take to predict the next one?
X, Y = [], [] # X: NN input training examples, Y: labels for each input in X
for w in words:
    # print(w)
    context = [0] * block_size
    for ch in w + '.':
        ix = stoi[ch]
        X.append(context)
        Y.append(ix)
        # print(''.join(itos[i] for i in context), '->', itos[ix])
        context = context[1:] + [ix] # crop and append
X = torch.tensor(X)
Y = torch.tensor(Y)
print('\nX.shape:', X.shape, 'X.dtype:', X.dtype)
print('Y.shape:', Y.shape, 'Y.dtype:', Y.dtype)
X.shape: torch.Size([228146, 3]) X.dtype: torch.int64
Y.shape: torch.Size([228146]) Y.dtype: torch.int64
# redefine SAME 3,481 params (with grads!): C, W1, b1, W2, b2
g = torch.Generator().manual_seed(2147483647) # for reproducibility
C = torch.randn((27, 2), generator=g) # embedding matrix (lookup table for input tokens)
W1 = torch.randn((6, 100), generator=g) # hidden layer's incoming weights: 6 inputs to layer, 100 hidden neurons in layer
b1 = torch.randn(100, generator=g) # 100 biases live "in" hidden layer's neurons
W2 = torch.randn((100, 27), generator=g) # output layer's incoming weights: 100 inputs to layer, 27 output neurons in layer
b2 = torch.randn(27, generator=g) # 27 biases live "in" output layer's neurons
parameters = [C, W1, b1, W2, b2] # list of all parameters (makes it easier to count)
print('num. of parameters:', sum(p.nelement() for p in parameters)) # total parameter count in network
# ensure all 3,481 parameters have gradient (to enable optimisation)
for p in parameters:
    p.requires_grad = True
num. of parameters: 3481
Training on 228,146 examples is very slow
Inspect the code cell and output below. Ten iterations of forward and backward passes are computed on all 228,146 examples.
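To make the cost concrete, one could time a single full-batch forward and backward pass (a hypothetical measurement, not a cell from the original notebook; the exact number depends on the machine):
import time
t0 = time.time()
emb = C[X] # full dataset: (228146, 3, 2)
h = torch.tanh(emb.view(-1, 6) @ W1 + b1) # (228146, 100)
logits = h @ W2 + b2 # (228146, 27)
loss = F.cross_entropy(logits, Y)
for p in parameters:
    p.grad = None
loss.backward()
print('one full-batch iteration took', time.time() - t0, 'seconds')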
# 10 training iters on full data set (228,146 training examples, super slow!)
for _ in range(10):
    # forward pass
    emb = C[X] # (228146, 3, 2) -> (228146, 6) on next line via emb.view(-1, 6)
    h = torch.tanh(emb.view(-1, 6) @ W1 + b1) # (228146, 100)
    logits = h @ W2 + b2 # (228146, 27)
    loss = F.cross_entropy(logits, Y) # simpler!
    print('loss:', loss.item())
    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()
    # gradient descent (update parameters)
    for p in parameters:
        p.data += -0.1 * p.grad
print(loss.item())
loss: 19.505229949951172
loss: 17.084491729736328
loss: 15.776531219482422
loss: 14.83333683013916
loss: 14.002594947814941
loss: 13.253252029418945
loss: 12.579911231994629
loss: 11.983097076416016
loss: 11.470491409301758
loss: 11.051854133605957
11.051854133605957
Train on mini batches instead: stochastic gradient descent (SGD)
- Construct a mini-batch: randomly select a portion of the data.
  - ix = torch.randint(0, X.shape[0], (32,)) creates a mini-batch of 32 random training examples: each index in ix is one of 0, 1, …, 228145
- Only perform training iterations (forward and backward passes) on those mini-batches
  - I.e. use those 32 integer indexes ix to index as follows: X[ix] and Y[ix] (see the shape check below)
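For illustration (a hypothetical check, not a cell from the original notebook), the indexing can be verified from the resulting shapes:
ix = torch.randint(0, X.shape[0], (32,)) # 32 random row indices into the full dataset
print(X[ix].shape, Y[ix].shape) # torch.Size([32, 3]) torch.Size([32])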
Trade-off: Why is using less (and random) data on each iteration acceptable?
- It is better to have an approximate gradient and take many gradient-descent steps
- than to evaluate the exact gradient and take only a few steps
Each iter only improves loss for that mini-batch. So how does the overall model improve?
- Each mini-batch is a random sample of the full dataset (32/228,146), so its gradient is an unbiased estimate of the true gradient
- Improving loss on a random mini-batch still nudges weights in a direction that reduces loss on average across the full dataset
- The per-batch loss is noisy — it can increase on individual steps if that batch is poorly represented by the current weights. This is expected.
- Evaluate true progress by running the full dataset through the model periodically — this is the real signal, not the noisy per-batch loss
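A minimal sketch of such a periodic full-dataset check, assuming the C, W1, b1, W2, b2, X, Y tensors defined above (the helper name full_loss is hypothetical):
@torch.no_grad() # evaluation only: no gradients needed
def full_loss():
    emb = C[X] # (228146, 3, 2)
    h = torch.tanh(emb.view(-1, 6) @ W1 + b1) # (228146, 100)
    logits = h @ W2 + b2 # (228146, 27)
    return F.cross_entropy(logits, Y).item()
Calling full_loss() every few hundred mini-batch steps gives the clean progress number described above, without letting the noisy per-batch loss mislead us.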
# i - train on mini-batches (fast!): each iter choose only 32 random training examples.
for _ in range(1000):
    # minibatch construct:
    ix = torch.randint(0, X.shape[0], (32,))
    # forward pass
    emb = C[X[ix]] # (228146, 3, 2) -> now mini batch (32, 3, 2) -> (32, 6) next line emb.view(-1, 6)
    h = torch.tanh(emb.view(-1, 6) @ W1 + b1) # (32, 100)
    logits = h @ W2 + b2 # (32, 27)
    loss = F.cross_entropy(logits, Y[ix])
    print(loss.item())
    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()
    # gradient descent update
    for p in parameters:
        p.data += -0.1 * p.grad
#print(loss.item())
9.896810531616211
8.759634971618652
11.592778205871582
11.068772315979004
10.387434959411621
... (output truncated) ...
2.670398473739624
2.873629093170166
2.327685832977295
# periodically forward pass FULL DATASET: clean loss number showing true model progress
emb = C[X] # (228146, 3, 2) -> (228146, 6) next line emb.view(-1, 6)
h = torch.tanh(emb.view(-1, 6) @ W1 + b1) # (228146, 100)
logits = h @ W2 + b2 # (228146, 27)
loss = F.cross_entropy(logits, Y)
print(loss.item())
2.6633293628692627
Learning rate parameter
Choosing a suitable initial learning rate
Gradient descent updates p.data += -lr * p.grad are scaled by a learning rate parameter lr.
- There are risks to picking a fixed lr value:
  - Too small (e.g. 0.0001): descent toward the minimum takes forever
  - Too large (e.g. 1 or 10): massively overshoots the minimum on every step, so nothing is being optimised
- Instead, generate candidate lrs by indexing from a curve of exponents 10**lre (here spanning 10**-3 = 0.001 up to 10**0 = 1):
# i - systematically choose learning rates
lre = torch.linspace(-3, 0, 1000) # lr exponents: 1000 linear steps between -3 and 0
lrs = 10**lre # exponentiate: 1000 linear steps between 10^-3 and 1
lrs
tensor([0.0010, 0.0010, 0.0010,  ..., 0.9863, 0.9931, 1.0000])
Experiment: Increase lr with iteration count i
Limitation: this is a heuristic, not a fully controlled experiment
- Weights are updating throughout the loop, so later (higher) `lr` values are applied to already-trained weights — not a clean comparison
- Higher `lr` values appear later in the loop, so observed loss behaviour could be due to the `lr` itself OR just accumulated weight updates from earlier iterations
- Use this as a cheap order-of-magnitude heuristic only, then verify with a full training run at the chosen `lr`
  - i.e. reset weights before testing the chosen `lr` / shortlisted `lr`s (a sketch of this follows below)
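A minimal sketch of that verification step, assuming the `X`, `Y` tensors and parameter shapes defined above; the shortlisted learning rates here are illustrative, not values from the notebook:

```python
# sketch: evaluate each shortlisted lr from a fresh parameter initialisation (fair comparison)
import torch
import torch.nn.functional as F

def init_params(seed=2147483647):
    g = torch.Generator().manual_seed(seed)
    C = torch.randn((27, 2), generator=g)
    W1 = torch.randn((6, 100), generator=g)
    b1 = torch.randn(100, generator=g)
    W2 = torch.randn((100, 27), generator=g)
    b2 = torch.randn(27, generator=g)
    params = [C, W1, b1, W2, b2]
    for p in params:
        p.requires_grad = True
    return params

def train(params, lr, steps=1000):
    C, W1, b1, W2, b2 = params
    for _ in range(steps):
        ix = torch.randint(0, X.shape[0], (32,))    # minibatch of 32 examples
        emb = C[X[ix]]                              # (32, 3, 2)
        h = torch.tanh(emb.view(-1, 6) @ W1 + b1)   # (32, 100)
        logits = h @ W2 + b2                        # (32, 27)
        loss = F.cross_entropy(logits, Y[ix])
        for p in params:
            p.grad = None
        loss.backward()
        for p in params:
            p.data += -lr * p.grad
    return loss.item()

# illustrative shortlist; each candidate starts from the same fresh initialisation
for lr in [0.01, 0.1, 0.5]:
    print(f'lr={lr}: final minibatch loss {train(init_params(), lr):.4f}')
```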
# dynamically set lr in training loop. track lr and lre. (as i increases, lr increases)
# ---
# redefine SAME 3,481 network params (with grads!): C, W1, b1, W2, b2
g = torch.Generator().manual_seed(2147483647) # for reproducibility
C = torch.randn((27, 2), generator=g) # embedding matrix (lookup table for input tokens)
W1 = torch.randn((6, 100), generator=g) # hidden layer's incoming weights: 6 inputs to layer, 100 hidden neurons in layer
b1 = torch.randn(100, generator=g) # 100 biases live "in" hidden layer's neurons
W2 = torch.randn((100, 27), generator=g) # output layer's incoming weights: 100 inputs to layer, 27 output neurons in layer
b2 = torch.randn(27, generator=g) # 27 biases live "in" output layer's neurons
parameters = [C, W1, b1, W2, b2] # list of all parameters (makes easier to count)
# print('num. of parameters:', sum(p.nelement() for p in parameters)) # total parameter count in network
# ensure all 3,481 parameters have gradient (to enable optimisation)
for p in parameters:
p.requires_grad = True
# ---
lri = [] # track lr used on each iteration
lrei = [] # track lr exponent used on each iteration
lossi = [] # track resulting loss on each iter
# training loop on mini-batches (32 examples per batch)
for i in range(1000):
# minibatch construct:
ix = torch.randint(0, X.shape[0], (32,))
# forward pass
emb = C[X[ix]] # (228146, 3, 2) -> now mini batch (32, 3, 2) -> (32, 6) next line emb.view(-1, 6)
h = torch.tanh(emb.view(-1, 6) @ W1 + b1) # (32, 100)
logits = h @ W2 + b2 # (32, 27)
loss = F.cross_entropy(logits, Y[ix])
print(loss.item())
# backward pass
for p in parameters:
p.grad = None
loss.backward()
# gradient descent update
lr = lrs[i]
for p in parameters:
p.data += -lr * p.grad
# track stats
lri.append(lr)
lrei.append(lre[i].item())
lossi.append(loss.item())
15.817461013793945
20.727684020996094
18.579952239990234
...
9.263570785522461
7.676933288574219
7.756092071533203
Plotting loss (y-axis) vs learning rate (x-axis) shows a sweet spot:
# plot loss (y-axis) vs learning rate (x-axis)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(lri, lossi)
ax1.set_title('Loss vs Learning Rate')
ax1.set_xlabel('Learning Rate')
ax1.set_ylabel('Loss')
ax2.plot(lrei, lossi)
ax2.set_title('Loss vs Learning Rate Exponent')
ax2.set_xlabel('Learning Rate Exponent')
ax2.set_ylabel('Loss')
plt.tight_layout()
plt.show()
| `lre`: Exponent (`10**lre`) | `lr`: Learning Rate | Behaviour |
|---|---|---|
| toward -3 | toward 0.001 | too small — loss barely reduces |
| roughly -1.5 to -0.5 | roughly 0.03 to 0.3 | sweet spot — loss decreasing and stable |
| toward 0 | toward 1 | too large — loss unstable, gets worse |
So we have some confidence that lr = 0.1 was a reasonable starting point.
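As a cross-check on eyeballing the plot, the tracked `lri` / `lossi` lists can be read off numerically. A minimal sketch, assuming those lists from the loop above; the smoothing window of 50 is an arbitrary choice to damp minibatch noise, and this is only a rough proxy for the verification runs that follow:

```python
# sketch: smooth the noisy per-minibatch losses, then read off the lr near the minimum
import torch
import torch.nn.functional as F

losses = torch.tensor([float(x) for x in lossi])
rates = torch.tensor([float(x) for x in lri])

window = 50                                   # arbitrary smoothing window
kernel = torch.ones(1, 1, window) / window    # moving-average filter
smoothed = F.conv1d(losses.view(1, 1, -1), kernel).view(-1)

best = smoothed.argmin().item()               # index into the smoothed curve
print('approx. best lr:', rates[best + window // 2].item())  # shift to window centre
```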
Update the learning rate (lr decay)
Note how after doing a few runs of 10,000 iterations each at lr = 0.1, the loss reduction eventually plateaus. We then decay the learning rate (e.g. by a factor of 10) to lr = 0.01 and do more iterations.
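The runs below do this manually, editing `lr` between cells. The same idea can be folded into a single loop as a step schedule; a minimal sketch, assuming the parameters and `X`, `Y` from above (the 50,000-iteration cutoff is illustrative, not the notebook's value):

```python
# sketch: step-decay learning rate schedule inside one training loop
import torch
import torch.nn.functional as F

for i in range(100000):
    ix = torch.randint(0, X.shape[0], (32,))    # minibatch construct
    emb = C[X[ix]]                              # (32, 3, 2)
    h = torch.tanh(emb.view(-1, 6) @ W1 + b1)   # (32, 100)
    logits = h @ W2 + b2                        # (32, 27)
    loss = F.cross_entropy(logits, Y[ix])
    for p in parameters:                        # backward pass
        p.grad = None
    loss.backward()
    lr = 0.1 if i < 50000 else 0.01             # decay 10x once progress plateaus (cutoff illustrative)
    for p in parameters:
        p.data += -lr * p.grad
```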
# reset network parameters: C, W1, b1, W2, b2
# redefine SAME 3,481 network params (with grads!)
g = torch.Generator().manual_seed(2147483647) # for reproducibility
C = torch.randn((27, 2), generator=g) # embedding matrix (lookup table for input tokens)
W1 = torch.randn((6, 100), generator=g) # hidden layer's incoming weights: 6 inputs to layer, 100 hidden neurons in layer
b1 = torch.randn(100, generator=g) # 100 biases live "in" hidden layer's neurons
W2 = torch.randn((100, 27), generator=g) # output layer's incoming weights: 100 inputs to layer, 27 output neurons in layer
b2 = torch.randn(27, generator=g) # 27 biases live "in" output layer's neurons
parameters = [C, W1, b1, W2, b2] # list of all parameters (makes easier to count)
# print('num. of parameters:', sum(p.nelement() for p in parameters)) # total parameter count in network
# ensure all 3,481 parameters have gradient (to enable optimisation)
for p in parameters:
p.requires_grad = True
# Run 1 (at lr = 0.1): 10,000 training iters (mini-batches)
for i in range(10000):
# minibatch construct:
ix = torch.randint(0, X.shape[0], (32,))
# forward pass
emb = C[X[ix]] # (228146, 3, 2) -> now mini batch (32, 3, 2) -> (32, 6) next line emb.view(-1, 6)
h = torch.tanh(emb.view(-1, 6) @ W1 + b1) # (32, 100)
logits = h @ W2 + b2 # (32, 27)
loss = F.cross_entropy(logits, Y[ix])
# backward pass
for p in parameters:
p.grad = None
loss.backward()
# gradient descent update
lr = 0.1
for p in parameters:
p.data += -lr * p.grad
# forward pass FULL DATASET: clean loss number showing true model progress
emb = C[X] # (228146, 3, 2) -> (228146, 6) next line emb.view(-1, 6)
h = torch.tanh(emb.view(-1, 6) @ W1 + b1) # (228146, 100)
logits = h @ W2 + b2 # (228146, 27)
loss = F.cross_entropy(logits, Y)
print('Run 1 - full dataset loss:', loss.item())
Run 1 - full dataset loss: 2.453782558441162
# Run 2 (at lr = 0.1): 10,000 training iters (mini-batches)
for i in range(10000):
# minibatch construct:
ix = torch.randint(0, X.shape[0], (32,))
# forward pass
emb = C[X[ix]] # (228146, 3, 2) -> now mini batch (32, 3, 2) -> (32, 6) next line emb.view(-1, 6)
h = torch.tanh(emb.view(-1, 6) @ W1 + b1) # (32, 100)
logits = h @ W2 + b2 # (32, 27)
loss = F.cross_entropy(logits, Y[ix])
# backward pass
for p in parameters:
p.grad = None
loss.backward()
# gradient descent update
lr = 0.1
for p in parameters:
p.data += -lr * p.grad
# forward pass FULL DATASET: clean loss number showing true model progress
emb = C[X] # (228146, 3, 2) -> (228146, 6) next line emb.view(-1, 6)
h = torch.tanh(emb.view(-1, 6) @ W1 + b1) # (228146, 100)
logits = h @ W2 + b2 # (228146, 27)
loss = F.cross_entropy(logits, Y)
print('Run 2 - full dataset loss:', loss.item())
Run 2 - full dataset loss: 2.4282751083374023
# Run 3 (at lr = 0.1): 10,000 training iters (mini-batches)
for i in range(10000):
# minibatch construct:
ix = torch.randint(0, X.shape[0], (32,))
# forward pass
emb = C[X[ix]] # (228146, 3, 2) -> now mini batch (32, 3, 2) -> (32, 6) next line emb.view(-1, 6)
h = torch.tanh(emb.view(-1, 6) @ W1 + b1) # (32, 100)
logits = h @ W2 + b2 # (32, 27)
loss = F.cross_entropy(logits, Y[ix])
# backward pass
for p in parameters:
p.grad = None
loss.backward()
# gradient descent update
lr = 0.1
for p in parameters:
p.data += -lr * p.grad
# forward pass FULL DATASET: clean loss number showing true model progress
emb = C[X] # (228146, 3, 2) -> (228146, 6) next line emb.view(-1, 6)
h = torch.tanh(emb.view(-1, 6) @ W1 + b1) # (228146, 100)
logits = h @ W2 + b2 # (228146, 27)
loss = F.cross_entropy(logits, Y)
print('Run 3 - full dataset loss:', loss.item())
Run 3 - full dataset loss: 2.4041359424591064
# Run 4 (at lr = 0.1): 10,000 training iters (mini-batches)
for i in range(10000):
# minibatch construct:
ix = torch.randint(0, X.shape[0], (32,))
# forward pass
emb = C[X[ix]] # (228146, 3, 2) -> now mini batch (32, 3, 2) -> (32, 6) next line emb.view(-1, 6)
h = torch.tanh(emb.view(-1, 6) @ W1 + b1) # (32, 100)
logits = h @ W2 + b2 # (32, 27)
loss = F.cross_entropy(logits, Y[ix])
# backward pass
for p in parameters:
p.grad = None
loss.backward()
# gradient descent update
lr = 0.1
for p in parameters:
p.data += -lr * p.grad
# forward pass FULL DATASET: clean loss number showing true model progress
emb = C[X] # (228146, 3, 2) -> (228146, 6) next line emb.view(-1, 6)
h = torch.tanh(emb.view(-1, 6) @ W1 + b1) # (228146, 100)
logits = h @ W2 + b2 # (228146, 27)
loss = F.cross_entropy(logits, Y)
print('Run 4 - full dataset loss:', loss.item())
Run 4 - full dataset loss: 2.3795299530029297
Plateaued loss improvement → decay learning rate 10×
Note that we have already surpassed the loss value of ~2.45 from the bigram NN approach in 04_from_bigrams_to_nns and 05_optimisation.
Lower training `loss` does not mean a better model
- This MLP has far greater capacity at 3,481 parameters vs. the bigram model's 27 × 27 = 729
- A larger model can achieve lower training loss simply by memorising the training data rather than learning general structure
- The bigram loss of ~2.45 was evaluated on the full dataset; if this MLP’s lower loss is driven by overfitting, it will perform worse on unseen data
- Must evaluate on a held-out validation set to make a fair comparison (a minimal split sketch follows below)
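A minimal sketch of that comparison, assuming the full `X`, `Y` built above and the trained parameters; the 90/10 split and seed are arbitrary choices here (the next notebook, 03_overfitting_train_val_test, sets this up properly):

```python
# sketch: hold out a validation split and report loss on both splits
import torch
import torch.nn.functional as F

g = torch.Generator().manual_seed(42)          # arbitrary seed for the split
perm = torch.randperm(X.shape[0], generator=g)
n_train = int(0.9 * X.shape[0])                # arbitrary 90/10 split
Xtr, Ytr = X[perm[:n_train]], Y[perm[:n_train]]
Xva, Yva = X[perm[n_train:]], Y[perm[n_train:]]

@torch.no_grad()
def split_loss(Xs, Ys):
    emb = C[Xs]                                 # (N, 3, 2)
    h = torch.tanh(emb.view(-1, 6) @ W1 + b1)   # (N, 100)
    logits = h @ W2 + b2                        # (N, 27)
    return F.cross_entropy(logits, Ys).item()

print('train loss:', split_loss(Xtr, Ytr))
print('val loss:  ', split_loss(Xva, Yva))
```

Strictly, the split should be made before training so the validation examples are truly unseen; the snippet above only illustrates the evaluation mechanics.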
# Run 5 (decay lr to 0.01): 10,000 training iters (mini-batches)
for i in range(10000):
# minibatch construct:
ix = torch.randint(0, X.shape[0], (32,))
# forward pass
emb = C[X[ix]] # (228146, 3, 2) -> now mini batch (32, 3, 2) -> (32, 6) next line emb.view(-1, 6)
h = torch.tanh(emb.view(-1, 6) @ W1 + b1) # (32, 100)
logits = h @ W2 + b2 # (32, 27)
loss = F.cross_entropy(logits, Y[ix])
# backward pass
for p in parameters:
p.grad = None
loss.backward()
# gradient descent update
lr = 0.01
for p in parameters:
p.data += -lr * p.grad
# forward pass FULL DATASET: clean loss number showing true model progress
emb = C[X] # (228146, 3, 2) -> (228146, 6) next line emb.view(-1, 6)
h = torch.tanh(emb.view(-1, 6) @ W1 + b1) # (228146, 100)
logits = h @ W2 + b2 # (228146, 27)
loss = F.cross_entropy(logits, Y)
print('Run 5 - full dataset loss:', loss.item())
Run 5 - full dataset loss: 2.315138816833496
Sources
- YouTube: The spelled-out intro to language modeling: building makemore
- Bengio et al. 2003: A Neural Probabilistic Language Model (implemented here)
- karpathy/makemore on GitHub
- Google Colab: Exercises
- ezyang’s blog: PyTorch Internals