Training LLMs in Low Precision
Modified on Jan 31
Consider how a real number $x$ can be represented on a computer. If $x \neq 0$ then $x = s|x|$ where $s \in \{+1, -1\}$ is the sign of $x$ and can be represented with one bit. For the positive number $|x|$ there exists a smallest integer $e$ such that $2^e \le |x| < 2^{e+1}$. Now we can divide this interval into two equal halves. By definition of $e$ we know that $|x| \in [2^e, 2^{e+1})$. We can again divide the interval into two equal halves and $|x|$ lies in either the left half or the right half (to break the tie, if $|x|$ lies in the middle, then we say it lies in the right half). As we make the sequence of interval partitions, the positive real number $|x|$ lies in either the left or the right half. Associating left with $0$ and right with $1$, we obtain a binary sequence $b_1 b_2 b_3 \cdots$. It is clear that after the $k$-th partition the situation is $|x| = 2^e\bigl(1 + \sum_{i=1}^{k} b_i 2^{-i}\bigr) + \epsilon_k$, where $\epsilon_k$ is less than the length $2^{e-k}$ of the half interval in the $k$-th partition. So given the positive real number $|x|$ and a desired precision $\delta > 0$, we can choose an integer $m$ such that $2^{e-m} < \delta$, so that the number $2^e\bigl(1 + \sum_{i=1}^{m} b_i 2^{-i}\bigr)$, which is representable by a binary sequence of length $m$, approximates $|x|$ to within $\delta$, namely $\bigl||x| - 2^e\bigl(1 + \sum_{i=1}^{m} b_i 2^{-i}\bigr)\bigr| < \delta$. Since $x = s|x|$ we can write $x \approx s \cdot 2^e\bigl(1 + \sum_{i=1}^{m} b_i 2^{-i}\bigr)$. Hardware approximates $x$ by storing one bit for the sign $s$, another $m$ bits $b_1 \cdots b_m$ called the mantissa of $x$, the bit length of which determines an upper bound on the error of approximating $|x|$, together with a representation of the exponent $e$.
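The bisection above can be carried out mechanically. The following is a minimal sketch; the helper names `mantissa_bits` and `value` are made up for illustration:

```python
def mantissa_bits(x, m):
    """Approximate positive real x by 2**e * (1 + sum b_i 2**-i) via bisection.

    Returns (e, bits) where e is the unique integer with 2**e <= x < 2**(e+1)
    and bits are the first m binary digits obtained by halving the interval.
    """
    assert x > 0
    e = 0
    while 2 ** (e + 1) <= x:
        e += 1
    while 2 ** e > x:
        e -= 1
    t = x / 2 ** e          # position of x inside [1, 2) after dividing out 2**e
    bits = []
    lo, hi = 1.0, 2.0
    for _ in range(m):
        mid = (lo + hi) / 2
        if t >= mid:        # ties go to the right half, as in the text
            bits.append(1)
            lo = mid
        else:
            bits.append(0)
            hi = mid
    return e, bits

def value(e, bits):
    """The representable number 2**e * (1 + sum b_i 2**-i)."""
    return 2 ** e * (1 + sum(b * 2 ** -(i + 1) for i, b in enumerate(bits)))
```

For example $5.5 = 2^2 \cdot 1.375$ is captured exactly by four mantissa bits, and the approximation error is always below $2^{e-m}$.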
Now we need a scheme for representing the integer $e$ in bits. Suppose that $k$ bits are designated to store the exponent; then there are $2^k$ configurations, meaning that the real numbers the float can approximate span $2^k$ orders of magnitude. Naturally we want to choose orders of magnitude about the multiplicative unit $1 = 2^0$. Customarily we designate $2^{k-1}$ of the orders of magnitude to be at or below the unit, and another $2^{k-1}$ orders of magnitude to be above the unit. The hardware implementation is to store the integer exponent $e$ with a non-negative integer $E = e + (2^{k-1} - 1)$; this way the smallest order of magnitude (with respect to this dtype), $e_{\min} = -(2^{k-1} - 1)$, is stored as $E = 0$, whereas the largest order of magnitude, $2^{k-1}$, is stored as $E = 2^k - 1$. This scheme is nothing but a bijection from $\{-(2^{k-1}-1), \ldots, 2^{k-1}\}$ to $\{0, 1, \ldots, 2^k - 1\}$. The quantity $2^{k-1} - 1$ is sometimes given the nondescript name bias just to be confusing! The mantissa, which we have until now defined as the bit string $b_1 \cdots b_m$, should really be identified with the real number $\sum_{i=1}^{m} b_i 2^{-i}$ that it represents, and together with $E$ and the sign bit $S$ represents the real number $x$ (approximately) on hardware as $x \approx (-1)^S \, 2^{E - (2^{k-1} - 1)} \bigl(1 + \sum_{i=1}^{m} b_i 2^{-i}\bigr).$
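A minimal sketch of this biased-exponent decoding, normals only (`decode` is a hypothetical helper, not a standard API):

```python
def decode(sign_bit, E, bits, k):
    """Decode a normal float from its stored fields: sign bit, biased
    exponent E (k bits), and list of mantissa bits, with bias = 2**(k-1) - 1."""
    bias = 2 ** (k - 1) - 1
    frac = 1 + sum(b * 2 ** -(i + 1) for i, b in enumerate(bits))
    return (-1) ** sign_bit * 2 ** (E - bias) * frac
```

With $k = 4$ the bias is $7$, so the stored $E = 7$ with a zero mantissa decodes to the unit $1.0$, and $E = 8$ with mantissa $100$ decodes to $2 \times 1.5 = 3.0$.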
The above analysis applies to $x \neq 0$. Since $0$ is a limit point of the set $\{2^e : e \in \mathbb{Z}\}$, there does not exist a smallest integer $e$ such that $2^e \le 0 < 2^{e+1}$. Suppose that a float dtype has $m$ bits to store the mantissa; it is natural to approximate zero with the smallest element in the set of positive representable numbers, the minimum value being $2^{e_{\min}}$, which corresponds to $E = 0$ and $b_1 = \cdots = b_m = 0$. Equivalently, zero can be approximated by the maximum, $-2^{e_{\min}}$, of the set consisting of the negative representable numbers. Therefore we have two representations of zero by the dtype. Now this argument is true to first order, and needs slight modification when we introduce the so-called subnormal numbers in practical floating point implementations. The idea is that $2^{e_{\min}}$ is not the smallest order of magnitude we can represent if in our implementation we customarily agree that when $E = 0$ we replace $2^{e_{\min}}\bigl(1 + \sum_{i=1}^{m} b_i 2^{-i}\bigr)$ by $2^{e_{\min}+1}\bigl(0 + \sum_{i=1}^{m} b_i 2^{-i}\bigr)$, where observe that $e_{\min}+1$ is the exponent stored as $E = 1$, whereby $2^{e_{\min}+1}$ is the smallest normal order of magnitude. If the mantissa has $m$ bits, then the smallest subnormal number is $2^{e_{\min}+1} \cdot 2^{-m}$ and the largest subnormal is $2^{e_{\min}+1}(1 - 2^{-m})$, corresponding with mantissa $0\cdots01$ and $1\cdots1$ respectively. The case with $E = 0$ and mantissa $0\cdots0$ is reserved to represent zero. Observe that magnitudes smaller than $2^{e_{\min}+1}$ are only expressible near zero, with subnormal numbers. That is, for a general real number $x$ with exponent $e$, the upper bound on the error of approximation is still $2^{e-m}$, even though for numbers below the smallest normal the error of approximation has bound $2^{e_{\min}+1-m}$.
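To make the subnormal bounds concrete, here is a small sketch assuming the IEEE-style convention above, where the subnormal exponent is $1 - \mathrm{bias}$ (the helper name is made up):

```python
def subnormal_range(k, m):
    """Smallest and largest positive subnormals for a dtype with k exponent
    bits and m mantissa bits: the subnormal exponent is 1 - bias and the
    implicit leading 1 is dropped."""
    bias = 2 ** (k - 1) - 1
    e_sub = 1 - bias
    smallest = 2 ** e_sub * 2 ** -m         # mantissa 0...01
    largest = 2 ** e_sub * (1 - 2 ** -m)    # mantissa 1...1
    return smallest, largest
```

For the FP16 layout ($k = 5$, $m = 10$) this yields the familiar smallest subnormal $2^{-24}$, just below the smallest normal $2^{-14}$.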
Let us build some intuition. (1) Observe that every mantissa bit of a normal number is significant, but for a subnormal number this need not be the case. For instance for a normal number with mantissa $00\cdots0$, every bit is significant thanks to the implicit leading $1$. For a subnormal number with mantissa $1b_2\cdots b_m$ every bit is significant, but a subnormal with mantissa $00\cdots01$ only has 1 significant bit. The leading zeros are placeholders and not significant. (2) A hand-wavy way to look at gaps between ints and floats. Let $n$ be the bit width: the gaps between consecutive Int-$n$ numbers are all the same, whereas for Float-$n$ numbers with an $m$-bit mantissa, the representable magnitudes are evenly spaced on a log scale. This is meant in the sense that for a float $2^e\bigl(1 + \sum_{i=1}^{m} b_i 2^{-i}\bigr)$ the nearby elements are spaced $2^{e-m}$ apart. In particular these gaps are larger by a factor of $2^e$ depending on the exponent of the float.
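Python's `math.ulp` exposes exactly this gap for float64 ($m = 52$), so intuition (2) can be checked directly:

```python
import math

# The gap between consecutive float64 values near 2**e * (1 + ...) is 2**(e - 52):
# it doubles every time the magnitude crosses a power of two.
gaps = {x: math.ulp(x) for x in [1.0, 2.0, 4.0, 1024.0]}
```

Printing `gaps` shows the spacing $2^{-52}, 2^{-51}, 2^{-50}, 2^{-42}$: constant within a binade, scaling with the exponent across binades.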
Let $x, y$ be non-negative integers. We say that a floating point data type has bit structure ExMy if every float is represented (in hardware or software) as 1 sign bit together with $x$ exponent bits and $y$ mantissa bits. It is sometimes the convention that if $y = 0$ then the number has no sign bit (as with the E8M0 scale type below).
Since any hardware floating point number consists of a finite number of bits, the float has a limited range of expressible orders of magnitude. For numbers outside this range, a data type sometimes designates bit patterns for the sign, mantissa, and exponent to represent $\pm\infty$. However this is optional; for example, in the so-called microscaled (MX) classes of floats, infinities are excluded in E4M3 MXFP8 but are included in E5M2 MXFP8. Another special value for a float is NaN, or not-a-number, which designates the result of undefined operations such as $0/0$. How to implement NaN for a given float is a choice. In some designs NaN is omitted, in others only one bit pattern (up to equivalence in sign) denotes NaN, and in yet others multiple such patterns are treated as NaN (compare E2M3, E4M3, E5M2 microscaled floats).
The aforementioned microscaled floats work as follows: instead of representing one real number at a time, MX floats represent collections of real numbers simultaneously. Suppose $d \in \mathbb{N}$, with a scale dtype $S$ and an element dtype $B$. Let $\mathcal{S}$ be the set of all $S$ floats, and let $\mathcal{B}$ be the set of all $B$ floats. The set $\mathcal{S} \times \mathcal{B}^d$ is called the $d$-dimensional microscaling (MX) float with scalers of type $S$ and block of type $B$. An element $(s, b_1, \ldots, b_d)$ is associated with the vector $(s b_1, \ldots, s b_d)$. There are dtype-specific rules, like if $s$ is NaN then $s b_i$ is NaN for all $i$, and implementation-defined rules for situations such as when some $b_i$ is $\pm\infty$. Within microscaled data types there are floats like MXFP4, where scales are of type E8M0, block elements are of type E2M1, and the dimension is $d = 32$, together with MX integers like MXINT8. Encodings for the mantissa are somewhat different than our discussion so far, in the sense that the mantissa has $2^{-m}$ factored out ($m$ being the number of mantissa bits), so that the unsigned integer $M = \sum_{i=1}^{m} b_i 2^{m-i}$ is stored as opposed to the bit string $b_1 \cdots b_m$ directly. In this scheme a normal number has $E \ge 1$ and value $(-1)^S \, 2^{E - \mathrm{bias}}\bigl(1 + M 2^{-m}\bigr)$, and a subnormal has $E = 0$ and value $(-1)^S \, 2^{1 - \mathrm{bias}} \cdot M 2^{-m}.$
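The association $(s, b_1, \ldots, b_d) \mapsto (s b_1, \ldots, s b_d)$ is easy to sketch; here the E8M0 scale is modeled as its stored integer exponent, and the function name is hypothetical:

```python
def mx_decode(scale_exp, elements):
    """Decode one MX block: an E8M0 scale (a pure power of two, modeled here
    by its integer exponent) times each low-precision block element."""
    s = 2.0 ** scale_exp
    return [s * b for b in elements]
```

For instance a block with scale exponent $3$ and E2M1 elements $(0.5, 1.5, -1.0)$ decodes to $(4, 12, -8)$.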
The dot product of two MX floats of the same dtype corresponds to the dot product of the vectors they represent: for every $(s, b_1, \ldots, b_d)$ and every $(t, c_1, \ldots, c_d)$ we define their dot product to be $st \sum_{i=1}^{d} b_i c_i$. The product between an element of $\mathcal{S}$ and an element of $\mathcal{B}$, and likewise the product between elements of $\mathcal{B}$ and elements of $\mathcal{B}$, are specified in the implementation. Specification of the output type depends on the situation.
Let $n = qd$. More generally a dot product can be defined between every vector $u \in (\mathcal{S} \times \mathcal{B}^d)^q$ and every $v \in (\mathcal{S} \times \mathcal{B}^d)^q$ by $\sum_{j=1}^{q} s_j t_j \sum_{i=1}^{d} b_{j,i} c_{j,i}$. This dot product requires the length of the vectors to be a multiple of the number $d$ of elements per MX block. This can be relaxed to any length $n$ by zero-padding to the nearest multiple of $d$ greater than or equal to $n$ and truncating the result back to length $n$.
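The blockwise dot product $\sum_j s_j t_j \sum_i b_{j,i} c_{j,i}$ with zero-padding can be sketched as follows (unit scales are used for simplicity, and the helper names are made up):

```python
def mx_dot(u, v):
    """Dot product of two lists of MX blocks, each block a (scale, elements)
    pair: per block it computes s * t * sum_i b_i * c_i, then sums over blocks."""
    total = 0.0
    for (s, bs), (t, cs) in zip(u, v):
        assert len(bs) == len(cs)
        total += s * t * sum(b * c for b, c in zip(bs, cs))
    return total

def pad_blocks(xs, d):
    """Zero-pad a raw vector to a multiple of d and split into unit-scale blocks."""
    xs = xs + [0.0] * (-len(xs) % d)
    return [(1.0, xs[i:i + d]) for i in range(0, len(xs), d)]
```

Padding with zeros leaves the dot product unchanged, which is why the length restriction is harmless.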
To represent a vector $(v_1, \ldots, v_d)$ in MX with maximal absolute value $a = \max_i |v_i|$, the corresponding scale factor is determined by $s = 2^{\lfloor \log_2 a \rfloor - \lfloor \log_2 b_{\max} \rfloor}$. In other words, the scale factor depends on $a$ and the largest normal number $b_{\max}$ for the block data type: it is the ratio of the largest power of two less than or equal to $a$ to the largest power of two less than or equal to $b_{\max}$. Observe that the scaling factor of MX is an integer power of $2$ (hence in MXFP8 the block scales are of type E8M0, which stores just an exponent).
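The power-of-two scale rule can be written directly; here $b_{\max} = 6$ would correspond to an E2M1 block dtype (the function name is made up):

```python
import math

def mx_scale(a, b_max):
    """MX scale factor: the ratio of the largest power of two <= a (the block
    amax) to the largest power of two <= b_max (largest normal of the block
    dtype); always an integer power of two."""
    return 2.0 ** (math.floor(math.log2(a)) - math.floor(math.log2(b_max)))
```

For $a = 100$ and $b_{\max} = 6$ this gives $2^{6-2} = 16$, so the block stores values around $100/16 \approx 6$.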
Let $Q$ be the map that quantizes real numbers to the block dtype with the scheme in the beginning of the blog, and which assigns values exceeding the maximal value to $\pm b_{\max}$ and values smaller in magnitude than the minimal normal value to the nearest subnormal or zero. The corresponding block values are then $Q(v_1/s), \ldots, Q(v_d/s)$.
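A round-to-nearest sketch of such a clamped quantizer over the non-negative E2M1 values (signs and tie-breaking rules are glossed over; the grid lists E2M1's subnormal $0.5$ and normals up to $6$):

```python
# Non-negative E2M1 values: zero, the one subnormal, and the normals.
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def quantize_to_grid(x, grid):
    """Round x to the nearest value in a finite grid of representable numbers;
    values beyond the endpoints clamp to the endpoints."""
    return min(grid, key=lambda g: abs(g - x))
```

Note that clamping falls out for free: any $x > 6$ is nearest to the endpoint $6$.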
Mixed Precision Training
While reducing the number of bits to represent a float leads to higher compute throughput and less memory, for a fixed number of bits one faces a trade-off between range, as governed by the exponent bits, and precision, as governed by the mantissa bits. In practice the choice of this trade-off is informed by the sensitivity of the numerics to either factor during different stages of model training. To be concrete, it is empirically observed that computations in the forward pass are more sensitive to the precision of the data type being used, while backward pass gradients demand more range. Given 8 bits per float, E4M3 is often the choice for the forward pass while E5M2 is used for backward gradient computation. In general, the use of different data types for different stages of neural network training is known as mixed precision training.
The idea behind block scaled data types like the MX format above and NVFP4, which we discuss subsequently, is grounded in the observation that the range of a local patch of a tensor can be factored out as the scale, one number per block that requires a greater number of bits for range, while the other factor can be stored in a low-bit data type. This is analogous (to make a very crude analogy!) to the idea that a vector can be stored as a magnitude and a direction: a block of data can be stored as a scale and some leaner data type.
Using block scaled data types introduces many nuances. In a general matrix multiplication (GEMM) that involves granular scaled blocks (with the purpose of localizing outliers, say), the scaling now needs to take place during the GEMM mainloop, as opposed to in the epilogue as for per-tensor scaling. Hardware support for this only arrived with the Blackwell GPU architecture. From the algorithmic point of view, the neural network backward pass typically requires multiplying the activation gradient from a previous layer by the transpose of the activations input to the current layer, and due to the noncommutativity of transposition with block scaled quantization when 1D blocks are used (i.e. blocks along the GEMM reduction, or equivalently dot product, dimension), the chain rule breaks 🙀.
NVFP4
The datatype NVFP4 is a block scaled floating point type that continues the spirit of MXFP8 with a lower average bit count per float, more granular blocks, and two-level scaling. The block size is 16, the scale type is the E4M3 8-bit float, while the block storage type is the E2M1 4-bit float. In addition there is a tensor-level scale stored in FP32 that ensures the largest value for each block is representable in the range of what can be stored as an E4M3 at the block level. Thus on average NVFP4 uses $(16 \times 4 + 8)/16 = 4.5$ bits per value, plus the amortized cost of the tensor-level scale.
NVFP4 Pre-training Recipe
Software emulation of scaling, quantization, and dequantization. Random Hadamard transform; stochastic rounding for bias reduction, traded off against more variance; 1D/2D block scaling; selective precision. Change of basis to transform the distribution to be more Gaussian-like. Complexity. Empirically found recommendation: apply the random Hadamard transform to the weight gradient as opposed to the forward and activation gradients. Gradients are sensitive to quantization bias. (TBD)
Definition
The following comes from an earlier draft of this blog
Hardware represents floating point numbers in the form $(-1)^s \, 2^{p} (1 + f)$, where $s \in \{0, 1\}$ determines the sign of the number and is stored with one bit. Depending on the datatype, a choice of $m$ bits is devoted to the mantissa $f$, and $e$ bits are devoted to the exponent $p$. The datatype thus determined requires $1 + e + m$ bits to represent a float, and is denoted EeMm. For instance, FP32 is E8M23, FP16 is E5M10, whereas BF16 is E8M7.
Quantization is the operation that maps numbers represented in a given datatype to numbers in another datatype requiring fewer bits. Thus quantization is a compression mechanism that reduces storage and communication footprint, and increases compute throughput. The efficiency gained via compression is traded off against degradation in accuracy in one form or another.
In practice one quantizes a set of numbers together. For example, to quantize real numbers $x_1, \ldots, x_n$ into $b$-bit integers, we can choose the mapping $x \mapsto \mathrm{round}(x/s)$ with $s = \max_j |x_j| / (2^{b-1} - 1)$. In particular the scale factor $s$ is chosen so that the maximum element of $\{|x_1|, \ldots, |x_n|\}$ gets mapped to $2^{b-1} - 1$. Since the mapping is many-to-one we can only hope to dequantize approximately, with error. More precisely, given an integer $q$ we can map it to $sq$, and this introduces an error $|x - s \cdot \mathrm{round}(x/s)| \le s/2$.
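The scheme above can be sketched for symmetric $b$-bit integer quantization (hypothetical helper names):

```python
def int_quantize(xs, b=8):
    """Symmetric b-bit integer quantization: the scale is chosen so the
    max-|x| element maps to 2**(b-1) - 1; results are clamped to the int range."""
    qmax = 2 ** (b - 1) - 1
    s = max(abs(x) for x in xs) / qmax
    q = [max(-qmax - 1, min(qmax, round(x / s))) for x in xs]
    return s, q

def int_dequantize(s, q):
    """Approximate inverse: map each integer back to s * q."""
    return [s * qi for qi in q]
```

After a round trip each element is recovered to within half a quantization step, $s/2$.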
In general, given a pair of quantization and dequantization maps $(Q, D)$, one can measure the error $|x - D(Q(x))|$.
Training in NVFP4
Based on the papers and documentation I can find, here is a sketch of training LLMs in NVFP4.
Like FP4 and MXFP4, NVFP4 has the bit structure E2M1. The distinction lies in how NVFP4 represents a set of numbers, which we refer to as a tensor. In particular, NVFP4 partitions the tensor into subsets of $16$ numbers each, called blocks. Each block is associated with an 8-bit E4M3 number $s$ called the block scale factor, such that each one of the 4-bit numbers $b_i$ belonging to the same block is reconstructed as $s b_i$. Additionally, an FP32 number is associated with the tensor itself, called the tensor scale factor.
By contrast, each MXFP4 partition consists of $32$ numbers, and its block scale factor is E8M0 (i.e. rounds to the nearest power of two). It can be shown that the expected square error with E8M0 is larger than that of E4M3, with the trade-off being that E8M0 has less overhead. The said partition and scaling in NVFP4 is handled by specialized tensor core hardware.
Given that E2M1 and E4M3 can represent numbers with maximum absolute values of $6$ and $448$ respectively, for a tensor $(x_i)_{i \in I}$ indexed by a set $I$, the tensor dequantization scale is $s_{\mathrm{tensor}} = \frac{\max_{i \in I} |x_i|}{448 \times 6}$ and is stored in FP32. Let $J \subseteq I$ be an indexing set for a block in the tensor; the corresponding block scale factor is $s_J = \frac{\max_{j \in J} |x_j|}{6}$. In fact, the block dequantization scale factor is stored in FP8 on the tensor core as $\hat{s}_J = Q_{\mathrm{E4M3}}\!\left(s_J / s_{\mathrm{tensor}}\right)$.
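A software sketch of the two-level scaling. Real hardware would round the block scale to E4M3; here it is kept exact, and the function name is made up:

```python
def nvfp4_scales(tensor, block=16):
    """Two-level NVFP4 scaling sketch: an FP32 tensor scale amax/(448*6),
    plus a per-block scale amax_block/6 divided by the tensor scale
    (which the E4M3 storage step, omitted here, would then round)."""
    amax = max(abs(x) for x in tensor)
    s_tensor = amax / (448 * 6)
    blocks = [tensor[i:i + block] for i in range(0, len(tensor), block)]
    s_blocks = [max(abs(x) for x in b) / 6 / s_tensor for b in blocks]
    return s_tensor, s_blocks
```

By construction the block holding the global maximum gets scale exactly $448$, the E4M3 maximum, which is the point of the tensor-level scale.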
Each block gets quantized as $q_j = Q_{\mathrm{E2M1}}\!\left(x_j / (\hat{s}_J \, s_{\mathrm{tensor}})\right)$ for $j \in J$, and the partial dot product during the GEMM is computed as $\hat{s}_J \, \hat{s}'_K \sum_j q_j q'_j$, where $(q'_j)$ and $\hat{s}'_K$ are the quantized block values and block scale of the other operand. After the GEMM, the tensor dequantization scales are applied.
There are experiments showing that NVFP4 should be used in the earlier layers of a transformer (in the forward pass direction), while keeping the later layers in higher precision.
Random Hadamard Transform
Hadamard matrices $H$ of dimension $2^r$ for an integer $r$ satisfy $H H^\top = 2^r I$ and $H_{ij} \in \{+1, -1\}$. We shall consider a randomized Hadamard matrix $HD$, where $D$ is a diagonal matrix of values $\pm 1$ chosen uniformly at random. In training, instead of operating on tensors directly, one applies the above NVFP4 quantization to the random-Hadamard-transformed tiles of the tensor. In some experiments $HD$ is applied to the inputs of the weight gradient GEMM.
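A sketch of the Sylvester construction and the $HD$ transform (unnormalized, so lengths grow by $2^{r/2}$; the seed handling and helper names are arbitrary):

```python
import random

def hadamard(r):
    """Sylvester construction of the 2**r Hadamard matrix:
    H_{2n} = [[H_n, H_n], [H_n, -H_n]], entries +/-1, H H^T = 2**r * I."""
    H = [[1]]
    for _ in range(r):
        H = [row + row for row in H] + [row + [-v for v in row] for row in H]
    return H

def random_hadamard_apply(x, seed=0):
    """Apply H D to vector x, where D is a random +/-1 diagonal."""
    r = (len(x) - 1).bit_length()
    assert len(x) == 2 ** r, "length must be a power of two"
    rng = random.Random(seed)
    d = [rng.choice([-1, 1]) for _ in range(len(x))]
    xd = [di * xi for di, xi in zip(d, x)]
    H = hadamard(r)
    return [sum(hij * xj for hij, xj in zip(row, xd)) for row in H]
```

Since both $H$ and $D$ are orthogonal up to scaling, $\|HDx\|^2 = 2^r \|x\|^2$ regardless of the random signs: the transform mixes coordinates (spreading outliers across the block) without changing the energy.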
Xue J. Zhao © 2026