The best way to understand this is to test yourself against the code cells below (run the ones you find confusing).

What is broadcasting?

Broadcasting lets PyTorch perform operations on tensors of different shapes without explicitly copying data, by “stretching” the smaller tensor to match the larger one.

Mental model: imagine the smaller tensor being copied/tiled to fill the shape of the larger one, but without actually allocating that memory.

Core rule

PyTorch aligns shapes from the right (trailing dimensions first); if one tensor has fewer dimensions, its shape is treated as if padded with leading 1s. A pair of aligned dimensions is compatible if it is either:

  • equal in both tensors, or
  • 1 in one of them (this dimension gets broadcast/stretched)
import torch

1D cases

# 1D + scalar
X = torch.arange(4)          # shape (4,)
Y = torch.tensor(10)         # shape ()  — 0-d tensor, stretched across all 4 entries
Z = torch.add(X, Y)          # shape (4,)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([0, 1, 2, 3]) 
 tensor(10) 
 tensor([10, 11, 12, 13])
torch.Size([4]) torch.Size([]) torch.Size([4])
# 1D + 1D same shape — no broadcasting
X = torch.tensor([0, 1, 2, 3])   # (4,)
Y = torch.tensor([0, 1, 2, 3])   # (4,)  — shapes already match, nothing stretches
Z = Y + X                        # (4,)  elementwise sum
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([0, 1, 2, 3]) 
 tensor([0, 1, 2, 3]) 
 tensor([0, 2, 4, 6])
torch.Size([4]) torch.Size([4]) torch.Size([4])
# 1D + 1D different size — ERROR (no dim is 1, so nothing can stretch)
X = torch.arange(4)          # (4,)
Y = torch.arange(3)          # (3,)  → right-aligned: 4 vs 3, unequal and neither is 1 → ✗
print(X, '\n', Y)
print(X.shape, Y.shape)
Z = X + Y                    # ✗
tensor([0, 1, 2, 3]) 
 tensor([0, 1, 2])
torch.Size([4]) torch.Size([3])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[4], line 6
      4 print(X, '\n', Y)
      5 print(X.shape, Y.shape)
----> 6 Z = X + Y                    # ✗
 
RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0

2D cases

# (4,) + (1,4) — same result, X is implicitly (1,4)
X = torch.arange(4)              # (4,)  — padded on the left to (1,4), then shapes match
Y = torch.arange(4).unsqueeze(0) # (1,4) — leading dim written out explicitly
Z = Y + X                        # (1,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([0, 1, 2, 3]) 
 tensor([[0, 1, 2, 3]]) 
 tensor([[0, 2, 4, 6]])
torch.Size([4]) torch.Size([1, 4]) torch.Size([1, 4])

Non-trivial case

# (4,) + (4,1) — outer sum to (4,4)
X = torch.arange(4)              # (4,)  — implicitly (1,4), stretched down the rows to (4,4)
Y = torch.arange(4).unsqueeze(1) # (4,1) — stretched across the columns to (4,4)
Z = Y + X                        # (4,4); Z[i,j] = Y[i,0] + X[j]
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([0, 1, 2, 3]) 
 tensor([[0],
        [1],
        [2],
        [3]]) 
 tensor([[0, 1, 2, 3],
        [1, 2, 3, 4],
        [2, 3, 4, 5],
        [3, 4, 5, 6]])
torch.Size([4]) torch.Size([4, 1]) torch.Size([4, 4])
# (3,4) + (4,) — the row Y is added to every row of X
X = torch.arange(12).view(3,4)    # (3,4)
Y = torch.arange(4)               # (4,)  — implicitly (1,4), repeated for each of the 3 rows
Z = Y + X                         # (3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]]) 
 tensor([0, 1, 2, 3]) 
 tensor([[ 0,  2,  4,  6],
        [ 4,  6,  8, 10],
        [ 8, 10, 12, 14]])
torch.Size([3, 4]) torch.Size([4]) torch.Size([3, 4])
# (3,4) + (1,4) — row added to every row (leading 1 written out)
X = torch.arange(12).view(3,4)    # (3,4)
Y = torch.arange(4).unsqueeze(0)  # (1,4) — stretched to (3,4)
Z = Y + X                         # (3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]]) 
 tensor([[0, 1, 2, 3]]) 
 tensor([[ 0,  2,  4,  6],
        [ 4,  6,  8, 10],
        [ 8, 10, 12, 14]])
torch.Size([3, 4]) torch.Size([1, 4]) torch.Size([3, 4])
# (3,4) + (3,1) — the column Y is added to every column of X
X = torch.arange(12).view(3,4)    # (3,4)
Y = torch.arange(3).unsqueeze(1)  # (3,1) — stretched to (3,4)
Z = Y + X                         # (3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]]) 
 tensor([[0],
        [1],
        [2]]) 
 tensor([[ 0,  1,  2,  3],
        [ 5,  6,  7,  8],
        [10, 11, 12, 13]])
torch.Size([3, 4]) torch.Size([3, 1]) torch.Size([3, 4])
# (3,4) + (3,) — ERROR: right-aligns to 4 vs 3 (fix: make Y shape (3,1))
X = torch.arange(12).reshape(3,4) # (3, 4)
Y = torch.arange(3)               # (   3,)  → padded to (1,3) → last dims 4 vs 3, neither is 1 → ✗
print(X, '\n', Y)
print(X.shape, Y.shape)
Z = X + Y                         # ✗
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]]) 
 tensor([0, 1, 2])
torch.Size([3, 4]) torch.Size([3])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[10], line 6
      4 print(X, '\n', Y)
      5 print(X.shape, Y.shape)
----> 6 Z = X + Y                         # ✗
 
RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 1
# (3,4) + (1,1) — single value added to the whole matrix
X = torch.arange(12).view(3,4)    # (3,4)
Y = torch.tensor(10).reshape(1,1) # (1,1) — stretched to (3,4)
Z = Y + X                         # (3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]]) 
 tensor([[10]]) 
 tensor([[10, 11, 12, 13],
        [14, 15, 16, 17],
        [18, 19, 20, 21]])
torch.Size([3, 4]) torch.Size([1, 1]) torch.Size([3, 4])
# (3,4) + (3,4) — identical shapes, plain elementwise add
X = torch.arange(12).view(3,4)    # (3,4)
Y = torch.arange(12).view(3,4)    # (3,4) — no stretching needed
Z = torch.add(X, Y)               # (3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]]) 
 tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]]) 
 tensor([[ 0,  2,  4,  6],
        [ 8, 10, 12, 14],
        [16, 18, 20, 22]])
torch.Size([3, 4]) torch.Size([3, 4]) torch.Size([3, 4])

Non-trivial case

# (1,4) + (3,1) — outer sum; both operands stretch to (3,4)
X = torch.arange(4).unsqueeze(0)  # (1,4) — copied down the rows
Y = torch.arange(3).unsqueeze(1)  # (3,1) — copied across the columns
Z = Y + X                         # (3,4); Z[i,j] = X[0,j] + Y[i,0]
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[0, 1, 2, 3]]) 
 tensor([[0],
        [1],
        [2]]) 
 tensor([[0, 1, 2, 3],
        [1, 2, 3, 4],
        [2, 3, 4, 5]])
torch.Size([1, 4]) torch.Size([3, 1]) torch.Size([3, 4])

3D cases

Non-trivial case

# (2,3,4) + (4,) — Y repeated over both leading dims
X = torch.arange(24).view(2,3,4) # (2,3,4)
Y = torch.arange(4)              # (4,)  — implicitly (1,1,4) → stretched to (2,3,4)
Z = Y + X                        # (2,3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],
 
        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]]) 
 tensor([0, 1, 2, 3]) 
 tensor([[[ 0,  2,  4,  6],
         [ 4,  6,  8, 10],
         [ 8, 10, 12, 14]],
 
        [[12, 14, 16, 18],
         [16, 18, 20, 22],
         [20, 22, 24, 26]]])
torch.Size([2, 3, 4]) torch.Size([4]) torch.Size([2, 3, 4])
# (2,3,4) + (3,4) — Y repeated along the first dim
X = torch.arange(24).view(2,3,4) # (2,3,4)
Y = torch.arange(12).view(3,4)   # (3,4) — implicitly (1,3,4) → stretched to (2,3,4)
Z = Y + X                        # (2,3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],
 
        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]]) 
 tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]]) 
 tensor([[[ 0,  2,  4,  6],
         [ 8, 10, 12, 14],
         [16, 18, 20, 22]],
 
        [[12, 14, 16, 18],
         [20, 22, 24, 26],
         [28, 30, 32, 34]]])
torch.Size([2, 3, 4]) torch.Size([3, 4]) torch.Size([2, 3, 4])

Non-trivial case

# (2,3,4) + (3,1) — stretched over the first and last dims
X = torch.arange(24).view(2,3,4) # (2,3,4)
Y = torch.arange(3).unsqueeze(1) # (3,1) — implicitly (1,3,1) → stretched to (2,3,4)
Z = Y + X                        # (2,3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],
 
        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]]) 
 tensor([[0],
        [1],
        [2]]) 
 tensor([[[ 0,  1,  2,  3],
         [ 5,  6,  7,  8],
         [10, 11, 12, 13]],
 
        [[12, 13, 14, 15],
         [17, 18, 19, 20],
         [22, 23, 24, 25]]])
torch.Size([2, 3, 4]) torch.Size([3, 1]) torch.Size([2, 3, 4])
# (2,3,4) + (1,3,1) — first and last dims stretched (1s written out)
X = torch.arange(24).view(2,3,4) # (2,3,4)
Y = torch.arange(3).view(1,3,1)  # (1,3,1) → stretched to (2,3,4)
Z = Y + X                        # (2,3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],
 
        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]]) 
 tensor([[[0],
         [1],
         [2]]]) 
 tensor([[[ 0,  1,  2,  3],
         [ 5,  6,  7,  8],
         [10, 11, 12, 13]],
 
        [[12, 13, 14, 15],
         [17, 18, 19, 20],
         [22, 23, 24, 25]]])
torch.Size([2, 3, 4]) torch.Size([1, 3, 1]) torch.Size([2, 3, 4])
# (2,3,4) + (2,1,4) — middle dim stretched
X = torch.arange(24).view(2,3,4) # (2,3,4)
Y = torch.arange(8).view(2,1,4)  # (2,1,4) → stretched to (2,3,4)
Z = Y + X                        # (2,3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],
 
        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]]) 
 tensor([[[0, 1, 2, 3]],
 
        [[4, 5, 6, 7]]]) 
 tensor([[[ 0,  2,  4,  6],
         [ 4,  6,  8, 10],
         [ 8, 10, 12, 14]],
 
        [[16, 18, 20, 22],
         [20, 22, 24, 26],
         [24, 26, 28, 30]]])
torch.Size([2, 3, 4]) torch.Size([2, 1, 4]) torch.Size([2, 3, 4])
# (2,3,4) + (2,3,1) — last dim stretched
X = torch.arange(24).view(2,3,4) # (2,3,4)
Y = torch.arange(6).view(2,3,1)  # (2,3,1) → stretched to (2,3,4)
Z = Y + X                        # (2,3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],
 
        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]]) 
 tensor([[[0],
         [1],
         [2]],
 
        [[3],
         [4],
         [5]]]) 
 tensor([[[ 0,  1,  2,  3],
         [ 5,  6,  7,  8],
         [10, 11, 12, 13]],
 
        [[15, 16, 17, 18],
         [20, 21, 22, 23],
         [25, 26, 27, 28]]])
torch.Size([2, 3, 4]) torch.Size([2, 3, 1]) torch.Size([2, 3, 4])

Non-trivial case

# (2,1,4) + (1,3,1) — each fills in the other's 1-dims → (2,3,4)
X = torch.arange(8).view(2,1,4)  # (2,1,4) → stretched to (2,3,4)
Y = torch.arange(3).view(1,3,1)  # (1,3,1) → stretched to (2,3,4)
Z = Y + X                        # (2,3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[[0, 1, 2, 3]],
 
        [[4, 5, 6, 7]]]) 
 tensor([[[0],
         [1],
         [2]]]) 
 tensor([[[0, 1, 2, 3],
         [1, 2, 3, 4],
         [2, 3, 4, 5]],
 
        [[4, 5, 6, 7],
         [5, 6, 7, 8],
         [6, 7, 8, 9]]])
torch.Size([2, 1, 4]) torch.Size([1, 3, 1]) torch.Size([2, 3, 4])
# (2,3,4) + (2,4) — ERROR: right-aligns to (2,3,4) vs (_,2,4) → 3 vs 2 (fix: Y.reshape(2,1,4))
X = torch.arange(24).reshape(2,3,4) # (2, 3, 4)
Y = torch.arange(8).reshape(2,4)    # (   2, 4)  → padded to (1,2,4) → middle dims 3 vs 2, neither is 1 → ✗
print(X, '\n', Y)
print(X.shape, Y.shape)
Z = X + Y                           # ✗
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],
 
        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]]) 
 tensor([[0, 1, 2, 3],
        [4, 5, 6, 7]])
torch.Size([2, 3, 4]) torch.Size([2, 4])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[21], line 6
      4 print(X, '\n', Y)
      5 print(X.shape, Y.shape)
----> 6 Z = X + Y                           # ✗
 
RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 1
# (1,3,4) + (2,1,4) — both 1-dims stretched → (2,3,4)
X = torch.arange(12).view(1,3,4) # (1,3,4) → stretched to (2,3,4)
Y = torch.arange(8).view(2,1,4)  # (2,1,4) → stretched to (2,3,4)
Z = Y + X                        # (2,3,4)
print('', X, '\n', Y, '\n', Z)
print(X.shape, Y.shape, Z.shape)
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]]]) 
 tensor([[[0, 1, 2, 3]],
 
        [[4, 5, 6, 7]]]) 
 tensor([[[ 0,  2,  4,  6],
         [ 4,  6,  8, 10],
         [ 8, 10, 12, 14]],
 
        [[ 4,  6,  8, 10],
         [ 8, 10, 12, 14],
         [12, 14, 16, 18]]])
torch.Size([1, 3, 4]) torch.Size([2, 1, 4]) torch.Size([2, 3, 4])

Sources

  • PyTorch Docs: Broadcasting semantics
  • Medium: Understanding Broadcasting in PyTorch
  • StackOverflow: How does pytorch broadcasting work?
  • X (Twitter): Akshay Pachaar on PyTorch broadcasting
  • D2L: Dive into Deep Learning Data Manipulation Broadcasting