Based on the papers and documentation I can find, here is a sketch of how LLMs are trained in NVFP4.
Like FP4 and MXFP4, NVFP4 has the E2M1 bit structure. The distinction lies in how NVFP4 represents a set of numbers, which we refer to as a tensor. In particular, NVFP4 partitions a tensor into blocks of 16 numbers. Each block is associated with an 8-bit E4M3 number $s_B$ called the block scale factor, such that each four-bit number $q_i$ belonging to the block is reconstructed as $q_i \cdot s_B$. Additionally, an FP32 number $s_T$, called the tensor scale factor, is associated with the tensor itself, so the full reconstruction is $\hat{x}_i = q_i \, s_B \, s_T$.
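To get a rough sense of the storage cost this implies, here is a back-of-the-envelope sketch (my own illustration, not the actual tensor-core memory layout; the constants are just the field widths stated above). The block and tensor scales add about half a bit per element on top of the 4-bit payload:

```python
# Rough NVFP4 storage cost per element (illustrative only).
BLOCK_SIZE = 16          # elements sharing one block scale
ELEM_BITS = 4            # E2M1 payload per element
BLOCK_SCALE_BITS = 8     # E4M3 block scale, one per block
TENSOR_SCALE_BITS = 32   # FP32 tensor scale, one per tensor

def nvfp4_bits_per_element(num_elements: int) -> float:
    num_blocks = -(-num_elements // BLOCK_SIZE)  # ceiling division
    total_bits = (num_elements * ELEM_BITS
                  + num_blocks * BLOCK_SCALE_BITS
                  + TENSOR_SCALE_BITS)
    return total_bits / num_elements

print(nvfp4_bits_per_element(4096 * 4096))  # ~4.5 bits per element
```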
By contrast, each MXFP4 block consists of 32 numbers, and its block scale factor is E8M0 (i.e. the scale is rounded to the nearest power of two). It can be shown that the expected squared error with E8M0 scales is larger than with E4M3 scales, the trade-off being that E8M0 has less overhead: a power-of-two multiply is just an exponent shift, and the larger block amortizes the 8-bit scale over twice as many elements. The partitioning and scaling in NVFP4 are handled by specialized tensor-core hardware.
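To get a feel for why a power-of-two scale loses more than an E4M3 scale, here is a toy comparison (my own illustration: the E4M3 helper only rounds the mantissa and ignores the format's finite range, and the actual error depends on the data distribution):

```python
import numpy as np

def round_e8m0(x):
    # Round a positive scale to the nearest power of two (MXFP4-style scale).
    return 2.0 ** np.round(np.log2(x))

def round_e4m3_mantissa(x):
    # Keep the exponent, round the mantissa to 3 explicit bits (E4M3-style
    # precision; range limits and subnormals are ignored in this toy).
    m, e = np.frexp(x)                # x = m * 2**e with 0.5 <= m < 1
    return np.ldexp(np.round(m * 16) / 16, e)

rng = np.random.default_rng(0)
scales = rng.uniform(0.1, 10.0, size=100_000)

def msre(q):
    # mean squared relative error against the true scales
    return float(np.mean(((q - scales) / scales) ** 2))

print(msre(round_e8m0(scales)), msre(round_e4m3_mantissa(scales)))
# The E8M0 error comes out substantially larger than the E4M3 error.
```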
Given that E2M1 and E4M3 can represent numbers with maximum absolute values of $6$ and $448$ respectively, the tensor scale factor for a tensor indexed by a set $T$ is the tensor dequantization scale $s_T = \frac{\max_{i \in T} |x_i|}{448 \times 6}$ and is stored in FP32. Let $B \subseteq T$ be an indexing set for a block in $T$; the corresponding ideal block scale factor is $\frac{\max_{i \in B} |x_i|}{6}$. In fact, the block dequantization scale factor is stored in FP8 on the tensor core as $s_B = \mathrm{rnd}_{\mathrm{E4M3}}\!\left(\frac{\max_{i \in B} |x_i| / 6}{s_T}\right)$.
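For concreteness, a worked example with made-up numbers: if the tensor's largest magnitude is $10$, then $s_T = \frac{10}{448 \times 6} \approx 0.00372$. A block whose largest magnitude is $2.5$ has ideal scale $\frac{2.5}{6} \approx 0.4167$, which is encoded relative to the tensor scale as $\frac{0.4167}{0.00372} = 112$, a value that E4M3 happens to represent exactly.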
Each element of a block gets quantized as $q_i = \mathrm{rnd}_{\mathrm{E2M1}}\!\left(\frac{x_i}{s_B \, s_T}\right)$, and a partial dot product during the GEMM is computed as $s_B^{A} \, s_B^{W} \sum_{i \in B} q_i^{A} \, q_i^{W}$, where $A$ and $W$ denote the two operand tensors and the accumulation is carried out in FP32. After the GEMM, the tensor dequantization scales $s_T^{A} \, s_T^{W}$ are applied.
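The whole recipe can be emulated in plain NumPy. The sketch below is a simulation for clarity, not how the tensor cores actually implement it: the rounding helpers are simplified, all function and variable names are mine, and the input length is assumed to be a multiple of the block size.

```python
import numpy as np

E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
E2M1_MAX, E4M3_MAX, BLOCK = 6.0, 448.0, 16

def round_e2m1(x):
    # Round each value to the nearest representable E2M1 magnitude, keeping the sign.
    idx = np.abs(np.abs(x)[..., None] - E2M1_GRID).argmin(axis=-1)
    return np.sign(x) * E2M1_GRID[idx]

def round_e4m3(x):
    # Round a positive scale to 3 mantissa bits and clamp to E4M3's normal range
    # (subnormals and special values are ignored in this sketch).
    m, e = np.frexp(np.clip(x, 2.0**-6, E4M3_MAX))
    return np.ldexp(np.round(m * 16) / 16, e)

def nvfp4_quantize(x):
    # Two-level scaling: FP32 tensor scale, E4M3 block scales, E2M1 elements.
    s_t = np.abs(x).max() / (E4M3_MAX * E2M1_MAX)
    blocks = x.reshape(-1, BLOCK)
    s_b = round_e4m3(np.abs(blocks).max(axis=1) / E2M1_MAX / s_t)
    q = round_e2m1(blocks / (s_b[:, None] * s_t))
    return q, s_b, np.float32(s_t)

def nvfp4_dot(x, w):
    # Emulated block-scaled dot product: each block's partial product is scaled
    # by the two E4M3 block scales and accumulated in FP32; the FP32 tensor
    # scales are applied once at the end, as after the real GEMM.
    qx, sbx, stx = nvfp4_quantize(x)
    qw, sbw, stw = nvfp4_quantize(w)
    partials = (qx * qw).sum(axis=1) * sbx * sbw
    return partials.sum(dtype=np.float32) * stx * stw

rng = np.random.default_rng(0)
x, w = rng.normal(size=256), rng.normal(size=256)
print(np.dot(x, w), nvfp4_dot(x, w))  # the NVFP4 result should track the FP32 dot product
```

The gap between the two printed numbers comes from the E2M1 rounding of the elements and the E4M3 rounding of the block scales; the two levels of scaling keep it small even though each element carries only four bits.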
There are experiments suggesting that NVFP4 should be used in the earlier layers of a transformer (in the forward-pass direction), while the last few layers are kept in higher precision.