- Prev: 02_nn_data_structs_and_forward_pass
- Next: 04_backprop_train_a_neuron
- Related:
`Value` object data structure, computation graph direction terminology
# imports, `Value` class, init toy nn, graphviz: trace() & draw_dot()
import math
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
# Object definitions from end of previous chapter:
# Value class:
class Value:
    def __init__(self, data, _children=(), _op='', label=''):
        self.data = data
        self.grad = 0.0
        self._prev = set(_children)
        self._op = _op
        self.label = label

    def __repr__(self):
        return f"Value(data={self.data})"

    def __add__(self, other):
        out = Value(self.data + other.data, (self, other), '+')
        return out

    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other), '*')
        return out
# Initialise toy neural network
a = Value(2.0, label='a')
b = Value(-3.0, label='b')
c = Value(10.0, label='c')
e = a*b; e.label = 'e'
d = e + c; d.label = 'd'
f = Value(-2.0, label='f')
L = d * f; L.label = 'L' # the output of our graph
L
# graphviz
from graphviz import Digraph
def trace(root):
    # recursively builds a set of all nodes and edges in a graph
    nodes, edges = set(), set()
    def build(v):
        if v not in nodes:
            nodes.add(v)
            for child in v._prev:
                edges.add((child, v))
                build(child)
    build(root)
    return nodes, edges
def draw_dot(root):
    dot = Digraph(format='svg', graph_attr={'rankdir': 'LR'}) # LR = left to right
    nodes, edges = trace(root)
    for n in nodes:
        uid = str(id(n))
        # for any value in the graph, create a rectangular ('record') node for it
        dot.node(name = uid, label = "{ %s | data %.4f | grad %.4f }" % (n.label, n.data, n.grad), shape='record')
        if n._op:
            # if this value is a result of some operation, create an op node for it
            dot.node(name = uid + n._op, label = n._op)
            # and connect this node to it
            dot.edge(uid + n._op, uid)
    for n1, n2 in edges:
        # connect n1 to the op node of n2
        dot.edge(str(id(n1)), str(id(n2)) + n2._op)
    return dot
# draw_dot(L)

6. Manually calculated gradient backpropagation
6.1. Manually fill gradients of L wrt nodes L, f, and d
Since $L = d \cdot f$, by substituting into the derivative definition we can show for node $f$:

$$\frac{\partial L}{\partial f} = \lim_{h \to 0} \frac{d \cdot (f + h) - d \cdot f}{h} = \lim_{h \to 0} \frac{d \cdot h}{h} = d$$

Hence $\frac{\partial L}{\partial f} = d = 4.0$, and by symmetry $\frac{\partial L}{\partial d} = f = -2.0$.
# manually set L.grad (i.e. d(L)/dL = 1), d.grad, and f.grad
L.grad = 1.0
# since L = d * f, by substituting into the derivative definition: (f(x+h) - f(x))/h, we can show
d.grad = f.data # (bc dL/dd = f = -2.0)
f.grad = d.data # (bc dL/df = d = 4.0)
print('L.grad:', L.grad)
print('d.grad:', d.grad)
print('f.grad:', f.grad)
draw_dot(L) # go and test the new gradients (4.0 and -2.0) via the numerical helpers below
L.grad: 1.0
d.grad: -2.0
f.grad: 4.0

Aside: Numerically verify derivatives (via helper functions)
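Each helper below rebuilds the graph twice and perturbs exactly one variable by a small `h`. As a minimal generic sketch of that pattern (the `numerical_grad` name and the `build_L` callback are illustrative, not part of the notebook):

```python
def numerical_grad(build_L, h=0.0001):
    # build_L(h) rebuilds the whole graph from scratch, adds 'h' to exactly
    # one variable, and returns the resulting L.data
    L1 = build_L(0.0)       # unperturbed forward pass
    L2 = build_L(h)         # forward pass with one variable nudged by h
    return (L2 - L1) / h    # finite-difference estimate of dL/d(that variable)

# example: estimate d(L)/df by perturbing f (should be ~4.0, i.e. d)
print(numerical_grad(lambda h: (Value(4.0) * Value(-2.0 + h)).data))
```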
def dl_dl(): # for d(L)/dL
    '''
    - gating function (like a staging area) for testing our manual calculations above,
    - keeps all vars local to dl_dl(), so avoids polluting the global scope
    - lets us set small values of 'h' to test dL/d... one variable at a time
    '''
    h = 0.0001
    # d = e + c; d.label = 'd' # official node definition
    d = Value(4.0, label='d') # hardcoded for brevity
    f = Value(-2.0, label='f')
    L = d * f; L.label = 'L'
    L1 = L.data

    d = Value(4.0, label='d')
    f = Value(-2.0, label='f')
    L = d * f; L.label = 'L'
    L.data += h # add a small amt 'h' to L itself
    L2 = L.data

    # therefore display d(L)/dL
    print('derivative of L wrt itself: \nd(L)/dL =', (L2 - L1)/h)
    print('\nrecall, derivative of any variable wrt itself = 1')
    print('therefore demonstrating that d(L)/dL = 1')

dl_dl()
derivative of L wrt itself:
d(L)/dL = 0.9999999999976694
recall, derivative of any variable wrt itself = 1
therefore demonstrating that d(L)/dL = 1

def dl_df(): # for d(L)/df
    h = 0.0001
    # d = e + c; d.label = 'd' # official node definition
    d = Value(4.0, label='d') # hardcoded for brevity
    f = Value(-2.0, label='f')
    L = d * f; L.label = 'L'
    L1 = L.data

    d = Value(4.0); d.label = 'd'
    f = Value(-2.0, label='f')
    f.data += h # add a small amt 'h' to f
    L = d * f; L.label = 'L'
    L2 = L.data # try instead to add 'h' here (instead of to 'f'); convince yourself that d(L)/dL gives you 1

    # therefore display d(L)/df
    print('derivative of L wrt f: \nd(L)/df =', (L2 - L1)/h)
    print('\nrecall, d =', d)
    print('therefore demonstrating that d(L)/df = d')

dl_df()
derivative of L wrt f:
d(L)/df = 3.9999999999995595
recall, d = Value(data=4.0)
therefore demonstrating that d(L)/df = d

def dl_dd(): # for d(L)/dd
    h = 0.0001
    # d = e + c; d.label = 'd' # official node definition
    d = Value(4.0, label='d') # hardcoded for brevity
    f = Value(-2.0, label='f')
    L = d * f; L.label = 'L'
    L1 = L.data

    d = Value(4.0, label='d')
    d.data += h # add a small amt 'h'
    f = Value(-2.0, label='f')
    L = d * f; L.label = 'L'
    L2 = L.data

    # therefore display d(L)/dd
    print('derivative of L wrt d: \nd(L)/dd =', (L2 - L1)/h)
    print('\nrecall, f =', f)
    print('therefore demonstrating that d(L)/dd = f')

dl_dd()
derivative of L wrt d:
d(L)/dd = -1.9999999999953388
recall, f = Value(data=-2.0)
therefore demonstrating that d(L)/dd = f

6.2. But what about node c? Since it influences L through node d, how do we determine its effect?
First, what is $\frac{\partial d}{\partial c}$? Recall $d = c + e$, so:

$$\frac{\partial d}{\partial c} = \lim_{h \to 0} \frac{(c + h + e) - (c + e)}{h} = \lim_{h \to 0} \frac{h}{h} = 1$$

Hence $\frac{\partial d}{\partial c} = 1.0$, and by symmetry $\frac{\partial d}{\partial e} = 1.0$.
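A quick numerical spot-check of that local derivative, using throwaway copies of `c` and `e` so the notebook's real nodes are left untouched (the `_tmp` names are just for illustration):

```python
h = 0.0001
c_tmp = Value(10.0)
e_tmp = Value(-6.0)
dd_dc = ((Value(10.0 + h) + e_tmp).data - (c_tmp + e_tmp).data) / h
print(dd_dc)  # ~1.0: the local derivative of a '+' node wrt either of its inputs
```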
Recap:
- The `+` node (i.e. `d`) knows that `c` and `e` were added to produce `d`
- Also, all `+` nodes have a "local derivative" of `1.0`, as shown above!
  - Therefore we know the "local derivatives" `dd/dc` and `dd/de` (both are `1.0`)
- And from before, we also know how `d` impacts `L`
- The question: how do we know how `L` is impacted by `c` (and `e`)?
The answer:
7. The chain rule from calculus
From Wikipedia
If $z$ depends on $y$, which itself depends on $x$ (that is, $y$ and $z$ are dependent variables), then $z$ depends on $x$ as well, via the intermediate variable $y$. In this case, the chain rule is expressed as:

$$\frac{dz}{dx} = \frac{dz}{dy} \cdot \frac{dy}{dx}$$

and

$$\left.\frac{dz}{dx}\right|_{x} = \left.\frac{dz}{dy}\right|_{y(x)} \cdot \left.\frac{dy}{dx}\right|_{x}$$

for indicating at which points the derivatives have to be evaluated.
7.1. Intuitive explanation
Intuitively, the chain rule states that knowing the instantaneous rate of change of z relative to y and that of y relative to x allows one to calculate the instantaneous rate of change of z relative to x as the product of the two rates of change.
As put by George F. Simmons: “If a car travels twice as fast as a bicycle and the bicycle is four times as fast as a walking man, then the car travels 2 × 4 = 8 times as fast as the man.”
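A quick numerical sketch of that quote (the speeds and variable names here are made up for illustration): if the bicycle's speed is four times the man's, $y = 4x$, and the car's is twice the bicycle's, $z = 2y$, then nudging $x$ and watching $z$ recovers $dz/dx = 2 \times 4 = 8$:

```python
h = 0.0001
x = 1.0                  # walking man's speed (arbitrary units)
y = lambda x: 4 * x      # bicycle: four times as fast as the man
z = lambda x: 2 * y(x)   # car: twice as fast as the bicycle

print((z(x + h) - z(x)) / h)  # ~8.0, i.e. dz/dy * dy/dx = 2 * 4
```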
7.2. By chain rule:
We know `+` nodes have local derivatives of 1.0 (for all of their inputs), so a `+` node simply "distributes" the gradient (i.e. the gradient passes through, unchanged):

$$\frac{\partial L}{\partial c} = \frac{\partial L}{\partial d} \cdot \frac{\partial d}{\partial c} = \frac{\partial L}{\partial d} \cdot 1.0 = -2.0$$

Hence $\frac{\partial L}{\partial c} = -2.0$, and by symmetry $\frac{\partial L}{\partial e} = -2.0$.
# manually set c.grad and e.grad (i.e. d(L)/dc, and d(L)/de)
# `+` nodes simply DISTRIBUTE gradients, because their "local gradients" are 1
# so by chain rule, it's simply multiplying by 1
c.grad = d.grad # (dL/dc = (dL/dd) * (dd/dc)) <---- and dd/dc = 1
e.grad = d.grad # (dL/de = (dL/dd) * (dd/de)) <---- and dd/de = 1
print('e.grad:', e.grad)
print('c.grad:', c.grad)
draw_dot(L) # go and test the new gradients (-2.0 and -2.0) via the numerical helpers below
e.grad: -2.0
c.grad: -2.0

Numerical verifications of derivatives d(L)/dc and d(L)/de
def dl_dc(): # for d(L)/dc
    h = 0.0001
    c = Value(10.0, label='c')
    # e = a*b; e.label = 'e' # official node definition
    e = Value(-6.0, label='e') # hardcoded for brevity
    d = e + c; d.label = 'd'
    f = Value(-2.0, label='f')
    L = d * f; L.label = 'L'
    L1 = L.data

    c = Value(10.0, label='c')
    c.data += h # add small h
    e = Value(-6.0, label='e')
    d = e + c; d.label = 'd'
    f = Value(-2.0, label='f')
    L = d * f; L.label = 'L'
    L2 = L.data

    # therefore display d(L)/dc
    print('derivative of L wrt c: \nd(L)/dc =', (L2 - L1)/h)
    print('\nrecall previously, d(L)/dd = f =', f)
    print('therefore demonstrating "gradient routing": d(L)/dc = d(L)/dd = f')

dl_dc()
derivative of L wrt c:
d(L)/dc = -1.9999999999953388
recall previously, d(L)/dd = f = Value(data=-2.0)
therefore demonstrating "gradient routing": d(L)/dc = d(L)/dd = f

def dl_de(): # for d(L)/de
    h = 0.0001
    c = Value(10.0, label='c')
    # e = a*b; e.label = 'e' # official node definition
    e = Value(-6.0); e.label = 'e' # hardcoded for brevity
    d = e + c; d.label = 'd'
    f = Value(-2.0, label='f')
    L = d * f; L.label = 'L'
    L1 = L.data

    c = Value(10.0, label='c')
    e = Value(-6.0, label='e')
    e.data += h # add small h
    d = e + c; d.label = 'd'
    f = Value(-2.0, label='f')
    L = d * f; L.label = 'L'
    L2 = L.data

    # therefore display d(L)/de
    print('derivative of L wrt e: \nd(L)/de =', (L2 - L1)/h)
    print('\nrecall previously, d(L)/dd = f =', f)
    print('therefore demonstrating "gradient routing": d(L)/de = d(L)/dd = f')

dl_de()
derivative of L wrt e:
d(L)/de = -1.9999999999953388
recall previously, d(L)/dd = f = Value(data=-2.0)
therefore demonstrating "gradient routing": d(L)/de = d(L)/dd = f

7.3. Apply chain rule again (now on a * multiplication node)
We know $\frac{\partial L}{\partial e} = -2.0$.

We want to find $\frac{\partial L}{\partial a}$ and $\frac{\partial L}{\partial b}$. Per the chain rule:

$$\frac{\partial L}{\partial a} = \frac{\partial L}{\partial e} \cdot \frac{\partial e}{\partial a}$$

and similarly:

$$\frac{\partial L}{\partial b} = \frac{\partial L}{\partial e} \cdot \frac{\partial e}{\partial b}$$

As we did earlier for node L (another `*` node), since $e = a \cdot b$ we can say $\frac{\partial e}{\partial a} = b$ and $\frac{\partial e}{\partial b} = a$.

Therefore $\frac{\partial L}{\partial a} = -2.0 \cdot b = 6.0$ and $\frac{\partial L}{\partial b} = -2.0 \cdot a = -4.0$.
# manually set a.grad and b.grad (i.e. d(L)/da, and d(L)/db)
a.grad = e.grad * b.data # (i.e. dL/da = dL/de * de/da = -2 * b = -2 * -3)
b.grad = e.grad * a.data # (i.e. dL/db = dL/de * de/db = -2 * a = -2 * 2)
print('a.grad:', a.grad)
print('b.grad:', b.grad)
draw_dot(L) # go and test the new gradients (6.0 and -4.0) via the numerical helpers below
a.grad: 6.0
b.grad: -4.0

Numerical verifications of derivatives d(L)/da and d(L)/db
def dl_da(): # for d(L)/da
    h = 0.0001
    a = Value(2.0, label='a')
    b = Value(-3.0, label='b')
    c = Value(10.0, label='c')
    e = a*b; e.label = 'e'
    d = e + c; d.label = 'd'
    f = Value(-2.0, label='f')
    L = d * f; L.label = 'L'
    L1 = L.data

    a = Value(2.0, label='a')
    a.data += h # add a small amt 'h'
    b = Value(-3.0, label='b')
    c = Value(10.0, label='c')
    e = a*b; e.label = 'e'
    d = e + c; d.label = 'd'
    f = Value(-2.0, label='f')
    L = d * f; L.label = 'L'
    L2 = L.data # try instead to add 'h' here (instead of to 'a'); convince yourself that d(L)/dL gives you 1

    # therefore display d(L)/da
    print('derivative of L wrt a: \nd(L)/da =', (L2 - L1)/h)
    print('\nrecall previously, d(L)/da = d(L)/de * d(e)/da = (-2)(b) = (-2)(-3) = 6')
    print('therefore demonstrating: d(L)/da = -2b')

dl_da()
derivative of L wrt a:
d(L)/da = 6.000000000021544
recall previously, d(L)/da = d(L)/de * d(e)/da = (-2)(b) = (-2)(-3) = 6
therefore demonstrating: d(L)/da = -2b

def dl_db(): # for d(L)/db
    h = 0.0001
    a = Value(2.0, label='a')
    b = Value(-3.0, label='b')
    c = Value(10.0, label='c')
    e = a*b; e.label = 'e'
    d = e + c; d.label = 'd'
    f = Value(-2.0, label='f')
    L = d * f; L.label = 'L'
    L1 = L.data

    a = Value(2.0, label='a')
    b = Value(-3.0, label='b')
    b.data += h # add a small amt 'h'
    c = Value(10.0, label='c')
    e = a*b; e.label = 'e'
    d = e + c; d.label = 'd'
    f = Value(-2.0, label='f')
    L = d * f; L.label = 'L'
    L2 = L.data # try instead to add 'h' here (instead of to 'b'); convince yourself that d(L)/dL gives you 1

    # therefore display d(L)/db
    print('derivative of L wrt b: \nd(L)/db =', (L2 - L1)/h)
    print('\nrecall previously, d(L)/db = d(L)/de * d(e)/db = (-2)(a) = (-2)(2) = -4')
    print('therefore demonstrating: d(L)/db = -2a')

dl_db()
derivative of L wrt b:
d(L)/db = -4.000000000008441
recall previously, d(L)/db = d(L)/de * d(e)/db = (-2)(a) = (-2)(2) = -4
therefore demonstrating: d(L)/db = -2a

Takeaways
- That's backpropagation: a recursive application of the chain rule backwards through the graph (a consolidated sketch of the full sweep follows below).
- We iterated backwards from the `L` node;
  - locally applied the chain rule at each node (`a` to `f`) to calculate the gradient of `L` wrt that node.
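As promised above, here is the whole manual backward sweep consolidated into one cell; it only re-applies the local derivatives already derived in this chapter, walking from `L` back to the leaves:

```python
# the whole manual backward sweep in one place, from L back to the leaves
L.grad = 1.0                 # dL/dL = 1
f.grad = L.grad * d.data     # '*' node: dL/df = d
d.grad = L.grad * f.data     # '*' node: dL/dd = f
c.grad = d.grad * 1.0        # '+' node: local derivative 1, gradient is distributed
e.grad = d.grad * 1.0
a.grad = e.grad * b.data     # '*' node: de/da = b, so dL/da = dL/de * b
b.grad = e.grad * a.data     # de/db = a, so dL/db = dL/de * a

for n in (a, b, c, e, d, f, L):
    print(n.label, n.grad)   # expected: a 6.0, b -4.0, c -2.0, e -2.0, d -2.0, f 4.0, L 1.0
```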
Towards training:
Using this information to improve the loss function (node `L`):
- Nudge the leaf nodes' (`a`, `b`, `c`, `f`) data in the direction of their own gradients
  - i.e. nudge positive-gradient nodes in the positive direction and vice versa
- Then re-evaluate the dependent nodes (`e`, `d`, and `L`)
  - i.e. perform a forward pass
Why only leaf nodes?
For illustrative reasons only!
- In a real network, we can usually change at least some of the nodes during the optimisation process:
  - we can change the weights (which are also leaf nodes),
  - but we can't change the input-data leaf nodes.
- In a real network, intermediate weight nodes may also be changed.
# nudge each leaf node's .data value in the direction of its gradient. see L improve
nudge = 0.01
# apply nudges
a.data += nudge * a.grad # note how we're nudging each input in the direction of its own gradient
b.data += nudge * b.grad
c.data += nudge * c.grad
f.data += nudge * f.grad
# forward pass:
e = a * b
d = e + c
L = d * f
# re-evaluate the nodes who are dependent on these leaf nodes
print(L.data, '\b, an improvement from previous L = -8.0')
-7.286496, an improvement from previous L = -8.0
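The same nudge-then-forward-pass step can be repeated. A minimal sketch (the `step` size and loop count are arbitrary choices; the gradients are recomputed each iteration from the chain-rule results derived above):

```python
step = 0.01
for i in range(5):
    # recompute gradients of L = (a*b + c) * f from the chain-rule results above
    f.grad = a.data * b.data + c.data   # dL/df = d
    c.grad = f.data                     # dL/dc = dL/dd = f
    a.grad = f.data * b.data            # dL/da = dL/de * de/da = f * b
    b.grad = f.data * a.data            # dL/db = dL/de * de/db = f * a

    # nudge each leaf in the direction of its own gradient
    a.data += step * a.grad
    b.data += step * b.grad
    c.data += step * c.grad
    f.data += step * f.grad

    # forward pass: re-evaluate the dependent nodes
    e = a * b; d = e + c; L = d * f
    print(i, L.data)  # L should keep improving (increasing) with these small steps
```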
Sources
- YouTube: The spelled-out intro to neural networks and backpropagation: building micrograd
- karpathy/micrograd on GitHub
- Jupyter notebooks from this chapter
- Google Colab exercises