- Related: row normalisation in 02_sampling, bias term in 04_from_bigrams_to_nns
The best way to understand this: test yourself against the code cells below (run the ones you find confusing).
What is broadcasting?
Broadcasting lets PyTorch perform operations on tensors of different shapes without explicitly copying data, by “stretching” the smaller tensor to match the larger one.
Mental model: imagine the smaller tensor being copied/tiled to fill the shape of the larger one, but without actually allocating that memory.
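This "stretch without copying" is visible directly in PyTorch: `Tensor.expand` returns a broadcast view whose stride is 0 along the stretched dimension, so every "copy" of the row points at the same memory. A minimal sketch:

```python
import torch

x = torch.arange(4)      # shape (4,)
v = x.expand(3, 4)       # broadcast view with shape (3, 4), no data copied

print(v.stride())                        # (0, 1) - stride 0 along the stretched dim
print(v.data_ptr() == x.data_ptr())      # True - same underlying storage
```

The stride-0 trick is exactly what broadcasting does under the hood during an op like `X + Y`: the small tensor is read repeatedly, not materialised at the larger shape.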
Core rule
PyTorch aligns shapes from the right. A pair of aligned dimensions is compatible if it is either:
- equal in both tensors, or
- 1 in one of them (this dimension gets broadcast/stretched)
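The rule is small enough to write out as code. Below is a sketch of the shape check (not PyTorch's actual implementation; recent PyTorch versions expose the same check as `torch.broadcast_shapes`):

```python
from itertools import zip_longest

def broadcast_shape(a, b):
    """Return the broadcast result shape of shapes a and b, or raise."""
    out = []
    # Align from the right; missing leading dims count as 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x == y or x == 1 or y == 1:
            out.append(max(x, y))     # the non-1 size wins
        else:
            raise ValueError(f"incompatible dims: {x} vs {y}")
    return tuple(reversed(out))

print(broadcast_shape((3, 4), (4,)))       # (3, 4)
print(broadcast_shape((2, 1, 4), (3, 1)))  # (2, 3, 4)
```

Every example below (including the error cases) can be predicted by running its two shapes through this function.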
import torch
1D cases
# 1D + scalar
X = torch.arange(4) # (4,)
Y = torch.tensor(10) # () → broadcast to (4,)
Z = X + Y # (4,)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([0, 1, 2, 3])
tensor(10)
tensor([10, 11, 12, 13])
torch.Size([4]) torch.Size([]) torch.Size([4])
# 1D + 1D same shape — no broadcasting
X = torch.arange(4) # (4,)
Y = torch.arange(4) # (4,) → same, no broadcast
Z = X + Y # (4,)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([0, 1, 2, 3])
tensor([0, 1, 2, 3])
tensor([0, 2, 4, 6])
torch.Size([4]) torch.Size([4]) torch.Size([4])
# 1D + 1D different size — ERROR
X = torch.arange(4) # (4,)
Y = torch.arange(3) # (3,) → 4≠3, neither is 1 → ✗
print(X, '\n', Y)
print(X.shape, Y.shape)
Z = X + Y # ✗
tensor([0, 1, 2, 3])
tensor([0, 1, 2])
torch.Size([4]) torch.Size([3])
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[4], line 6
4 print(X, '\n', Y)
5 print(X.shape, Y.shape)
----> 6 Z = X + Y # ✗
RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0
2D cases
# (4,) + (1,4) — same values as the 1D case, X treated as (1,4), result is (1,4)
X = torch.arange(4) # ( 4,) → treated as (1,4), no broadcast needed
Y = torch.arange(4).reshape(1,4) # (1, 4,) → same
Z = X + Y # (1,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([0, 1, 2, 3])
tensor([[0, 1, 2, 3]])
tensor([[0, 2, 4, 6]])
torch.Size([4]) torch.Size([1, 4]) torch.Size([1, 4])
Non-trivial case
# (4,) + (4,1) — outer sum to (4,4)
X = torch.arange(4) # ( 4,) → treated as (1,4) → broadcast to (4,4)
Y = torch.arange(4).reshape(4,1) # ( 4, 1) → broadcast to (4,4)
Z = X + Y # (4,4) Z[i,j] = Y[i] + X[j]
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([0, 1, 2, 3])
tensor([[0],
[1],
[2],
[3]])
tensor([[0, 1, 2, 3],
[1, 2, 3, 4],
[2, 3, 4, 5],
[3, 4, 5, 6]])
torch.Size([4]) torch.Size([4, 1]) torch.Size([4, 4])
# (3,4) + (4,) — Y row broadcast across all rows
X = torch.arange(12).reshape(3,4) # (3, 4)
Y = torch.arange(4) # ( 4,) → treated as (1,4) → broadcast to (3,4)
Z = X + Y # (3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
tensor([0, 1, 2, 3])
tensor([[ 0, 2, 4, 6],
[ 4, 6, 8, 10],
[ 8, 10, 12, 14]])
torch.Size([3, 4]) torch.Size([4]) torch.Size([3, 4])
# (3,4) + (1,4) — Y row broadcast across all rows (explicit 1)
X = torch.arange(12).reshape(3,4) # (3, 4)
Y = torch.arange(4).reshape(1,4) # (1, 4) → broadcast to (3,4)
Z = X + Y # (3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
tensor([[0, 1, 2, 3]])
tensor([[ 0, 2, 4, 6],
[ 4, 6, 8, 10],
[ 8, 10, 12, 14]])
torch.Size([3, 4]) torch.Size([1, 4]) torch.Size([3, 4])
# (3,4) + (3,1) — Y column broadcast across all columns
X = torch.arange(12).reshape(3,4) # (3, 4)
Y = torch.arange(3).reshape(3,1) # (3, 1) → broadcast to (3,4)
Z = X + Y # (3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
tensor([[0],
[1],
[2]])
tensor([[ 0, 1, 2, 3],
[ 5, 6, 7, 8],
[10, 11, 12, 13]])
torch.Size([3, 4]) torch.Size([3, 1]) torch.Size([3, 4])
# (3,4) + (3,) — ERROR: right-aligns to 4 vs 3
X = torch.arange(12).reshape(3,4) # (3, 4)
Y = torch.arange(3) # ( 3,) → treated as (1,3) → right-aligns: 4≠3, neither is 1 → ✗
print(X, '\n', Y)
print(X.shape, Y.shape)
Z = X + Y # ✗
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
tensor([0, 1, 2])
torch.Size([3, 4]) torch.Size([3])
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[10], line 6
4 print(X, '\n', Y)
5 print(X.shape, Y.shape)
----> 6 Z = X + Y # ✗
RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 1
# (3,4) + (1,1) — scalar broadcast across entire matrix
X = torch.arange(12).reshape(3,4) # (3, 4)
Y = torch.tensor([[10]]) # (1, 1) → broadcast to (3,4)
Z = X + Y # (3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
tensor([[10]])
tensor([[10, 11, 12, 13],
[14, 15, 16, 17],
[18, 19, 20, 21]])
torch.Size([3, 4]) torch.Size([1, 1]) torch.Size([3, 4])
# (3,4) + (3,4) — same shape, no broadcasting
X = torch.arange(12).reshape(3,4) # (3, 4)
Y = torch.arange(12).reshape(3,4) # (3, 4) → same, no broadcast
Z = X + Y # (3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
tensor([[ 0, 2, 4, 6],
[ 8, 10, 12, 14],
[16, 18, 20, 22]])
torch.Size([3, 4]) torch.Size([3, 4]) torch.Size([3, 4])
Non-trivial case
# (1,4) + (3,1) — outer sum. both broadcast to (3,4)
X = torch.arange(4).reshape(1,4) # (1, 4) → broadcast to (3,4)
Y = torch.arange(3).reshape(3,1) # (3, 1) → broadcast to (3,4)
Z = X + Y # (3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[0, 1, 2, 3]])
tensor([[0],
[1],
[2]])
tensor([[0, 1, 2, 3],
[1, 2, 3, 4],
[2, 3, 4, 5]])
torch.Size([1, 4]) torch.Size([3, 1]) torch.Size([3, 4])
3D cases
Non-trivial case
# (2,3,4) + (4,) — broadcast over last dim
X = torch.arange(24).reshape(2,3,4) # (2, 3, 4)
Y = torch.arange(4) # ( 4,) → treated as (1,1,4) → broadcast to (2,3,4)
Z = X + Y # (2,3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
tensor([0, 1, 2, 3])
tensor([[[ 0, 2, 4, 6],
[ 4, 6, 8, 10],
[ 8, 10, 12, 14]],
[[12, 14, 16, 18],
[16, 18, 20, 22],
[20, 22, 24, 26]]])
torch.Size([2, 3, 4]) torch.Size([4]) torch.Size([2, 3, 4])
# (2,3,4) + (3,4) — broadcast over first dim
X = torch.arange(24).reshape(2,3,4) # (2, 3, 4)
Y = torch.arange(12).reshape(3,4) # ( 3, 4) → treated as (1,3,4) → broadcast to (2,3,4)
Z = X + Y # (2,3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
tensor([[[ 0, 2, 4, 6],
[ 8, 10, 12, 14],
[16, 18, 20, 22]],
[[12, 14, 16, 18],
[20, 22, 24, 26],
[28, 30, 32, 34]]])
torch.Size([2, 3, 4]) torch.Size([3, 4]) torch.Size([2, 3, 4])
Non-trivial case
# (2,3,4) + (3,1) — broadcast over first and last dim
X = torch.arange(24).reshape(2,3,4) # (2, 3, 4)
Y = torch.arange(3).reshape(3,1) # ( 3, 1) → treated as (1,3,1) → broadcast to (2,3,4)
Z = X + Y # (2,3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
tensor([[0],
[1],
[2]])
tensor([[[ 0, 1, 2, 3],
[ 5, 6, 7, 8],
[10, 11, 12, 13]],
[[12, 13, 14, 15],
[17, 18, 19, 20],
[22, 23, 24, 25]]])
torch.Size([2, 3, 4]) torch.Size([3, 1]) torch.Size([2, 3, 4])
# (2,3,4) + (1,3,1) — broadcast over first and last dim (explicit 1s)
X = torch.arange(24).reshape(2,3,4) # (2, 3, 4)
Y = torch.arange(3).reshape(1,3,1) # (1, 3, 1) → broadcast to (2,3,4)
Z = X + Y # (2,3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
tensor([[[0],
[1],
[2]]])
tensor([[[ 0, 1, 2, 3],
[ 5, 6, 7, 8],
[10, 11, 12, 13]],
[[12, 13, 14, 15],
[17, 18, 19, 20],
[22, 23, 24, 25]]])
torch.Size([2, 3, 4]) torch.Size([1, 3, 1]) torch.Size([2, 3, 4])
# (2,3,4) + (2,1,4) — broadcast over middle dim
X = torch.arange(24).reshape(2,3,4) # (2, 3, 4)
Y = torch.arange(8).reshape(2,1,4) # (2, 1, 4) → broadcast to (2,3,4)
Z = X + Y # (2,3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
tensor([[[0, 1, 2, 3]],
[[4, 5, 6, 7]]])
tensor([[[ 0, 2, 4, 6],
[ 4, 6, 8, 10],
[ 8, 10, 12, 14]],
[[16, 18, 20, 22],
[20, 22, 24, 26],
[24, 26, 28, 30]]])
torch.Size([2, 3, 4]) torch.Size([2, 1, 4]) torch.Size([2, 3, 4])
# (2,3,4) + (2,3,1) — broadcast over last dim
X = torch.arange(24).reshape(2,3,4) # (2, 3, 4)
Y = torch.arange(6).reshape(2,3,1) # (2, 3, 1) → broadcast to (2,3,4)
Z = X + Y # (2,3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
tensor([[[0],
[1],
[2]],
[[3],
[4],
[5]]])
tensor([[[ 0, 1, 2, 3],
[ 5, 6, 7, 8],
[10, 11, 12, 13]],
[[15, 16, 17, 18],
[20, 21, 22, 23],
[25, 26, 27, 28]]])
torch.Size([2, 3, 4]) torch.Size([2, 3, 1]) torch.Size([2, 3, 4])
Non-trivial case
# (2,1,4) + (1,3,1) — broadcast over both middle dims → (2,3,4)
X = torch.arange(8).reshape(2,1,4) # (2, 1, 4) → broadcast to (2,3,4)
Y = torch.arange(3).reshape(1,3,1) # (1, 3, 1) → broadcast to (2,3,4)
Z = X + Y # (2,3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[[0, 1, 2, 3]],
[[4, 5, 6, 7]]])
tensor([[[0],
[1],
[2]]])
tensor([[[0, 1, 2, 3],
[1, 2, 3, 4],
[2, 3, 4, 5]],
[[4, 5, 6, 7],
[5, 6, 7, 8],
[6, 7, 8, 9]]])
torch.Size([2, 1, 4]) torch.Size([1, 3, 1]) torch.Size([2, 3, 4])
# (2,3,4) + (2,4) — ERROR: right-aligns to (2,3,4) vs (_,2,4) → 3 vs 2
X = torch.arange(24).reshape(2,3,4) # (2, 3, 4)
Y = torch.arange(8).reshape(2,4) # ( 2, 4) → treated as (1,2,4) → right-aligns: 3≠2, neither is 1 → ✗
print(X, '\n', Y)
print(X.shape, Y.shape)
Z = X + Y # ✗
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
tensor([[0, 1, 2, 3],
[4, 5, 6, 7]])
torch.Size([2, 3, 4]) torch.Size([2, 4])
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[21], line 6
4 print(X, '\n', Y)
5 print(X.shape, Y.shape)
----> 6 Z = X + Y # ✗
RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 1
# (1,3,4) + (2,1,4) — both 1-dims broadcast → (2,3,4)
X = torch.arange(12).reshape(1,3,4) # (1, 3, 4) → broadcast to (2,3,4)
Y = torch.arange(8).reshape(2,1,4) # (2, 1, 4) → broadcast to (2,3,4)
Z = X + Y # (2,3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]]])
tensor([[[0, 1, 2, 3]],
[[4, 5, 6, 7]]])
tensor([[[ 0, 2, 4, 6],
[ 4, 6, 8, 10],
[ 8, 10, 12, 14]],
[[ 4, 6, 8, 10],
[ 8, 10, 12, 14],
[12, 14, 16, 18]]])
torch.Size([1, 3, 4]) torch.Size([2, 1, 4]) torch.Size([2, 3, 4])
Sources
- PyTorch Docs: Broadcasting semantics
- Medium: Understanding Broadcasting in PyTorch
- StackOverflow: How does pytorch broadcasting work?
- X (Twitter): Akshay Pachaar
- D2L: Dive into Deep Learning → Data Manipulation → Broadcasting
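Tying this back to the row normalisation mentioned at the top: dividing a (3,4) matrix by its row sums is exactly the (3,4) + (3,1) pattern, and `keepdim=True` is what keeps the size-1 dimension so the right-alignment works. A sketch (the count matrix here is made up for illustration):

```python
import torch

N = torch.arange(1, 13, dtype=torch.float32).reshape(3, 4)  # (3, 4) "counts"
row_sums = N.sum(dim=1, keepdim=True)   # (3, 1): keepdim preserves the column dim
P = N / row_sums                        # (3, 1) broadcasts across columns to (3, 4)
print(P.sum(dim=1))                     # each row now sums to 1
```

With `keepdim=False` the sum would be shape (3,), which right-aligns against (3,4) as 4 vs 3 and raises the same RuntimeError as the (3,4) + (3,) case above.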
