
How Neural Networks Train

August 19, 2025

A high-level overview of computation graphs, operators, operator fusion, and tiling—how your model code becomes fast kernels on GPUs/NPUs.

Who is this for?
Engineers who know how to write model(x) and want intuition for what actually happens between that line and the hardware.


TL;DR

  • Models become graphs (ops as nodes, tensors as edges).
  • Graphs compile into operator kernels (matmul, conv, elementwise).
  • Most speedups come from moving data less (fusion, tiling, layout), not just flops.
  • Compilers and runtimes schedule ops to maximize parallelism and locality.

From Math to Graphs

A simple layer:

y = \mathrm{ReLU}(W x + b)

becomes a dataflow graph:

W, x → MatMul(W, x) → Add(+ b) → ReLU → y
  • Nodes: MatMul, Add, ReLU
  • Edges: tensors produced/consumed
  • Why graphs? Dependencies are explicit → enables parallel execution and whole-graph optimization.
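
To see this graph concretely, here is a minimal sketch (assuming PyTorch with torch.fx available) that traces a tiny module and prints its nodes; the printed graph is exactly the dataflow above: a placeholder for x, a call to the linear module, and a call to relu.

python-snippet
import torch
import torch.fx

class TinyLayer(torch.nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = torch.nn.Linear(in_dim, out_dim)  # W x + b

    def forward(self, x):
        return torch.relu(self.linear(x))

# Symbolic tracing records each op as a node and each tensor as an edge.
gm = torch.fx.symbolic_trace(TinyLayer(16, 8))
print(gm.graph)  # placeholder, call_module (linear), call_function (relu), output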

Operators: The Building Blocks

Each node maps to a backend kernel (often vendor-tuned):

  • MatMul → cuBLAS / oneDNN / accelerated GEMM
  • Conv2D → cuDNN / MIOpen
  • Elementwise ops (Add, ReLU) → lightweight fused kernels

A forward pass is effectively a sequence of kernel launches in dependency order, each handed its shapes, strides, and pointers.

python-snippet
# Illustration in PyTorch
y = torch.relu(x @ W.T + b)  # triggers multiple kernels under the hood
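
To make "multiple kernels" visible, here is a hedged sketch using torch.profiler; the exact op names and timings will vary by PyTorch version and device.

python-snippet
import torch
from torch.profiler import profile, ProfilerActivity

W = torch.randn(128, 256)
b = torch.randn(128)
x = torch.randn(256)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    y = torch.relu(x @ W.T + b)

# Each underlying aten op (the matmul, the add, the relu) shows up as its own entry.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))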

Bottleneck: Data Movement

Accelerators are compute-rich but bandwidth-bound. Writing an intermediate to DRAM and reading it back can dwarf arithmetic time.

Key idea:

  • Minimize round-trips to slow memory; keep values in registers or shared memory/cache as long as possible.
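
As a rough back-of-the-envelope check (plain arithmetic, no particular hardware assumed), consider an elementwise add:

python-snippet
# Elementwise fp32 add, y = a + b, over N elements:
N = 1_000_000
flops = N                   # one add per element
bytes_moved = 3 * 4 * N     # read a, read b, write y; 4 bytes per fp32 value
print(flops / bytes_moved)  # ~0.083 flops per byte -> the memory system, not the ALUs, sets the pace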

Operator Fusion

Goal: merge small ops to avoid extra loads/stores.

Before (3 kernels):

text-snippet
MatMul → write C
Add(bias) → read C, write C'
ReLU → read C', write Y

After (1 fused kernel):

text-snippet
for each output tile:
    C = TileGEMM(A, B)
    C = C + bias
    Y = max(C, 0)
    write Y

Benefits:

  • Fewer kernel launches
  • Fewer DRAM trips
  • Better register/cache reuse
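
In day-to-day PyTorch you rarely write the fused kernel yourself. A hedged sketch of asking the compiler to do it (PyTorch 2.x, Inductor backend; whether the bias add and ReLU actually fuse into the GEMM epilogue depends on hardware and heuristics):

python-snippet
import torch

def layer(x, W, b):
    return torch.relu(x @ W.T + b)

# torch.compile traces the graph and hands it to a backend (Inductor by default),
# which can fuse the pointwise bias add and ReLU instead of launching them separately.
compiled_layer = torch.compile(layer)

W, b, x = torch.randn(128, 256), torch.randn(128), torch.randn(256)
y = compiled_layer(x, W, b)  # first call compiles; later calls reuse the kernels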

Tiling / Blocking (Fit Work to the Memory Hierarchy)

Large tensors don’t fit in fast memory. Tiling breaks problems into sub-blocks that do:

  • GEMM: compute C_{ij} += A_{ik} \times B_{kj} over tiles of i, j, k
  • Tiles live in shared memory / L1 cache, accumulators in registers
  • Improves arithmetic intensity (flops per byte from DRAM)

C_{i_0:i_0+T,\, j_0:j_0+T} = \sum_{k_0} A_{i_0:i_0+T,\, k_0:k_0+T}\; B_{k_0:k_0+T,\, j_0:j_0+T}
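
A NumPy sketch of the idea (illustration only; real kernels do this with on-chip memory and registers rather than Python loops):

python-snippet
import numpy as np

def tiled_matmul(A, B, T=64):
    """Blocked GEMM: each (T x T) tile of A and B is reused across a whole
    output tile while it is 'hot', instead of being re-streamed from slow memory."""
    M, K = A.shape
    _, N = B.shape
    C = np.zeros((M, N), dtype=A.dtype)
    for i0 in range(0, M, T):
        for j0 in range(0, N, T):
            acc = np.zeros((min(T, M - i0), min(T, N - j0)), dtype=A.dtype)  # "registers"
            for k0 in range(0, K, T):
                acc += A[i0:i0+T, k0:k0+T] @ B[k0:k0+T, j0:j0+T]
            C[i0:i0+T, j0:j0+T] = acc  # one store per output tile
    return C

A = np.random.rand(256, 192).astype(np.float32)
B = np.random.rand(192, 320).astype(np.float32)
assert np.allclose(tiled_matmul(A, B), A @ B, atol=1e-3)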

Parallelism (Exploit the Graph)

  • Intra-op: split a single op across many cores/warps (e.g., tiles over output space).
  • Inter-op: run independent branches concurrently (e.g., ResNet skip path and conv path).
  • Pipelining: overlap compute and IO (prefetch next tile while computing current).
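
A sketch of inter-op parallelism with CUDA streams in PyTorch (assumes a CUDA device; production code also has to mind allocator/stream interactions, so treat this as an illustration of the idea rather than a drop-in pattern):

python-snippet
import torch

if torch.cuda.is_available():
    x  = torch.randn(64, 256, device="cuda")
    W1 = torch.randn(256, 256, device="cuda")
    W2 = torch.randn(256, 256, device="cuda")

    main = torch.cuda.current_stream()
    s1, s2 = torch.cuda.Stream(), torch.cuda.Stream()
    s1.wait_stream(main)  # inputs were produced on the main stream
    s2.wait_stream(main)

    with torch.cuda.stream(s1):
        a = x @ W1                 # branch 1
    with torch.cuda.stream(s2):
        b = torch.relu(x @ W2)     # branch 2, independent of branch 1

    main.wait_stream(s1)           # rejoin before combining the branches
    main.wait_stream(s2)
    y = a + b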

Memory Layout & Strides Matter

  • Choose layouts (e.g., NCHW, NHWC) that match kernel expectations to avoid implicit transposes.
  • Align tensors and pad to avoid bank conflicts/misaligned accesses.
  • Sometimes a layout transform upfront unlocks faster fused kernels downstream.
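
A small PyTorch illustration of layout vs logical shape (channels_last is PyTorch's NHWC memory format; whether it is faster depends on the kernels your backend picks):

python-snippet
import torch

x = torch.randn(8, 3, 224, 224)   # logical shape NCHW
print(x.stride())                 # strides show channels-first memory order

# Same logical shape, NHWC ("channels_last") physical layout; conv kernels that
# expect NHWC can then run without an implicit transpose.
x_nhwc = x.to(memory_format=torch.channels_last)
print(x_nhwc.shape, x_nhwc.stride())  # shape unchanged, strides permuted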

Example: Conv → Bias → ReLU

Naive: three passes over the output feature map (conv, bias add, ReLU).
Fused: while writing each output tile from conv accumulation, add bias and apply ReLU in-register, then store once.

Result: less bandwidth, fewer launches, better cache hit rate.


Worked Example: A Tiny Forward Pass

Let y = \mathrm{ReLU}(W x + b), with shapes W \in \mathbb{R}^{M \times K}, x \in \mathbb{R}^{K}, b \in \mathbb{R}^{M}.

  1. Planner: Picks GEMM variant (e.g., tensor cores if fp16).
  2. Tiling: Block over M in chunks of T_M and over K in chunks of T_K.
  3. Fusion: Accumulate Wx in registers; add b and apply ReLU before the store.
  4. Schedule: Launch enough thread blocks to cover all M-tiles; overlap global reads of W and chunks of x.
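
The same worked example as a NumPy sketch (tiling over M and K plus the fused bias/ReLU epilogue; the tile sizes T_M, T_K are arbitrary here):

python-snippet
import numpy as np

def fused_linear_relu(W, x, b, TM=64, TK=64):
    """y = ReLU(W x + b), computed tile by tile: partial sums for one M-tile stay
    in a local accumulator, and bias + ReLU are applied before the single store."""
    M, K = W.shape
    y = np.empty(M, dtype=W.dtype)
    for m0 in range(0, M, TM):
        acc = np.zeros(min(TM, M - m0), dtype=W.dtype)      # accumulator "registers"
        for k0 in range(0, K, TK):
            acc += W[m0:m0+TM, k0:k0+TK] @ x[k0:k0+TK]       # TileGEMM step
        y[m0:m0+TM] = np.maximum(acc + b[m0:m0+TM], 0.0)     # fused epilogue, one write
    return y

W = np.random.randn(300, 500).astype(np.float32)
x = np.random.randn(500).astype(np.float32)
b = np.random.randn(300).astype(np.float32)
assert np.allclose(fused_linear_relu(W, x, b), np.maximum(W @ x + b, 0.0), atol=1e-3)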

Where Compilers Fit (XLA, Inductor, TVM, Glow)

  • Graph IR: High-level ops (HLO, FX, Relay)
  • Passes: Canonicalize, fuse, constant-fold, layout and dtype propagation
  • Lowering: Emit kernels (Triton/CUDA) or call libraries
  • Auto-tuning: Search tile sizes, unroll factors, vectorization for your shapes/hardware
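
As one concrete knob, PyTorch's torch.compile exposes an auto-tuning mode (results and compile time vary widely by shape and hardware; sketch only):

python-snippet
import torch

def layer(x, W, b):
    return torch.relu(x @ W.T + b)

# "max-autotune" asks the backend to search kernel configurations (tile sizes,
# unrolling, etc.) for the shapes it sees, trading compile time for runtime.
fast_layer = torch.compile(layer, mode="max-autotune")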

Practical Tips

  • Prefer fusable patterns: bias + activation right after GEMM/Conv.
  • Keep tensors in friendly dtypes/layouts (e.g., fp16/bf16 when safe).
  • Batch work where possible (larger batches = higher utilization).
  • Avoid shape polymorphism in hot loops when it prevents fusion/specialization.
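
For the dtype tip, a minimal autocast sketch (bf16 on CPU here just to keep it runnable anywhere; the same pattern applies on GPU with device_type="cuda"):

python-snippet
import torch

model = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.ReLU())
x = torch.randn(32, 256)

# Run matmul-heavy ops in bf16 where it is numerically safe; weights stay fp32.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = model(x)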

A Note on Training vs Inference

  • Inference leans harder on fusion & quantization (stable shapes).
  • Training adds backward graphs + optimizer ops; compilers still fuse, but need to preserve grads and often deal with more dynamic shapes.

Glossary

  • Operator (Op): Primitive computation (MatMul, Conv, Add, ReLU).
  • Kernel: Compiled implementation of an op on a device.
  • Fusion: Combine ops into one kernel to reduce memory traffic.
  • Tiling/Blocking: Split work to fit fast memory and maximize reuse.
  • Layout: Tensor memory order (e.g., NCHW).
  • Scheduler: Decides op order/parallel execution respecting deps.

Key Takeaways

  • The fastest models move data the least.
  • Graphs enable global reasoning; fusion/tiling realize local efficiency.
  • Understanding these ideas makes you a better practitioner—even if you never write a kernel.