# imports, `Value` class, reset_graph() to init nn, graphviz: trace() & draw_dot()
import math
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
 
# Object definitions from end of previous chapter:
class Value:
 
    # NB: __init__ is called automatically when you create a new instance of a class; it initialises the instance's attributes.
    def __init__(self, data, _children=(), _op='', label=''):
        self.data = data
        self.grad = 0.0
        self._backward = lambda: None # function for local chain rule at each node. default: do nothing (e.g. at leaf node)
        self._prev = set(_children)
        self._op = _op
        self.label = label
 
    # NB: __repr__ is a special ("dunder") method: it returns a string representation of the object, for debugging, logging, etc.
    def __repr__(self):
        return f"Value(data={self.data})" 
    
    def __add__(self, other):
        out = Value(self.data + other.data, (self, other), '+')
        
        # CHAIN RULE: 
        # -> propagate `out.grad`; i.e. the gradient of '+' (addition) operation's output
        # -> "backwards" into `self.grad` and `other.grad`; i.e. the children nodes' gradients
        def _backward():
            self.grad = 1.0 * out.grad # '+' node: gradient distributed back, as-is
            other.grad = 1.0 * out.grad # same for 2nd child node
        
        # save the WHOLE _backward function as an attribute; leaving off the () matters, since () would CALL it here.
        # we want to call it LATER (e.g. running x._backward() on some `Value` object `x`)
        out._backward = _backward
        
        return out
 
    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other), '*')
        
        # CHAIN RULE: 
        # -> propagate `out.grad`; i.e. the gradient of '*' (multiplication) operation's output
        # -> "backwards" into `self.grad` and `other.grad`; i.e. the children nodes' gradients
        def _backward():
            
            self.grad = other.data * out.grad
            other.grad = self.data * out.grad
        out._backward = _backward
        
        return out
    
    def tanh(self):
        x = self.data
        t = (math.exp(2*x) - 1)/(math.exp(2*x) + 1)
        out = Value(t, (self, ), 'tanh')
        
        # CHAIN RULE: 
        # -> propagate `out.grad`; i.e. the gradient of 'tanh' (act. fn.) operation's output
        # -> "backwards" into `self.grad`; i.e. the (single) child node's gradient
        def _backward():
            self.grad = (1 - t**2) * out.grad # local derivative of tanh (per wikipedia)
        out._backward = _backward
    
        return out
    
    def backward(self):
    
        topo = []
        visited = set()
        def build_topo(v):
            if v not in visited:
                visited.add(v)
                for child in v._prev:
                    build_topo(child)
                topo.append(v)
        build_topo(self)
        
        self.grad = 1.0
        for node in reversed(topo):
            node._backward()
 
# helper: re-initialise graph
def reset_graph(reset_level):
 
    # declare the graph's variables as globals, so assignments below update the notebook scope
    global x1, x2, w1, w2, x1w1, x2w2, x1w1x2w2, b, n, o
 
    if reset_level == 'gradients':
        x1.grad = x2.grad = w1.grad = w2.grad = x1w1.grad = x2w2.grad = x1w1x2w2.grad = b.grad = n.grad = o.grad = 0
 
        print("reset_graph(): All gradients have been reset to 0")
 
    # reset all variables
    elif reset_level == 'graph':
        # redefine inputs (x1,x2), weights (w1,w2), and then the graph (n = x1*w1 + x2*w2 + b)
        x1 = Value(2.0, label='x1'); x2 = Value(0.0, label='x2')
        w1 = Value(-3.0, label='w1'); w2 = Value(1.0, label='w2')
        x1w1 = x1 * w1; x1w1.label = 'x1*w1'; x2w2 = x2 * w2; x2w2.label = 'x2*w2'
        x1w1x2w2 = x1w1 + x2w2; x1w1x2w2.label = 'x1*w1 + x2*w2'
 
        # manually chosen bias to make the numbers nice for education (b=8 to see tanh squishing!; b=6.8813735870195432 makes o = 0.7071 and the local tanh derivative 1 - o**2 exactly 0.5)
        b = Value(6.8813735870195432, label='b')
        n = x1w1x2w2 + b; n.label = 'n'
 
        # re-run the activation function on n (the raw cell body / pre-activation) and create the output node o
        o = n.tanh(); o.label = 'o'
 
        print("reset_graph(): All vars, initial and intermediate, have been reset. All gradients now 0")
 
    else: print("reset_graph(): please specify the level of reset desired 'gradients' or 'graph'")
 
# graphviz
from graphviz import Digraph
 
def trace(root):
    # recursively builds a set of all nodes and edges in a graph
    nodes, edges = set(), set()
    def build(v):
        if v not in nodes: 
            nodes.add(v)
            for child in v._prev:
                edges.add((child, v))
                build(child)
    build(root)
    return nodes, edges
 
def draw_dot(root):
    dot = Digraph(format='svg', graph_attr={'rankdir': 'LR'}) # LR = left to right
    nodes, edges = trace(root)
    for n in nodes:
        uid = str(id(n))
        # for any value in the graph, create a rectangular ('record') node for it
        dot.node(name = uid, label = "{ %s | data %.4f | grad %.4f }" % (n.label, n.data, n.grad), shape='record')
        if n._op:
            # if this value is a result of some operation, create an op node for it
            dot.node(name = uid + n._op, label = n._op)
            # and connect this node to it
            dot.edge(uid + n._op, uid)
    for n1, n2 in edges:
        # connect n1 to the op node of n2
        dot.edge(str(id(n1)), str(id(n2)) + n2._op)
    return dot
 

12. Automate _backward() pass calls (all nodes)

12.1. Initialise the network

Reset the data values, zero all gradients, and draw the graph

# i - re-initialise graph; visualise
reset_graph('graph')
draw_dot(o)
reset_graph(): All vars, initial and intermediate, have been reset. All gradients now 0
[draw_dot output: computation graph of o = tanh(x1*w1 + x2*w2 + b); data values populated (x1 = 2.0, w1 = -3.0, x2 = 0.0, w2 = 1.0, b = 6.8814, n = 0.8814, o = 0.7071) and all grads 0.0000]

12.2. Topologically sort the expression graph

To avoid calling ._backward() individually on each node:

  • Simply ensure that a given node's gradient has been fully calculated before calling ._backward() on that node
  • i.e. all of a node's dependencies (the gradients on the paths between it and the output) are evaluated before we propagate further backwards to earlier (input) nodes.

The expression graph is a directed acyclic graph (DAG); a topological sort orders its nodes so that every edge points the same way, left to right (inspect the output below):

# build a topological graph
topo = [] # this list is a global variable
visited = set() # maintain set of visited nodes (global variable)
 
def build_topo(v):
    '''
    given a root node v, this function recursively appends the children of v (and
    their children, and so on) to the global `topo` list before v itself,
    ensuring all edges point in one direction

    Inputs:
        v: root node

    Returns:
        None: the global variable `topo` is now a topologically sorted list of nodes
    '''
    if v not in visited:
        visited.add(v) # mark the node as visited, so each node is processed only once
        # for child in v._prev: # non-deterministic, but valid!
        for child in sorted(v._prev, key=lambda x: x.label): # enforce stability for education purposes
            build_topo(child) # recursively call
        topo.append(v) # a node only adds itself to the 'topo' global list AFTER all its CHILDREN have been added.
 
build_topo(o) # start the topo sort at node `o` (the root node)
topo # note: this (global) list is now topo sorted!
[Value(data=6.881373587019543),
 Value(data=-3.0),
 Value(data=2.0),
 Value(data=-6.0),
 Value(data=1.0),
 Value(data=0.0),
 Value(data=0.0),
 Value(data=-6.0),
 Value(data=0.8813735870195432),
 Value(data=0.7071067811865476)]
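
As a quick sanity check (a small sketch, not part of the original notebook), we can verify the defining invariant of the `topo` list built above:

# i - verify the topo-sort invariant: every node appears AFTER all of its children
for i, node in enumerate(topo):
    for child in node._prev:
        assert topo.index(child) < i, f"{child.label!r} must precede {node.label!r}"
print("topo order verified: all children precede their parents")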

12.3. Run node._backward() on each node in reverse topo order

# i - run `._backward()` on all nodes in reverse topological order (start at final node, `o`)
o.grad = 1.0    # base case: init "global gradient" for final node
 
for node in reversed(topo):
    node._backward()
 
draw_dot(o)
[draw_dot output: same graph after the reverse-topological pass; o.grad = 1.0000 and the gradients have been propagated back through every node]

12.4. Define dedicated backward() method for the whole expression graph (in Value)

  • Move our recursive topological sort function into the Value class, to define a backward() method
  • Call this once, on the final node: o.backward() (hence self = o)
# extend `Value` class: define `backward()` method (recursive, called once on the final node)
class Value:
 
    def __init__(self, data, _children=(), _op='', label=''):
        self.data = data
        self.grad = 0.0
        self._backward = lambda: None # default no-op; the '+', '*' and 'tanh' ops each overwrite this with their own _backward
        self._prev = set(_children)
        self._op = _op
        self.label = label
 
    def __repr__(self):
        return f"Value(data={self.data})" 
 
    def __add__(self, other):
        out = Value(self.data + other.data, (self, other), '+')
 
        def _backward():
            self.grad = out.grad * 1.0
            other.grad = out.grad * 1.0
            
        out._backward = _backward
        
        return out
 
    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other), '*')
 
        def _backward():
            self.grad = out.grad * other.data
            other.grad = out.grad * self.data
            
        out._backward = _backward
        
        return out
 
    def tanh(self): 
        x = self.data
        t = (math.exp(2*x) - 1)/(math.exp(2*x) + 1)
        out = Value(t, (self, ), 'tanh')
 
        def _backward():
            self.grad = (1 - t**2) * out.grad
            
        out._backward = _backward
        
        return out
 
    def backward(self):
        '''
        intended to be called once, on the FINAL node of the expression graph
        (so self = o when we call o.backward())
        '''
        topo = [] # list is now local to this backward() function's scope
        visited = set() # now local to this backward() function's scope
        
        def build_topo(v):
            if v not in visited:
                visited.add(v)
                # for child in v._prev: # non-deterministic, but valid!
                for child in sorted(v._prev, key=lambda x: x.label): # enforce stability for education purposes
                    build_topo(child)
                topo.append(v) # appended only AFTER all children; `topo` looks global from build_topo()'s POV, but is scoped to this backward() call
        build_topo(self)
        
        self.grad = 1.0 # initialise the current node's gradient to 1.0
        for node in reversed(topo): # i.e. starting at the current node, and going backwards to its children
            node._backward()

Reset graph, call recursive o.backward() function, and visualise:

# i - re-init -> call recursive backprop
reset_graph('graph') # this works because we have already redefined the Value object!
o.backward()
draw_dot(o)
reset_graph(): All vars, initial and intermediate, have been reset. All gradients now 0
[draw_dot output: full graph with gradients: o 1.0000, n 0.5000, b 0.5000, x1*w1+x2*w2 0.5000, x1*w1 0.5000, x2*w2 0.5000, x1 -1.5000, w1 1.0000, x2 0.5000, w2 0.0000]
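
As an optional check (a small sketch, not part of the original lecture), we can assert that backward() reproduces the hand-derived gradients for this neuron:

# i - sanity check: backward() matches the manually derived gradients
expected = {o: 1.0, n: 0.5, b: 0.5, x1: -1.5, w1: 1.0, x2: 0.5, w2: 0.0}
for val, g in expected.items():
    assert abs(val.grad - g) < 1e-6, f"{val.label}: got {val.grad}, expected {g}"
print("all gradients match the manual backprop values")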

13. There is a bug with _backward():

Gradient computations via backpropagation are wrong in a few cases, such as the following:

  • We have a data node a and we set b = a + a
  • We have data nodes a and b and we set d = a * b, e = a + b, and f = d * e.

In both cases, a node influences the output through more than one path on the forward pass, so on the backward pass its gradient should receive one contribution per path; plain assignment (=) lets the last path overwrite the others.
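
To see what the correct gradients should be, here is a quick hand-derivation via the multivariate chain rule (using the data values from the cells below): for b = a + a, db/da = 1 + 1 = 2, one contribution per use of a. For d = a*b, e = a+b, f = d*e with a = -2 and b = 3: df/da = e*b + d*1 = 1*3 + (-6) = -3, and df/db = e*a + d*1 = 1*(-2) + (-6) = -8. Compare these with the wrong values backpropagated below.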

Incorrect gradient ex. 1:

# i - example 1: incorrect gradient if we set `b = a + a`
a = Value(3.0, label='a')
b = a + a; b.label = 'b'
b.backward()
draw_dot(b)
[draw_dot output: a data 3.0000 grad 1.0000 -> (+) -> b data 6.0000 grad 1.0000; a.grad should be 2.0, but the second write overwrote the first]

Incorrect gradient ex. 2:

# i - example 2: incorrect gradient when a node influences an upstream node multiple times (e.g. through multiple paths)
a = Value(-2.0, label='a')
b = Value(3.0, label='b')
d = a * b; d.label = 'd'
e = a + b; e.label = 'e'
f = d * e; f.label = 'f'
f.backward()
draw_dot(f)
[draw_dot output: f data -6.0000 grad 1.0000; d grad 1.0000, e grad -6.0000, a data -2.0000 grad 3.0000, b data 3.0000 grad -2.0000; a.grad and b.grad each hold only the last-written path's contribution, not the sums -3.0 and -8.0]

Error fix (redefine Value to accumulate gradients)

For the multivariate case of the chain rule, we need to accumulate (+=) the gradients, instead of overwriting (=) them!

# modify `Value` class to accumulate gradients (+=) per multivariate chain rule
class Value:
 
    def __init__(self, data, _children=(), _op='', label=''):
        self.data = data
        self.grad = 0.0 # initialising gradients to 0 lets us accumulate (+=) them in the multivariate chain rule case
        self._backward = lambda: None
        self._prev = set(_children)
        self._op = _op
        self.label = label
 
    def __repr__(self):
        return f"Value(data={self.data})" 
 
    def __add__(self, other):
        out = Value(self.data + other.data, (self, other), '+')
 
        def _backward():
            # accumulate the gradients (+=) bc multivar chain rule (i.e. a node influences the final node MORE THAN ONCE)
            self.grad += out.grad * 1.0
            other.grad += out.grad * 1.0
            
        out._backward = _backward
        
        return out
 
    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other), '*')
 
        def _backward():
            # accumulate the gradients (+=) bc multivar chain rule (i.e. a node influences the final node MORE THAN ONCE)
            self.grad += out.grad * other.data
            other.grad += out.grad * self.data
            
        out._backward = _backward
        
        return out
 
    def tanh(self): 
        x = self.data
        t = (math.exp(2*x) - 1)/(math.exp(2*x) + 1)
        out = Value(t, (self, ), 'tanh')
 
        def _backward():
            # accumulate the gradients (+=) bc multivar chain rule (i.e. a node influences the final node MORE THAN ONCE)
            self.grad += (1 - t**2) * out.grad
            
        out._backward = _backward
        
        return out
 
    def backward(self):
        topo = []
        visited = set()
        
        def build_topo(v):
            if v not in visited:
                visited.add(v)
                for child in v._prev:
                    build_topo(child)
                topo.append(v)
        build_topo(self)
        
        self.grad = 1.0
        for node in reversed(topo):
            node._backward()

Retry both examples

# reset example 1
a = Value(3.0, label='a'); b = a + a; b.label = 'b'
b.backward()
draw_dot(b)
[draw_dot output: a data 3.0000 grad 2.0000 -> (+) -> b data 6.0000 grad 1.0000; a.grad is now correctly 2.0]
# reset example 2
a = Value(-2.0, label='a'); b = Value(3.0, label='b'); d = a * b; d.label = 'd'; e = a + b; e.label = 'e'; f = d * e; f.label = 'f'
f.backward()
draw_dot(f)
[draw_dot output: f data -6.0000 grad 1.0000; d grad 1.0000, e grad -6.0000, a data -2.0000 grad -3.0000, b data 3.0000 grad -8.0000; both gradients now accumulate correctly]
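
As a final numerical check (a sketch, not part of the original notebook), a finite-difference estimate of df/da and df/db agrees with the accumulated gradients:

# i - numerical gradient check for example 2: f(a, b) = (a*b) * (a + b)
h = 1e-6
f_num = lambda a, b: (a * b) * (a + b)
a0, b0 = -2.0, 3.0
print((f_num(a0 + h, b0) - f_num(a0, b0)) / h) # ~ -3.0, matches a.grad
print((f_num(a0, b0 + h) - f_num(a0, b0)) / h) # ~ -8.0, matches b.grad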

Sources