GPU Kernels & Compilers: A Compressed Interview Course
From linear algebra to FlashAttention, MoE, and the XLA lowering pipeline — the working set for a GPU-kernel / ML-compiler interview, told with diagrams, the underlying math, and real kernels. Weighted toward the compiler / XLA / PTX axis.
What they actually test START HERE
A kernel/compiler interview is not a LeetCode screen. The interviewer is probing whether you can reason about where the time and the bytes go, and whether you can move fluidly between four layers of the stack without losing the thread.
The five questions behind every question
- Where are the FLOPs and where are the bytes? Is this op compute-bound, memory-bound, or communication-bound? (Module 04.)
- What does the memory hierarchy force? Tiling, fusion, recompute-vs-store — all fall out of HBM << SRAM << registers.
- What does the hardware demand? Tensor-core shapes, alignment, coalescing, occupancy, precision/scaling.
- What can the compiler do for you, and where does it stop? Fusion boundaries, why some patterns need a hand-written kernel.
- How do you know it's right and fast? Numerical checks vs a reference, autotuning, regression tests, profiling.
Each module gives you the intuition (a diagram), the math you'd write on the whiteboard, and a real kernel or IR snippet. The green boxes are interview answers you can say almost verbatim; the orange boxes are gotchas interviewers love to poke at. Read top-to-bottom once; thereafter jump via the sidebar.
“My instinct on any op is to compute its arithmetic intensity — FLOPs per byte — and place it on the roofline. That tells me whether to chase tensor-core utilization or to fuse and cut HBM traffic. Then I check whether the compiler already fuses it; if it can't cross that boundary, I write the kernel.”
Linear algebra you actually need FOUNDATION
You don't need spectral theory for these interviews. You need to be fluent in shapes, reductions, and the fact that almost everything reduces to matmul.
The four core operations
| Op | Shapes | FLOPs | Arithmetic intensity (FLOPs per element moved) |
|---|---|---|---|
| Dot product | (K)·(K) → scalar | 2K | ~1 (memory-bound) |
| Matrix–vector (GEMV) | (M,K)·(K) → (M) | 2MK | ~1–2 (memory-bound) |
| Matrix–matrix (GEMM) | (M,K)·(K,N) → (M,N) | 2MNK | scales with tile size (compute-bound) |
| Outer product | (M)⊗(N) → (M,N) | MN | <1 (very memory-bound) |
Matrix multiplication \(C = AB\) with \(A\in\mathbb{R}^{M\times K}\), \(B\in\mathbb{R}^{K\times N}\):
\[ C_{ij} = \sum_{k=1}^{K} A_{ik}\,B_{kj} \qquad \text{cost} = 2MNK \text{ FLOPs}\]
The factor of 2 is one multiply + one add per term. Memorize 2MNK — you will use it to size every GEMM on the roofline.
Why “everything is a matmul”
A linear layer is \(Y = XW^\top + b\). A convolution lowers to a matmul (im2col / implicit GEMM). Attention is two matmuls with a softmax between them. This is why a fast, well-scheduled GEMM is the single most valuable kernel on the planet — and why tensor cores exist to do nothing else.
\(B\) vs \(B^\top\) changes the memory access pattern, not the math. In a GEMM you choose whether to read \(B\) transposed; tensor-core paths often want a particular layout (e.g. TN: A row-major, B column-major). “Transpose” in a fused graph is frequently a no-op metadata change or folded into the load, not a separate kernel.
“Broadcasting, bias-add, transpose, reshape — these are bookkeeping. The compute lives in the contractions. So I budget by counting 2MNK for the matmuls and treating the elementwise glue as memory traffic to be fused away.”
Matrix multiplication & tiling CORE KERNEL
A naive GEMM is correct and slow: each output element re-reads a full row of \(A\) and column of \(B\) from HBM. The whole game is reuse — load a block once into fast memory, do as much math as possible on it, then move on.
The memory hierarchy that forces tiling
| Level | ~Latency | ~Bandwidth | Size |
|---|---|---|---|
| Registers | ~1 cycle | ~10s TB/s | ~256 KB / SM |
| Shared mem / SRAM | ~20 cycles | ~10s TB/s | ~100–228 KB / SM |
| L2 cache | ~200 cycles | several TB/s | tens of MB |
| HBM (global) | ~400–600 cycles | ~1.5–8 TB/s | tens of GB |
The jump from HBM to SRAM is ~10–100×. Tiling exists to pay the HBM cost once per block and amortize it over \(O(\text{tile})\) FLOPs.
A BM×BN output tile over a BK step does 2·BM·BN·BK FLOPs while loading (BM·BK + BK·BN) elements, so its intensity is 2·BM·BN / (BM+BN) FLOPs per element loaded (divide by the element size for FLOP/byte). Bigger square tiles → higher intensity → you cross from memory-bound into compute-bound. This is why tiles are square-ish and as large as registers/SRAM allow.
A real tiled GEMM in Triton
Triton lets you write the tile schedule in Python while it handles vectorization, shared-memory staging, and tensor-core selection. This is the canonical kernel to be able to sketch.
```python
import triton
import triton.language as tl

@triton.jit
def matmul_kernel(A, B, C, M, N, K,
                  sam, sak, sbk, sbn, scm, scn,           # strides of A, B, C
                  BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BM + tl.arange(0, BM)   # rows of C this block owns
    offs_n = pid_n * BN + tl.arange(0, BN)   # cols of C this block owns
    offs_k = tl.arange(0, BK)
    a_ptrs = A + offs_m[:, None] * sam + offs_k[None, :] * sak
    b_ptrs = B + offs_k[:, None] * sbk + offs_n[None, :] * sbn
    acc = tl.zeros((BM, BN), dtype=tl.float32)   # accumulate in FP32
    for k in range(0, K, BK):
        # mask the ragged K tail and the M/N edges of boundary tiles
        a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] < K - k), other=0.0)
        b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k) & (offs_n[None, :] < N), other=0.0)
        acc += tl.dot(a, b)                       # ← maps to tensor-core MMA
        a_ptrs += BK * sak
        b_ptrs += BK * sbk
    c = acc.to(C.dtype.element_ty)
    c_ptrs = C + offs_m[:, None] * scm + offs_n[None, :] * scn
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, c, mask=mask)
```
The whole kernel is: pick your output tile → loop over K loading A/B tiles → tl.dot accumulates in FP32 → store once. tl.dot is what the backend turns into mma/wgmma/tcgen05.
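You can see the reuse this schedule buys without a GPU. Below is a pure-Python sketch of the same loop structure: one "program" per output tile, a K loop in BK steps, an accumulator per tile, and a single store. It also counts element loads, so you can check that FLOPs per element loaded comes out to 2·BM·BN/(BM+BN). It models the schedule only, not how Triton actually executes. The tile sizes are arbitrary.

```python
import random

def naive_matmul(A, B):
    """Reference: C[i][j] = sum_k A[i][k] * B[k][j]."""
    M, K, N = len(A), len(B), len(B[0])
    return [[sum(A[i][k] * B[k][j] for k in range(K)) for j in range(N)] for i in range(M)]

def tiled_matmul(A, B, BM=4, BN=4, BK=4):
    """Output-tiled GEMM with a K loop. Returns (C, elements_loaded, flops)."""
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    loaded = flops = 0
    for m0 in range(0, M, BM):                       # one "program" per output tile
        for n0 in range(0, N, BN):
            rows = range(m0, min(m0 + BM, M))        # clip ragged edge tiles (the masks)
            cols = range(n0, min(n0 + BN, N))
            acc = [[0.0] * len(cols) for _ in rows]  # the tile's accumulator
            for k0 in range(0, K, BK):
                ks = range(k0, min(k0 + BK, K))
                a_tile = [[A[i][k] for k in ks] for i in rows]   # load A tile once
                b_tile = [[B[k][j] for j in cols] for k in ks]   # load B tile once
                loaded += len(rows) * len(ks) + len(ks) * len(cols)
                for ii in range(len(rows)):                      # each loaded element is reused
                    for jj in range(len(cols)):
                        acc[ii][jj] += sum(a_tile[ii][kk] * b_tile[kk][jj] for kk in range(len(ks)))
                flops += 2 * len(rows) * len(cols) * len(ks)
            for ii, i in enumerate(rows):                        # store once per tile
                for jj, j in enumerate(cols):
                    C[i][j] = acc[ii][jj]
    return C, loaded, flops

rng = random.Random(0)
A = [[rng.uniform(-1, 1) for _ in range(16)] for _ in range(16)]
B = [[rng.uniform(-1, 1) for _ in range(16)] for _ in range(16)]
C, loaded, flops = tiled_matmul(A, B, BM=8, BN=8, BK=4)
print(flops / loaded, 2 * 8 * 8 / (8 + 8))   # both 8.0: FLOPs per element loaded
```

Doubling BM and BN to 16 in this toy would raise the ratio to 16. This is the "bigger square tiles" effect, bounded on real hardware by register and SRAM capacity.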
Tensor cores: the MMA the tl.dot becomes
Tensor cores execute a small fixed-shape matrix-multiply-accumulate (MMA) per instruction. You don't get to pick arbitrary shapes — the hardware shape constrains your tiling.
| Generation | Instruction | Issued by | Representative shape (m·n·k) |
|---|---|---|---|
| Ampere | mma | 1 warp | m16 n8 k16 (or k8) |
| Hopper | wgmma | warpgroup (4 warps) | m64 nN k16, N a multiple of 8 up to 256 |
| Blackwell | tcgen05.mma | single thread → tensor memory | up to m256 n256 k16 across 2 SMs |
MMA shapes mean your BM/BN/BK should be multiples of the hardware tile (e.g. 16) and your pointers/leading dims should be 16-byte aligned, or you fall off the tensor-core path onto slow CUDA-core code. Hopper's wgmma also reads its B operand (and optionally A) from shared memory in a specific swizzled layout, so “why is my GEMM at 30% of peak?” very often has a layout/alignment answer.
“GEMM performance is a reuse problem. I tile the output, stage A/B tiles in shared memory, and accumulate in FP32 across the K loop so each HBM byte feeds many MMAs. The tile sizes are bounded above by registers/SRAM and below by the tensor-core MMA shape and alignment requirements.”
GPU execution model FOUNDATION
You can't reason about a kernel without the cost model underneath it. The GPU is a throughput machine: it hides latency with massive parallelism rather than caches and out-of-order tricks.
The five things that govern speed
- Occupancy — enough resident warps per SM to hide HBM latency. Bounded by registers/thread and shared memory/block. Too many registers → fewer warps → stalls show through.
- Coalescing — the 32 lanes of a warp should touch contiguous global addresses so the hardware merges them into a few 128-byte transactions. Strided/scattered access multiplies your byte traffic.
- Shared-memory bank conflicts — SMEM has 32 banks; if lanes hit the same bank with different addresses, accesses serialize. Padding / swizzling fixes it.
- Divergence — an `if` that splits a warp executes both sides, with the inactive lanes masked off. Branchy per-lane code wastes throughput.
- Latency hiding — the GPU tolerates ~500-cycle HBM latency by having other warps ready to issue. This is why “make it more parallel” often beats “make each thread faster.”
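The coalescing and bank-conflict bullets can be checked with a few lines of arithmetic. This sketch counts how many 128-byte segments one warp-wide access touches. It also counts how many distinct 4-byte words land on the same shared-memory bank. It assumes 4-byte elements, 128-byte transactions, and 32 banks of 4 bytes each. Real hardware also has 32-byte sectors and broadcast rules, so treat this as a back-of-envelope model.

```python
WARP = 32
SEGMENT = 128   # bytes per global-memory transaction (assumed)
BANKS = 32      # shared-memory banks, 4 bytes wide (assumed)

def global_transactions(byte_addrs):
    """Distinct 128-byte segments touched by one warp-wide load."""
    return len({a // SEGMENT for a in byte_addrs})

def bank_conflict_degree(byte_addrs):
    """Max distinct 4-byte words mapped to one bank (1 = conflict-free; same word = broadcast)."""
    words_per_bank = {}
    for a in byte_addrs:
        word = a // 4
        words_per_bank.setdefault(word % BANKS, set()).add(word)
    return max(len(words) for words in words_per_bank.values())

lanes = range(WARP)
contiguous = [4 * lane for lane in lanes]    # lane i reads float i
strided = [4 * 32 * lane for lane in lanes]  # lane i reads row i of a 32-wide FP32 tile (a column)
padded = [4 * 33 * lane for lane in lanes]   # same column read, rows padded by one float

print(global_transactions(contiguous), global_transactions(strided))   # 1 32
print(bank_conflict_degree(strided), bank_conflict_degree(padded))     # 32 1
```

The padded layout is the classic "+1 column" fix. The row pitch of 33 words puts consecutive rows in consecutive banks, so a column read becomes conflict-free.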
100% occupancy is not the target; enough occupancy to hide latency is. A register-heavy GEMM at 50% occupancy can beat a 100%-occupancy version because the registers hold a bigger accumulator tile (more reuse). Interviewers like to hear that you optimize for sustained throughput, not for an occupancy number.
“I think of the SM as a latency-hiding engine. My job is to give it coalesced loads, conflict-free shared-memory access, and enough independent work (warps + ILP) to cover HBM latency — while keeping the accumulator in registers for reuse.”
Roofline & the three bounds MENTAL MODEL
The roofline is the single most useful diagram in this whole field, and it's the lens behind half the interview questions. It answers: is my kernel limited by compute, by memory bandwidth, or by communication?
Define arithmetic intensity \(I = \dfrac{\text{FLOPs}}{\text{bytes moved}}\) (FLOP/byte). Achievable performance is
\[ P = \min\big(\underbrace{P_{\text{peak}}}_{\text{compute roof}},\ \underbrace{B \cdot I}_{\text{memory roof}}\big) \]
The ridge point is \(I^{*} = P_{\text{peak}} / B\). Below \(I^{*}\) you are memory-bound; above it, compute-bound. On an H100-class part \(P_{\text{peak}}\!\sim\!10^{15}\) FLOP/s and \(B\!\sim\!3\times10^{12}\) B/s, so \(I^{*}\approx 300\) FLOP/byte — a sobering number.
The three bounds, by example
| Bound | Limited by | Typical ops | The fix |
|---|---|---|---|
| Compute | peak FLOP/s (tensor cores) | large square GEMM, conv | better tiling, higher MMA utilization, lower precision |
| Memory | HBM bandwidth | softmax, LayerNorm, bias+act, GEMV, attention's I/O | fuse to cut round-trips; recompute instead of store |
| Communication | NVLink / network BW & latency | all-reduce, all-to-all (MoE), pipeline/tensor parallel | overlap comm with compute; reduce message volume; topology-aware collectives |
Take an elementwise op such as bias + activation on an (M,N) BF16 tensor. It does ~2MN FLOPs (a couple of ops per element) but moves 2·M·N·2 bytes (read once, write once). Intensity ≈ 2MN / (4MN) = 0.5 FLOP/byte, far below the ~300 ridge. Verdict: deeply memory-bound. Therefore, never launch it as its own kernel if you can fuse it into the producer GEMM's epilogue. This is the whole motivation for Module 05.
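The roofline formula and the elementwise example reduce to a few lines of arithmetic. The sketch below uses the rough H100-class figures quoted above (10^15 FLOP/s, 3×10^12 B/s) as assumptions, not measurements. It compares an elementwise op, a square GEMM, and a decode-style GEMV against the ridge point.

```python
P_PEAK = 1e15   # FLOP/s, rough tensor-core peak (assumed)
BW = 3e12       # B/s, rough HBM bandwidth (assumed)

def attainable(intensity, p_peak=P_PEAK, bw=BW):
    """Roofline: min(compute roof, bandwidth * intensity)."""
    return min(p_peak, bw * intensity)

def gemm_intensity(M, N, K, bytes_per_elt=2):
    """2MNK FLOPs over one read of A and B and one write of C."""
    return 2 * M * N * K / (bytes_per_elt * (M * K + K * N + M * N))

def elementwise_intensity(flops_per_elt=2, bytes_per_elt=2):
    """Read one element, write one element."""
    return flops_per_elt / (2 * bytes_per_elt)

ridge = P_PEAK / BW   # ~333 FLOP/byte
for name, I in [("bias+act", elementwise_intensity()),
                ("GEMM 4096^3", gemm_intensity(4096, 4096, 4096)),
                ("decode GEMV", gemm_intensity(1, 4096, 4096))]:
    bound = "compute" if I >= ridge else "memory"
    print(f"{name:12s} I={I:8.1f} FLOP/B  {bound}-bound  {attainable(I) / 1e12:7.1f} TFLOP/s")
```

The square GEMM lands around 1365 FLOP/byte, well right of the ridge. The M=1 GEMV lands near 1 FLOP/byte, which is why decode behaves like a memory-bound kernel.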
“First question I ask of any op: what's its arithmetic intensity versus the ridge point? Elementwise and normalization ops sit far left — memory-bound — so the win is fusion, not faster math. Big GEMMs sit right of the ridge — compute-bound — so the win is MMA utilization. Collectives are a separate, comm-bound roof where the win is overlap.”
Fused element-wise kernels FUSION 101
Elementwise and reduction ops are memory-bound (Module 04). The lever is fusion: do multiple ops in one pass so intermediate tensors never touch HBM.
Fused bias + activation (epilogue)
The pattern: compute the GEMM tile, then apply bias and the activation before the single store. As a standalone elementwise kernel it looks like this; in practice you splice it into the matmul epilogue.
```python
import triton
import triton.language as tl

@triton.jit
def bias_gelu_kernel(X, Bias, Y, M, N, sxm, sxn, BM: tl.constexpr, BN: tl.constexpr):
    pid_m, pid_n = tl.program_id(0), tl.program_id(1)
    offs_m = pid_m * BM + tl.arange(0, BM)
    offs_n = pid_n * BN + tl.arange(0, BN)
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    ptr = X + offs_m[:, None] * sxm + offs_n[None, :] * sxn
    x = tl.load(ptr, mask=mask, other=0.0).to(tl.float32)
    b = tl.load(Bias + offs_n, mask=offs_n < N, other=0.0).to(tl.float32)
    x = x + b[None, :]
    # tanh approximation of GELU: 0.5*x*(1 + tanh(z)), z = 0.7978845608*(x + 0.044715*x^3).
    # Since 1 + tanh(z) = 2*sigmoid(2z), this equals x*sigmoid(2z) and needs only tl.sigmoid.
    z = 0.7978845608 * (x + 0.044715 * x * x * x)
    y = x * tl.sigmoid(2.0 * z)
    # Y is assumed to share X's strides
    tl.store(Y + offs_m[:, None] * sxm + offs_n[None, :] * sxn,
             y.to(Y.dtype.element_ty), mask=mask)
```
Note FP32 math on BF16/FP16 storage — compute in high precision, store in low. That theme repeats everywhere.
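The tanh form above approximates the exact GELU, x·Φ(x), where Φ is the standard normal CDF. This stdlib sketch, computed in double precision, measures how close the two are over [-6, 6]. It is a quick sanity check, not a statement about any particular library's GELU.

```python
import math

def gelu_exact(x):
    """GELU(x) = x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    """The tanh approximation used in bias_gelu_kernel."""
    return 0.5 * x * (1.0 + math.tanh(0.7978845608 * (x + 0.044715 * x ** 3)))

xs = [i / 100.0 for i in range(-600, 601)]
max_err = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in xs)
print(f"max |exact - tanh approx| on [-6, 6]: {max_err:.2e}")
```

The gap stays below 1e-3. That is smaller than BF16's rounding step for outputs of order one, so the cheaper approximation is usually invisible after the downcast on store.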
Row reduction — the other primitive
Softmax, LayerNorm, RMSNorm, and loss functions are all row reductions followed by an elementwise pass. The skill is doing the reduction in one block with a tree/warp reduction, keeping the row resident.
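```python
import triton
import triton.language as tl

@triton.jit
def row_sum_kernel(X, Out, N, BN: tl.constexpr):
    row = tl.program_id(0)                      # one program per row
    cols = tl.arange(0, BN)                     # assumes BN >= N and contiguous rows
    x = tl.load(X + row * N + cols, mask=cols < N, other=0.0).to(tl.float32)
    tl.store(Out + row, tl.sum(x, axis=0))      # tl.sum lowers to a warp/tree reduction
```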
If the row is wider than one block can hold (N > tile), you need a two-stage reduction (partial sums per block → combine) or the online/streaming trick from Module 06. Interviewers test whether you notice the row doesn't fit and reach for a numerically-stable combine instead of a naive second kernel.
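Here is a two-stage reduction in plain Python. Each fixed-size chunk stands in for one block and is reduced to a small partial (count, mean, sum of squared deviations). A second stage then merges the partials without re-reading the data. The merge is Chan et al.'s parallel variance update, shown as one example of a numerically stable combine for LayerNorm-style statistics. The chunk size of 256 is an arbitrary stand-in for BN.

```python
def block_stats(chunk):
    """Stage 1 (one block): count, mean, and M2 = sum of squared deviations."""
    n = len(chunk)
    mean = sum(chunk) / n
    m2 = sum((x - mean) ** 2 for x in chunk)
    return n, mean, m2

def merge(a, b):
    """Stage 2: combine two partials (Chan et al. parallel update)."""
    na, mean_a, m2a = a
    nb, mean_b, m2b = b
    n = na + nb
    delta = mean_b - mean_a
    mean = mean_a + delta * nb / n
    m2 = m2a + m2b + delta * delta * na * nb / n
    return n, mean, m2

def row_mean_var(row, chunk=256):
    partials = [block_stats(row[i:i + chunk]) for i in range(0, len(row), chunk)]
    total = partials[0]
    for p in partials[1:]:
        total = merge(total, p)
    n, mean, m2 = total
    return mean, m2 / n   # population variance, as LayerNorm uses
```

Merging (mean, M2) pairs avoids the cancellation of the naive E[x²] − E[x]² formula when the mean is large relative to the spread.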
“Anything memory-bound, I try to fuse into a neighboring compute-bound kernel's epilogue or prologue. A bias+GELU after a GEMM should never be its own kernel — it rides the GEMM tile in registers, turning three HBM passes into one.”
Softmax, LayerNorm, RMSNorm FUSED FORWARD
These three are the canonical fused-reduction kernels and they show up constantly. Know the math, the numerical-stability trick, and the one-pass kernel for each.
Numerically-stable softmax
\[ \text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}} \;=\; \frac{e^{x_i-m}}{\sum_j e^{x_j-m}},\quad m=\max_j x_j \]
Subtracting the row max \(m\) is mandatory: \(e^{x_i}\) overflows for \(x_i\gtrsim 88\) in FP32 (and ~11 in FP16!). The shifted form is identical mathematically and bounded by 1.
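```python
import triton
import triton.language as tl

@triton.jit
def softmax_kernel(X, Y, N, BN: tl.constexpr):
    row = tl.program_id(0)                      # one program per row
    cols = tl.arange(0, BN)                     # assumes BN >= N and contiguous rows
    mask = cols < N
    x = tl.load(X + row * N + cols, mask=mask, other=-float("inf")).to(tl.float32)
    x = x - tl.max(x, axis=0)                   # stability shift; padded lanes stay -inf
    num = tl.exp(x)                             # exp(-inf) = 0, so padding adds nothing
    y = num / tl.sum(num, axis=0)
    tl.store(Y + row * N + cols, y.to(Y.dtype.element_ty), mask=mask)
```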
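The stability shift is easy to demonstrate on a CPU. Python's `math.exp` works in double precision and raises `OverflowError` above roughly 709. That is the same failure an FP32 kernel hits above about 88. The stable version below also serves as a reasonable ground-truth reference in tests.

```python
import math

def softmax_naive(xs):
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_stable(xs):
    m = max(xs)                            # the row max, as in softmax_kernel
    exps = [math.exp(x - m) for x in xs]   # every term is in (0, 1]
    total = sum(exps)
    return [e / total for e in exps]

logits = [1000.0, 999.0, 998.0]
try:
    softmax_naive(logits)
except OverflowError:
    print("naive softmax overflowed")
print(softmax_stable(logits))   # ≈ [0.665, 0.245, 0.090]
```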
Online (streaming) softmax — the FlashAttention seed
When the row doesn't fit in SRAM, you process it in blocks while maintaining a running max \(m\) and running denominator \(\ell\), rescaling the accumulator when the max grows. This recurrence is exactly what makes FlashAttention possible (Module 09).
For a new block with local max \(m_{\text{blk}}\) and local sum \(\ell_{\text{blk}}=\sum e^{x-m_{\text{blk}}}\):
\[ m_{\text{new}}=\max(m,m_{\text{blk}}),\quad \ell_{\text{new}}=e^{\,m-m_{\text{new}}}\ell + e^{\,m_{\text{blk}}-m_{\text{new}}}\ell_{\text{blk}} \]
The factor \(e^{m-m_{\text{new}}}\) corrects the old partial sum for the new max. Carry the same correction through the weighted value accumulator and you never materialize the full row.
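Here is the recurrence in plain Python. The row is streamed in blocks, the pair (m, ℓ) is carried between blocks, and ℓ is rescaled whenever the max grows. The block size is arbitrary. The final (m, ℓ) should match a two-pass computation of the max and of Σⱼ e^{xⱼ − m}.

```python
import math

def online_normalizer(xs, block=4):
    """Return (m, l): m = max_j x_j and l = sum_j exp(x_j - m), reading xs block by block."""
    m, l = -math.inf, 0.0
    for start in range(0, len(xs), block):
        blk = xs[start:start + block]
        m_blk = max(blk)
        l_blk = sum(math.exp(x - m_blk) for x in blk)
        m_new = max(m, m_blk)
        # exp(m - m_new) corrects the old partial sum; it is 0 on the first block (m = -inf)
        l = math.exp(m - m_new) * l + math.exp(m_blk - m_new) * l_blk
        m = m_new
    return m, l

def softmax_streaming(xs, block=4):
    m, l = online_normalizer(xs, block)
    return [math.exp(x - m) / l for x in xs]   # second, elementwise pass
```

FlashAttention applies the same correction factor to its running output accumulator. That removes the second pass entirely.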
LayerNorm vs RMSNorm
| | LayerNorm | RMSNorm |
|---|---|---|
| Formula | \(y=\dfrac{x-\mu}{\sqrt{\sigma^2+\epsilon}}\,\gamma+\beta\) | \(y=\dfrac{x}{\sqrt{\frac1d\sum_i x_i^2+\epsilon}}\,\gamma\) |
| Stats needed | mean and variance (2 reductions) | mean-square (1 reduction) |
| Recenters? | yes (subtracts μ) | no |
| Params | γ and β | γ only |
| Why used | classic, stable | cheaper, ~same quality — default in LLaMA-style models |
```python
import triton
import triton.language as tl

@triton.jit
def rmsnorm_kernel(X, W, Y, N, eps, BN: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BN)                     # assumes BN >= N and contiguous rows
    mask = cols < N
    x = tl.load(X + row * N + cols, mask=mask, other=0.0).to(tl.float32)
    ms = tl.sum(x * x, axis=0) / N              # mean of squares (1 pass)
    inv = 1.0 / tl.sqrt(ms + eps)
    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
    y = x * inv * w
    tl.store(Y + row * N + cols, y.to(Y.dtype.element_ty), mask=mask)
```
RMSNorm drops the mean subtraction and β — one reduction instead of two, no second statistic to carry. That's the whole reason it's popular.
Summing squares of BF16 values in BF16 loses precision catastrophically: BF16 has only 7 explicit mantissa bits (8 significant bits counting the implicit 1). Load → upcast to FP32 → reduce → normalize → downcast on store. A variance computed in low precision is a classic source of silent NaNs and accuracy regressions, and a favorite interview “why is training diverging?” prompt.
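A stdlib simulation makes the failure concrete. Every partial sum is rounded to BF16 while squares are accumulated. The `bf16_round` helper is a hand-rolled model: it applies round-to-nearest-even to the top 16 bits of the FP32 encoding and does not handle NaN or overflow. With 8 significant bits, 257 is not representable, so 256 + 1 rounds back to 256 and the accumulator stalls.

```python
import struct

def bf16_round(x):
    """Round a float to the nearest BF16 value (ties to even). NaN/inf are not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]   # FP32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)                   # round-to-nearest-even on the low half
    bits &= 0xFFFF0000                                    # keep sign, 8 exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def sum_squares(xs, rounding=lambda v: v):
    acc = 0.0
    for x in xs:
        acc = rounding(acc + x * x)
    return acc

xs = [1.0] * 4096
print(sum_squares(xs))               # 4096.0 (double precision, standing in for FP32)
print(sum_squares(xs, bf16_round))   # 256.0: the BF16 accumulator stalls
```

Suppose that BF16 sum fed an RMSNorm. The mean of squares would come out as 256/4096 = 0.0625 instead of 1.0, and every output would be scaled by roughly 4×. Nothing overflows, so nothing obviously breaks.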
“Softmax and the norms are one-pass fused reductions: load the row once, reduce in FP32, normalize, store once. For softmax I subtract the row max for stability; when the row exceeds SRAM I switch to the online-softmax recurrence — which is the exact mechanism FlashAttention uses to avoid the N×N matrix.”
Low precision & block scaling NUMERICS
Lower precision means more FLOP/s and less HBM traffic — it moves you up and right on the roofline. The catch is dynamic range and quantization error. This module is the numerics an interviewer will push hardest on for a 2025-era kernel role.
Float anatomy & the format zoo
A float is sign · exponent · mantissa. Exponent bits buy range; mantissa bits buy precision. Every low-precision format is a different bet on that trade.
| Format | Bits (S/E/M) | Buys | Used for |
|---|---|---|---|
| FP32 | 1 / 8 / 23 | reference precision | accumulation, master weights |
| TF32 | 1 / 8 / 10 | FP32 range, less mantissa | tensor-core matmul inputs (Ampere+) |
| BF16 | 1 / 8 / 7 | FP32 range, low precision | training (no loss scaling needed) |
| FP16 | 1 / 5 / 10 | precision, narrow range | inference; training needs loss scaling |
| FP8 E4M3 | 1 / 4 / 3 | more precision | fwd activations & weights |
| FP8 E5M2 | 1 / 5 / 2 | more range | gradients (backward) |
| FP4 E2M1 | 1 / 2 / 1 | 4 bits → 16 values total | weights/activations with block scaling |
Same 16 bits, opposite bets. BF16 keeps FP32's 8 exponent bits, so its range matches FP32. Small gradients rarely underflow and large ones rarely overflow, so training “just works” without loss scaling. FP16 spends those bits on mantissa (10 vs 7), so it's more precise. Its narrow range, though, means training needs loss scaling: multiply the loss up so small gradients survive, then divide back before the update. Rule of thumb: BF16 for training, FP16 where precision matters and range is controlled.
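The range-versus-precision trade follows directly from the bit counts for IEEE-style formats, which reserve the all-ones exponent for inf/NaN and have an implicit leading 1. The sketch covers FP32, BF16, FP16 and FP8 E5M2. E4M3 and E2M1 depart from that convention (they have no infinities), so this simple formula does not apply to them and they are left out.

```python
def ieee_like_limits(e_bits, m_bits):
    """(largest finite, smallest normal, machine epsilon) for an IEEE-style format."""
    bias = 2 ** (e_bits - 1) - 1
    max_exp = (2 ** e_bits - 2) - bias       # all-ones exponent is reserved for inf/NaN
    largest = (2.0 - 2.0 ** -m_bits) * 2.0 ** max_exp
    smallest_normal = 2.0 ** (1 - bias)
    eps = 2.0 ** -m_bits                     # gap between 1.0 and the next value
    return largest, smallest_normal, eps

for name, e, m in [("FP32", 8, 23), ("BF16", 8, 7), ("FP16", 5, 10), ("FP8 E5M2", 5, 2)]:
    big, tiny, eps = ieee_like_limits(e, m)
    print(f"{name:9s} max={big:.4g}  min_normal={tiny:.4g}  eps={eps:.4g}")
```

BF16 shares FP32's smallest normal (about 1.2e-38), while FP16's is about 6.1e-5. That gap is why tiny FP16 gradients underflow without loss scaling. BF16's epsilon of 2^-7 is what the "low precision" column means.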
FP16/BF16 matmul from first principles
The non-negotiable rule: inputs low precision, accumulate in FP32. Tensor cores are built for exactly this — they take FP16/BF16 (or FP8) operands and accumulate into an FP32 register tile. Summing thousands of products in BF16 would lose the small terms; FP32 accumulation preserves them.
This is why the Triton GEMM above keeps its accumulator (acc) in FP32 and casts only on store. In PTX the accumulator type is explicit: the .f32 in the MMA opcode is the FP32 register tile:
```ptx
// Ampere MMA: BF16 inputs, FP32 accumulate
// fields: m16n8k16 = shape, .row.col = A/B layouts, .f32.bf16.bf16.f32 = D, A, B, C types
mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {d0,d1,d2,d3}, {a..}, {b..}, {c0,c1,c2,c3};
```
FP8, FP4, and why FP4 requires scaling
FP4 E2M1 has just 16 representable values spanning roughly ±6. You cannot represent a realistic weight distribution with that directly — so you scale. Each small block of values shares a scale factor that maps the block's range into FP4's tiny range; you store the 4-bit values plus the per-block scale.
MXFP4 vs NVFP4 — know the difference cold
| | MXFP4 (OCP microscaling) | NVFP4 (Blackwell) |
|---|---|---|
| Element | FP4 E2M1 | FP4 E2M1 |
| Block size | 32 values | 16 values |
| Scale format | E8M0 (power-of-two only) | FP8 E4M3 (fractional) + per-tensor FP32 |
| Levels of scaling | one (per-block) | two (per-block + per-tensor) |
| Error | higher (coarser block, pow-2 scale) | lower (finer block, fractional scale) |
| Backing | open standard (AMD, Intel, NV, …) | NVIDIA Blackwell-native |
NVFP4's two wins: a smaller block (16) narrows the dynamic range each scale must cover, and a fractional E4M3 scale (not just powers of two) fits that range more tightly. Both reduce quantization error; reports put MXFP4 at needing materially more training tokens to match NVFP4 loss.
Round-to-nearest in a format with mantissa step \(\Delta\) gives per-element error \(\lesssim \Delta/2\) relative to the block scale. Total error ≈ (scale mismatch to the block's true range) × (mantissa coarseness). Shrinking the block tightens the first factor; more mantissa or fractional scales tightens the second. That's the entire design space.
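Here is a toy comparison of the two schemes on random Gaussian data, in plain Python. Every piece is a simplification and an assumption:
- E2M1 is modeled as the magnitude grid {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
- The MX-style scale is the smallest power of two that avoids clipping. The OCP reference chooses the shared exponent differently and saturates instead.
- The NV-style scale is amax/6 rounded with a crude 3-mantissa-bit E4M3 model.
- NVFP4's second-level per-tensor FP32 scale is omitted.

The point is to show the direction of the effect, not to reproduce published numbers.

```python
import math
import random

E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]   # FP4 E2M1 magnitudes

def to_e2m1(v):
    """Nearest E2M1 value, magnitude clamped to 6."""
    mag = min(E2M1, key=lambda g: abs(g - min(abs(v), 6.0)))
    return math.copysign(mag, v)

def round_e4m3(s):
    """Crude E4M3 model for a positive normal value: keep 3 mantissa bits."""
    step = 2.0 ** (math.floor(math.log2(s)) - 3)
    return round(s / step) * step

def mx_scale(amax):
    return 2.0 ** math.ceil(math.log2(amax / 6.0))    # power of two (E8M0-like)

def nv_scale(amax):
    return round_e4m3(amax / 6.0)                    # fractional (E4M3-like)

def quant_mse(xs, block, scale_fn):
    err = 0.0
    for i in range(0, len(xs), block):
        blk = xs[i:i + block]
        amax = max(abs(v) for v in blk)
        s = scale_fn(amax) if amax > 0 else 1.0
        err += sum((v - to_e2m1(v / s) * s) ** 2 for v in blk)
    return err / len(xs)

rng = random.Random(0)
data = [rng.gauss(0.0, 1.0) for _ in range(4096)]
mse_mx = quant_mse(data, 32, mx_scale)   # MXFP4-like: 32-wide blocks, power-of-two scale
mse_nv = quant_mse(data, 16, nv_scale)   # NVFP4-like: 16-wide blocks, fractional scale
print(f"MX-style MSE {mse_mx:.4f}   NV-style MSE {mse_nv:.4f}")
```

Both levers from the error budget above show up here. The smaller block lowers each block's amax, and the fractional scale maps that amax onto 6 instead of somewhere in (3, 6].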
Tensor-core constraints & block-scaled matmul
Low-precision MMAs impose hard rules your kernel must satisfy:
- Format pairs are fixed — the MMA opcode names operand and accumulator types; you can't feed arbitrary combinations.
- K must align to the MMA K and to the scaling block — e.g. FP8/FP4 paths want K as a multiple of the block (16/32) so each MMA step consumes whole blocks with their scale.
- Scales have a prescribed layout — Blackwell's block-scaled MMA reads the scale tensor in a specific interleaved order; getting that layout wrong is a top cause of “correct in NumPy, garbage on device.”
```python
# Block-scaled matmul, conceptual inner loop (FP4 weights + per-block scales).
# Pseudocode: load_fp4, load_scale and dequant are illustrative placeholders, not Triton APIs.
acc = tl.zeros((BM, BN), tl.float32)
for k in range(0, K, BLOCK):                 # BLOCK = 16 (NVFP4) or 32 (MXFP4)
    a4 = load_fp4(A, ...)                    # packed 4-bit values
    b4 = load_fp4(B, ...)
    sa = load_scale(Ascale, ...)             # one scale per BLOCK values along K
    sb = load_scale(Bscale, ...)
    a = dequant(a4) * sa                     # → higher precision
    b = dequant(b4) * sb
    acc += tl.dot(a, b)                      # FP32 accumulate
# On Blackwell this whole block+scale path is a single hardware-scaled MMA.
```
People assume FP4 “loses 12 bits of math.” In practice the accumulator is still FP32; the loss is in representing each input. So accuracy work concentrates on scaling (block size, scale format, outlier handling), not on the matmul's adder. If asked “how do you keep FP4 accurate?”, talk scaling and outliers, not accumulation.
“Lower precision is a roofline move: more FLOP/s, fewer bytes. BF16 for training because its range matches FP32; FP8 split into E4M3 forward and E5M2 for gradients; FP4 only with block scaling. The accumulator stays FP32 throughout — the accuracy game is the scaling scheme. NVFP4 beats MXFP4 by using a 16-element block and a fractional E4M3 scale instead of 32 elements and a power-of-two scale.”
Matmul in practice SHAPES · TUNING · TESTS
The same GEMM code performs wildly differently across shapes. Recognizing the shape regime — and tuning, then locking it down with tests — is the practical core of the job.
Four shape regimes
| Shape | Example | Regime | Strategy |
|---|---|---|---|
| Square | 4096³ | compute-bound | large tiles, max tensor-core util |
| Tall-skinny (large K) | (256, 65536)·(65536, 256) | few output tiles, long K loop → idle SMs | split-K / stream-K to expose parallelism |
| MLP prefill | (B·S, d)·(d, 4d) | compute-bound (large B·S) | fuse bias+act epilogue (Module 05) |
| Decode (M=1) | (1, d)·(d, 4d) | memory-bound (GEMV) | weight-only quant; batch requests |
When K is large but M·N is small (few output tiles → few blocks → idle SMs), partition the K loop across multiple blocks that each compute a partial sum, then reduce (atomic add or a second pass). It trades an extra reduction for parallelism — essential for tall-skinny shapes. Stream-K generalizes this to balance work across all SMs.
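Here is split-K in miniature, in plain Python. Each of several "blocks" owns one slice of the K loop and produces a full-size partial C. A final pass adds the partials, which is the role an atomic add or a second kernel plays on a GPU. The shapes and split count are illustrative.

```python
import random

def matmul_k_range(A, B, k_lo, k_hi):
    """Partial C over K indices [k_lo, k_hi): what one split-K block computes."""
    M, N = len(A), len(B[0])
    return [[sum(A[i][k] * B[k][j] for k in range(k_lo, k_hi)) for j in range(N)]
            for i in range(M)]

def split_k_matmul(A, B, splits):
    K = len(B)
    chunk = -(-K // splits)                          # ceil(K / splits)
    partials = [matmul_k_range(A, B, s, min(s + chunk, K)) for s in range(0, K, chunk)]
    M, N = len(A), len(B[0])
    # Reduction step: sum the partial tiles (atomic add / second pass on a GPU).
    return [[sum(p[i][j] for p in partials) for j in range(N)] for i in range(M)]

rng = random.Random(0)
M, N, K = 4, 4, 1024                                 # few outputs, long K loop
A = [[rng.uniform(-1, 1) for _ in range(K)] for _ in range(M)]
B = [[rng.uniform(-1, 1) for _ in range(N)] for _ in range(K)]
C_ref = matmul_k_range(A, B, 0, K)
C_split = split_k_matmul(A, B, splits=8)
print(max(abs(C_ref[i][j] - C_split[i][j]) for i in range(M) for j in range(N)))   # tiny, often nonzero
```

The small difference comes from floating-point reassociation. Split-K changes the summation order, so its output is not bit-identical to the unsplit kernel, and tolerance-based tests need to allow for that.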
Autotuning
The best BM/BN/BK, num_warps, and pipeline num_stages depend on shape, dtype, and architecture. You don't guess — you search, then cache the winner per shape key.
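```python
import triton
import triton.language as tl

configs = [
    triton.Config({'BM': 128, 'BN': 256, 'BK': 64}, num_warps=8, num_stages=3),
    triton.Config({'BM': 128, 'BN': 128, 'BK': 64}, num_warps=8, num_stages=4),
    triton.Config({'BM': 64,  'BN': 128, 'BK': 64}, num_warps=4, num_stages=4),
    triton.Config({'BM': 64,  'BN': 64,  'BK': 32}, num_warps=4, num_stages=3),
]

@triton.autotune(configs=configs, key=['M', 'N', 'K'])   # re-tune when shape changes
@triton.jit
def matmul_kernel(A, B, C, M, N, K,
                  sam, sak, sbk, sbn, scm, scn,
                  BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr):
    ...   # same body as the tiled GEMM above; BM/BN/BK now come from the chosen config
```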
num_stages is software pipelining. num_stages controls how many K-loop iterations are prefetched (double/triple buffering of shared memory) to overlap global loads with MMAs. More stages hide more latency but cost shared memory, so they interact with tile size and occupancy. “Why did bigger tiles get slower?” is often answered by “you ran out of SMEM for the pipeline and stages dropped.”
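Here is a rough budget for how stages and tile size compete for shared memory. Each stage buffers one BM×BK tile of A and one BK×BN tile of B. The 227 KB figure is an assumption, roughly H100's per-block opt-in limit. The model ignores epilogue buffers and padding, so real limits are somewhat tighter.

```python
SMEM_BUDGET = 227 * 1024   # bytes per block (assumed, roughly H100)

def smem_per_stage(BM, BN, BK, bytes_per_elt=2):
    """Bytes for one pipeline stage: an A tile plus a B tile."""
    return (BM * BK + BK * BN) * bytes_per_elt

def max_stages(BM, BN, BK, bytes_per_elt=2, budget=SMEM_BUDGET):
    return budget // smem_per_stage(BM, BN, BK, bytes_per_elt)

for BM, BN, BK in [(128, 256, 64), (128, 128, 64), (64, 64, 32)]:
    print(BM, BN, BK, smem_per_stage(BM, BN, BK) // 1024, "KB/stage,",
          "max stages:", max_stages(BM, BN, BK))
```

Under these assumptions, the 128×256×64 BF16 config uses 48 KB per stage. It fits 3 or 4 stages but not 5, which is why tile size and num_stages are tuned together.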
Numerical checks & regression testing
A kernel isn't done until it's provably correct against a reference and guarded against regressions. Tolerances must match the dtype — BF16 results are not bit-exact to FP32.
```python
import torch

def check(M, N, K, dtype):
    a = torch.randn(M, K, device='cuda', dtype=dtype)
    b = torch.randn(K, N, device='cuda', dtype=dtype)
    ref = a.float() @ b.float()            # FP32 reference on the same rounded inputs
    out = triton_matmul(a, b).float()      # triton_matmul: hypothetical launch wrapper for matmul_kernel
    # tolerance scales with K (accumulated rounding) and with dtype
    rtol = {torch.float16: 1e-2, torch.bfloat16: 2e-2, torch.float32: 1e-4}[dtype]
    torch.testing.assert_close(out, ref, rtol=rtol, atol=rtol * K**0.5)

# regression matrix: shapes × dtypes, run in CI; also assert TFLOP/s ≥ baseline
for dt in (torch.float16, torch.bfloat16):
    for (M, N, K) in [(4096, 4096, 4096), (65536, 256, 128), (1, 4096, 4096)]:
        check(M, N, K, dt)
```
Two regressions to catch: correctness (output drifts outside tolerance after a refactor) and performance (achieved TFLOP/s or GB/s drops below a recorded baseline for a given shape). CI runs the shape×dtype matrix, compares to stored numbers, and fails the build on either. This is exactly the “autotuning + regression testing + numerical checks” triad from your topic list.
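The performance half of that check is simple arithmetic once you have a timing. Convert time to TFLOP/s using 2MNK, then compare against a stored baseline with some slack for run-to-run noise. The baseline table and the 5% slack below are made-up placeholders, not recommended values.

```python
def tflops(M, N, K, seconds):
    """Achieved TFLOP/s for one GEMM, using the 2MNK FLOP count."""
    return 2 * M * N * K / seconds / 1e12

BASELINE = {(4096, 4096, 4096): 600.0}   # hypothetical stored TFLOP/s per shape

def check_perf(shape, seconds, slack=0.05):
    """Fail if achieved throughput drops more than `slack` below the recorded baseline."""
    got = tflops(*shape, seconds)
    floor = BASELINE[shape] * (1 - slack)
    assert got >= floor, f"{shape}: {got:.1f} TFLOP/s < {floor:.1f} (baseline {BASELINE[shape]})"
    return got

print(check_perf((4096, 4096, 4096), 2e-4))   # ~687 TFLOP/s: passes
```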
“I classify the shape first. Square is compute-bound — chase MMA utilization. Tall-skinny needs split-K to fill the machine. Decode is M=1, memory-bound, so I quantize weights and batch requests. Then I autotune tile/warps/stages per shape, and lock it in with a CI matrix that checks both numerical tolerance against an FP32 reference and TFLOP/s against a baseline.”
Attention & FlashAttention THE BIG ONE
Attention is two matmuls with a softmax between them. The whole performance story is that the middle matrix is \(S\times S\) — quadratic in sequence length — and you must never write it to HBM.
\[ \text{Attn}(Q,K,V)=\text{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right)V \]
Shapes per head: \(Q,K,V\in\mathbb{R}^{S\times d}\). Scores \(QK^\top\in\mathbb{R}^{S\times S}\). The \(1/\sqrt{d}\) keeps the logits' variance ~1 so softmax doesn't saturate. Batched over \(B\) sequences and \(H\) heads: tensors are \((B,H,S,d)\).
FlashAttention: tile + online softmax, never materialize S×S
FlashAttention keeps a block of \(Q\) in SRAM and streams blocks of \(K,V\) past it, updating the output with the online-softmax recurrence from Module 06. The \(S\times S\) matrix exists only one tile at a time, in SRAM.
Standard attention does \(\Theta(Nd + N^2)\) HBM accesses (it touches the \(N\times N\) matrix), where \(N = S\) is the sequence length. FlashAttention does \(\Theta(N^2 d^2 / M)\) accesses, where \(M\) is the SRAM size. This is IO-optimal: no exact attention algorithm does asymptotically fewer across all SRAM sizes \(M\). For real \(d,M\) that means up to ~9× fewer HBM accesses and up to ~7.6× faster than standard PyTorch attention on GPT-2, with memory that is linear, not quadratic, in sequence length. Same exact result, just I/O-aware.
FlashAttention forward kernel (Triton-style)
```python
import triton
import triton.language as tl

@triton.jit
def flash_attn_fwd(Q, K, V, O, scale, S, D: tl.constexpr,
                   BQ: tl.constexpr, BK: tl.constexpr, CAUSAL: tl.constexpr):
    # One head, contiguous (S, D) tensors; assumes S is a multiple of BQ and BK (no edge masks).
    qb = tl.program_id(0)                                   # this program owns one Q block
    offs_q = qb * BQ + tl.arange(0, BQ)
    offs_d = tl.arange(0, D)                                # D must be constexpr for tl.arange
    q = tl.load(Q + offs_q[:, None]*D + offs_d[None, :])    # BQ×D stays resident
    m = tl.full((BQ,), -float("inf"), tl.float32)           # running row max
    l = tl.zeros((BQ,), tl.float32)                         # running denominator
    acc = tl.zeros((BQ, D), tl.float32)                     # running output
    kv_end = (qb + 1) * BQ if CAUSAL else S                 # causal: skip future blocks
    for kb in range(0, kv_end, BK):
        offs_k = kb + tl.arange(0, BK)
        k = tl.load(K + offs_k[:, None]*D + offs_d[None, :])
        v = tl.load(V + offs_k[:, None]*D + offs_d[None, :])
        s = tl.dot(q, tl.trans(k)) * scale                  # BQ×BK scores (in SRAM only)
        if CAUSAL:                                          # per-element mask; only bites on diagonal blocks
            s = tl.where(offs_q[:, None] >= offs_k[None, :], s, -float("inf"))
        m_new = tl.maximum(m, tl.max(s, axis=1))            # online softmax update
        p = tl.exp(s - m_new[:, None])
        alpha = tl.exp(m - m_new)                           # correction for old acc/denom
        l = l * alpha + tl.sum(p, axis=1)
        acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)
        m = m_new
    acc = acc / l[:, None]                                  # finalize
    tl.store(O + offs_q[:, None]*D + offs_d[None, :], acc.to(O.dtype.element_ty))
```
The three carried values m, l, acc and the alpha rescale are the entire trick. Drop the FP32 accumulators or the rescale and you get NaNs or wrong sums.
Causal masking, head dim, and seq-len sweeps
- Causal masking: token \(i\) attends only to \(j\le i\). Whole K/V blocks above the diagonal are skipped (the `kv_end` bound), which is roughly half the work, and only the diagonal block needs the per-element triangular mask. Forgetting to skip the upper blocks leaves an easy 2× on the table.
- Head dimension \(d\): usually 64 or 128. It sets the SRAM tile footprint (each Q block is BQ×\(d\), each K/V block BK×\(d\)). Larger \(d\) raises arithmetic intensity but shrinks how many rows fit per block. Small-\(d\) attention is more memory-bound.
- Sequence-length sweep: standard attention's memory grows \(O(S^2)\) and OOMs; FlashAttention's memory is \(O(S)\) while compute stays \(O(S^2 d)\). So as \(S\) grows the gap widens — the canonical benchmark plot you should be able to describe.
Check against a reference softmax(QKᵀ/√d)V computed in FP32. Two subtle bugs: (1) using the local block max instead of the running max in the rescale, and (2) forgetting the alpha correction on acc — both pass small random tests but fail on adversarial inputs (one large logit far down the sequence). Test with a planted outlier logit, not just randn.
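Here is a plain-Python reference for a single head. It follows the kernel's loop: Q block resident, K/V streamed, carried m, l and acc, and the alpha rescale. It is checked against the direct formula using a planted outlier logit, as suggested above. It also counts visited K/V blocks to show the causal skip. Block sizes and data are arbitrary.

```python
import math
import random

def attention_ref(Q, K, V, causal):
    """Direct softmax(QK^T / sqrt(d)) V, one query row at a time."""
    d = len(Q[0])
    out = []
    for i, q in enumerate(Q):
        js = range(i + 1) if causal else range(len(K))
        s = [sum(a * b for a, b in zip(q, K[j])) / math.sqrt(d) for j in js]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        out.append([sum(pj * V[j][c] for pj, j in zip(p, js)) / sum(p) for c in range(d)])
    return out

def flash_ref(Q, K, V, causal, BQ=4, BK=2):
    """Blocked online-softmax attention; returns (output, number of K/V blocks visited)."""
    S, d = len(Q), len(Q[0])
    out, visited = [], 0
    for q0 in range(0, S, BQ):
        rows = range(q0, min(q0 + BQ, S))
        m = {i: -math.inf for i in rows}
        l = {i: 0.0 for i in rows}
        acc = {i: [0.0] * d for i in rows}
        kv_end = min(q0 + BQ, S) if causal else S          # skip blocks entirely in the future
        for k0 in range(0, kv_end, BK):
            visited += 1
            cols = range(k0, min(k0 + BK, kv_end))
            for i in rows:
                s = [sum(a * b for a, b in zip(Q[i], K[j])) / math.sqrt(d)
                     if (not causal or j <= i) else -math.inf for j in cols]
                m_new = max(m[i], max(s))
                p = [math.exp(x - m_new) for x in s]
                alpha = math.exp(m[i] - m_new)             # rescale the old l and acc
                l[i] = l[i] * alpha + sum(p)
                acc[i] = [a * alpha + sum(pj * V[j][c] for pj, j in zip(p, cols))
                          for c, a in enumerate(acc[i])]
                m[i] = m_new
        out.extend([a / l[i] for a in acc[i]] for i in rows)
    return out, visited

rng = random.Random(0)
S, d = 10, 4
Q, K, V = ([[rng.gauss(0, 1) for _ in range(d)] for _ in range(S)] for _ in range(3))
K[7] = [30.0 * v for v in Q[9]]      # planted outlier: one huge logit late in the sequence
for causal in (False, True):
    got, visited = flash_ref(Q, K, V, causal)
    ref = attention_ref(Q, K, V, causal)
    err = max(abs(a - b) for r, g in zip(ref, got) for a, b in zip(r, g))
    print(f"causal={causal}: max err {err:.1e}, K/V blocks visited {visited}")   # 15 vs 11 blocks
```

Try replacing `alpha` with 1.0 or `m_new` with the block-local max. The planted-outlier rows then stop matching the reference, which is exactly the pair of bugs described above.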
“Standard attention is memory-bound because it round-trips the S×S scores through HBM. FlashAttention keeps a Q block in SRAM, streams K/V, and uses the online-softmax recurrence to fold each block into a running (max, denom, output) — so the S×S matrix never hits HBM. That's O(N²d²/M) accesses, IO-optimal, linear memory in S. Causal masking lets me skip the future blocks for roughly a 2× on top.”
Mixture of Experts & distributed SCALE-OUT
MoE replaces a dense FFN with many experts and routes each token to only a few. You get more parameters at constant FLOPs — but you buy a communication problem, and that's where the kernel/systems interview goes.
Routing & top-k gating
\[ g = \text{softmax}(x W_r),\quad \mathcal{E}=\text{top-}k(g),\quad y=\sum_{e\in\mathcal{E}} \tilde g_e\,\text{Expert}_e(x),\quad \tilde g_e=\frac{g_e}{\sum_{e'\in\mathcal{E}} g_{e'}} \]
A tiny linear \(W_r\) scores experts per token; keep the top-\(k\) (usually 1–2), renormalize their gate weights, and combine the chosen experts' outputs weighted by those gates.
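```python
import torch

T, d, E, k = 8, 16, 4, 2                     # tokens, model dim, experts, experts per token (toy sizes)
x = torch.randn(T, d)
Wr = torch.randn(d, E)                       # router weights

logits = x @ Wr                              # [T, E] token→expert scores
probs = logits.softmax(-1)
topv, topi = probs.topk(k, dim=-1)           # [T, k] weights and expert ids
topv = topv / topv.sum(-1, keepdim=True)     # renormalize the kept gates
# dispatch: sort/permute tokens so each expert sees a contiguous run
order = topi.view(-1).argsort()              # flat (token, slot) indices grouped by expert id
token_of = order // k                        # which token each sorted slot came from
```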
Dispatch / combine, grouped GEMM, all-to-all
- Dispatch: permute tokens so each expert gets a contiguous block; when experts live on other GPUs (expert parallelism), this permute is an all-to-all — every device sends each other device its tokens.
- Grouped GEMM: experts share weight shapes but receive different token counts (ragged M). One kernel does a batch of variable-M GEMMs without launching one kernel per expert and without padding waste.
- Combine: a second all-to-all routes outputs back to the originating device, then the weighted sum applies the gate weights.
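Here is the dispatch → expert compute → combine pipeline on a single device, in plain Python. The (token, expert) assignments are sorted by expert so each expert sees a contiguous run. The per-expert counts form the ragged M a grouped GEMM would consume. Each expert's output is then scatter-added back, weighted by the gate. The experts are hypothetical scalar-multiply functions standing in for FFNs, and the cross-device all-to-alls are not modeled.

```python
def moe_forward(x, topi, topv, experts):
    """x: token vectors; topi/topv: per-token expert ids and (renormalized) gate weights."""
    k = len(topi[0])
    # Dispatch: flatten (token, slot) assignments and sort them by expert id (stable sort).
    assign = sorted(((topi[t][s], t, topv[t][s]) for t in range(len(x)) for s in range(k)),
                    key=lambda a: a[0])
    group_sizes = [0] * len(experts)
    for e, _, _ in assign:
        group_sizes[e] += 1                      # ragged M per expert for a grouped GEMM
    # Expert compute over each contiguous run, then combine: y[t] += gate * expert_e(x[t]).
    y = [[0.0] * len(x[0]) for _ in x]
    for e, t, g in assign:
        out = experts[e](x[t])
        y[t] = [yi + g * oi for yi, oi in zip(y[t], out)]
    return y, group_sizes

experts = [lambda v, c=c: [c * vi for vi in v] for c in (1.0, 2.0, 3.0)]   # toy "experts"
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
topi = [[0, 2], [1, 2], [0, 2]]
topv = [[0.5, 0.5], [0.25, 0.75], [1.0, 0.0]]
y, sizes = moe_forward(x, topi, topv, experts)
print(y, sizes)   # [[2.0, 4.0], [8.25, 11.0], [5.0, 6.0]] [2, 1, 3]
```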
Load balancing, expert parallelism, overlap
- Load balancing: routing can pile tokens onto a few popular experts → idle GPUs + stragglers. Fixes: an auxiliary load-balancing loss that pushes the router toward uniform usage, a capacity factor (max tokens/expert, drop or reroute overflow), and noisy (jittered) gating.
- Expert parallelism (EP): shard experts across GPUs. Each token may need an expert on another device → the all-to-alls. EP is usually combined with data/tensor/pipeline parallelism.
- Overlap: the killer optimization. While expert \(e\)'s GEMM runs, prefetch/communicate the next group's tokens — overlap all-to-all with grouped-GEMM compute so the comm-bound dispatch hides behind compute.
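One common form of the auxiliary loss, in the Switch Transformer style, is E · Σₑ fₑ · Pₑ. Here fₑ is the fraction of tokens routed to expert e and Pₑ is the mean router probability for e. It is used below as one example, not the only choice. The loss equals 1 for perfectly uniform routing and grows toward E as routing collapses onto a single expert. The capacity and imbalance helpers are illustrative, and their names are made up for this sketch.

```python
import math

def load_balance_loss(probs, top1):
    """probs: per-token router probabilities (T x E); top1: chosen expert per token."""
    T, E = len(probs), len(probs[0])
    f = [sum(1 for e in top1 if e == i) / T for i in range(E)]   # fraction of tokens per expert
    P = [sum(p[i] for p in probs) / T for i in range(E)]         # mean router probability
    return E * sum(fi * Pi for fi, Pi in zip(f, P))

def expert_capacity(tokens, experts, k=1, capacity_factor=1.25):
    """Max tokens one expert accepts; overflow is dropped or rerouted."""
    return math.ceil(capacity_factor * tokens * k / experts)

def imbalance(top1, E):
    """max/mean tokens per expert: the straggler factor an all-to-all ends up waiting on."""
    counts = [top1.count(i) for i in range(E)]
    return max(counts) / (sum(counts) / E)

uniform = load_balance_loss([[0.25] * 4] * 8, [0, 1, 2, 3] * 2)
collapsed = load_balance_loss([[1.0, 0.0, 0.0, 0.0]] * 8, [0] * 8)
print(uniform, collapsed, expert_capacity(4096, 8), imbalance([0, 0, 0, 1], 2))   # 1.0 4.0 640 1.5
```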
Bottleneck taxonomy & a microbenchmark
| Stage | Bound | Why | Lever |
|---|---|---|---|
| Router top-k | cheap / memory | tiny GEMM + sort | rarely the bottleneck |
| Dispatch all-to-all | communication | cross-device token shuffle | overlap; fewer bytes; topology-aware |
| Expert grouped GEMM | compute (prefill) / memory (decode) | big GEMMs vs M=1 GEMV | tile/MMA util; weight-only quant at decode |
| Combine all-to-all | communication | route outputs back | overlap with next layer |
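The table's prefill/decode split for the expert GEMM follows directly from arithmetic intensity. The sketch below assumes BF16 operands, each read or written once from HBM. It also uses the cheat sheet's ridge of ≈ 300 FLOP/byte, which you should check against your part's spec sheet. `gemm_intensity` and `classify` are hypothetical helpers.

```python
def gemm_intensity(M, N, K, bytes_per_elt=2):
    """FLOPs per HBM byte for C[M,N] = A[M,K] @ B[K,N], each operand touched once."""
    flops = 2 * M * N * K
    bytes_moved = bytes_per_elt * (M * K + K * N + M * N)
    return flops / bytes_moved

RIDGE = 300.0  # FLOP/byte ballpark from the cheat sheet; verify for your target part

def classify(M, N, K, bytes_per_elt=2, ridge=RIDGE):
    """Place a GEMM left (memory-bound) or right (compute-bound) of the roofline ridge."""
    I = gemm_intensity(M, N, K, bytes_per_elt)
    return ("compute" if I >= ridge else "memory"), I
```

For an expert, M is the number of tokens routed to it (about T·k/E when balanced). At decode, M is tiny. Intensity then sits near 1 FLOP/byte because the K×N weight read dominates. That is why the lever there is shrinking weight bytes rather than tile tuning.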
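# Dispatch microbenchmark: is your MoE comm- or compute-bound?
# Assumes a torchrun launch with dist.init_process_group("nccl") and
# torch.cuda.set_device(local_rank) already done; T must be divisible by world size.
import torch, torch.distributed as dist, time
def bench_all_to_all(T, d, iters=50, warmup=5):
    x = torch.randn(T, d, device='cuda', dtype=torch.bfloat16)
    out = torch.empty_like(x)
    for _ in range(warmup):                 # exclude NCCL setup / first-call cost
        dist.all_to_all_single(out, x)
    torch.cuda.synchronize(); t0 = time.perf_counter()
    for _ in range(iters):
        dist.all_to_all_single(out, x)      # the dispatch primitive
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - t0
    bytes_moved = T * d * 2 * iters         # BF16 = 2 B/elt; includes the 1/world chunk that stays local
    gbps = bytes_moved / elapsed / 1e9
    return gbps                             # compare to NVLink/network peak
# If gbps ≈ link peak → comm-bound (overlap it). If experts' TFLOP/s ≈ tensor-core
# peak → compute-bound. Whichever saturates first is your bottleneck.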
The all-to-all completes only when the slowest rank finishes, so one overloaded expert stalls everyone (straggler). The fix is upstream — load-balancing loss and capacity limits — not a faster collective. Measure per-expert token counts before blaming the network. This is the favorite MoE follow-up.
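Measuring the counts is a one-liner on the router's `topi`. The sketch assumes experts are placed contiguously, `experts_per_rank` to a rank, so rank load is a sum over its experts. On a real multi-GPU run you would first sum the counts across ranks, for example with an all-reduce. `expert_load_report` is a hypothetical name.

```python
import numpy as np

def expert_load_report(topi, num_experts, experts_per_rank=1):
    """Per-expert and per-rank token counts plus max/mean rank imbalance (1.0 = balanced)."""
    counts = np.bincount(np.asarray(topi).reshape(-1), minlength=num_experts)
    per_rank = counts.reshape(-1, experts_per_rank).sum(axis=1)  # assumes contiguous placement
    mean = per_rank.mean()
    imbalance = float(per_rank.max() / mean) if mean > 0 else float("inf")
    return counts, per_rank, imbalance
```

Because the collective waits on the busiest rank, an imbalance of 1.5 means that rank does roughly 1.5× the balanced share of work. To a first approximation, the whole step slows by about that factor no matter how fast the links are.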
“MoE turns a dense FFN into routing + grouped GEMM + two all-to-alls. The experts are compute-bound at prefill and memory-bound at decode; the dispatch/combine are comm-bound. The biggest wins are overlapping the all-to-all with expert compute and keeping experts balanced via an aux loss and capacity factor — because the collective waits on the slowest rank.”
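A back-of-envelope model shows why overlap matters. It assumes the dispatch is split into chunks whose all-to-all and grouped-GEMM times are known, with one stream for communication and one for compute. The times are hypothetical inputs. The model ignores the combine all-to-all and any contention between the two streams.

```python
def serial_time(comm, compute):
    """No overlap: each chunk's all-to-all, then its expert GEMM, back to back."""
    return sum(comm) + sum(compute)

def overlapped_time(comm, compute):
    """Two-stream pipeline: the comm stream ships chunk i+1 while the compute stream runs
    chunk i. Chunk i's GEMM starts once its tokens arrived and chunk i-1's GEMM finished."""
    comm_done = compute_done = 0.0
    for c, g in zip(comm, compute):
        comm_done += c                                    # comm stream runs back to back
        compute_done = max(comm_done, compute_done) + g   # wait for tokens and the previous GEMM
    return compute_done
```

With four equal chunks of 1 unit comm and 1 unit compute, the serial schedule takes 8 units and the pipelined one takes 5. The total never drops below max(Σcomm, Σcompute). Once communication dominates, overlap can only hide the compute, and the remaining wins come from moving fewer bytes or fixing imbalance.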
Compiler: XLA / StableHLO / PTX / MLIR EMPHASIS
This is the axis to win on. The story is one pipeline: a framework graph is progressively lowered through stable, then internal, then hardware IRs, and the compiler's superpower — and its limit — is fusion.
StableHLO — the portability contract
Purpose: StableHLO is a stable, versioned operation set (an MLIR dialect) that sits between ML frameworks (JAX, PyTorch, TensorFlow) and ML compilers (XLA, IREE). It's the interop layer — a frontend can serialize StableHLO with backward-compatibility guarantees and hand it to any compiler. Think of it as “the bytecode of ML graphs.”
// StableHLO for y = exp(a + b) — note explicit tensor types & shapes
func.func @f(%a: tensor<4x8xf32>, %b: tensor<4x8xf32>) -> tensor<4x8xf32> {
%0 = stablehlo.add %a, %b : tensor<4x8xf32>
%1 = stablehlo.exponential %0 : tensor<4x8xf32>
return %1 : tensor<4x8xf32>
}
HLO & fusion — XLA's single most important pass
XLA lowers StableHLO into its internal HLO and runs optimization passes. The one that matters most is fusion: merge multiple ops into one kernel so intermediate tensors stay in registers/SRAM instead of round-tripping HBM. Since most non-GEMM ML ops are memory-bound (Module 04), fusion is where XLA earns its speedups.
Run as two kernels, the add-then-exp above writes the intermediate a + b to HBM and reads it back. That is five tensor transfers: read a, read b, write t, read t, write y. Fused, it reads a and b once and writes y once, three transfers, with a + b living only in registers. This is exactly the elementwise-fusion win from Module 05, now done automatically by the compiler. XLA even carries operator-boundary hints from the frontend (e.g. “this region is a softmax/BatchNorm”), because recognizing the pattern where the user wrote it fuses better than rediscovering it from primitives.
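A toy fusion planner helps make "which ops share a kernel" concrete. It assumes a linear chain of ops and four hypothetical op kinds with hand-written rules. This is not XLA's actual algorithm or cost model, only an illustration of the kind of decisions it makes.

```python
def plan_fusion(ops):
    """Greedy fusion over a linear chain of (name, kind) ops; returns one list per kernel.

    Toy rules (hypothetical, far simpler than XLA's cost model):
      elementwise -> joins the current group (fuses with producers or a GEMM epilogue)
      gemm        -> library call: closes the current group and starts a new one
      reduce      -> absorbs preceding elementwise producers, then closes the group;
                     it does not ride a library GEMM's epilogue
      opaque      -> sort / data-dependent scatter / control flow: always its own kernel
    """
    groups, cur, cur_is_gemm = [], [], False
    for name, kind in ops:
        if kind == "elementwise":
            cur.append(name)
        elif kind == "gemm":
            if cur:
                groups.append(cur)
            cur, cur_is_gemm = [name], True
        elif kind == "reduce":
            if cur_is_gemm:
                groups.append(cur)
                cur = []
            groups.append(cur + [name])
            cur, cur_is_gemm = [], False
        elif kind == "opaque":
            if cur:
                groups.append(cur)
            groups.append([name])
            cur, cur_is_gemm = [], False
        else:
            raise ValueError(f"unknown op kind: {kind}")
    if cur:
        groups.append(cur)
    return groups
```

Under these rules, add → exp collapses to one kernel. A naive attention chain (GEMM, scale, max, sub+exp, sum, divide, GEMM) still splits into five kernels, and the S×S score tensor crosses HBM between them.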
Lowering boundaries — where the compiler stops (and you write a kernel)
This is the question that separates a compiler engineer from a framework user. Fusion and codegen stop at boundaries like:
- Library calls: big GEMMs/convs go to cuBLAS/cuDNN as custom calls — XLA schedules them but won't fuse through them. An epilogue may fuse onto the GEMM, but two library GEMMs won't merge.
- Incompatible iteration spaces: fusing a reduction's consumer with an unrelated elementwise op can force recomputation of the producer or a tiling that isn't profitable, so the cost model declines.
- Opaque / data-dependent ops: dynamic shapes, sorts, scatter with unknown extents, control flow — hard to fuse and tile statically.
- Algorithms the op set can't express efficiently: FlashAttention's online-softmax tiling isn't something XLA derives from \(\mathrm{softmax}(QK^\top)V\) primitives, so it ships as a custom call (a hand-written Triton/CUDA kernel). Same for many fused MoE paths.
XLA can't derive it because the win requires changing the algorithm (recompute + online softmax to avoid materializing the S×S score matrix), not just merging existing ops. Compilers fuse and schedule the graph you gave them; they don't invent an IO-optimal algorithm. That's the precise seam where a kernel engineer adds value, and a great thing to articulate.
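The algorithmic change can be seen in its smallest form as the online-softmax recurrence for a single query row. Scores arrive in blocks, and a running max `m`, normalizer `l`, and unnormalized output `acc` are rescaled whenever the max grows. The full probability row is therefore never stored. This is a one-row NumPy sketch, not the tiled GPU kernel. `block` stands in for the SRAM tile size.

```python
import numpy as np

def online_softmax_weighted_sum(scores, V, block=4):
    """softmax(scores) @ V for one query row, streamed over key blocks of size `block`."""
    m, l = -np.inf, 0.0                       # running max and running normalizer
    acc = np.zeros(V.shape[1])                # running unnormalized sum of p_j * v_j
    for s in range(0, scores.shape[0], block):
        x, v = scores[s:s + block], V[s:s + block]
        m_new = max(m, x.max())
        scale = np.exp(m - m_new)             # rescale old partials to the new max (0 on first block)
        p = np.exp(x - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ v
        m = m_new
    return acc / l
```

The result matches the two-pass `softmax(s) @ V`, but each block is touched once. The same rescaling trick lets FlashAttention keep Q/K/V tiles in SRAM.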
MLIR, PTX, and where Triton sits
- MLIR is the compiler infrastructure: multiple dialects (StableHLO, linalg, affine, gpu, nvvm, llvm) and passes that progressively lower from high-level to hardware. StableHLO is one dialect; lowering walks down to the LLVM dialect.
- PTX is NVIDIA's virtual ISA: forward-compatible assembly. ptxas compiles PTX → SASS, the actual per-architecture machine code. LLVM/NVVM emit PTX.
- Triton is a kernel DSL with its own MLIR-based pipeline (Triton IR → Triton GPU IR → LLVM → PTX). It's how you write the custom-call kernels XLA can't generate: your FlashAttention, your fused MoE GEMM.
“JAX/PyTorch lower to StableHLO — a stable, portable opset — which XLA turns into HLO and optimizes, chiefly by fusion: merging memory-bound ops so intermediates never hit HBM, then lowering through LLVM to PTX and SASS. The compiler fuses and schedules the graph it's given; it won't invent FlashAttention's online-softmax algorithm or fuse through a cuBLAS call. Those boundaries are exactly where I write a Triton/CUDA custom call. MLIR is the dialect-based infrastructure that makes this progressive lowering composable.”
Training & gradients context CONTEXT
Even for a kernel role you must speak training fluently — half the kernels you'll write are backward passes, and the memory tricks (recompute, checkpointing) come straight from the training loop.
\[ \theta \leftarrow \theta - \eta\,\nabla_\theta \mathcal{L} \]
SGD with momentum keeps a velocity (\(v\leftarrow \mu v + \nabla_\theta\mathcal L,\ \theta\leftarrow\theta-\eta v\)); Adam keeps per-parameter first/second moment estimates. Optimizer state (Adam = 2× params) is a real memory cost, and a reason FP8/FP4 and sharding matter.
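Both update rules fit in a few lines. The sketch uses the standard Adam form with bias correction and common default hyperparameters, which are assumptions. The `m` and `s` arrays are the two extra parameter-sized tensors behind "Adam = 2× params".

```python
import numpy as np

def sgd_momentum_step(theta, v, grad, lr=0.1, mu=0.9):
    """v <- mu*v + grad; theta <- theta - lr*v."""
    v = mu * v + grad
    return theta - lr * v, v

def adam_step(theta, m, s, grad, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam step at step count t (starting from 1); m, s are per-parameter state."""
    m = b1 * m + (1 - b1) * grad              # first-moment estimate
    s = b2 * s + (1 - b2) * grad ** 2         # second-moment estimate
    m_hat = m / (1 - b1 ** t)                 # bias correction
    s_hat = s / (1 - b2 ** t)
    return theta - lr * m_hat / (np.sqrt(s_hat) + eps), m, s
```

On the first step, bias correction makes `m_hat / sqrt(s_hat)` ≈ sign(grad). Each parameter therefore moves by about `lr` whatever its gradient's magnitude.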
Backward is just more matmuls and reductions
Autodiff applies the chain rule op-by-op (reverse-mode / VJP). The key fact for kernels: the backward of a linear layer is two GEMMs.
Forward \(Y = XW\). Given \(\partial\mathcal{L}/\partial Y = G\):
\[ \frac{\partial \mathcal L}{\partial X} = G\,W^\top, \qquad \frac{\partial \mathcal L}{\partial W} = X^\top G \]
So one forward GEMM becomes three GEMMs total across fwd+bwd. Softmax/LayerNorm/RMSNorm backward are fused reductions, just like their forward.
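The two backward GEMMs and the 3× FLOP count can be checked directly. The sketch uses the loss \(\mathcal L=\sum G\odot Y\) for a fixed \(G\), which makes \(\partial\mathcal L/\partial Y=G\) exact. `linear_fwd_bwd` is a hypothetical helper.

```python
import numpy as np

def linear_fwd_bwd(X, W, G):
    """Y = X @ W and its two backward GEMMs given G = dL/dY; also returns total FLOPs."""
    M, K = X.shape
    N = W.shape[1]
    Y = X @ W          # forward:  [M,K] @ [K,N], 2MNK FLOPs
    dX = G @ W.T       # backward: [M,N] @ [N,K], 2MNK FLOPs
    dW = X.T @ G       # backward: [K,M] @ [M,N], 2MNK FLOPs
    return Y, dX, dW, 3 * 2 * M * N * K
```

Because this loss is linear in X and in W, perturbing either one changes \(\mathcal L\) by exactly \(\sum dX\odot E\) (or \(\sum dW\odot F\)). That makes a clean exact check of both formulas.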
The memory ideas you already met
- Activation checkpointing / recompute: don't store every activation for backward — recompute it from a checkpoint. Trades FLOPs for memory. FlashAttention's backward does exactly this: it recomputes the attention tiles rather than storing the S×S probabilities.
- Mixed-precision training: keep FP32 master weights, do compute in BF16, accumulate in FP32. FP16 additionally needs loss scaling (Module 07) so small gradients don't underflow.
- Gradient checking = numerical regression test: compare the analytic gradient to a finite-difference estimate \(\frac{\mathcal L(\theta+\epsilon)-\mathcal L(\theta-\epsilon)}{2\epsilon}\) within tolerance — the training-side analog of the kernel numerical checks in Module 08.
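The central-difference check from the last bullet fits in a short function. The sketch assumes FP64 inputs; in low precision the \(\epsilon\) trade-off breaks down. The default `eps`/`rtol`/`atol` values are illustrative, not prescribed. `grad_check` is a hypothetical name, and it perturbs `theta` in place and restores it.

```python
import numpy as np

def grad_check(f, theta, analytic_grad, eps=1e-6, rtol=1e-4, atol=1e-7):
    """Compare analytic_grad to (f(theta + eps*e_i) - f(theta - eps*e_i)) / (2*eps), per coordinate."""
    num = np.zeros_like(theta)
    for i in np.ndindex(theta.shape):
        orig = theta[i]
        theta[i] = orig + eps
        f_plus = f(theta)
        theta[i] = orig - eps
        f_minus = f(theta)
        theta[i] = orig                                   # restore the parameter
        num[i] = (f_plus - f_minus) / (2 * eps)
    return bool(np.allclose(analytic_grad, num, rtol=rtol, atol=atol)), num
```

For example, with \(\mathcal L(W)=\sum\tanh(XW)\) the analytic gradient is \(X^\top(1-\tanh^2(XW))\), which passes. A deliberately doubled gradient fails.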
“Backward turns each forward GEMM into two more (dX = G·Wᵀ, dW = Xᵀ·G), so most of my kernels exist in pairs. Memory pressure from activations is managed by checkpointing — recompute instead of store — which is the same idea FlashAttention's backward uses. Precision-wise: FP32 master weights, BF16 compute, FP32 accumulate, loss scaling only if FP16.”
Interview drills + cheat sheet REVIEW
Rapid-fire — answer out loud
Q. Is softmax compute- or memory-bound? Memory — intensity ≪ ridge; the win is fusing it (and using the stable/online form).
Q. Why accumulate matmul in FP32 with BF16 inputs? Summing thousands of low-mantissa products loses small terms; FP32 accumulator preserves them. Tensor cores do this natively.
Q. What makes FlashAttention fast? It's IO-aware: it tiles Q/K/V into SRAM and uses online softmax so the O(N²) score matrix never hits HBM. That gives O(N²d²/M) HBM accesses (N = sequence length, d = head dim, M = SRAM size), IO-optimal, with memory linear in N.
Q. NVFP4 vs MXFP4? Block of 16 + fractional FP8 (E4M3) scale + per-tensor FP32 scale (NVFP4) vs block of 32 + power-of-two E8M0 scale (MXFP4); NVFP4 has less quantization error.
Q. Where does XLA stop fusing? At library calls (cuBLAS), across incompatible iteration spaces, at data-dependent/dynamic ops, and at algorithms the op set can't express (attention); those become custom calls.
Q. Your MoE is slow — first thing you check? Per-expert token counts. All-to-all waits on the slowest rank; imbalance (straggler), not bandwidth, is the usual culprit — fix with aux loss + capacity.
Q. Decode-time matmul (M=1) — strategy? Memory-bound GEMV; weight-only quantization and request batching, not bigger tiles.
Q. Tall-skinny GEMM underutilizes the GPU — why and fix? Too few output tiles → idle SMs; split-K/stream-K to expose parallelism over the K loop.
Q. Why might bigger tiles run slower? They consume registers/SMEM, cutting occupancy or pipeline stages below what's needed to hide latency.
Q. How do you know a kernel is correct? Compare to an FP32 reference with dtype-appropriate tolerance (rtol scales with K), include adversarial inputs (planted outliers), and gate it in CI alongside a TFLOP/s baseline.
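The last drill answer can be made concrete with a check whose tolerance grows with K. The sketch emulates a kernel with FP16 inputs, FP32 accumulation, and FP16 output, and compares it to a higher-precision (FP64) reference. It uses the standard forward-error bound \(|C-\mathrm{ref}|\le c\,(K u_{acc}+u_{out})\,(|A|\,|B|)\). The `safety` factor and function name are assumptions.

```python
import numpy as np

def check_matmul(C_kernel, A, B, eps_acc=2.0**-24, eps_out=2.0**-11, safety=2.0):
    """Return (passed, worst err/bound ratio) for a low-precision GEMM result.

    Bound per element: safety * (K*eps_acc + eps_out) * (|A| @ |B|), i.e. the allowed
    error grows linearly with the reduction length K (FP32 accumulate, FP16 output).
    """
    K = A.shape[1]
    A64, B64 = A.astype(np.float64), B.astype(np.float64)
    ref = A64 @ B64                                          # high-precision reference
    bound = safety * (K * eps_acc + eps_out) * (np.abs(A64) @ np.abs(B64))
    err = np.abs(C_kernel.astype(np.float64) - ref)
    ratio = err / (bound + 1e-30)
    return bool(np.all(ratio <= 1.0)), float(ratio.max())
```

The worst-case ratio is worth logging in CI next to the TFLOP/s baseline. A ratio that creeps toward 1 across commits flags a numerics regression before it becomes a hard failure.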
One-glance cheat sheet
GEMM FLOPs = 2MNK · I = FLOPs/bytes · ridge I* = P_peak/B ≈ 300
stable softmax: e^(x−max) · FA traffic: O(N²d²/M) · RMSNorm: x/√(mean(x²)+ε)·γ
linear bwd: dX = G·Wᵀ, dW = Xᵀ·G · SGD: θ ← θ − η∇L
Memory-bound → fuse. Compute-bound → MMA util / lower precision. Comm-bound → overlap.
Square → big tiles · Tall-skinny → split-K · M=1 decode → weight-only quant.
BF16 train · FP8 (E4M3 fwd / E5M2 grad) · FP4 only with block scaling.
| Architecture · MMA instruction (shape) | Issuer | FP formats (sign/exp/mantissa bits) |
|---|---|---|
| Ampere mma m16n8k16 | 1 warp | BF16 1/8/7 · FP16 1/5/10 |
| Hopper wgmma m64nNk16 | warpgroup (4 warps) | FP8 E4M3 1/4/3 · E5M2 1/5/2 |
| Blackwell tcgen05 up to 256×256×16 | single thread + tensor memory | FP4 E2M1 1/2/1 (block-scaled) |
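The "(block-scaled)" qualifier on FP4 can be illustrated with fake quantization. Each block is scaled so its absolute max maps to 6.0, the largest E2M1 magnitude. Values are rounded to the E2M1 grid and then rescaled. This is a simplified sketch, not either spec. Ties round to the lower magnitude here rather than to nearest-even. The real-valued scale is left unrounded, whereas NVFP4 would round it to E4M3 and add a per-tensor FP32 scale. The `pow2_scale=True` path only approximates MXFP4's E8M0 scale by rounding it up to a power of two. Function names are hypothetical.

```python
import numpy as np

FP4_E2M1 = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])  # E2M1 magnitudes

def round_to_e2m1(x):
    """Round |x| to the nearest E2M1 magnitude (ties go to the smaller one), keep the sign."""
    idx = np.abs(np.abs(x)[..., None] - FP4_E2M1).argmin(-1)
    return np.sign(x) * FP4_E2M1[idx]

def block_quantize(x, block=16, pow2_scale=False):
    """Fake-quantize a 1-D array (length divisible by block) with one scale per block."""
    xb = x.reshape(-1, block)
    amax = np.abs(xb).max(axis=1, keepdims=True)
    scale = np.where(amax > 0, amax / 6.0, 1.0)       # map the block's amax to E2M1 max
    if pow2_scale:
        scale = 2.0 ** np.ceil(np.log2(scale))        # power-of-two scale, MXFP4-like
    return (round_to_e2m1(xb / scale) * scale).reshape(x.shape)
```

On Gaussian data, small blocks with a real-valued scale give lower mean-squared error than 32-wide blocks with a power-of-two scale. This is the intuition behind the NVFP4-vs-MXFP4 drill answer.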
1. Classify the bound (compute / memory / comm) via arithmetic intensity.
2. Name the lever that bound implies (MMA util / fusion / overlap).
3. Cross a layer: connect it to a hardware constraint (MMA shape, SRAM size, link BW) or a compiler boundary (does it fuse?).
4. Say how you'd verify (reference + tolerance, profiler counter, microbenchmark).
Hitting all four reads as senior.
Further reading (primary sources)
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness — arXiv:2205.14135.
- OpenXLA — XLA architecture & StableHLO docs (openxla.org).
- NVIDIA — “Introducing NVFP4 for Efficient and Accurate Low-Precision Inference” (developer.nvidia.com).
- OCP Microscaling (MX) formats specification — for MXFP4/MXFP8.
- NVIDIA PTX ISA & CUTLASS tutorials, for mma/wgmma/tcgen05 and tensor-memory GEMMs.
Built as a compressed study aid. Verify exact hardware peak numbers against the spec sheet for your target part before quoting them in an interview.