# imports, graphviz: trace() & draw_dot()
import math
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
 
# Object definitions from end of previous chapter:
# graphviz
from graphviz import Digraph
 
def trace(root):
    # recursively builds a set of all nodes and edges in a graph
    nodes, edges = set(), set()
    def build(v):
        if v not in nodes: 
            nodes.add(v)
            for child in v._prev:
                edges.add((child, v))
                build(child)
    build(root)
    return nodes, edges
 
def draw_dot(root):
    dot = Digraph(format='svg', graph_attr={'rankdir': 'LR'}) # LR = left to right
    nodes, edges = trace(root)
    for n in nodes:
        uid = str(id(n))
        # for any value in the graph, create a rectangular ('record') node for it
        dot.node(name = uid, label = "{ %s | data %.4f | grad %.4f }" % (n.label, n.data, n.grad), shape='record')
        if n._op:
            # if this value is a result of some operation, create an op node for it
            dot.node(name = uid + n._op, label = n._op)
            # and connect this node to it
            dot.edge(uid + n._op, uid)
    for n1, n2 in edges:
        # connect n1 to the op node of n2
        dot.edge(str(id(n1)), str(id(n2)) + n2._op)
    return dot

Helper function: Reset graph

Quickly resets the graph to one of two levels, selected by the reset_level argument:

  • gradients: zeros out all gradients .grad (retains .data values)
  • graph: re-initialises the entire graph (nodes: x1, x2, w1, w2, x1w1, x2w2, x1w1x2w2, b, n, and o)
def reset_graph(reset_level):
 
    # declare the graph's Value objects as globals, so both branches modify the module-level variables
    global x1, x2, w1, w2, x1w1, x2w2, x1w1x2w2, b, n, o
 
    if reset_level == 'gradients':
        x1.grad = x2.grad = w1.grad = w2.grad = x1w1.grad = x2w2.grad = x1w1x2w2.grad = b.grad = n.grad = o.grad = 0
 
        print("reset_graph(): All gradients have been reset to 0")
 
    # reset all variables
    elif reset_level == 'graph':
        # redefine inputs (x1,x2), weights (w1,w2), and then the graph (n = x1*w1 + x2*w2 + b)
        x1 = Value(2.0, label='x1'); x2 = Value(0.0, label='x2')
        w1 = Value(-3.0, label='w1'); w2 = Value(1.0, label='w2')
        x1w1 = x1 * w1; x1w1.label = 'x1*w1'; x2w2 = x2 * w2; x2w2.label = 'x2*w2'
        x1w1x2w2 = x1w1 + x2w2; x1w1x2w2.label = 'x1*w1 + x2*w2'
 
        # manually set the bias to make the numbers nice for teaching (try b=8 to see tanh squashing!; b=6.8813735870195432 makes the gradients come out to round numbers, e.g. w1.grad = 1)
        b = Value(6.8813735870195432, label='b')
        n = x1w1x2w2 + b; n.label = 'n'
 
        # try re-run the activation function on n (the raw cell body) and draw the output node o
        o = n.tanh(); o.label = 'o'
 
        print("reset_graph(): All vars, initial and intermediate, have been reset. All gradients now 0")
 
    else:
        print("reset_graph(): please specify the desired reset level: 'gradients' or 'graph'")

9. Automating backpropagation: Define the backward pass

In section 8.3 (previous chapter), we manually threaded each gradient backwards — 10+ steps by hand. This doesn’t scale.

The fix: each operation (+, *, tanh, …) attaches a _backward() closure to its output Value at creation time. When called, it:

  • applies the chain rule locally at that operation
  • propagates out.grad backwards into the gradients of that operation’s inputs (self.grad and other.grad)

Notes:

  • Calling _backward() on each node, from the output back to the leaves, is occasionally referred to as the backward pass (or backprop)
  • + and * are operations — they’re not themselves Value nodes.
    • Each operation produces a Value object: out (which does carry .data and .grad),
    • with _backward defined to write gradients back into the inputs (self, other).

Confused? See backprop-graph-terminology
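
To see the closure mechanics in isolation, here is a minimal sketch using a hypothetical Box class and a free-standing add() function (both invented for illustration; they are not the Value class below). Each operation returns an output object whose _backward closure has captured the operation’s inputs and output from the enclosing scope:

# minimal sketch of the closure pattern (hypothetical `Box` class, illustration only)
class Box:
    def __init__(self, data):
        self.data = data
        self.grad = 0.0
        self._backward = lambda: None  # leaf default: nothing to propagate

def add(a, b):
    out = Box(a.data + b.data)
    def _backward():
        # chain rule for '+': local derivative is 1 for each input
        a.grad = 1.0 * out.grad
        b.grad = 1.0 * out.grad
    out._backward = _backward  # store the closure; it is called later, not now
    return out

p, q = Box(2.0), Box(3.0)
s = add(p, q)
s.grad = 1.0   # seed the output gradient
s._backward()  # writes gradients back into the inputs p and q
print(p.grad, q.grad)  # 1.0 1.0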

10. Implement _backward() pass for each operation in Value class

  • First, create a _backward function to do the chain rule locally at each node
    • defaults to lambda: None (a no-op); e.g. at leaf nodes there is nothing to propagate
  • More specifically, at each operation, we propagate the output gradient (out.grad)
    • back into the gradients of that operation’s inputs (self.grad and other.grad).
  • The single underscore prefix on the _backward() method name (and the _prev, _op attribute names) is a convention (not Python-enforced)
    • It denotes a private method/attribute (intended for internal use within the class only)

Important. Expand code cell:

# extend `Value` class (and its methods) with lambda function attribute: `_backward` (function as attribute!)
class Value:
 
    # NB: __init__ is auto-called when you create a new instance of a class; it initialises the attributes of the instance.
    def __init__(self, data, _children=(), _op='', label=''):
        self.data = data
        self.grad = 0.0
        self._backward = lambda: None # function for local chain rule at each node. default: do nothing (e.g. at leaf node)
        self._prev = set(_children)
        self._op = _op
        self.label = label
 
    # NB: __repr__ is a special ('dunder') method: it provides a string representation of an object, for debugging, logging, etc
    def __repr__(self):
        return f"Value(data={self.data})" 
    
    def __add__(self, other):
        out = Value(self.data + other.data, (self, other), '+')
        
        # CHAIN RULE: 
        # -> propagate `out.grad`; i.e. the gradient of '+' (addition) operation's output
        # -> "backwards" into `self.grad` and `other.grad`; i.e. the children nodes' gradients
        def _backward():
            self.grad = 1.0 * out.grad # '+' node, gradient distributed back, as-is
            other.grad = 1.0 * out.grad # same for 2nd child node
        
        # save the WHOLE _backward FUNCTION as an attribute; adding () would CALL it here.
        # we want to call it LATER (e.g. running x._backward() on some `Value` object `x`)
        out._backward = _backward
        
        return out
 
    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other), '*')
        
        # CHAIN RULE: 
        # -> propagate `out.grad`; i.e. the gradient of '*' (multiplication) operation's output
        # -> "backwards" into `self.grad` and `other.grad`; i.e. the children nodes' gradients
        def _backward():
            self.grad = other.data * out.grad # '*' node: gradient scaled by the OTHER child's data
            other.grad = self.data * out.grad # (and vice versa)
        out._backward = _backward
        
        return out
    
    def tanh(self):
        x = self.data
        t = (math.exp(2*x) - 1)/(math.exp(2*x) + 1)
        out = Value(t, (self, ), 'tanh')
        
        # CHAIN RULE: 
        # -> propagate `out.grad`; i.e. the gradient of 'tanh' (act. fn.) operation's output
        # -> "backwards" into `self.grad`; i.e. the (single) child node's gradient
        def _backward():
            self.grad = (1 - t**2) * out.grad # local derivative of tanh (per wikipedia)
        out._backward = _backward
    
        return out
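
As a quick sanity check, here is a tiny throwaway expression (the names a, b, c, d, e are arbitrary, used only for this demo) showing each _backward closure writing gradients into its inputs:

# quick check of the new _backward closures on a throwaway expression e = a*b + c
a = Value(2.0, label='a'); b = Value(-3.0, label='b'); c = Value(10.0, label='c')
d = a * b; d.label = 'd'   # d.data = -6.0
e = d + c; e.label = 'e'   # e.data =  4.0

e.grad = 1.0    # seed the output gradient
e._backward()   # '+' node: routes 1.0 unchanged into d.grad and c.grad
d._backward()   # '*' node: a.grad = b.data * d.grad, b.grad = a.data * d.grad

print(a.grad, b.grad, c.grad, d.grad)  # -3.0 2.0 1.0 1.0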

Each operation has a different local derivative

Expand the code cell above to inspect the def _backward(): closures for each operation type in the Value class.

Here’s a summary (Confused? See backprop-graph-terminology):

| Operation | Forward | Local derivative | Gradient flow (_backward) |
| --- | --- | --- | --- |
| Leaf | (stored .data) | n/a | no-op — gradient stays here (lambda: None) |
| Addition | out = a + b | 1 w.r.t. each child | gradient passes through unchanged to both children |
| Multiplication | out = a * b | b w.r.t. a, and a w.r.t. b | gradient scaled by the other child’s value |
| tanh | out = tanh(x) | 1 - tanh(x)^2 = 1 - out^2 | gradient scaled by tanh’s local sensitivity |
| Sigmoid | out = σ(x) | σ(x)(1 - σ(x)) = out(1 - out) | same structure as tanh — derivative expressible from output alone |
| ReLU | out = max(0, x) | 1 if x > 0 else 0 | binary gate — gradient passes if input was positive, killed otherwise |
| GELU | out = x·Φ(x) | Φ(x) + x·φ(x) | Φ = std. normal CDF, φ = PDF; derivative requires access to x, not just out |
| Max | out = max(a, b) | 1 for the larger input, 0 for the other | gradient router — only the winning input receives gradient, loser gets 0 |
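
The last few rows (Sigmoid, ReLU, GELU, Max) are not implemented in the Value class above. As an illustration of the pattern only, here is a sketch of how hypothetical relu() and sigmoid() methods might look if added, mirroring the structure of tanh(); the method names and the monkey-patching at the end are assumptions, not part of the class as defined:

# sketch only: hypothetical extra activation methods following the tanh() pattern above
def relu(self):
    out = Value(max(0.0, self.data), (self,), 'relu')
    def _backward():
        # binary gate: pass the gradient only if the forward input was positive
        self.grad = (1.0 if self.data > 0 else 0.0) * out.grad
    out._backward = _backward
    return out

def sigmoid(self):
    s = 1.0 / (1.0 + math.exp(-self.data))
    out = Value(s, (self,), 'sigmoid')
    def _backward():
        # derivative expressible from the output alone: s * (1 - s)
        self.grad = s * (1.0 - s) * out.grad
    out._backward = _backward
    return out

# attach to the class defined above (monkey-patching, purely for illustration)
Value.relu = relu
Value.sigmoid = sigmoid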

11. Manually call _backward() pass on each node

11.1. Initialise the network

Set real data values, and zero all gradients

# i - re-initialise graph; visualise
reset_graph('graph')
draw_dot(o)
reset_graph(): All vars, initial and intermediate, have been reset. All gradients now 0
[draw_dot(o) output: graph rendered left-to-right with x1 (data 2.0), w1 (data -3.0), x2 (data 0.0), w2 (data 1.0), x1*w1 (data -6.0), x2*w2 (data 0.0), x1*w1 + x2*w2 (data -6.0), b (data 6.8814), n (data 0.8814), o (data 0.7071); all grad values 0.0000]

11.2. Call _backward() (backprop) manually for each node

  • Initialise “global gradient” for the final node
    • o.grad = 1
  • Manually call the ._backward() method node by node, working backwards from the output (in this order):
    • o._backward() propagates o’s gradient to n
    • n._backward() propagates n’s gradient to children b and x1w1x2w2
    • b._backward() does nothing: b is a leaf node, so its _backward is still the default lambda: None
    • x1w1x2w2._backward() propagates x1w1x2w2’s gradient to children x1w1 and x2w2
    • x1w1._backward() propagates x1w1’s gradient to children x1 and w1
    • x2w2._backward() propagates x2w2’s gradient to children x2 and w2
# Uncomment draw_dot(o) lines one at a time to see gradient propagate backwards
 
print('o.grad (before initialising): ', o.grad)
o.grad = 1.0 # init global gradient for final node
print('--> o.grad (initialised): ', o.grad)
 
print('\nn.grad (before o._backward): ', n.grad)
o._backward() # this will route o's gradient backward to n
print('--> n.grad (after o._backward): ', n.grad)
# draw_dot(o) # voila! check n.grad
 
print('\nx1w1x2w2.grad (before n._backward): ', x1w1x2w2.grad, '\nb.grad (before n._backward): ', b.grad)
n._backward() # this will route n's gradient backward to its children, b and x1w1x2w2
print('--> x1w1x2w2.grad (after n._backward): ', x1w1x2w2.grad, '\n--> b.grad (after n._backward): ', b.grad)
# draw_dot(o) # voila! check b.grad and x1w1x2w2.grad
 
b._backward() # b's _backward is still the default lambda: None (empty function); nothing happens!
# draw_dot(o)
 
print('\nx1w1.grad (before x1w1x2w2._backward): ', x1w1.grad, '\nx2w2.grad (before x1w1x2w2._backward): ', x2w2.grad)
x1w1x2w2._backward()
print('--> x1w1.grad (after x1w1x2w2._backward): ', x1w1.grad, '\n--> x2w2.grad (after x1w1x2w2._backward): ', x2w2.grad)
# draw_dot(o)
 
print('\nx1.grad (before x1w1._backward): ', x1.grad, '\nw1.grad (before x1w1._backward): ', w1.grad, '\nx2.grad (before x2w2._backward): ', x2.grad, '\nw2.grad (before x2w2._backward): ', w2.grad)
x1w1._backward()
x2w2._backward()
print('--> x1.grad (after x1w1._backward): ', x1.grad, '\n--> w1.grad (after x1w1._backward): ', w1.grad, '\n--> x2.grad (after x2w2._backward): ', x2.grad, '\n--> w2.grad (after x2w2._backward): ', w2.grad)
 
draw_dot(o)
o.grad (before initialising):  0.0
--> o.grad (initialised):  1.0
 
n.grad (before o._backward):  0.0
--> n.grad (after o._backward):  0.4999999999999999
 
x1w1x2w2.grad (before n._backward):  0.0 
b.grad (before n._backward):  0.0
--> x1w1x2w2.grad (after n._backward):  0.4999999999999999 
--> b.grad (after n._backward):  0.4999999999999999
 
x1w1.grad (before x1w1x2w2._backward):  0.0 
x2w2.grad (before x1w1x2w2._backward):  0.0
--> x1w1.grad (after x1w1x2w2._backward):  0.4999999999999999 
--> x2w2.grad (after x1w1x2w2._backward):  0.4999999999999999
 
x1.grad (before x1w1._backward):  0.0 
w1.grad (before x1w1._backward):  0.0 
x2.grad (before x2w2._backward):  0.0 
w2.grad (before x2w2._backward):  0.0
--> x1.grad (after x1w1._backward):  -1.4999999999999996 
--> w1.grad (after x1w1._backward):  0.9999999999999998 
--> x2.grad (after x2w2._backward):  0.4999999999999999 
--> w2.grad (after x2w2._backward):  0.0
[draw_dot(o) output: same graph with gradients filled in: o grad 1.0000, n grad 0.5000, b grad 0.5000, x1*w1 + x2*w2 grad 0.5000, x1*w1 grad 0.5000, x2*w2 grad 0.5000, x1 grad -1.5000, w1 grad 1.0000, x2 grad 0.5000, w2 grad 0.0000]
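
As a hand check, re-applying the same local derivatives used in the closures above reproduces these numbers: o = tanh(n), so n.grad = (1 - o.data**2) * o.grad ≈ 0.5; the '+' nodes pass that 0.5 through unchanged, and each '*' node scales it by the other child's data.

# hand check of the printed gradients, using the same local derivatives as the closures above
print(1 - o.data**2)                 # ≈ 0.5 -> n.grad, and (via the '+' nodes) b, x1w1x2w2, x1w1, x2w2
print(w1.data * 0.5, x1.data * 0.5)  # -1.5, 1.0 -> x1.grad, w1.grad
print(w2.data * 0.5, x2.data * 0.5)  #  0.5, 0.0 -> x2.grad, w2.grad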

Sources