Summary: A directed acyclic graph (DAG) where each node is a mathematical operation (or a leaf tensor) and each edge is a data dependency; built implicitly during the forward pass, it is the structure that backpropagation walks in reverse to compute gradients.

What it is

For any differentiable computation y = f(x), the computation graph is the DAG that decomposes f into atomic operations:

  • Leaf nodes — inputs and learnable parameters (raw tensors with no _prev). Examples: input x, weight matrix W, bias b.
  • Internal nodes — operations applied to other nodes (e.g. matmul, add, exp, log, sum). Each stores its operands and a local backward rule.
  • Root — the final output (typically a scalar loss).
  • Edges — pure data dependencies. An edge from u to v means "u is an input to the op that produces v." Edges carry no weights — weights are themselves leaf nodes.

The forward pass evaluates each node in topological order; the backward pass walks the same graph in reverse, accumulating each node's gradient ∂L/∂node by repeatedly applying the chain rule (see backprop-graph-terminology for the upstream/downstream conventions).
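The whole mechanism fits in a few dozen lines of Python. The sketch below is in the spirit of micrograd but is not the actual library; Value, _prev, and the topological-sort backward() are written out only to show concretely what "leaf", "internal node", and "reverse traversal" mean.

```python
import math

class Value:
    """Scalar autodiff node: data, grad, operand set (_prev), and a local backward rule."""
    def __init__(self, data, _prev=(), _op=""):
        self.data = data
        self.grad = 0.0
        self._prev = set(_prev)          # operand nodes; empty for leaves
        self._op = _op                   # which op produced this node (for inspection)
        self._backward = lambda: None    # local chain-rule step, set by the op below

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other), "+")
        def _backward():                 # add: gradient passes through unchanged
            self.grad += out.grad
            other.grad += out.grad
        out._backward = _backward
        return out

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data * other.data, (self, other), "*")
        def _backward():                 # multiply: d(ab)/da = b, d(ab)/db = a
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad
        out._backward = _backward
        return out

    def exp(self):
        out = Value(math.exp(self.data), (self,), "exp")
        def _backward():                 # exp: d(e^a)/da = e^a = out.data
            self.grad += out.data * out.grad
        out._backward = _backward
        return out

    def backward(self):
        topo, seen = [], set()
        def build(v):                    # depth-first topological sort of the DAG
            if v not in seen:
                seen.add(v)
                for p in v._prev:
                    build(p)
                topo.append(v)
        build(self)
        self.grad = 1.0                  # seed: d(out)/d(out) = 1
        for node in reversed(topo):      # walk the graph in reverse, apply local rules
            node._backward()

# One neuron, y = exp(w*x + b): w, x, b are leaves; *, +, exp are internal nodes.
w, x, b = Value(2.0), Value(3.0), Value(-5.0)
y = (w * x + b).exp()
y.backward()
print(y.data, w.grad, x.grad, b.grad)    # dy/dw = x*e^(wx+b), dy/dx = w*e^(wx+b), dy/db = e^(wx+b)
```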

Granularity is a choice

The “atomic operation” level is an implementation choice, not a property of the math.

| Library | Atomic op | Example: exp(x @ W + b) |
| --- | --- | --- |
| micrograd (Karpathy) | scalar *, +, exp, … | one node per scalar multiply and add — many nodes for a single neuron |
| PyTorch / JAX / TF | tensor-level op (matmul, add, exp, …) | 3 internal nodes: matmul → add → exp, plus 3 leaves (x, W, b) |

Coarser granularity = fewer nodes, faster traversal, fused gradient rules. Finer granularity = simpler per-op math, easier to teach. Both produce the same gradients.
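In PyTorch the tensor-level graph can be inspected directly through grad_fn and next_functions. A small sketch; the printed node names (e.g. ExpBackward0) are an implementation detail and can vary across versions.

```python
import torch

x = torch.randn(1, 3)                        # plain input, requires_grad=False
W = torch.randn(3, 4, requires_grad=True)    # leaf parameters
b = torch.randn(4, requires_grad=True)

y = torch.exp(x @ W + b)                     # forward pass records the graph

# One backward node per tensor op; on current versions the names are typically
# ExpBackward0 -> AddBackward0 -> MmBackward0, with AccumulateGrad at the leaves.
exp_node = y.grad_fn
add_node = exp_node.next_functions[0][0]
print(type(exp_node).__name__)
print([type(f).__name__ for f, _ in exp_node.next_functions])
print([type(f).__name__ for f, _ in add_node.next_functions])
```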

Static vs dynamic

  • Dynamic (define-by-run) — PyTorch (default), JAX without jit, TF eager. The graph is rebuilt on every forward pass. Lets you use Python control flow (if, for) that depends on tensor values; the cost is no whole-program optimisation.
  • Static (define-then-run) — torch.compile, JAX jit, TF 1.x graph mode. The graph is captured once and reused; allows kernel fusion, shape specialisation, ahead-of-time compilation. Control flow that varies per input has to be expressed via graph primitives (jax.lax.cond, torch.cond).

For training a typical neural network, the distinction is invisible to the user — both produce the same gradients via the same chain-rule traversal.
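A sketch of the dynamic case in eager PyTorch: the branch is plain Python, re-decided and re-recorded on every forward pass.

```python
import torch

def f(x):
    # Data-dependent Python control flow: legal in define-by-run mode because
    # the graph is rebuilt from scratch on each forward pass.
    if x.sum() > 0:
        return x.exp().sum()
    return (3 * x).sum()

x = torch.randn(4, requires_grad=True)
loss = f(x)          # records a graph for whichever branch actually ran
loss.backward()      # reverse traversal of that recorded graph
print(x.grad)

# Under static capture (torch.compile with fullgraph=True, jax.jit, TF graph
# mode) this data-dependent `if` would instead have to be expressed with a
# graph primitive such as torch.cond or jax.lax.cond.
```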

Worked example

For a single linear-then-exp computation y = exp(x @ W + b) with input x, weight W, and bias b, at the tensor-op level:

   x ──┐
       ├─► matmul ──► z₁ ──┐
   W ──┘                   ├─► add ──► z₂ ──► exp ──► y
                       b ──┘
  • 3 leaves (x, W, b), 3 internal ops (matmul, add, exp), 1 output y.
  • During backward, the upstream gradient ∂L/∂y flows in at the right (seeded to 1 if y itself is the root); each op’s local rule pushes a gradient to each of its inputs (exp → multiply by y; add → pass through to both branches; matmul → multiply by the other operand transposed). A runnable check of these rules follows this list.
  • W and b are leaves, so their accumulated .grad is what the optimiser uses; x is a leaf too but typically has requires_grad=False.
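A quick check of those local rules against autograd (a sketch with assumed shapes, e.g. a 1×3 input):

```python
import torch

x = torch.randn(1, 3)                        # leaf, requires_grad=False
W = torch.randn(3, 1, requires_grad=True)    # leaf parameter
b = torch.randn(1, requires_grad=True)       # leaf parameter

z1 = x @ W            # matmul node
z2 = z1 + b           # add node
y  = z2.exp()         # exp node
loss = y.sum()        # scalar root, as a typical loss would be

loss.backward()       # seeds dL/dloss = 1 and walks the graph in reverse

# exp rule:    dL/dz2 = dL/dy * y = y         (multiply by the output)
# add rule:    dL/dz1 = dL/db = dL/dz2        (pass through; b sums over the broadcast dim)
# matmul rule: dL/dW  = x^T @ dL/dz1
assert torch.allclose(W.grad, x.T @ y.detach())
assert torch.allclose(b.grad, y.detach().sum(0))
assert x.grad is None                         # leaf without requires_grad accumulates nothing
```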

At the scalar-op level (micrograd-style), the same expression for a neuron with 3 inputs and one output expands to roughly 3 multiplies + 2 adds + 1 add (for b) + 1 exp, i.e. on the order of 7–8 internal nodes for a single output neuron.
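Written out at that granularity with hypothetical concrete values, one line per scalar op node:

```python
import math

# 3-input neuron, scalar ops only; the weights and inputs are illustrative values.
x1, x2, x3 = 1.0, 2.0, 3.0
w1, w2, w3 = 0.1, 0.2, 0.3
bias = 0.5

p1 = x1 * w1          # multiply
p2 = x2 * w2          # multiply
p3 = x3 * w3          # multiply
s1 = p1 + p2          # add
s2 = s1 + p3          # add
z  = s2 + bias        # add (bias)
y  = math.exp(z)      # exp -> 7 internal op nodes, versus 3 at tensor granularity
print(y)
```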

Distinct from the network architecture diagram

The MLP “circles and arcs” picture and the computation graph are two views of the same network at different levels of abstraction.

  • Network diagram — neurons as nodes, weighted connections as edges. One circle = one neuron’s full σ(w·x + b). A static architectural description.
  • Computation graph — operations as nodes, data dependencies as edges. That same neuron expands into many op-nodes. A dynamic operational trace built at runtime.

Full breakdown: network-diagram-vs-computation-graph.

Why it matters

  • Autodiff. Every modern ML framework computes gradients by recording a computation graph and replaying it backward. There is no “neural-network gradient formula” — just the chain rule, applied node-by-node, which is exactly what the graph encodes.
  • Generality. The graph framework doesn’t care that you’re training a neural network. Any differentiable Python program (physics simulation, probabilistic model, optimisation problem) gets the same treatment — this is why JAX-style grad(f) works on arbitrary functions (a sketch follows this list).
  • Memory-vs-compute trade-offs. The graph holds intermediate activations needed for backward. Techniques like gradient checkpointing drop some intermediates and recompute them on the backward pass — only legible once you think of the network as a graph of stored ops.
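To illustrate the generality point, a minimal sketch: grad applies to any differentiable function, neural network or not.

```python
import jax
import jax.numpy as jnp

# Not a neural network: an arbitrary differentiable scalar function.
def f(x):
    return jnp.sin(x) ** 2 + x ** 3

dfdx = jax.grad(f)     # traces f's computation graph and traverses it backward
print(dfdx(2.0))       # 2*sin(2)*cos(2) + 3*2**2 ≈ 11.24
```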

See also