# imports, reset_graph() to init nn, graphviz: trace() & draw_dot()
 
import math
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
 
# helper: re-initialise graph
def reset_graph(reset_level):
 
    # declare global gradients
    global x1, x2, w1, w2, x1w1, x2w2, x1w1x2w2, b, n, o
 
    if reset_level == 'gradients':
        x1.grad = x2.grad = w1.grad = w2.grad = x1w1.grad = x2w2.grad = x1w1x2w2.grad = b.grad = n.grad = o.grad = 0
 
        print("reset_graph(): All gradients have been reset to 0")
 
    # reset all variables
    elif reset_level == 'graph':
        # redefine inputs (x1,x2), weights (w1,w2), and then the graph (n = x1*w1 + x2*w2 + b)
        x1 = Value(2.0, label='x1'); x2 = Value(0.0, label='x2')
        w1 = Value(-3.0, label='w1'); w2 = Value(1.0, label='w2')
        x1w1 = x1 * w1; x1w1.label = 'x1*w1'; x2w2 = x2 * w2; x2w2.label = 'x2*w2'
        x1w1x2w2 = x1w1 + x2w2; x1w1x2w2.label = 'x1*w1 + x2*w2'
 
        # manually change bias to make number nice for education: (b=8 to see tanh squishing!, b=6.8813735870195432 so deriv = 1)
        b = Value(6.8813735870195432, label='b'); 
        n = x1w1x2w2 + b; n.label = 'n'
 
        # try re-run the activation function on n (the raw cell body) and draw the output node o
        o = n.tanh(); o.label = 'o'
 
        print("reset_graph(): All vars, initial and intermediate, have been reset. All gradients now 0")
 
    else: print("reset_graph(): please specify the level of reset desired 'gradients' or 'graph'")
 
# graphviz
from graphviz import Digraph
 
def trace(root):
    # recursively builds a set of all nodes and edges in a graph
    nodes, edges = set(), set()
    def build(v):
        if v not in nodes: 
            nodes.add(v)
            for child in v._prev:
                edges.add((child, v))
                build(child)
    build(root)
    return nodes, edges
 
def draw_dot(root):
    dot = Digraph(format='svg', graph_attr={'rankdir': 'LR'}) # LR = left to right
    nodes, edges = trace(root)
    for n in nodes:
        uid = str(id(n))
        # for any value in the graph, create a rectangular ('record') node for it
        dot.node(name = uid, label = "{ %s | data %.4f | grad %.4f }" % (n.label, n.data, n.grad), shape='record')
        if n._op:
            # if this value is a result of some operation, create an op node for it
            dot.node(name = uid + n._op, label = n._op)
            # and connect this node to it
            dot.edge(uid + n._op, uid)
    for n1, n2 in edges:
        # connect n1 to the op node of n2
        dot.edge(str(id(n1)), str(id(n2)) + n2._op)
    return dot

Exercise

  • We implemented tanh as a single composite operation (method: .tanh()).
    • Valid, because we know its local derivative (see self.grad in _backward(); formula below)
  • Now re-implement it using only its constituent operations
  • Bonus: good practice implementing a few more neuron operations!
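For reference, the composite tanh's known local derivative, which its _backward() below applies as self.grad += (1 - t**2) * out.grad:

$\frac{d}{dx}\tanh(x) = 1 - \tanh^2(x)$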

Approach:

  1. Generalise existing “left operand” methods to handle expressions with mixed operand types:
    1. __add__() method must handle: Value + int (i.e. Value.__add__(int))
    2. __mul__() method must handle: Value * int (i.e. Value.__mul__(int))
    3. How: assume the non-Value operand is an int/float and wrap it: Value(other)
  2. Fallbacks: create reflected versions of the above (for swapped operands; see the sketch after this list)
    1. New __radd__() method handles: int + Value (i.e. int.__radd__(Value))
    2. New __rmul__() method handles: int * Value (i.e. int.__rmul__(Value))
  3. Define exponentiation method: exp()
    1. math.exp() builtin function, and single input (self)
  4. One could define a division method __truediv__()
    1. But it’s more general to implement a __pow__() (e.g. for x**k)
    2. Division is a special case of multiplication (by the inverse: a / b => a * (1/b) => a * b**-1)
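
A minimal sketch of step 2's reflected-operand fallback, using a hypothetical stripped-down class V (not the full Value defined below). Python evaluates 2 * v by first trying int.__mul__(v), which returns NotImplemented, and then falls back to the reflected V.__rmul__(2):

class V:
    def __init__(self, data):
        self.data = data
 
    def __mul__(self, other):                                # handles: V * int/float
        other = other if isinstance(other, V) else V(other)  # wrap the non-V operand
        return V(self.data * other.data)
 
    def __rmul__(self, other):                               # handles: int/float * V
        return self * other                                  # route back to __mul__
 
print((V(3.0) * 2).data)  # 6.0 via V.__mul__
print((2 * V(3.0)).data)  # 6.0 via the V.__rmul__ fallback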

For steps 3 and 4, we need to define each operation's local derivative so that gradients can be backpropagated (out.grad flowing back into self.grad):

| Operation      | Forward                     | Local Derivative        | _backward() Gradient Flow            |
|----------------|-----------------------------|-------------------------|--------------------------------------|
| Exponentiation | out = e^x (via math.exp(x)) | d(e^x)/dx = e^x         | self.grad += out.data * out.grad     |
| Power          | out = x**k                  | d(x^k)/dx = k * x^(k-1) | self.grad += k * x**(k-1) * out.grad |

Recall the conventional _backward() gradient flow direction: downstream grad = local grad * upstream grad, accumulated into each operand's .grad.
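
A quick numerical sanity check of the two local derivatives above, using a central finite difference (a hedged sketch; the values of h, x and k are illustrative, not part of the notebook):

import math
 
h = 1e-6
x, k = 1.3, 3.0
 
# exponentiation: d(e^x)/dx = e^x
print(math.exp(x), (math.exp(x + h) - math.exp(x - h)) / (2 * h))  # both ~3.6693
 
# power: d(x^k)/dx = k * x^(k-1)
print(k * x**(k - 1), ((x + h)**k - (x - h)**k) / (2 * h))         # both ~5.0700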

Implementation (extend Value)

# extend `Value` class with the constituent methods listed above:
class Value:
    def __init__(self, data, _children=(), _op='', label=''):
        self.data = data
        self.grad = 0.0
        self._backward = lambda: None
        self._prev = set(_children)
        self._op = _op
        self.label = label
 
    def __repr__(self):
        return f"Value(data={self.data})" 
 
    def __add__(self, other):
        # pre-process `other`. If it is non-`Value`, assume `int`/`float` and wrap in `Value()`
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other), '+')
 
        def _backward():
            self.grad += out.grad * 1.0
            other.grad += out.grad * 1.0
        out._backward = _backward
        
        return out
 
    def __mul__(self, other):
        # pre-process `other`. If it is non-`Value`, assume int/float and wrap in `Value()`
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data * other.data, (self, other), '*')
 
        def _backward():
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad
        out._backward = _backward
        
        return out
 
    def __radd__(self, other):  # fallback for swapped operands: i.e. other + self
        return self + other     # route to `__add__`
    
    def __rmul__(self, other):  # fallback for swapped operands: i.e. other * self
        return self * other     # route to `__mul__`
    
    # ensure `other` is NEVER a `Value` object. Only int/float allowed
    def __pow__(self, other):
        assert isinstance(other, (int, float)), "only supporting int/float powers for now"
        out = Value(self.data**other, (self,), f'**{other}')
        
        # recall downstream grad = local grad * upstream grad
        # local gradient for x^k: d(x^k)/dx = kx^(k-1)
        def _backward():
            self.grad += other * (self.data ** (other - 1)) * out.grad
        out._backward = _backward
        
        return out
    
    def __truediv__(self, other): # i.e. self / other but...
        return self * other**-1   # use previously defined __mul__() and __pow__(), instead of implementing a `/` operation with its own `_backward()`
    
    def __neg__(self): # -self
        return self * -1        # use previously defined __mul__() to evaluate this `Value` * `int` expression
    
    def __sub__(self, other):   # self - other
        return self + (-other)  # use previously defined __add__() and __neg__(), instead of implementing a `-` operation with its own `_backward()`
 
    def tanh(self): 
        x = self.data
        t = (math.exp(2*x) - 1)/(math.exp(2*x) + 1)
        out = Value(t, (self, ), 'tanh')
 
        def _backward():
            self.grad += (1 - t**2) * out.grad
        out._backward = _backward
        
        return out
    
    # define exponentiation method
    def exp(self):
        x = self.data                               # input data value
        out = Value(math.exp(x), (self, ), 'exp')   # output data value: use builtin math.exp(x)
        
        # recall downstream grad = local grad * upstream grad
        # local gradient for exp: d(e^x)/dx = e^x (i.e. out.data, just calculated!)
        def _backward():
            self.grad += out.data * out.grad
        out._backward = _backward
        
        return out
    
 
    def backward(self):
        topo = []
        visited = set()
        
        def build_topo(v):
            if v not in visited:
                visited.add(v)
                for child in v._prev:
                    build_topo(child)
                topo.append(v)
        build_topo(self)
        
        self.grad = 1.0
        for node in reversed(topo):
            node._backward()

Test implementation

Original tanh

# original o = n.tanh() call
reset_graph('graph')
o.backward()
draw_dot(o)
reset_graph(): All vars, initial and intermediate, have been reset. All gradients now 0
[draw_dot(o) output: computation graph rendered left-to-right with a single tanh op node. Leaf gradients after o.backward(): x1 = -1.5, x2 = 0.5, w1 = 1.0, w2 = 0.0, b = 0.5; o: data 0.7071, grad 1.0]

Simplified tanh

Inspect the new graph:

  • Data values must be arithmetically consistent in the forward pass (and o.data should match the original graph).
  • The tanh operation node should be decomposed into a series of simple operation nodes
  • Inspect backpropagated gradients at leaves (x1, x2, w1, w2, b). Should match the above.
# reset graph -> overwrite o = n.tanh() node
reset_graph('graph')
 
# overwrite o = n.tanh() -> express activation function as constituent operations 
e = (2*n).exp()
o = (e - 1) / (e + 1)
o.label = 'o'
 
# perform backward pass, and draw the output node o
o.backward()
draw_dot(o)
reset_graph(): All vars, initial and intermediate, have been reset. All gradients now 0
[draw_dot(o) output: the same computation graph with tanh decomposed into *, exp, +, and **-1 op nodes. Leaf gradients match the original graph: x1 = -1.5, x2 = 0.5, w1 = 1.0, w2 = 0.0, b = 0.5; o: data 0.7071, grad 1.0]

Takeaways

  • The level at which a neuron operation (method) is implemented is arbitrary.
    • Simple operations like addition (+) and complex composite ones like tanh are equally valid building blocks (see the sigmoid sketch below).
  • The only prerequisite for implementing an operation is that you can perform:
    • Forward pass: compute some output(s) as a function of some input(s),
    • Backward pass: differentiate the operation (i.e. we can find and chain its local gradient)
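
To make the first takeaway concrete, a hedged sketch (it assumes the extended Value class and the math import above; sigmoid_composed and sigmoid_fused are illustrative names) of a sigmoid activation implemented at both levels, mirroring what we did for tanh:

# (a) from constituent operations: sigma(x) = 1 / (1 + e^(-x))
def sigmoid_composed(v):
    return ((-v).exp() + 1) ** -1
 
# (b) as a single fused op with its known local derivative: sigma * (1 - sigma)
def sigmoid_fused(v):
    s = 1 / (1 + math.exp(-v.data))
    out = Value(s, (v,), 'sigmoid')
    def _backward():
        v.grad += s * (1 - s) * out.grad
    out._backward = _backward
    return out
 
a = Value(0.5, label='a'); sigmoid_composed(a).backward()
b = Value(0.5, label='b'); sigmoid_fused(b).backward()
print(a.grad, b.grad)   # both ~0.2350: same gradients, different op granularity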

Sources