6D Parallelism for Distributed Training
Mar 2
Matrix multiplication is a fundamental operator in neural networks. In the previous blog post we worked out, in full generality, how to differentiate the loss through such an operator, and we applied this to the attention operator, a compound operator built from a series of matrix multiplications. Let us now assemble a clear view of the process of training a neural network, and in particular examine what data must be stored and how long each piece persists.
Neural network weights are the first piece of data that must be stored for the duration of training: they are used during the forward pass to compute activations and during the backward pass to compute activation gradients. The second piece of data that requires storage is the activations of each layer since, as we have already seen, the activations that enter a layer as inputs during the forward pass are needed to compute the gradients of that layer's weights. An activation can be released as soon as the weight gradients it is associated with have been computed. The third piece of data is the activation gradients. They are used to continue backpropagation through the next layer (or operation) in backward order, i.e. the preceding layer in forward order, and can be released once the activation and weight gradients of that layer have been computed. The fourth piece of data is the weight gradients, which live from the moment they are produced until the corresponding weights are updated. The fifth type of data is associated with the optimization algorithm that updates the weights using the gradients. In the Adam optimizer, we keep track of the exponential moving average of each weight's gradient and of the square of each weight's gradient. These moving averages are usually called optimizer states, and they persist throughout training.
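To make the fifth item concrete, here is a minimal NumPy sketch of a single Adam update for one weight tensor, assuming the standard hyperparameters $\beta_1 = 0.9$, $\beta_2 = 0.999$ and bias correction. The function name `adam_step` is ours, not a library API. The arrays `m` and `v` are the optimizer states: each has the same shape as the weight and must be kept for the whole of training, so Adam adds two weight-sized arrays to the persistent memory.

```python
import numpy as np

def adam_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update (hypothetical helper). m and v are the optimizer states:
    exponential moving averages of the gradient g and of its componentwise square."""
    m = beta1 * m + (1 - beta1) * g          # EMA of the gradient
    v = beta2 * v + (1 - beta2) * g * g      # EMA of the squared gradient
    m_hat = m / (1 - beta1 ** t)             # bias correction (t counts from 1)
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
    return w, m, v

w = np.array([1.0, -2.0])
g = np.array([0.5, -0.5])
w1, m1, v1 = adam_step(w, g, np.zeros(2), np.zeros(2), t=1)
```

With zero initial states and `t=1`, bias correction makes the first step roughly `lr * sign(g)`, so `w1` is approximately `[0.999, -1.999]`.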
Managing how long these data persist and how they are distributed across GPUs and CPUs is a core systems optimization problem.
Distributed Data Parallel
We replicate the model weights identically across $N$ GPUs, and each GPU holds a distinct subset of the training data. Each GPU runs the forward pass on its data to compute the loss, then the backward pass to compute the gradient. The gradients are AllReduced (summed) over the $N$ parallel agents and scaled by $1/N$ to obtain the average gradient, so every agent now holds an identical copy of the global gradient. Identical gradients produce identical optimizer states across agents, and identical gradients together with identical optimizer states produce identical updated weights on every GPU. The procedure is then repeated.
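The DDP step can be simulated in a single process. This sketch assumes a hypothetical linear least-squares model, uses plain SGD instead of Adam for brevity, and stands in for the AllReduce with an explicit Python average (`all_reduce_mean` is a hypothetical helper, not a communication-library call). With equally sized local batches, the average of the local gradients equals the full-batch gradient, so every replica ends the step with the same weights.

```python
import numpy as np

def local_grad(w, X, y):
    # Gradient of the mean squared error mean((X @ w - y)**2) on one local batch.
    return 2.0 * X.T @ (X @ w - y) / len(y)

def all_reduce_mean(tensors):
    # Simulated AllReduce (sum over agents) followed by scaling by 1/N:
    # every agent receives its own copy of the same average.
    avg = sum(tensors) / len(tensors)
    return [avg.copy() for _ in tensors]

def ddp_step(replicas, shards, lr=0.1):
    # Each agent: forward + backward on its own data only.
    grads = [local_grad(w, X, y) for w, (X, y) in zip(replicas, shards)]
    grads = all_reduce_mean(grads)
    # Identical weights and identical gradients give identical updated weights.
    return [w - lr * g for w, g in zip(replicas, grads)]

rng = np.random.default_rng(0)
N, n_per, d = 4, 8, 3
X = rng.normal(size=(N * n_per, d))
y = rng.normal(size=N * n_per)
# Pairwise disjoint, equally sized local datasets.
shards = [(X[i * n_per:(i + 1) * n_per], y[i * n_per:(i + 1) * n_per]) for i in range(N)]
w0 = rng.normal(size=d)
replicas = ddp_step([w0.copy() for _ in range(N)], shards)
```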
Fully Sharded Data Parallel
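Suppose a neural network consists of $L$ layers, where each layer $f_\ell$ is a function with suitable domain and range. The neural network evaluated on an input $x$, leading to the loss $\mathcal{L}$, is defined by $\mathcal{L} = f_L \circ f_{L-1} \circ \cdots \circ f_1(x)$. Let us introduce the notation $a_\ell = f_\ell(W_\ell; a_{\ell-1})$, meaning that $f_\ell$ is a function parameterized by the weight $W_\ell$ that takes as input the previous layer's activation $a_{\ell-1}$, with $a_0 = x$.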
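Given a dataset $D$, weight matrices $W_1, \dots, W_{L-1}$, and $N$ parallel agents (GPUs, TPUs, etc.), we partition the dataset into pairwise disjoint subsets $D = D_1 \cup \cdots \cup D_N$ and partition each weight matrix into $N$ slices, for example by rows, $W_\ell = \begin{bmatrix} W_\ell^{(1)} \\ \vdots \\ W_\ell^{(N)} \end{bmatrix}$. To the $i$-th parallel agent we send the data $D_i$ and the weight shards $W_1^{(i)}, \dots, W_{L-1}^{(i)}$.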
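A back-of-the-envelope sketch of what sharding buys, under an assumed simplification: fp32 weights, fp32 gradients and the two fp32 Adam moments, i.e. 16 bytes per parameter, ignoring activations and the transient full-layer weights and gradients that are materialized one layer at a time. The helper name is hypothetical.

```python
# Assumed accounting (illustrative only): 4 bytes each for the weight, its
# gradient, and the two Adam moments. Activations and transient gathered
# tensors are ignored.
BYTES_PER_PARAM = {"weights": 4, "grads": 4, "adam_m": 4, "adam_v": 4}

def persistent_bytes_per_agent(num_params, num_agents, sharded):
    total = num_params * sum(BYTES_PER_PARAM.values())
    if not sharded:
        return total                  # DDP: every agent holds a full replica
    return -(-total // num_agents)    # sharded: ceil(total / N) per agent

replicated = persistent_bytes_per_agent(7_000_000_000, 8, sharded=False)
sharded = persistent_bytes_per_agent(7_000_000_000, 8, sharded=True)
print(replicated / 1e9, sharded / 1e9)  # 112.0 14.0
```

For a hypothetical 7-billion-parameter model on 8 agents, this gives 112 GB per agent when everything is replicated, as in DDP, and 14 GB per agent when sharded.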
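To describe one step of training, each parallel agent $i$ selects a training example $x^{(i)} \in D_i$ and puts $a_0^{(i)} = x^{(i)}$. For $\ell = 1, \dots, L-1$, suppose the activation $a_{\ell-1}^{(i)}$ has been computed and the forward pass is ready to advance through layer $\ell$ on parallel agent $i$. We call AllGather to assemble the full weight $W_\ell$ from the shards $W_\ell^{(1)}, \dots, W_\ell^{(N)}$, compute $a_\ell^{(i)} = f_\ell(W_\ell; a_{\ell-1}^{(i)})$, and immediately free the device memory holding the gathered shards owned by other agents. We store the layer-$(\ell-1)$ activation $a_{\ell-1}^{(i)}$ for the weight gradient calculation through layer $\ell$ during the backward pass. The forward pass ends with the local loss $\mathcal{L}^{(i)} = f_L(a_{L-1}^{(i)})$.
Let us turn to the backward pass. In our notation $f_L$ computes the loss function and this layer has no weights, so backpropagating through it gives just the activation gradient $\delta_{L-1}^{(i)} = \frac{\partial f_L}{\partial a_{L-1}}\big(a_{L-1}^{(i)}\big)$. As a reminder, $a_\ell$ is the output of layer $\ell$, which is the input to layer $\ell+1$. We treat $a_\ell$ as a variable, whereas the actual activation carries a superscript, $a_\ell^{(i)}$, to identify the device $i$; likewise $\delta_\ell^{(i)} = \partial \mathcal{L}^{(i)} / \partial a_\ell$ denotes the activation gradient evaluated at device $i$'s activations.
For each $\ell = L-1, \dots, 1$, suppose device $i$ has computed its local activation gradient $\delta_\ell^{(i)}$. Then the local weight gradient can be computed as a function of the map $f_\ell$, the activation gradient, and the local activation, $g_\ell^{(i)} = \big(\frac{\partial f_\ell}{\partial W_\ell}(W_\ell; a_{\ell-1}^{(i)})\big)^{\top} \delta_\ell^{(i)}$; for a linear layer $a_\ell = W_\ell a_{\ell-1}$ this is simply $g_\ell^{(i)} = \delta_\ell^{(i)} \big(a_{\ell-1}^{(i)}\big)^{\top}$. Since the weights are kept in shards across the parallel agents, we partition the weight gradient the same way, $g_\ell^{(i)} = \begin{bmatrix} g_\ell^{(i,1)} \\ \vdots \\ g_\ell^{(i,N)} \end{bmatrix}$, where $g_\ell^{(i,j)}$ has the shape of $W_\ell^{(j)}$. We use ReduceScatter to provide each parallel agent with the average gradient corresponding to the weight shard it owns: agent $j$ receives $\sum_{i=1}^{N} g_\ell^{(i,j)}$, and scaling by $1/N$ gives $\bar g_\ell^{(j)} = \frac{1}{N}\sum_{i=1}^{N} g_\ell^{(i,j)}$, the average gradient for its local weight shard $W_\ell^{(j)}$. The local Adam optimizer stores exponential moving averages of $\bar g_\ell^{(j)}$ and of the componentwise square of $\bar g_\ell^{(j)}$. Thus the optimizer states are sharded as well.
Finally, each parallel agent needs to compute its local activation gradient $\delta_{\ell-1}^{(i)}$ as a function of $f_\ell$ (which we differentiate via the chain rule) and the weight $W_\ell$ of layer $\ell$, together with $\delta_\ell^{(i)}$, namely $\delta_{\ell-1}^{(i)} = \big(\frac{\partial f_\ell}{\partial a_{\ell-1}}(W_\ell; a_{\ell-1}^{(i)})\big)^{\top} \delta_\ell^{(i)}$, which for a linear layer is $W_\ell^{\top}\delta_\ell^{(i)}$ and visibly requires the full weight. Computing this allows backpropagation to advance from layer $\ell$ to layer $\ell-1$. To do this we need to AllGather the full weight $W_\ell$ on each of the parallel agents. After computing the activation gradient, each agent again frees the gathered shards and keeps only its own shard $W_\ell^{(i)}$.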
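The forward and backward steps above can be checked numerically with a single-process simulation. This is a minimal sketch, assuming linear layers $a_\ell = a_{\ell-1}W_\ell^{\top}$ in row-batch layout (so the weight gradient is `delta.T @ a_prev`), a weightless loss layer $f_L(a) = \frac{1}{2B}\sum_b \lVert a_b \rVert^2$, and row-sharded weights. `all_gather` and `reduce_scatter_mean` are hypothetical stand-ins built from NumPy concatenation and splitting; the agents run one after another, and the ReduceScatter is done after the whole backward pass rather than layer by layer. The final lines compute a reference gradient with the full weights for comparison.

```python
import numpy as np

def all_gather(shards):
    # Every agent obtains the full weight by concatenating the row shards.
    return np.concatenate(shards, axis=0)

def reduce_scatter_mean(full_grads):
    # full_grads[i] is agent i's full-shape weight gradient. Sum over agents,
    # split by rows like the weights, and hand slice j (scaled by 1/N) to agent j.
    n = len(full_grads)
    return [s / n for s in np.split(sum(full_grads), n, axis=0)]

def fsdp_step(weight_shards, batches):
    # weight_shards[l][i] is agent i's row shard of layer l's weight.
    n, num_layers = len(batches), len(weight_shards)
    local_grads = [[None] * n for _ in range(num_layers)]
    for i, x in enumerate(batches):            # agents simulated sequentially
        acts = [x]                             # a_0^{(i)} = x^{(i)}
        for l in range(num_layers):            # forward pass
            W = all_gather(weight_shards[l])   # gather, use, then discard
            acts.append(acts[-1] @ W.T)        # stored for the backward pass
        delta = acts[-1] / len(x)              # gradient of the weightless loss layer
        for l in reversed(range(num_layers)):  # backward pass
            local_grads[l][i] = delta.T @ acts[l]   # local weight gradient
            W = all_gather(weight_shards[l])        # re-gather the full weight
            delta = delta @ W                       # activation gradient for the layer below
    # Each agent keeps only the averaged gradient of the shard it owns.
    return [reduce_scatter_mean(local_grads[l]) for l in range(num_layers)]

rng = np.random.default_rng(1)
N = 2
W1, W2 = rng.normal(size=(6, 4)), rng.normal(size=(4, 6))
shards = [np.split(W1, N, axis=0), np.split(W2, N, axis=0)]
batches = [rng.normal(size=(3, 4)) for _ in range(N)]
grads = fsdp_step(shards, batches)

# Reference: average over agents of the gradients computed with full weights.
ref1, ref2 = np.zeros_like(W1), np.zeros_like(W2)
for x in batches:
    a1 = x @ W1.T
    d2 = (a1 @ W2.T) / len(x)
    ref2 += d2.T @ a1 / N
    ref1 += (d2 @ W2).T @ x / N
```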
Tensor Parallel
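Attention and the MLP are the two fundamental operators in a transformer, and their underlying mathematical operation is matrix multiplication. There are two fundamental ways to write a matrix product with block matrices, giving rise to two ways in which a GEMM can be performed across GPUs. For $A \in \mathbb{R}^{m \times k}$ and $B \in \mathbb{R}^{k \times n}$, the first way splits the shared inner dimension, $AB = \begin{bmatrix} A_1 & \cdots & A_N \end{bmatrix} \begin{bmatrix} B_1 \\ \vdots \\ B_N \end{bmatrix} = \sum_{i=1}^{N} A_i B_i$, so device $i$ computes the full-size partial product $A_i B_i$ and an AllReduce forms the sum. The second way splits $B$ by columns, $AB = A \begin{bmatrix} B_1 & \cdots & B_N \end{bmatrix} = \begin{bmatrix} AB_1 & \cdots & AB_N \end{bmatrix}$, so device $i$ computes the column block $AB_i$ and an AllGather assembles the result. Tensor parallelism is the application of these simple ideas to the matrix multiplications that occur in a transformer.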
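The two decompositions can be checked directly in NumPy. In this sketch the list of partial products and the list of column blocks stand in for per-device results, while `sum` and `np.concatenate` play the roles of AllReduce and AllGather respectively; the function names are hypothetical.

```python
import numpy as np

def gemm_inner_split(A, B, n):
    # Type 1: split the inner dimension. Device i holds column block A_i and row
    # block B_i and computes a full-size partial product; AllReduce = the sum.
    partials = [a @ b for a, b in zip(np.split(A, n, axis=1), np.split(B, n, axis=0))]
    return sum(partials)

def gemm_column_split(A, B, n):
    # Type 2: split B by columns. Device i holds all of A plus B_i and computes
    # the column block A @ B_i; AllGather = concatenation along columns.
    blocks = [A @ b for b in np.split(B, n, axis=1)]
    return np.concatenate(blocks, axis=1)

rng = np.random.default_rng(0)
A, B = rng.normal(size=(5, 8)), rng.normal(size=(8, 6))
```

The trade-off is visible in the shapes: type 1 communicates full-size $m \times n$ partial sums, whereas type 2 requires every device to hold all of $A$.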
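Let us take the MLP layer as the first example. Recall that in modern architectures the feedforward layer is computed on an input $X$ (one row per token) as $\mathrm{FFN}(X) = \big(\mathrm{SiLU}(XW_1) \odot XW_3\big) W_2$, where the sigmoid linear unit (SiLU, also known as Swish) activation function $\mathrm{SiLU}(z) = z\,\sigma(z)$ is applied componentwise to $XW_1$ and $\odot$ is the componentwise product. Let there be $N$ parallel agents. We can write $W_1 = \begin{bmatrix} W_1^{(1)} & \cdots & W_1^{(N)} \end{bmatrix}$ and $W_3 = \begin{bmatrix} W_3^{(1)} & \cdots & W_3^{(N)} \end{bmatrix}$ as column blocks, so that each device $i$ computes $H^{(i)} = \mathrm{SiLU}(XW_1^{(i)}) \odot XW_3^{(i)}$; because SiLU and $\odot$ act componentwise, $H = \mathrm{SiLU}(XW_1) \odot XW_3 = \begin{bmatrix} H^{(1)} & \cdots & H^{(N)} \end{bmatrix}$ without any communication. Then we can either partition $W_2$ along its row direction, conformably with the column blocks of $H$, as $W_2 = \begin{bmatrix} W_2^{(1)} \\ \vdots \\ W_2^{(N)} \end{bmatrix}$, and perform a matrix multiplication of the first type above, $HW_2 = \sum_{i=1}^{N} H^{(i)} W_2^{(i)}$, followed by AllReduce (device $i$ already holds $H^{(i)}$, so nothing needs to be gathered first); or we can first partition each $H^{(i)}$ along its column direction, $H^{(i)} = \begin{bmatrix} H^{(i,1)} & \cdots & H^{(i,N)} \end{bmatrix}$. We then use AllToAll to send $H^{(i,j)}$ from device $i$ to device $j$, so that on device $j$ we get $\tilde H^{(j)} = \begin{bmatrix} H^{(1,j)} & \cdots & H^{(N,j)} \end{bmatrix}$, which satisfies $HW_2 = \sum_{j=1}^{N} \tilde H^{(j)} \tilde W_2^{(j)}$, where $\tilde W_2^{(j)}$ stacks the rows of $W_2$ that match the columns of $\tilde H^{(j)}$. Now each device $j$ computes $\tilde H^{(j)} \tilde W_2^{(j)}$, followed by AllReduce to obtain $HW_2 = \mathrm{FFN}(X)$.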
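A minimal NumPy check of both variants, assuming the SwiGLU form above with $X$ replicated on every device and a hidden width divisible by $N^2$ (needed for the sub-blocks of the AllToAll variant). Python lists stand in for per-device tensors, `sum` stands in for AllReduce, and the list reshuffle stands in for AllToAll; all function names are hypothetical.

```python
import numpy as np

def silu(z):
    return z / (1.0 + np.exp(-z))      # z * sigmoid(z)

def ffn(X, W1, W3, W2):
    # Reference (unsharded) SwiGLU feedforward.
    return (silu(X @ W1) * (X @ W3)) @ W2

def local_hidden(X, W1, W3, n):
    # Device i builds H^{(i)} from column blocks W1^{(i)}, W3^{(i)}; no communication.
    return [silu(X @ a) * (X @ b)
            for a, b in zip(np.split(W1, n, axis=1), np.split(W3, n, axis=1))]

def ffn_tp_allreduce(X, W1, W3, W2, n):
    H = local_hidden(X, W1, W3, n)
    W2_rows = np.split(W2, n, axis=0)                  # row blocks matching H's columns
    return sum(h @ w for h, w in zip(H, W2_rows))      # sum = AllReduce

def ffn_tp_all_to_all(X, W1, W3, W2, n):
    H = local_hidden(X, W1, W3, n)
    H_sub = [np.split(h, n, axis=1) for h in H]                          # H^{(i,j)}
    W_sub = [np.split(w, n, axis=0) for w in np.split(W2, n, axis=0)]    # matching rows of W2
    # AllToAll: device j receives H^{(i,j)} from every device i.
    H_tilde = [np.concatenate([H_sub[i][j] for i in range(n)], axis=1) for j in range(n)]
    W_tilde = [np.concatenate([W_sub[i][j] for i in range(n)], axis=0) for j in range(n)]
    return sum(h @ w for h, w in zip(H_tilde, W_tilde))  # sum = AllReduce

rng = np.random.default_rng(0)
X = rng.normal(size=(3, 4))
W1, W3 = rng.normal(size=(4, 8)), rng.normal(size=(4, 8))
W2 = rng.normal(size=(8, 4))
```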
Sequence Parallel
Recall that, in general, an input to a transformer block has shape $(B, S, d)$: batch size, sequence length, and hidden dimension. For operations that are applied independently to each position along the sequence dimension, one can split the input along the sequence dimension and compute the operation on each chunk in parallel; this is called sequence parallelism. Let us be more precise.
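We adopt the notation $x_{b,s} \in \mathbb{R}^d$ to denote the hidden (embedding) vector belonging to batch element $b$ at sequence position $s$. Suppose $\phi$ is a function on $\mathbb{R}^d$ and a neural network layer computes $y_{b,s} = \phi(x_{b,s})$ for all $b$ and $s$. Let $N$ be a positive integer; we can choose numbers $0 = s_0 < s_1 < \cdots < s_N = S$ so that each parallel agent $i$ computes $y_{b,s} = \phi(x_{b,s})$ for all $b$ and all $s_{i-1} \le s < s_i$.
Let's apply this to a concrete positionwise operation in transformer models. Root mean square normalization with weight $\gamma \in \mathbb{R}^d$ and a small constant $\epsilon > 0$ is the mapping $\mathrm{RMSNorm}_\gamma : \mathbb{R}^d \to \mathbb{R}^d$, $\mathrm{RMSNorm}_\gamma(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{k=1}^{d} x_k^2 + \epsilon}} \odot \gamma$, where $\odot$ is the componentwise product. The RMS normalization layer of an LLM is a positionwise operator, meaning that for a sequence of token embeddings $(x_1, \dots, x_S)$ the RMSNorm layer with trainable weight $\gamma$ transforms this sequence into $(\mathrm{RMSNorm}_\gamma(x_1), \dots, \mathrm{RMSNorm}_\gamma(x_S))$. This positionwise property allows RMSNorm to be computed on separate GPUs, each holding a slice of the positions, with the results subsequently AllGathered. Because each GPU stores activations for only about $S/N$ positions, this is an effective method for managing long contexts. In particular, in a pre-norm transformer architecture, sequence parallelism can then be followed by the tensor parallelism described above.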
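Because RMSNorm acts on each position independently, splitting along the sequence axis and concatenating the chunks (standing in for AllGather) reproduces the unsharded result. This sketch assumes $\epsilon = 10^{-6}$ and uses `np.array_split`, so the sequence length need not be divisible by $N$; the function names are hypothetical.

```python
import numpy as np

def rmsnorm(x, gamma, eps=1e-6):
    # Normalizes every position's d-dimensional vector independently (last axis).
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * gamma

def rmsnorm_sequence_parallel(x, gamma, n):
    # x has shape (batch, seq, d). Agent i normalizes its own chunk of positions;
    # concatenating along the sequence axis stands in for the AllGather.
    chunks = np.array_split(x, n, axis=1)
    return np.concatenate([rmsnorm(c, gamma) for c in chunks], axis=1)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 10, 8))
gamma = rng.normal(size=8)
```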
Context Parallel
Context parallelism is a generalization of sequence parallelism. CP splits activations along the sequence dimension $S$, which works directly for all positionwise (position-independent) operations in the model. The exception is attention, the central operation of autoregressive transformer models, which exchanges information between positions in a sequence. Since attention depends on other positions, when CP is applied the keys and values of the complete sequence are AllGathered during the forward pass and their gradients are ReduceScattered during the backward pass (the simplest solution), while the queries remain sharded across context-parallel agents.
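A forward-pass sketch of this simplest context-parallel attention, assuming single-head causal attention and contiguous sequence chunks: keys and values are gathered in full, while each agent keeps only its own queries and must build the causal mask from global positions rather than local ones. The backward ReduceScatter of key and value gradients is not simulated, and the function names are hypothetical.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def causal_attention(q, k, v, q_pos, k_pos):
    # q: (Sq, d); k, v: (Sk, d). Positions are global sequence indices.
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = np.where(k_pos[None, :] > q_pos[:, None], -np.inf, scores)  # causal mask
    return softmax(scores) @ v

def context_parallel_attention(q, k, v, n):
    pos = np.arange(q.shape[0])
    # Each agent owns a contiguous chunk of positions for q, k and v.
    q_sh, k_sh, v_sh, p_sh = (np.array_split(t, n) for t in (q, k, v, pos))
    # Forward pass: AllGather keys and values over the complete sequence.
    k_full, v_full = np.concatenate(k_sh), np.concatenate(v_sh)
    # Queries stay sharded: agent i only produces outputs for its own positions.
    outs = [causal_attention(qi, k_full, v_full, pi, pos) for qi, pi in zip(q_sh, p_sh)]
    return np.concatenate(outs)

rng = np.random.default_rng(0)
S, d = 7, 4
q, k, v = (rng.normal(size=(S, d)) for _ in range(3))
```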
TBD: mathematical description of CP and discussion of activation recomputation during attention.
Pipeline Parallel
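Let $F = f_L \circ \cdots \circ f_1$ be a neural network, let $P$ be an integer, and let $0 = \ell_0 < \ell_1 < \cdots < \ell_P = L$ be a set of indices of the functions. Define $F_p = f_{\ell_p} \circ \cdots \circ f_{\ell_{p-1}+1}$ for all $p = 1, \dots, P$. Pipeline parallelism decomposes $F = F_P \circ \cdots \circ F_1$ and computes $z_p = F_p(z_{p-1})$ on parallel agent $p$ for each $p = 1, \dots, P$, where $z_0 = x$ for some input $x$ to $F$; agent $p$ receives $z_{p-1}$ from agent $p-1$ and sends $z_p$ to agent $p+1$.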
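The decomposition is plain function composition. In this sketch the layers are hypothetical `tanh` layers, `boundaries` plays the role of the indices $\ell_0 < \cdots < \ell_P$, and the loop passes $z_{p-1}$ from stage to stage as agent $p-1$ would send it to agent $p$. It ignores microbatching and scheduling, which is where the pipeline bubble comes from.

```python
import numpy as np
from functools import reduce

def compose(fs):
    # compose([f1, f2, f3])(x) == f3(f2(f1(x))): apply left to right.
    return lambda x: reduce(lambda acc, f: f(acc), fs, x)

def make_stages(layers, boundaries):
    # boundaries = [l_0 = 0, l_1, ..., l_P = L]; stage p gets f_{l_{p-1}+1}, ..., f_{l_p}.
    return [compose(layers[a:b]) for a, b in zip(boundaries[:-1], boundaries[1:])]

def pipeline_forward(stages, x):
    z = x                        # z_0 = x
    for stage in stages:         # stage p runs on agent p; z_{p-1} is sent to it
        z = stage(z)             # z_p = F_p(z_{p-1})
    return z

def make_layer(W):
    # Hypothetical layer f(a) = tanh(a @ W); the factory avoids late binding of W.
    return lambda a: np.tanh(a @ W)

rng = np.random.default_rng(0)
layers = [make_layer(rng.normal(size=(4, 4))) for _ in range(6)]
stages = make_stages(layers, [0, 2, 5, 6])   # P = 3 stages
x = rng.normal(size=(3, 4))
```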
TBD: Pipeline bubble, computation and communication overlap, DualPipe, GPipe, microbatches, and interleaving forward and backward.
Expert Parallel
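Let $E$ and $d$ be integers, and let $g_1, \dots, g_E : \mathbb{R}^d \to \mathbb{R}$ (gating scores) and $e_1, \dots, e_E : \mathbb{R}^d \to \mathbb{R}^d$ (experts) be continuously differentiable maps on $\mathbb{R}^d$. Let $k \le E$ be an integer, and associate each token $x \in \mathbb{R}^d$ with a subset $S(x) \subseteq \{1, \dots, E\}$ of cardinality $k$. The abstract MoE problem for a token $x$ is to compute $y(x) = \sum_{j \in S(x)} g_j(x)\, e_j(x)$. Expert parallelism computes $y(x)$ by dispatching $x$ to the $k$-subset of parallel agents holding the experts in $S(x)$, which compute $e_j(x)$ for each $j \in S(x)$, and then combining the results to form the sum.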
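A single-process sketch of dispatch and combine under stated assumptions: one expert per agent, hypothetical experts $e_j(x) = \tanh(xA_j)$, gating scores $g_j(x)$ from a softmax router with no renormalization over the top $k$ (real models differ on this point), and dictionaries standing in for the two AllToAll exchanges. The result is compared against the direct per-token sum; all names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
E, d, k = 4, 6, 2                     # experts, hidden size, experts per token
A = rng.normal(size=(E, d, d))        # hypothetical expert weights, one expert per agent
R = rng.normal(size=(d, E))           # hypothetical router weights

def expert(j, x):
    return np.tanh(x @ A[j])          # e_j(x), hosted on agent j

def gate(x):
    z = x @ R
    z = np.exp(z - z.max())
    return z / z.sum()                # g_1(x), ..., g_E(x)

def top_k(g):
    return np.argsort(g)[-k:]         # S(x): indices of the k largest scores

def moe_reference(x):
    g = gate(x)
    return sum(g[j] * expert(j, x) for j in top_k(g))

def moe_expert_parallel(tokens):
    gates = [gate(x) for x in tokens]
    subsets = [top_k(g) for g in gates]
    # Dispatch (first AllToAll): agent j receives every token t with j in S(x_t).
    inbox = {j: [t for t, S in enumerate(subsets) if j in S] for j in range(E)}
    # Each agent applies only its own expert to the tokens it received.
    outbox = {j: [(t, expert(j, tokens[t])) for t in inbox[j]] for j in range(E)}
    # Combine (second AllToAll): weighted sum back at each token's origin.
    out = np.zeros_like(tokens)
    for j in range(E):
        for t, y in outbox[j]:
            out[t] += gates[t][j] * y
    return out

tokens = rng.normal(size=(5, d))
```

The `inbox` lists generally have different lengths per expert, which is exactly the load-imbalance issue raised in the next paragraph.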
Admittedly, we have abstracted away many implementation details, such as shared experts, how the expert layers and scores are computed, and the skip connections. This abstract formulation lets us see cleanly the essential computation and communication of expert parallelism for each token. In practice, one has a sequence of tokens and a batch of sequences. One problem that arises is that different numbers of tokens get routed to the various experts, so one must carefully design efficient variable-length GEMMs at each expert for inference, as well as proper expert load balancing. However, these are additional problems on top of the expert parallelism analyzed here.
Xue J. Zhao © 2026