From Adam to Mixed Precision
An Infra Engineer's Deep Dive into the ML Stack

I've spent four years as an ML Infra engineer working on the platform layer — AI agents, serving infrastructure, search and recommendation systems. The work I enjoy most is the kind where you can't just fix the surface symptom; you have to understand what's happening several layers down. A client asks why their request is slow, and what starts as a latency investigation turns into a puzzle about distributed system behavior.
But I've been operating on top of the ML stack without fully understanding what's inside it. A colleague once told me: "No matter what kind of request it is, it eventually has to run on a physical machine." So I decided to go back to fundamentals. This blog is a record of that.
The Starting Point: Why Does Adam Eat So Much Memory?
My re-learning started with Karpathy's building micrograd. While working through it, I recalled an article about calculating a model's memory footprint during training, which specifically called out Adam's massive GPU memory consumption. I'd always used Adam as the default optimizer — the thing everyone reaches for without thinking twice. But I'd never understood why it's so memory-hungry.
Adam in 60 Seconds
A model's parameters are just numbers — weights that influence the output. The gradient of a parameter is the derivative of the loss with respect to that parameter: it tells you how much changing that parameter affects the loss. Think of the parameter as position, the gradient as the force acting on it.
Gradient descent is simple: compute the gradient, step in the opposite direction. w_new = w_old - α * gradient, where α (the learning rate) controls how hard you pull. But vanilla gradient descent is terrible at navigating complex loss landscapes — it oscillates in steep narrow valleys and crawls across flat plateaus.
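To make that concrete, here is vanilla gradient descent in a few lines of Python. The one-dimensional loss and starting point are made up purely for illustration:

# Gradient descent on a toy 1-D loss: loss(w) = (w - 3)^2, minimized at w = 3.
def grad(w):
    return 2 * (w - 3)        # d(loss)/dw

w, lr = 0.0, 0.1              # initial weight and learning rate (alpha)
for _ in range(50):
    w = w - lr * grad(w)      # step against the gradient
print(w)                      # converges toward 3.0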
Adam (Adaptive Moment Estimation) fixes this with two mechanisms:
First moment (momentum): A running average of past gradients. If gradients have been consistently pointing the same direction, momentum builds up and you accelerate — like a heavy ball rolling downhill that powers through small bumps.
Second moment (adaptive learning rate): A running average of squared gradients. If gradients have been volatile, this value spikes, and Adam automatically shrinks the step size for that parameter.
The step looks roughly like: step ≈ (learning_rate × momentum) / (√volatility + ε). Momentum in the numerator accelerates you; volatility in the denominator slows you down. Each parameter gets its own individually tuned step size.
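Here is a rough sketch of that bookkeeping for a single scalar parameter. The parameter and gradient values are made up; β1, β2, and ε are the usual defaults, and the bias-correction terms are explained in a later section:

# One Adam update for a single parameter, written out by hand.
beta1, beta2, eps, lr = 0.9, 0.999, 1e-8, 1e-3
m, v, t = 0.0, 0.0, 0          # first moment, second moment, step counter
w = 1.0                        # the parameter being trained

def adam_step(w, g, m, v, t):
    t += 1
    m = beta1 * m + (1 - beta1) * g            # momentum: running average of gradients
    v = beta2 * v + (1 - beta2) * g * g        # volatility: running average of squared gradients
    m_hat = m / (1 - beta1 ** t)               # bias correction (more on this below)
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (v_hat ** 0.5 + eps)  # accelerate on momentum, brake on volatility
    return w, m, v, t

w, m, v, t = adam_step(w, g=0.5, m=m, v=v, t=t)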
This "accelerate when smooth, brake when rough" pattern is essentially TCP congestion control. Slow start with exponential window growth maps onto momentum accumulation — when ACKs come back fast (gradients are consistent), you aggressively ramp up. Multiplicative decrease when you hit congestion maps onto the second moment clamping down the step size when gradients become volatile. Even TCP's Fast Recovery has a parallel: when TCP detects mild packet loss (duplicate ACKs rather than a full timeout), it halves the window and continues probing rather than slamming back to zero. Similarly, Adam's second moment uses an exponential moving average rather than reacting to instantaneous gradient spikes — it distinguishes between temporary turbulence and genuine divergence. The math is different, but the design intuition is the same: don't overreact to transient noise, don't ignore real trouble.
One nuance Claude pointed out when I discussed this: optimizer choice can actually affect final model quality, not just training speed. Different optimizers tend to land in different local minima — Adam tends to find sharper minima (which may generalize slightly worse), while SGD with momentum sometimes finds flatter ones (which tend to generalize better). This is apparently why some large model training runs switch from Adam to SGD in later stages.
The Memory Problem
Here's where it matters for infra. With basic SGD, you store two things per parameter: the weight and its gradient. With Adam, you also store the first moment and the second moment — per parameter.
In mixed-precision training, these optimizer states must be kept in FP32 (4 bytes each) for numerical stability. That means Adam adds 12 extra bytes per parameter: an FP32 master weight copy, plus the FP32 first and second moments.
For a 7B parameter model, Adam's optimizer states alone consume roughly 84 GB of GPU memory — before you count the model weights, gradients, or activations. This is why ZeRO exists: it shards optimizer states across GPUs because no single GPU can hold them all.
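The back-of-the-envelope arithmetic, spelled out (exact layouts vary a bit by framework, but the proportions hold):

params = 7e9                   # 7B-parameter model

fp16_weight  = 2               # half-precision weight used in forward/backward
fp16_grad    = 2               # half-precision gradient
fp32_master  = 4               # FP32 master copy of the weight
fp32_moment1 = 4               # Adam first moment (FP32)
fp32_moment2 = 4               # Adam second moment (FP32)

adam_extra = fp32_master + fp32_moment1 + fp32_moment2        # the 12 "extra" bytes
print(params * adam_extra / 1e9)                              # ~84 GB of optimizer state
print(params * (fp16_weight + fp16_grad + adam_extra) / 1e9)  # ~112 GB before activations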
Bias Correction: A Cold-Start Fix
Adam includes a clever fix for what amounts to a cold-start problem. Both moments are initialized to zero. In the first few steps, the exponential moving average is heavily biased toward zero (the history is almost entirely zeros), so the optimizer thinks the real gradients are tiny. The model barely moves.
Bias correction divides by (1 - β^t), which is small in early steps (amplifying the estimate to its true scale) and converges to 1 as training progresses (at which point the correction silently disappears). Without it, Adam loses its "just works out of the box with default hyperparameters" property — and that robustness is arguably what made it dominate the industry.
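The correction factors are easy to eyeball. With the usual default β values:

# How much 1 / (1 - beta^t) amplifies the raw moving average at step t.
beta1, beta2 = 0.9, 0.999
for t in (1, 10, 100, 10_000):
    print(t, 1 / (1 - beta1 ** t), 1 / (1 - beta2 ** t))
# t=1:      10x for the first moment, 1000x for the second; without this, early steps are near zero
# t=10_000: both factors are ~1.0, and the correction has quietly faded away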
This "cold-start correction that gracefully fades out" pattern feels borrowable in other engineering contexts — anywhere you're bootstrapping a running estimate from zero history.
A Side Note: Adam vs. Google Vizier
Before this deep dive, I'd been fuzzy on where Adam sits relative to systems like Google Vizier. Both involve "optimization," so it's an easy conflation.
Adam is a gradient-based optimizer that runs inside the training loop, updating weights step by step. Vizier is a black-box optimization service that sits outside the loop, managing entire training runs to find the best hyperparameters. Vizier might find the best learning rate and beta values for your Adam optimizer. They're complementary systems at different abstraction levels.
Mixed Precision: Not About Saving Space
I'd seen "mixed-precision training" referenced countless times but never deeply understood the why. My instinct was that it was a memory optimization — use half the bytes, fit more on the GPU. That's partially true, but it misses the real story.
Why FP32 Master Weights Can't Be Negotiated
In mixed-precision training, forward and backward passes run in half precision (FP16 or BF16, 2 bytes per value). But master weights and Adam's optimizer states must stay in FP32.
The reason is numerical swamping. In later stages of training, Adam's updates become extremely small — on the order of 0.00001. If your current weight is 1.0:
In FP32: 1.0 - 0.00001 = 0.99999. The update is faithfully recorded.
In FP16: the mantissa doesn't have enough bits, so the tiny update gets silently rounded away: 1.0 - 0.00001 = 1.0.
The gradient was computed correctly, Adam did its job correctly, but the weight doesn't move. Training silently stalls. This is why FP32 master weights are non-negotiable — you're trading 2x memory for the ability to actually converge.
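You can watch the swamping happen directly in NumPy; the 1.0 weight and the 1e-5 update are the same toy numbers as above:

import numpy as np

update = 1e-5                                 # a typical late-training Adam update
print(np.float32(1.0) - np.float32(update))   # 0.99999 -> the update is recorded
print(np.float16(1.0) - np.float16(update))   # 1.0     -> silently swallowed by FP16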
The Real Reason for Half Precision: Hardware Physics
So if we're stuck with FP32 for optimizer states anyway, why bother with half precision for forward and backward passes?
This was my "aha" moment: mixed precision is fundamentally about compute throughput and memory bandwidth, not storage capacity.
Modern GPUs have Tensor Cores physically optimized for FP16/BF16 matrix multiplication. On an NVIDIA A100: ~19.5 TFLOPS for FP32 versus ~312 TFLOPS for FP16 on Tensor Cores. That's roughly a 16x speedup. Is that sliver of extra FP32 precision per operation worth a 16x slowdown? No.
Then there's the memory bandwidth wall. In large model training, the bottleneck is often how fast you can move data from HBM into the on-chip SRAM that feeds the compute cores. Every FP16 value is 2 bytes instead of 4, cutting data transfer volume in half. In a bandwidth-bound workload, that's the difference between keeping Tensor Cores busy and having them sit idle waiting for data.
And there's activation memory — the intermediate results from each layer saved for backpropagation. Unlike model parameters and optimizer states (fixed once the model is defined), activations scale dynamically with batch size and sequence length. In the era of 128K+ context windows, activation memory can explode. Cutting it from FP32 to FP16 directly frees space for larger batches or longer contexts.
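Putting the pieces together: the standard PyTorch mixed-precision recipe runs the forward and backward passes under autocast while the weights and Adam state stay in FP32, with a loss scaler so small FP16 gradients don't underflow before the FP32 update. This is only a sketch, with a toy model, random data, and the assumption of a CUDA GPU:

import torch

model = torch.nn.Linear(1024, 1024).cuda()         # toy stand-in for a real model
optimizer = torch.optim.Adam(model.parameters())   # weights and Adam moments live in FP32
scaler = torch.cuda.amp.GradScaler()                # scales the loss so FP16 grads don't underflow

for _ in range(10):
    x = torch.randn(32, 1024, device="cuda")        # random stand-in batch
    target = torch.randn(32, 1024, device="cuda")
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(x), target)  # matmuls run on FP16 Tensor Cores
    scaler.scale(loss).backward()                    # backward in half precision, gradients scaled up
    scaler.step(optimizer)                           # grads unscaled, FP32 weight update happens here
    scaler.update()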
Activation memory in training can't be optimized the same way as in inference. In inference, KV cache avoids recomputing past tokens' keys and values during autoregressive generation. But training processes the entire sequence at once, and backpropagation needs the full activations from every layer — not just the final results, but the entire intermediate computation trail. The training-side answer is gradient checkpointing: selectively discarding activations during the forward pass and recomputing them during backward, trading compute for memory.
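A minimal sketch of that trade using torch.utils.checkpoint; the block here is a toy stand-in for something like a transformer layer:

import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(                         # toy stand-in for one transformer block
    torch.nn.Linear(1024, 4096), torch.nn.GELU(), torch.nn.Linear(4096, 1024),
)

x = torch.randn(8, 1024, requires_grad=True)
y = checkpoint(block, x, use_reentrant=False)        # block's internal activations are not saved...
y.sum().backward()                                   # ...they're recomputed here during backward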
A Quick Note on BF16
BF16 (Brain Floating Point) reveals a hardware-software co-design philosophy. FP16's problem: its exponent is too small (5 bits), giving it a narrow dynamic range and making gradient overflow/underflow a real risk. Google designed BF16 for their TPUs by keeping the same 8-bit exponent as FP32 (preserving full dynamic range) but truncating the mantissa. Less precision, far more robust in practice. It's now the default on NVIDIA's newer GPUs as well — a case where a TPU hardware design choice propagated back and reshaped the entire industry's training practices.
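The trade-off is visible straight from the number formats; torch.finfo reports each dtype's range and precision:

import torch

for dtype in (torch.float16, torch.bfloat16, torch.float32):
    fi = torch.finfo(dtype)
    print(f"{str(dtype):15s} max={fi.max:.3e}  eps={fi.eps:.1e}")
# float16:  max ~ 6.6e+04, eps ~ 9.8e-04 -> narrow range, overflow is a real risk
# bfloat16: max ~ 3.4e+38, eps ~ 7.8e-03 -> FP32's range, much coarser precision
# float32:  max ~ 3.4e+38, eps ~ 1.2e-07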
What's Next
Next up: transformer internals from an infra perspective, continuing with Karpathy's neural network series.
References:
The spelled-out intro to neural networks and backpropagation: building micrograd — Andrej Karpathy
Adam: A Method for Stochastic Optimization (Kingma & Ba, 2014)
Google Vizier: A Service for Black-Box Optimization (Golovin et al., 2017)
A Study of BFLOAT16 for Deep Learning Training (Kalamkar et al., 2019)
BFloat16: The secret to high performance on Cloud TPUs — Google Cloud Blog
ZeRO: Memory Optimizations Toward Training Trillion Parameter Models (Rajbhandari et al., 2019)

