GPU Kernel Level Optimization for Efficient MoE Training
Jan 20
Variable Length Grouped General Matrix Multiply
Recall that the general matrix multiply (GEMM) is the operation $C \leftarrow \alpha AB + \beta C$ on matrices $A$, $B$, $C$ and scalars $\alpha$, $\beta$. We shall be concerned with the special case where $A$ has dimensions $M \times K$ and $B$ has dimensions $K \times N$, and we denote the problem size by $(M, N, K)$.
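As a concrete reference for the definition above, here is a minimal pure-Python sketch of GEMM on row-major lists of lists. The helper names `gemm`, `problem_size`, and `gemm_flops` are hypothetical and chosen only for illustration. The FLOP count follows the usual convention of one multiply and one add per term of each inner product.

```python
def gemm(alpha, A, B, beta, C):
    """Return alpha * (A @ B) + beta * C for row-major lists of lists.

    A is M x K, B is K x N, C is M x N. This is a slow reference, not a kernel.
    """
    M, K = len(A), len(A[0])
    N = len(B[0])
    assert len(B) == K and len(C) == M and len(C[0]) == N, "shape mismatch"
    return [
        [alpha * sum(A[m][k] * B[k][n] for k in range(K)) + beta * C[m][n] for n in range(N)]
        for m in range(M)
    ]


def problem_size(A, B):
    """Problem size (M, N, K) of A @ B."""
    return (len(A), len(B[0]), len(A[0]))


def gemm_flops(M, N, K):
    """Conventional count: each of the M * N outputs needs K multiplies and K adds."""
    return 2 * M * N * K


A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
print(gemm(1, A, B, 0, [[0, 0], [0, 0]]))  # [[19, 22], [43, 50]]
```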
This special case is relevant in mixture-of-experts (MoE) layers, where we have a set of such multiplications, one per expert, over the tokens routed to that expert. Within an MoE layer, these multiplications can have different shapes because each expert receives a different number of tokens. For example, if $M$ is the token dimension, then only $M$ varies across experts while $N$ and $K$ stay fixed, and we refer to the problem as a variable-length-M grouped GEMM.
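The variable-length-M case can be sketched as a loop over experts that share the same $N$ and $K$ but each own a different number of rows $M_e$ of a packed token matrix. This is a minimal sketch assuming a hypothetical `offsets` array (a prefix sum of per-expert token counts); a real kernel would schedule the tiles of all experts onto the GPU together rather than loop over experts one at a time.

```python
from itertools import accumulate


def matmul(A, B):
    """Plain row-major matmul; returns [] when A has no rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def offsets_from_counts(counts):
    """Prefix sum: expert e owns packed rows offsets[e]:offsets[e + 1]."""
    return [0] + list(accumulate(counts))


def varlen_m_grouped_gemm(X_packed, offsets, weights):
    """One GEMM per expert; M_e = offsets[e + 1] - offsets[e] varies, N and K are shared."""
    out = []
    for e, W in enumerate(weights):
        rows = X_packed[offsets[e]:offsets[e + 1]]  # M_e x K slice (M_e may be 0)
        out.extend(matmul(rows, W))                 # M_e x N result, kept in packed order
    return out


def problem_sizes(offsets, K, N):
    """Per-expert problem sizes (M_e, N, K)."""
    return [(offsets[e + 1] - offsets[e], N, K) for e in range(len(offsets) - 1)]


X = [[1, 0], [0, 1], [2, 3]]
offsets = offsets_from_counts([2, 0, 1])  # expert 1 receives no tokens
W = [[[1], [1]], [[5], [5]], [[10], [100]]]
print(varlen_m_grouped_gemm(X, offsets, W))  # [[1], [1], [320]]
print(problem_sizes(offsets, K=2, N=1))      # [(2, 1, 2), (0, 1, 2), (1, 1, 2)]
```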
The up-projection of a given expert requires the tokens routed to it to be gathered from the input, while the down-projection operates on tokens that are already packed contiguously.
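The gather for the up-projection can be sketched as follows, assuming the router returns each token's top-K expert ids as a list of lists (a hypothetical format). Grouping the (token, expert) assignments by expert yields a gather index. The up-projection reads rows `X[gather_idx]` out of order, but its outputs come out in packed, per-expert-contiguous order, which is why the down-projection needs no further gather. Note that with top-K routing each token's row appears K times in the packed order.

```python
def build_gather_index(topk_experts, num_experts):
    """Group (token, expert) assignments by expert.

    Returns (gather_idx, counts): gather_idx lists token ids in packed order,
    counts[e] is the number of tokens routed to expert e.
    """
    buckets = [[] for _ in range(num_experts)]
    for token, experts in enumerate(topk_experts):
        for e in experts:
            buckets[e].append(token)
    counts = [len(b) for b in buckets]
    gather_idx = [token for bucket in buckets for token in bucket]
    return gather_idx, counts


def gather_rows(X, gather_idx):
    """Up-projection input: token rows in packed (per-expert contiguous) order."""
    return [X[t] for t in gather_idx]


# Three tokens, top-2 routing over three experts.
topk = [[0, 2], [1, 0], [2, 1]]
idx, counts = build_gather_index(topk, num_experts=3)
print(idx, counts)  # [0, 1, 1, 2, 0, 2] [2, 2, 2]
```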
Forward and Backward MoE Kernels
Let us turn to the kernel design in the SonicMoE paper. A total of eight kernels are associated with the MoE layer. The forward computation has up-projection, down-projection, and expert aggregation kernels. Recall that, in general, backward gradients are computed with respect to the activations as well as the weights. In the case of MoE, the backward activation gradients are computed with respect to the down-projection, the up-projection, and the MoE input, and two more kernels compute the gradients with respect to the up- and down-projection weight matrices. In addition, SonicMoE provides a top-K routing kernel.
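To make the kernel list concrete, the standard forward and backward equations of one MoE layer are written out below. The notation is introduced here for illustration and is not taken from the paper: $X_e$ holds the tokens gathered for expert $e$, $\mathcal{E}(t)$ is the set of experts that token $t$ is routed to, $s_{t,e}$ is the routing score, and a subscript $t$ on a per-expert matrix selects token $t$'s row. Under this notation, the first line covers the three forward kernels, $dA_e$, $dX_e$, and $dX_t$ are the three activation gradients, and $dW_{2,e}$, $dW_{1,e}$ are the two weight gradients. The exact assignment of terms to kernels in SonicMoE may differ.

```latex
\begin{aligned}
&\text{Forward:}\quad H_e = X_e W_{1,e}, \quad A_e = \mathrm{SwiGLU}(H_e), \quad Y_e = A_e W_{2,e}, \quad
O_t = \sum_{e \in \mathcal{E}(t)} s_{t,e}\, Y_{e,t} \\[4pt]
&\text{Backward:}\quad dY_{e,t} = s_{t,e}\, dO_t, \quad dA_e = dY_e\, W_{2,e}^{\top}, \quad dW_{2,e} = A_e^{\top} dY_e, \\
&\phantom{\text{Backward:}\quad} dH_e = \mathrm{SwiGLU}'(H_e;\, dA_e), \quad dX_e = dH_e\, W_{1,e}^{\top}, \quad dW_{1,e} = X_e^{\top} dH_e, \\
&\phantom{\text{Backward:}\quad} dX_t = \sum_{e \in \mathcal{E}(t)} dX_{e,t}, \quad ds_{t,e} = \langle dO_t,\, Y_{e,t} \rangle
\end{aligned}
```

The last term, $ds_{t,e}$, is the gradient that flows back into the router scores.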
Kernel (Epilogue) Fusion
The first strategy fuses the input gather (during token routing) with the load of the inputs from global memory (HBM) into shared memory (SMEM). In this phase the token indices are fetched first, and the activations at those indices are then loaded with the cp.async PTX instruction. The authors of SonicMoE note that on the Blackwell architecture, the 2-CTA GEMM requires the leader CTA to wait until the gather has completed in both CTAs, which leads to the following pipeline structure for the two CTAs: 1 warp to fetch token indices, 4 warps to gather, and 1 warp to relay the signal and perform the MMA.
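At the algorithmic level, this fusion means the gathered matrix `X[gather_idx]` is never written back to global memory: each M-tile fetches its token indices, gathers only those rows into an on-chip tile, and multiplies immediately. The sketch below models this in pure Python, with a list standing in for the SMEM tile. The function name and `tile_m` are hypothetical, and none of the warp specialization, cp.async, or 2-CTA details are modeled.

```python
def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def gather_fused_gemm(X, gather_idx, W, tile_m=2):
    """Compute X[gather_idx] @ W tile by tile, without materializing X[gather_idx]."""
    out = []
    for start in range(0, len(gather_idx), tile_m):
        idx_tile = gather_idx[start:start + tile_m]  # 1) fetch this tile's token indices
        smem_tile = [X[t] for t in idx_tile]         # 2) gather just these rows "on chip"
        out.extend(matmul(smem_tile, W))             # 3) tile GEMM on the gathered rows
    return out


X = [[1, 2], [3, 4], [5, 6]]
W = [[1, 0], [0, 1]]  # identity, so the output equals the gathered rows
print(gather_fused_gemm(X, [2, 0, 2], W))  # [[5, 6], [1, 2], [5, 6]]
```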
The second fusion occurs in the GEMM epilogues of the MoE layer. Specifically, SwiGLU is fused into the forward up-projection, while the SwiGLU backward is fused into the down-projection gradient computation.
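For reference, SwiGLU computes $\mathrm{SiLU}(g) \odot u$ from two halves of the up-projection output, a gate $g$ and an up part $u$. Because it is purely elementwise, it can be applied to an output tile in the epilogue. The sketch below gives the forward and the backward that such fused epilogues would evaluate. How $g$ and $u$ are laid out within the up-projection output (concatenated or interleaved columns) is left as an assumption: these hypothetical helpers simply take the two halves as separate lists.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def swiglu_fwd(gate, up):
    """a = silu(g) * u elementwise, with silu(g) = g * sigmoid(g)."""
    return [g * sigmoid(g) * u for g, u in zip(gate, up)]


def swiglu_bwd(gate, up, da):
    """Given the upstream gradient da = dL/da, return (dL/dgate, dL/dup)."""
    d_gate, d_up = [], []
    for g, u, d in zip(gate, up, da):
        sig = sigmoid(g)
        dsilu = sig * (1.0 + g * (1.0 - sig))  # derivative of g * sigmoid(g)
        d_gate.append(d * u * dsilu)
        d_up.append(d * g * sig)
    return d_gate, d_up


# Central finite-difference check on one gate element.
g, u, eps = [0.3, -1.2], [2.0, 0.5], 1e-6
dg, du = swiglu_bwd(g, u, [1.0, 1.0])
numeric = (swiglu_fwd([g[0] + eps], [u[0]])[0] - swiglu_fwd([g[0] - eps], [u[0]])[0]) / (2 * eps)
print(abs(numeric - dg[0]) < 1e-6)  # True
```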
The third fusion involves the quantities associated with the routing scores and the down-projection, respectively.
Computation and Async IO Overlap
A useful strategy for sustaining high Tensor Core throughput when the epilogue is heavy is ping-pong scheduling. Two warp groups, each collectively issuing WGMMA for its own tile, take turns: at any given moment one performs the GEMM while the other performs its epilogue IO, and later they exchange roles. This strategy was used in FlashAttention-3, and SonicMoE uses it in the down-projection activation-gradient computation and in the forward down-projection epilogue.
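A toy timing model shows why ping-pong helps when the epilogue is expensive. The assumptions are deliberately simple: the Tensor Cores are a single resource that only one warp group's GEMM can use at a time, tile $i$ belongs to warp group $i \bmod$ `num_groups`, every tile takes fixed times `t_mma` and `t_epi`, and a warp group cannot start its next GEMM until it has finished its previous epilogue. The function name and parameters are hypothetical.

```python
def makespan(num_tiles, t_mma, t_epi, num_groups=2):
    """Finish time of num_tiles output tiles under the toy model above."""
    tc_free = 0.0                    # when the shared Tensor Cores become idle
    group_free = [0.0] * num_groups  # when each warp group finishes its last epilogue
    finish = 0.0
    for i in range(num_tiles):
        g = i % num_groups           # tiles alternate between warp groups
        start = max(tc_free, group_free[g])
        tc_free = start + t_mma          # this group's GEMM occupies the Tensor Cores
        group_free[g] = tc_free + t_epi  # its epilogue can overlap the other group's GEMM
        finish = max(finish, group_free[g])
    return finish


print(makespan(8, t_mma=1.0, t_epi=1.0, num_groups=1))  # 16.0: GEMM and epilogue serialize
print(makespan(8, t_mma=1.0, t_epi=1.0, num_groups=2))  # 9.0: epilogues hidden behind GEMMs
```

In this model, two warp groups keep the Tensor Cores continuously busy whenever `t_epi <= t_mma`; a longer epilogue starts to leave gaps between GEMMs.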
The paper also uses asynchronous TMA for data movement between SMEM and GMEM in a subset of the kernels, for example to store the output of the forward down-projection.
On Blackwell, ping-pong scheduling is combined with Tensor Memory (TMEM) and UMMA for better pipelining.
Top-K Sorting Kernel
The SonicMoE paper also provides a top-K kernel based on bitonic sort, with optional softmax fusion.
Claire Zhao © 2026