Suppose a neural network consists of layers $f_1, \dots, f_L$, whereby each layer $f_l$ is a function with suitable domain and range. The neural network evaluated on an input $x$, leading to loss $\mathcal{L}$, is defined by
$$\mathcal{L}(x) = (f_L \circ f_{L-1} \circ \cdots \circ f_1)(x).$$
Let us introduce the notation $a_l = f_l(w_l, a_{l-1})$, that is, $f_l$ is a function parameterized by weight $w_l$ that takes as input the previous layer's activation $a_{l-1}$.
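As a concrete, purely illustrative instance of this composition, the sketch below takes each $f_l$ to be a linear map followed by a ReLU and lets the final layer $f_L$ be a squared-norm loss; these specific layer and loss choices are assumptions for the example, not part of the setup above.

```python
import numpy as np

def layer(w, a_prev):
    """A generic layer f_l(w_l, a_{l-1}): here a linear map followed by ReLU."""
    return np.maximum(w @ a_prev, 0.0)

def loss_layer(a):
    """The last layer f_L computes the loss and has no weights."""
    return 0.5 * np.sum(a ** 2)

def forward(weights, x):
    """Evaluate the composition f_L ∘ f_{L-1} ∘ ... ∘ f_1 on input x."""
    a = x
    for w in weights:          # a_l = f_l(w_l, a_{l-1})
        a = layer(w, a)
    return loss_layer(a)

rng = np.random.default_rng(0)
weights = [rng.standard_normal((4, 4)) for _ in range(3)]
x = rng.standard_normal(4)
print(forward(weights, x))     # a non-negative scalar loss
```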
Given a dataset $D$, weight matrices $w_1, \dots, w_L$ and $N$ parallel agents (GPU, TPU, etc.), we partition the dataset into pairwise disjoint subsets
$$D = D^{(1)} \sqcup D^{(2)} \sqcup \cdots \sqcup D^{(N)},$$
and partition each weight matrix $w_l$ into $N$ slices, for example row-wise:
$$w_l = \begin{pmatrix} w_l^{(1)} \\ \vdots \\ w_l^{(N)} \end{pmatrix}.$$
To the $i$-th parallel agent we send the data subset $D^{(i)}$ and the weight shards $w_1^{(i)}, \dots, w_L^{(i)}$.
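A minimal sketch of this partitioning, assuming a toy dataset, two weight matrices, and row-wise slicing; `np.array_split` is just a single-process stand-in for a real sharding mechanism:

```python
import numpy as np

N = 4  # number of parallel agents

# Hypothetical dataset of 12 samples and two weight matrices.
rng = np.random.default_rng(1)
dataset = rng.standard_normal((12, 8))
weights = [rng.standard_normal((8, 8)), rng.standard_normal((8, 8))]

# Pairwise disjoint data subsets D^(1), ..., D^(N).
data_shards = np.array_split(dataset, N, axis=0)

# Row-wise weight slices w_l^(1), ..., w_l^(N) for each layer l.
weight_shards = [np.array_split(w, N, axis=0) for w in weights]

# Agent i holds D^(i) and the shards w_1^(i), ..., w_L^(i).
agents = [{"data": data_shards[i],
           "shards": [weight_shards[l][i] for l in range(len(weights))]}
          for i in range(N)]

# Stacking the slices recovers the full weight matrix.
assert np.allclose(np.vstack(weight_shards[0]), weights[0])
```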
To describe one step in training, each of the $N$ parallel agents selects a training data point $x^{(i)} \in D^{(i)}$ and puts $a_0^{(i)} = x^{(i)}$. For $l = 1, \dots, L$, suppose the activation $a_{l-1}^{(i)}$ has been computed and the forward pass is ready to advance through layer $l$ on parallel agent $i$. We call AllGather to assemble the full weight
$$w_l = \begin{pmatrix} w_l^{(1)} \\ \vdots \\ w_l^{(N)} \end{pmatrix}$$
on agent $i$, compute
$$a_l^{(i)} = f_l(w_l, a_{l-1}^{(i)}),$$
and immediately free the memory on device $i$ of the weight shards $w_l^{(j)}$, $j \neq i$. We store the layer-$l$ activation $a_l^{(i)}$ for the weight gradient calculation during the backward pass.

Let us turn to the backward pass. In our notation $f_L$ computes the loss function and this layer has no weights, so as the first backward gradient we get just $\frac{\partial \mathcal{L}}{\partial a_{L-1}}$. Just as a reminder: $a_l$ is the output of layer $l$, which is the input to layer $l+1$. We will treat $a_l$ as a variable, whereas the actual activation $a_l^{(i)}$ has a superscript to identify the device $i$. For each $l$, suppose that device $i$ has computed its local activation gradient $\frac{\partial \mathcal{L}}{\partial a_l}^{(i)}$. Then the local weight gradient can be computed as a function of the map $f_l$, the activation gradient, and the local activation $a_{l-1}^{(i)}$:
$$\frac{\partial \mathcal{L}}{\partial w_l}^{(i)} = \frac{\partial \mathcal{L}}{\partial a_l}^{(i)} \cdot \frac{\partial f_l}{\partial w_l}\big(w_l, a_{l-1}^{(i)}\big).$$
Since the weights are kept in shards across the parallel agents, we need to partition the weight gradient too. We will use ReduceScatter to provide each parallel agent with the average gradient corresponding to the weight shard it owns, namely agent $i$ receives
$$\sum_{j=1}^{N} \frac{\partial \mathcal{L}}{\partial w_l^{(i)}}^{(j)};$$
after scaling by $\frac{1}{N}$, this is the average gradient for its local weight shard $w_l^{(i)}$. The local Adam optimizer stores exponential moving averages of this average gradient and of its component-wise square. Thus the optimizer states are sharded as well.

Finally, each parallel agent needs to compute its local activation gradient as a function of the map $f_l$ (which we differentiate via the chain rule) and the weight $w_l$ of layer $l$, together with the activation gradient, namely
$$\frac{\partial \mathcal{L}}{\partial a_{l-1}}^{(i)} = \frac{\partial \mathcal{L}}{\partial a_l}^{(i)} \cdot \frac{\partial f_l}{\partial a_{l-1}}\big(w_l, a_{l-1}^{(i)}\big).$$
Computing this allows backpropagation to advance from layer $l$ to layer $l-1$. To do this we need to AllGather the weight $w_l$ on each of the $N$ parallel agents. After computing the activation gradient, the full weight $w_l$ is freed again and only the local shard $w_l^{(i)}$ is retained.
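The whole step can be simulated on a single process. In the sketch below, `all_gather` (a plain `vstack`) stands in for the AllGather collective; averaging followed by row-slicing stands in for ReduceScatter; the layers are linear maps with a squared-norm loss; and a plain SGD update stands in for the sharded Adam optimizer. All of these concrete choices are assumptions made for the example.

```python
import numpy as np

N, L, d = 2, 2, 4        # parallel agents, weight layers, layer width
lr = 1e-4                # SGD step size (stand-in for the sharded Adam update)
rng = np.random.default_rng(2)

# shards[l][i] is agent i's slice w_l^(i); stacking the slices gives w_l.
shards = [[rng.standard_normal((d // N, d)) for _ in range(N)] for _ in range(L)]
inputs = [rng.standard_normal(d) for _ in range(N)]  # one sample x^(i) per agent

def all_gather(l):
    """AllGather: reassemble the full weight w_l from its N shards."""
    return np.vstack(shards[l])

def avg_loss():
    """Average over agents of the loss 0.5 * ||a_L||^2."""
    total = 0.0
    for i in range(N):
        a = inputs[i]
        for l in range(L):
            a = all_gather(l) @ a
        total += 0.5 * (a @ a)
    return total / N

loss_before = avg_loss()

# --- Forward pass: each agent stores a_0^(i), ..., a_L^(i); weights are
# gathered per layer, used, and (conceptually) freed again right away.
acts = [[x] for x in inputs]
for l in range(L):
    w = all_gather(l)
    for i in range(N):
        acts[i].append(w @ acts[i][-1])        # a_l^(i) = w_l a_{l-1}^(i)

# --- Backward pass: for the loss 0.5 * ||a_L||^2 we have dL/da_L = a_L.
grads = [acts[i][-1] for i in range(N)]
rows = d // N
for l in reversed(range(L)):
    w = all_gather(l)                          # re-gather w_l for this step
    # Local weight gradients dL/dw_l^(i) = g_l^(i) (outer) a_{l-1}^(i).
    wgrads = [np.outer(grads[i], acts[i][l]) for i in range(N)]
    # "ReduceScatter": average over agents; each agent keeps its row block.
    avg = sum(wgrads) / N
    for i in range(N):
        shards[l][i] -= lr * avg[i * rows:(i + 1) * rows, :]
        grads[i] = w.T @ grads[i]              # advance to dL/da_{l-1}^(i)

loss_after = avg_loss()
print(loss_before, loss_after)   # the average loss should have decreased
```

Note that the full weight `w` gathered for the backward step is a copy, so updating the shards inside the loop does not alter the weights used to propagate the activation gradient at that layer, matching the order of operations described above.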