How Neural Networks Train
August 19, 2025
A high-level overview of computation graphs, operators, operator fusion, and tiling—how your model code becomes fast kernels on GPUs/NPUs.
Who is this for?
Engineers who know how to write model(x) and want intuition for what actually happens between that line and the hardware.
TL;DR
- Models become graphs (ops as nodes, tensors as edges).
- Graphs compile into operator kernels (matmul, conv, elementwise).
- Most speedups come from moving data less (fusion, tiling, layout), not just flops.
- Compilers and runtimes schedule ops to maximize parallelism and locality.
From Math to Graphs
A simple layer:
y = ReLU(xWᵀ + b)
becomes a dataflow graph:
- Nodes: MatMul, Add, ReLU
- Edges: tensors produced/consumed
- Why graphs? Dependencies are explicit → enables parallel execution and whole-graph optimization.
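To make the graph tangible, here is a minimal sketch using torch.fx, PyTorch's symbolic tracer. The TinyLayer module and its sizes are made up for illustration; the point is just to see ops as nodes and tensors as edges.

```python
# Minimal sketch: trace a small module and print its dataflow graph.
# Requires PyTorch with torch.fx; TinyLayer and the sizes are illustrative.
import torch
import torch.fx

class TinyLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 32)  # MatMul + Add under the hood

    def forward(self, x):
        return torch.relu(self.linear(x))

gm = torch.fx.symbolic_trace(TinyLayer())
print(gm.graph)  # nodes: placeholder -> linear -> relu -> output; edges are the tensors
```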
Operators: The Building Blocks
Each node maps to a backend kernel (often vendor-tuned):
- MatMul → cuBLAS / oneDNN / accelerated GEMM
- Conv2D → cuDNN / MIOpen
- Elementwise ops (Add, ReLU) → lightweight fused kernels
A forward pass is effectively a sequence of kernel launches in dependency order, each handed its shapes, strides, and pointers.
```python
# Illustration in PyTorch
y = torch.relu(x @ W.T + b)  # triggers multiple kernels under the hood
```
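If you want to see those kernels for yourself, the built-in profiler is one option. The shapes below are arbitrary, and the exact entries you see depend on your backend (on GPU you would also enable the CUDA activity).

```python
# Sketch: profile one line of "model code" and see the separate ops behind it.
# Shapes are arbitrary; on GPU, add ProfilerActivity.CUDA to activities.
import torch
from torch.profiler import profile, ProfilerActivity

x = torch.randn(256, 512)
W = torch.randn(1024, 512)
b = torch.randn(1024)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    y = torch.relu(x @ W.T + b)

# Expect distinct entries for the matmul, the add, and the relu.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```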
Bottleneck: Data Movement
Accelerators are compute-rich but bandwidth-bound. Writing an intermediate to DRAM and reading it back can dwarf arithmetic time.
Key idea:
- Minimize round-trips to slow memory, keep values in registers or shared/cache as long as possible.
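A back-of-envelope calculation shows why this matters. The sketch below assumes fp32 and that, when ops run as separate kernels, every intermediate makes a full round-trip through DRAM; the matrix sizes are arbitrary.

```python
# Rough DRAM-traffic estimate for MatMul -> Add(bias) -> ReLU, unfused vs fused.
# Assumes fp32 (4 bytes/element) and that unfused intermediates round-trip DRAM.
M, N, K = 4096, 4096, 4096
bytes_per_elem = 4

inputs  = (M * K + K * N + N) * bytes_per_elem   # A, B, bias read once either way
unfused = inputs + 5 * M * N * bytes_per_elem    # write C, read C, write C', read C', write Y
fused   = inputs + 1 * M * N * bytes_per_elem    # write Y once
flops   = 2 * M * N * K                          # multiply-adds in the GEMM

print(f"arithmetic intensity, unfused: {flops / unfused:.0f} flops/byte")
print(f"arithmetic intensity, fused:   {flops / fused:.0f} flops/byte")
```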
Operator Fusion
Goal: merge small ops to avoid extra loads/stores.
Before (3 kernels):
```text
MatMul     → write C
Add(bias)  → read C,  write C'
ReLU       → read C', write Y
```
After (1 fused kernel):
```text
for each output tile:
    C = TileGEMM(A, B)
    C = C + bias
    Y = max(C, 0)
    write Y
```
Benefits:
- Fewer kernel launches
- Fewer DRAM trips
- Better register/cache reuse
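In PyTorch 2.x you usually don't fuse by hand; you ask the compiler. Here is a minimal sketch of the user-facing API. Whether the bias-add and ReLU actually fold into the GEMM epilogue depends on the backend and hardware.

```python
# Sketch: let the compiler fuse the epilogue (PyTorch 2.x, Inductor backend).
# Fusion decisions are backend/hardware dependent; this only shows the API surface.
import torch

def layer(x, W, b):
    return torch.relu(x @ W.T + b)

compiled_layer = torch.compile(layer)  # traces the graph, fuses, generates kernels

x = torch.randn(256, 512)
W = torch.randn(1024, 512)
b = torch.randn(1024)
y = compiled_layer(x, W, b)            # first call compiles, later calls reuse kernels
```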
Tiling / Blocking (Fit Work to the Memory Hierarchy)
Large tensors don’t fit in fast memory. Tiling breaks problems into sub-blocks that do:
- GEMM: compute C tile by tile, streaming the matching blocks of A and B
- Tiles live in shared memory / L1 cache, accumulators in registers
- Improves arithmetic intensity (flops per byte from DRAM)
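The loop structure is easier to see in plain NumPy than in a real kernel. The sketch below is purely illustrative (the tile size is arbitrary and NumPy won't be fast here); it mirrors how a kernel blocks the output and steps over the reduction dimension.

```python
# Illustrative tiled GEMM: block the output, step over K, store each tile once.
# Not fast in NumPy; the point is the blocking structure, not performance.
import numpy as np

def tiled_matmul(A, B, tile=64):
    M, K = A.shape
    _, N = B.shape
    C = np.zeros((M, N), dtype=A.dtype)
    for i0 in range(0, M, tile):                  # tile over output rows
        for j0 in range(0, N, tile):              # tile over output cols
            acc = np.zeros((min(tile, M - i0), min(tile, N - j0)), dtype=A.dtype)
            for k0 in range(0, K, tile):          # step over the reduction dim
                a = A[i0:i0 + tile, k0:k0 + tile]     # "load" a tile of A
                b = B[k0:k0 + tile, j0:j0 + tile]     # "load" a tile of B
                acc += a @ b                          # accumulate in fast memory
            C[i0:i0 + tile, j0:j0 + tile] = acc       # single store per output tile
    return C

A, B = np.random.rand(256, 320), np.random.rand(320, 192)
assert np.allclose(tiled_matmul(A, B), A @ B)
```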
Parallelism (Exploit the Graph)
- Intra-op: split a single op across many cores/warps (e.g., tiles over output space).
- Inter-op: run independent branches concurrently (e.g., ResNet skip path and conv path).
- Pipelining: overlap compute and IO (prefetch next tile while computing current).
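As a sketch of the inter-op case, two independent branches can be issued on separate CUDA streams. Real runtimes do this scheduling for you; the layers and sizes below are made up, and it requires a CUDA device.

```python
# Sketch: run two independent branches on separate CUDA streams (inter-op parallelism).
# Frameworks normally handle this scheduling; sizes and layers are illustrative.
import torch

if torch.cuda.is_available():
    x = torch.randn(8, 256, 56, 56, device="cuda")
    conv3x3 = torch.nn.Conv2d(256, 256, 3, padding=1).cuda()
    conv1x1 = torch.nn.Conv2d(256, 256, 1).cuda()

    s1, s2 = torch.cuda.Stream(), torch.cuda.Stream()
    s1.wait_stream(torch.cuda.current_stream())   # both branches wait until x is ready
    s2.wait_stream(torch.cuda.current_stream())

    with torch.cuda.stream(s1):
        branch = conv3x3(x)                       # one independent branch
    with torch.cuda.stream(s2):
        shortcut = conv1x1(x)                     # another, with no dependency between them

    torch.cuda.current_stream().wait_stream(s1)   # join before combining
    torch.cuda.current_stream().wait_stream(s2)
    y = torch.relu(branch + shortcut)
```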
Memory Layout & Strides Matter
- Choose layouts (e.g., NCHW, NHWC) that match kernel expectations to avoid implicit transposes.
- Align tensors and pad to avoid bank conflicts/misaligned accesses.
- Sometimes a layout transform upfront unlocks faster fused kernels downstream.
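For example, in PyTorch you can opt a conv workload into NHWC (channels_last) without changing any math; whether it's faster depends on the backend. Sizes here are arbitrary.

```python
# Sketch: switch a conv to channels_last (NHWC) layout; same values, new strides.
# Benefit depends on the backend/kernel; sizes are arbitrary.
import torch

x = torch.randn(8, 64, 56, 56)                       # default NCHW layout
conv = torch.nn.Conv2d(64, 128, 3, padding=1)

x_nhwc = x.to(memory_format=torch.channels_last)     # relayout once, up front
conv = conv.to(memory_format=torch.channels_last)
y = conv(x_nhwc)

print(x_nhwc.stride())                               # strides reveal NHWC ordering
print(y.is_contiguous(memory_format=torch.channels_last))
```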
Example: Conv → Bias → ReLU
Naive: three passes over the output feature map (conv, add bias, relu).
Fused: while writing each output tile from conv accumulation, add bias and apply ReLU in-register, then store once.
Result: less bandwidth, fewer launches, better cache hit rate.
Worked Example: A Tiny Forward Pass
Let Y = ReLU(XWᵀ + b), with shapes X: [M, K], W: [N, K], b: [N], Y: [M, N].
- Planner: Picks a GEMM variant (e.g., tensor cores if fp16).
- Tiling: Block the output over M and N in tiles; step over the reduction dimension K in chunks.
- Fusion: Accumulate into registers; add b and apply ReLU before the store.
- Schedule: Launch enough thread blocks to cover all output tiles; overlap global reads of X with streaming chunks of W.
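A NumPy sketch of that schedule (tile size arbitrary, output-only tiling for brevity): compute a tile of X·Wᵀ, add b and apply ReLU while the tile is still "in registers", then store it once.

```python
# Sketch of the fused schedule above: per output tile, GEMM + bias + ReLU, one store.
# Output-only tiling for brevity; tile size and shapes are arbitrary.
import numpy as np

def fused_linear_relu(X, W, b, tile=64):
    M, _ = X.shape
    N = W.shape[0]
    Y = np.empty((M, N), dtype=X.dtype)
    for i0 in range(0, M, tile):
        for j0 in range(0, N, tile):
            acc = X[i0:i0 + tile] @ W[j0:j0 + tile].T           # tile of the GEMM
            acc += b[j0:j0 + tile]                              # fused bias add
            Y[i0:i0 + tile, j0:j0 + tile] = np.maximum(acc, 0)  # fused ReLU, single store
    return Y

X, W, b = np.random.randn(128, 512), np.random.randn(256, 512), np.random.randn(256)
assert np.allclose(fused_linear_relu(X, W, b), np.maximum(X @ W.T + b, 0))
```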
Where Compilers Fit (XLA, Inductor, TVM, Glow)
- Graph IR: High-level ops (HLO, FX, Relay)
- Passes: Canonicalize, fuse, constant-fold, layout and dtype propagation
- Lowering: Emit kernels (Triton/CUDA) or call libraries
- Auto-tuning: Search tile sizes, unroll factors, vectorization for your shapes/hardware
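The auto-tuning step is exposed to users too. For instance, PyTorch's Inductor will search tile sizes and kernel choices when asked, at the cost of longer compile time; a minimal sketch:

```python
# Sketch: opt into auto-tuning so the compiler searches tile sizes/kernels
# for your shapes and hardware (PyTorch 2.x; adds compile time on the first call).
import torch

@torch.compile(mode="max-autotune")
def layer(x, W, b):
    return torch.relu(x @ W.T + b)

y = layer(torch.randn(256, 512), torch.randn(1024, 512), torch.randn(1024))
```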
Practical Tips
- Prefer fusable patterns: bias + activation right after GEMM/Conv.
- Keep tensors in friendly dtypes/layouts (e.g., fp16/bf16 when safe).
- Batch work where possible (larger batches = higher utilization).
- Avoid shape polymorphism in hot loops when it prevents fusion/specialization.
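For the dtype tip above, autocast is the usual low-effort way to try reduced precision. This sketch uses bf16 on CPU; which ops actually run in bf16 is decided op-by-op by the framework.

```python
# Sketch: run matmul-heavy code in bf16 via autocast so the backend can pick
# reduced-precision GEMM kernels where it considers them safe.
import torch

x = torch.randn(256, 512)
W = torch.randn(1024, 512)
b = torch.randn(1024)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    h = x @ W.T           # autocast runs the matmul in bf16
    y = torch.relu(h + b)

print(h.dtype)  # torch.bfloat16: the GEMM ran in reduced precision
```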
A Note on Training vs Inference
- Inference leans harder on fusion & quantization (stable shapes).
- Training adds backward graphs + optimizer ops; compilers still fuse, but need to preserve grads and often deal with more dynamic shapes.
Glossary
- Operator (Op): Primitive computation (MatMul, Conv, Add, ReLU).
- Kernel: Compiled implementation of an op on a device.
- Fusion: Combine ops into one kernel to reduce memory traffic.
- Tiling/Blocking: Split work to fit fast memory and maximize reuse.
- Layout: Tensor memory order (e.g., NCHW).
- Scheduler: Decides op order/parallel execution respecting deps.
Key Takeaways
- The fastest models move data the least.
- Graphs enable global reasoning; fusion/tiling realize local efficiency.
- Understanding these ideas makes you a better practitioner—even if you never write a kernel.