As Large Language Models (LLMs) grow in size and complexity, finding ways to reduce their computational and energy costs has become a critical challenge. One popular solution is quantization, where the precision of parameters is reduced from the standard 16-bit floating point (FP16) or 32-bit floating point (FP32) to lower-bit formats such as 8-bit or 4-bit. While this approach significantly reduces memory usage and speeds up computation, it often comes at the expense of accuracy. Reducing the precision too aggressively can cause the model to lose crucial information, resulting in degraded performance.
BitNet is a special transformer architecture that represents each parameter with only three values, (-1, 0, 1), offering an extreme quantization of just 1.58 ($ \log_2(3) $) bits per parameter. However, it requires training a model from scratch. While the results are impressive, not everybody has the budget to pre-train a large language model. To overcome this limitation, we explored a few tricks that allow fine-tuning an existing model down to 1.58 bits! Keep reading to find out how!
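As a quick check of the arithmetic, the snippet below computes $ \log_2(3) $ and the resulting weight storage for an 8B-parameter model. It is a back-of-the-envelope sketch that assumes every parameter is quantized; in practice some tensors (for example embeddings) may stay in higher precision, and ternary weights are typically stored with 2 bits each so that four values fit in one byte.

```python
import math

def bits_per_param(num_levels: int) -> float:
    """Information content (in bits) of a parameter restricted to num_levels values."""
    return math.log2(num_levels)

def weight_memory_gb(num_params: float, bits: float) -> float:
    """Storage needed for num_params weights at the given bit width, in GB (1e9 bytes)."""
    return num_params * bits / 8 / 1e9

ternary_bits = bits_per_param(3)  # about 1.585 bits for {-1, 0, 1}
for label, bits in [("FP16", 16), ("ternary (ideal)", ternary_bits), ("ternary (2-bit packed)", 2)]:
    print(f"{label:>22}: {weight_memory_gb(8e9, bits):5.2f} GB for 8B weights")
```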
Table of Contents
- TL;DR
- BitNet in more depth
- Pre-training results in 1.58 bits
- Fine-tuning in 1.58 bits
- Kernels used and benchmarks
- Conclusion
- Acknowledgements
- Additional resources
TL;DR
BitNet is an architecture introduced by Microsoft Research that uses extreme quantization, representing each parameter with only three values: -1, 0, and 1. This results in a model that uses only 1.58 bits per parameter, significantly reducing computational and memory requirements.
BitNet uses INT8 additions when performing matrix multiplication, in contrast to the FP16 multiply-add operations of traditional LLM architectures such as Llama.
The new computation paradigm of BitNet b1.58 (source: BitNet paper /abs/2402.17764)
In theory, this lowers energy consumption: BitNet b1.58 saves 71.4 times the arithmetic-operation energy for matrix multiplication compared to the Llama baseline.
Energy consumption of BitNet b1.58 compared to Llama (source: BitNet paper /abs/2402.17764)
We have successfully fine-tuned a Llama3 8B model using the BitNet architecture, achieving strong performance on downstream tasks. The 8B models we developed are released under the HF1BitLLM organization. Two of these models were fine-tuned on 10B tokens with different training setups, while the third was fine-tuned on 100B tokens. Notably, our models surpass the Llama 1 7B model on the MMLU benchmark.
How to use with Transformers
To integrate the BitNet architecture into Transformers, we introduced a new quantization method called "bitnet" (PR). This method replaces the standard Linear layers with specialized BitLinear layers designed for the BitNet architecture, which perform the corresponding on-the-fly activation quantization, weight unpacking, and matrix multiplication.
Loading and testing the model in Transformers is straightforward, and the API does not change at all:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "HF1BitLLM/Llama3-8B-1.58-100B-tokens",
    device_map="cuda",
    torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

input_text = "Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:"

# Tokenize, generate a short continuation, and decode it back to text
input_ids = tokenizer.encode(input_text, return_tensors="pt").cuda()
output = model.generate(input_ids, max_new_tokens=10)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(generated_text)
With this code, everything is handled seamlessly behind the scenes, so there's no need to worry about additional complexity; you just need to install the latest version of transformers.
For a quick test of the model, check out this notebook.
BitNet in more depth
BitNet replaces the traditional Linear layers in Multi-Head Attention and Feed-Forward Networks with specialized layers called BitLinear, which use ternary precision (or even binary precision, in the initial version). The BitLinear layers we use in this project quantize the weights to ternary precision (values of -1, 0, and 1) and quantize the activations to 8-bit precision. We use different BitLinear implementations for training and inference, as described in the next sections.
The main obstacle to training in ternary precision is that the weight values are discretized (via the round() function) and therefore non-differentiable. BitLinear solves this with a neat trick: the STE (Straight-Through Estimator). The STE lets gradients flow through the non-differentiable rounding operation by approximating its gradient as 1, that is, by treating round() as the identity function in the backward pass. Put differently, the STE passes the gradient through the rounding step as if no rounding had happened, which makes it possible to update the weights with standard gradient-based optimization.
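The effect of the STE is easy to see on a toy tensor. The PyTorch sketch below (values are arbitrary) backpropagates once through a plain round() and once through the detach()-based STE used in the training code of this section: the first gives zero gradients, while the second passes a gradient of 1 straight through even though the forward values are still rounded.

```python
import torch

x = torch.tensor([0.3, 1.7, -0.6], requires_grad=True)

# Plain rounding: autograd defines the gradient of round() as zero everywhere,
# so no learning signal reaches x.
torch.round(x).sum().backward()
grad_plain = x.grad.clone()

x.grad = None
# STE: the forward value equals round(x), but the rounding residual is detached,
# so backpropagation only sees the identity term x and the gradient is 1.
x_ste = x + (torch.round(x) - x).detach()
x_ste.sum().backward()
grad_ste = x.grad.clone()

print("forward value:", x_ste.detach())
print("grad through round():", grad_plain)
print("grad through STE:", grad_ste)
```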
The BitNet model architecture with BitLinear layers (source: BitNet paper /pdf/2310.11453)
Training
We train in full precision, but quantize the weights to ternary values as we go, using symmetric per-tensor quantization. First, we compute the mean of the absolute values of the weight matrix and use its reciprocal as the scale. We then multiply the weights by this scale (equivalently, divide them by the mean absolute value), round the results, clamp them to the interval [-1, 1], and finally rescale (dequantize) them back to full precision:
$$ scale_w = \frac{1}{\frac{1}{nm} \sum_{ij} |W_{ij}|} $$
$$ W_q = \text{clamp}_{[-1,1]}(\text{round}(W \cdot scale_w)) $$
$$ W_{dequantized} = \frac{W_q}{scale_w} $$
The activations are then quantized to a specified bit-width (8 bits in our case) using per-token absmax quantization (for a full description of quantization methods, check out this post). This involves scaling the activations into the range [-128, 127] to fit the 8-bit width. The quantization formulas are:
$$ scale_x = \frac{127}{|X|_{\text{max}, \, \text{dim}=-1}} $$
$$ X_q = \text{clamp}_{[-128,127]}(\text{round}(X \cdot scale_x)) $$
$$ X_{dequantized} = \frac{X_q}{scale_x} $$
To make these formulas clearer, here are examples of weight and activation quantization on 3x3 matrices:
Example 1: Quantization of the weight matrix
Let $ W $ be a 3x3 weight matrix.
Step 1: Calculate the scale of the weights
Using the formula:
$$ scale_w = \frac{1}{\frac{1}{nm} \sum_{ij} |W_{ij}|} $$
we compute the mean of the absolute values of $ W $ and take its reciprocal. For the example matrix, the resulting scale is $ scale_w \approx 1.2 $.
Step 2: Quantize the weight matrix
Using the formula:
$$ W_q = \text{clamp}_{[-1,1]}(\text{round}(W \cdot scale_w)) $$
we first scale the weights by $ scale_w \approx 1.2 $, then round the results and clamp them to the interval $ [-1, 1] $.
Step 3: Dequantize the weights
Finally, we dequantize the weights with:
$$ W_{dequantized} = \frac{W_q}{scale_w} $$
Dividing by $ scale_w $ restores the weights to their original range, up to the quantization error.
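The three steps can be traced in PyTorch. The matrix below is a hypothetical example, chosen so that its mean absolute value is 7.5 / 9 and hence $ scale_w = 1.2 $, as in the text. weight_quant_steps is an illustrative helper that follows the same formulas as weight_quant in the training code but returns every intermediate result.

```python
import torch

def weight_quant_steps(w: torch.Tensor):
    """Per-tensor absmean ternary quantization, returning each intermediate step."""
    scale_w = 1.0 / w.abs().mean().clamp(min=1e-5)   # Step 1: reciprocal of mean |w|
    w_q = (w * scale_w).round().clamp(-1, 1)         # Step 2: scale, round, clamp
    w_deq = w_q / scale_w                            # Step 3: back to the original range
    return scale_w, w_q, w_deq

# Hypothetical 3x3 weights: mean |w| = 7.5 / 9, so scale_w = 1.2
W = torch.tensor([[0.8, -0.5, 1.2],
                  [-1.5, 0.4, -0.9],
                  [1.3, -0.7, 0.2]])
scale_w, W_q, W_deq = weight_quant_steps(W)
print(scale_w)
print(W_q)
print(W_deq)
```

Note how -1.5 and 1.3 (scaled to -1.8 and 1.56) both saturate at the clamp, while the small entries 0.4 and 0.2 (scaled to 0.48 and 0.24) are mapped to 0.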
Example 2: Quantization of the activation matrix
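Let $ X $ be a 3x3 activation matrix.
Step 1: Calculate the scale of the activations
For each row (or channel), calculate its maximum absolute value:
- Row 1: Maximum absolute value = 1.0
- Row 2: Maximum absolute value = 1.2
- Row 3: Maximum absolute value = 0.8
Then compute the scale for each row:
$$ scale_x = \frac{127}{|X|_{\text{max}, \, \text{dim}=-1}} $$
which gives $ 127/1.0 = 127 $, $ 127/1.2 \approx 105.83 $, and $ 127/0.8 = 158.75 $ for rows 1 to 3.
Step 2: Quantize the activation matrix
Using the formula:
$$ X_q = \text{clamp}_{[-128,127]}(\text{round}(X \cdot scale_x)) $$
we multiply each row by its corresponding scale, then round the values and clamp them to the range $ [-128, 127] $.
Step 3: Dequantize the activations
Finally, we dequantize the activations with:
$$ X_{dequantized} = \frac{X_q}{scale_x} $$
which recovers the original values up to the quantization error.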
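The same trace for the activations, in PyTorch. The matrix is hypothetical, built only so that its row maxima are 1.0, 1.2, and 0.8 as listed above; activation_quant_steps is an illustrative helper mirroring activation_quant from the training code.

```python
import torch

def activation_quant_steps(x: torch.Tensor):
    """Per-row (per-token) absmax INT8 quantization, returning each intermediate step."""
    scale_x = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)  # Step 1
    x_q = (x * scale_x).round().clamp(-128, 127)                                # Step 2
    x_deq = x_q / scale_x                                                       # Step 3
    return scale_x, x_q, x_deq

# Hypothetical activations whose row maxima are 1.0, 1.2 and 0.8
X = torch.tensor([[1.0, -0.6, 0.7],
                  [-0.9, 0.4, -1.2],
                  [0.8, -0.5, 0.3]])
scale_x, X_q, X_deq = activation_quant_steps(X)
print(scale_x.squeeze())
print(X_q)
print((X - X_deq).abs().max())
```

The largest value in each row maps exactly to ±127, and the round-trip error of each entry is at most half a quantization step, $ 0.5 / scale_x $.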
We apply Layer Normalization (LN) before quantizing the activations to maintain the variance of the output:
$$ \text{LN}(x) = \frac{x - E(x)}{\sqrt{\text{Var}(x) + \epsilon}} $$
where ε is a small constant that prevents division by zero (numerical overflow) when the variance is close to zero.
As mentioned earlier, the round() function is not differentiable. We use detach() as a trick to implement a differentiable STE (Straight-Through Estimator) in the backward pass:
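# Adapted from /microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
import torch
import torch.nn as nn
import torch.nn.functional as F

def LN(x, eps=1e-5):
    # Minimal helper (assumption): parameter-free layer norm over the last dimension
    return F.layer_norm(x, x.shape[-1:], eps=eps)

def activation_quant(x):
    # Per-token absmax quantization to the INT8 range, then rescaled (fake quantization)
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127) / scale
    return y

def weight_quant(w):
    # Per-tensor absmean quantization to {-1, 0, 1}, then rescaled (fake quantization)
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    u = (w * scale).round().clamp_(-1, 1) / scale
    return u

class BitLinear(nn.Linear):
    """
    Only for training
    """
    def forward(self, x):
        w = self.weight
        x_norm = LN(x)

        # A trick for implementing Straight-Through-Estimator (STE) using detach()
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()

        # Perform quantized linear transformation
        y = F.linear(x_quant, w_quant)
        return y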
Inference
During inference, we simply quantize the weights to ternary values without rescaling them. We apply the same approach to the activations using 8-bit precision, then perform the matrix multiplication with an efficient kernel, followed by dividing by both the weight and activation scales. This should significantly improve inference speed, particularly on optimized hardware. Note that the rescaling works differently during training, where the matrix multiplications are kept in fp16/bf16/fp32 for proper training.
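# Adapted from /microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
import torch
import torch.nn as nn
import torch.nn.functional as F

def LN(x, eps=1e-5):
    # Minimal helper (assumption): parameter-free layer norm over the last dimension
    return F.layer_norm(x, x.shape[-1:], eps=eps)

def efficient_kernel(x_quant, w):
    # Hypothetical placeholder for an optimized low-bit matmul kernel (such as the
    # Triton kernel shown later); here it is a plain floating-point matmul
    return F.linear(x_quant, w.to(x_quant.dtype))

def activation_quant_inference(x):
    x = LN(x)
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127)  # kept in the INT8 range, not rescaled
    return y, scale

class BitLinear(nn.Linear):
    """
    Only for inference
    """
    def forward(self, x):
        w = self.weight  # weights here are already quantized to (-1, 0, 1)
        w_scale = self.w_scale  # assumed to be stored alongside the quantized weights
        x_quant, x_scale = activation_quant_inference(x)
        y = efficient_kernel(x_quant, w) / w_scale / x_scale
        return y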
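To see why dividing by the two scales after the low-precision matmul is enough, the sketch below compares a training-style path (dequantize first, then multiply) with an inference-style path (multiply the quantized values, then divide by both scales). It uses random tensors, plain F.linear in place of an optimized kernel, and skips the LN step for brevity. Both paths agree up to floating-point rounding because the per-tensor weight scale and the per-token activation scale factor out of the matrix product.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 16)  # activations: 4 tokens, 16 features
w = torch.randn(8, 16)  # full-precision weights of an 8x16 linear layer

# Quantization parameters shared by both paths
w_scale = 1.0 / w.abs().mean().clamp(min=1e-5)
w_q = (w * w_scale).round().clamp(-1, 1)                      # ternary weights
x_scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
x_q = (x * x_scale).round().clamp(-128, 127)                  # INT8-range activations

# Training-style: dequantize, then matmul in floating point
y_train = F.linear(x_q / x_scale, w_q / w_scale)
# Inference-style: matmul on quantized values, then divide by both scales
y_infer = F.linear(x_q, w_q) / w_scale / x_scale

print((y_train - y_infer).abs().max())
```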
Pre-training results in 1.58 bits
Before attempting fine-tuning, we first tried to reproduce the pre-training results of the BitNet paper. We started with a small dataset, tinystories, and a Llama3 8B model. We found that adding a normalization function, as the paper does, improves performance. For example, after 2000 training steps, our validation perplexity was 6.3 without normalization and 5.9 with it. Training was stable in both cases.
Pre-training plots with layer normalization (blue) and without (orange)
While this approach looks very interesting for pre-training, only a few institutions can afford pre-training at scale. However, there is a wide range of strong pre-trained models, and it would be very useful if they could be converted to 1.58 bits after pre-training. Other groups had reported that fine-tuning results were not as strong as those achieved with pre-training, so we set out to investigate whether we could make 1.58-bit fine-tuning work.
Fine-tuning in 1.58 bits
When we began fine-tuning from the pre-trained Llama3 8B weights, the model performed slightly better, but not as well as we expected.
Note: All our experiments were conducted using Nanotron. If you're interested in trying 1.58-bit pre-training or fine-tuning, you can check out this PR.
Fine-tuning curve vs. pre-training curve
To understand why, we tried to examine the weight distributions of the randomly initialized model and the pre-trained model to identify possible problems.
Distribution of the random weights (a mixture of two standard deviations)
Distribution of the pre-trained Llama3 weights
The two distributions also lead to different quantization scales.
The initial random weight distribution is a mixture of two normal distributions:
- One with a standard deviation of $$ 0.025 $$
- Another with a standard deviation of $$ \frac{0.025}{\sqrt{2 \cdot \text{num\_hidden\_layers}}} = \frac{0.025}{\sqrt{2 \cdot 32}} = 0.003125 $$
This comes from using different standard deviations for column-linear and row-linear weights in nanotron. In the quantized version, every matrix has one of only two weight scales (50.25 and 402), which are the inverse of the mean absolute value of each matrix's weights: scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
- For $ \text{scale} = 50.25 $, the mean absolute value is $ 1/50.25 \approx 0.0199 $, which gives $ \text{std} \approx 0.025 $, matching our first standard deviation. The standard deviation is derived from the expectation of the half-normal distribution of $ |w| $: $$ E(|w|) = \text{std}(w) \cdot \sqrt{\frac{2}{\pi}} $$
- For $ \text{scale} = 402 $, the mean absolute value is $ 1/402 \approx 0.0025 $, which gives $ \text{std} \approx 0.0031 $, matching our second standard deviation.
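The scale-to-standard-deviation conversion can be checked numerically. The sketch below inverts the half-normal relation, std = E|w| · sqrt(π/2) with E|w| = 1 / scale, for the two observed scales, and compares the second result with 0.025 / sqrt(2 · 32), assuming num_hidden_layers = 32 as in Llama3 8B. std_from_scale is an illustrative helper name.

```python
import math

def std_from_scale(scale: float) -> float:
    """Std of a zero-mean Gaussian weight matrix recovered from its absmean scale.
    scale = 1 / E|w|, and for w ~ N(0, std^2), E|w| = std * sqrt(2 / pi)."""
    mean_abs = 1.0 / scale
    return mean_abs * math.sqrt(math.pi / 2)

print(f"scale 50.25 -> std {std_from_scale(50.25):.4f}")   # first component
print(f"scale 402   -> std {std_from_scale(402):.5f}")     # second component
print(f"0.025 / sqrt(2 * 32) = {0.025 / math.sqrt(2 * 32):.6f}")
```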
On the other hand, the distribution of the pretrained weights looks like a normal distribution with a standard deviation of $ 0.013 $ .
Clearly, the pre-trained model starts with more information (scale), while the randomly initialized model starts with practically no information and accumulates it over time. Our conclusion was that starting from random weights gives the model minimal initial information, enabling a gradual learning process, whereas during fine-tuning, abruptly introducing BitLinear layers makes the model lose all of its prior information.
To improve the fine-tuning results, we have tried different techniques. For example, we tried using per-row and per-column quantization instead of per-tensor quantization to retain more information from the Llama 3 weights. We also tried changing the way the scales are calculated: instead of just using the mean absolute value of the weights as a scale, we used the mean absolute value of the outliers (values that are more than k times the mean absolute value, where k is a constant that we tried to vary in our experiments) as a scale, but we didn't notice a significant improvement.
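import torch

def scale_outliers(tensor, threshold_factor=1):
    # Mean absolute value of the "outliers": entries whose magnitude exceeds
    # threshold_factor times the tensor's mean absolute value
    mean_absolute_value = torch.mean(torch.abs(tensor))
    threshold = threshold_factor * mean_absolute_value
    outliers = tensor[torch.abs(tensor) > threshold]
    mean_outlier_value = torch.mean(torch.abs(outliers))
    return mean_outlier_value

def weight_quant_scaling(w):
    # Same as weight_quant, but using the outlier-based scale
    scale = 1.0 / scale_outliers(w).clamp_(min=1e-5)
    quantized_weights = (w * scale).round().clamp_(-1, 1) / scale
    return quantized_weights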
We observed that both the random weights and the Llama 3 weights produced losses starting at approximately the same value of 13. This suggests that the Llama 3 model loses all of its prior information when quantization is introduced. To further investigate how much information the model loses during this process, we experimented with per-group quantization.
As a sanity check, we first set the group size to 1, which essentially means no quantization. In this scenario, the loss started at 1.45, the same as in normal fine-tuning. However, when we increased the group size to 2, the loss jumped to around 11. This indicates that even with a minimal group size of 2, the model still loses nearly all of its information.
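The reconstruction side of this sanity check can be reproduced on a toy tensor. weight_quant_per_group below is a hypothetical helper that extends the per-tensor absmean scheme with one scale per group of consecutive weights. With group_size=1, each weight is divided by its own magnitude, rounds to ±1, and is rescaled back exactly, which is why that setting amounts to no quantization; larger groups introduce real rounding error.

```python
import torch

def weight_quant_per_group(w: torch.Tensor, group_size: int) -> torch.Tensor:
    """Hypothetical per-group variant of weight_quant: one absmean scale for every
    `group_size` consecutive weights (w.numel() must be divisible by group_size)."""
    groups = w.reshape(-1, group_size)
    scale = 1.0 / groups.abs().mean(dim=-1, keepdim=True).clamp(min=1e-5)
    q = (groups * scale).round().clamp(-1, 1) / scale
    return q.reshape(w.shape)

torch.manual_seed(0)
w = torch.randn(4, 8)
for g in (1, 2, 4, 32):
    err = (weight_quant_per_group(w, g) - w).abs().mean().item()
    print(f"group_size={g:>2}: mean |error| = {err:.4f}")
```

This only measures weight reconstruction error; it says nothing about the loss values above, which depend on the full model.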
To address this issue, we considered introducing quantization gradually rather than applying it abruptly to the weights and activations of each tensor. To achieve this, we introduced a lambda value to control the process:
lambda_ = ?  # quantization strength, between 0 and 1
x_quant = x + lambda_ * (activation_quant(x) - x).detach()
w_quant = w + lambda_ * (weight_quant(w) - w).detach()
When lambda is set to 0, essentially no quantization occurs, while at lambda=1, full quantization is applied.
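The interpolation can be checked on a toy tensor. In the sketch below, warmup_quant is a hypothetical helper that applies the lambda blend to the weights only, reusing the same weight_quant formula as the training code. It shows that lambda=0 returns the original weights, lambda=1 returns the fully quantized ones, and the gradient with respect to w stays 1 for every lambda because the quantization residual is detached.

```python
import torch

def weight_quant(w):
    # Same per-tensor absmean ternary fake quantization as in the training code
    scale = 1.0 / w.abs().mean().clamp(min=1e-5)
    return (w * scale).round().clamp(-1, 1) / scale

def warmup_quant(w: torch.Tensor, lambda_: float) -> torch.Tensor:
    """Hypothetical helper: blend full-precision and quantized weights.
    lambda_=0 -> no quantization, lambda_=1 -> full quantization."""
    return w + lambda_ * (weight_quant(w) - w).detach()

torch.manual_seed(0)
w = torch.randn(4, 4, requires_grad=True)
for lambda_ in (0.0, 0.25, 0.5, 1.0):
    w.grad = None
    out = warmup_quant(w, lambda_)
    out.sum().backward()
    dist = (out - weight_quant(w)).abs().mean().item()
    print(f"lambda={lambda_:.2f}  distance to fully quantized: {dist:.4f}  "
          f"grad values: {w.grad.unique().tolist()}")
```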
We initially tested a number of discrete lambda values, such as 0.25, 0.5, 0.75, and 1. However, this approach did not result in a significant improvement in the results, mainly because lambda=0.25 was already high enough to make the loss start high.
As a result, we decided to experiment with a lambda value that adjusts dynamically with the training step.
Using this dynamic lambda value led to better loss convergence, but the perplexity (ppl) results at inference, with lambda set to 1, were still far from satisfactory. We realized this was likely because the model had not been trained long enough with lambda=1. To address this, we adjusted our lambda value to improve the training process:
lambda_ = min(2 * training_step / total_training_steps, 1)
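With this schedule, lambda reaches 1 halfway through training, so the second half of the run is spent fully quantized. The small helper below (linear_lambda is a hypothetical name) makes that explicit for a 2000-step run.

```python
def linear_lambda(training_step: int, total_training_steps: int) -> float:
    """Linear warmup that reaches full quantization (lambda = 1) at 50% of training."""
    return min(2 * training_step / total_training_steps, 1)

for step in (0, 500, 1000, 1500, 2000):
    print(step, linear_lambda(step, 2000))
```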
With this configuration, after 2000 steps we have:
Fine-tuning plot with lambda = min(2*training_step/total_training_steps, 1)
Our fine-tuning method shows better convergence overall. You can observe a slight increase in the loss curve around 1,000 steps, which corresponds to when lambda approaches 1, that is, full quantization. However, immediately after this point, the loss starts to converge again, leading to an improved perplexity of approximately 4.
Despite this progress, when we tested the quantized model on the WikiText dataset (instead of the tinystories dataset we used for fine-tuning), it showed a very high perplexity. This suggests that fine-tuning the model in low-bit mode on a specific dataset causes it to lose much of its general knowledge. This issue might arise because the minimal representations we aim for with ternary weights can vary significantly from one dataset to another. To address this problem, we scaled up our training process to include the larger FineWeb-edu dataset. We kept a lambda value of:
lambda_ = min(training_step/1000, 1)
We chose this lambda value because it seemed to be a good starting point for warming up the model. We then trained the model on the FineWeb-edu dataset for 5,000 steps with a learning rate of 1e-4, using a batch size (BS) of 2M tokens, for a total of 10B tokens.
Finding the right learning rate and the right decay rate is challenging; this seems to be a key factor in model performance.
Fine-tuning plot with warmup quantization on FineWeb-edu
After fine-tuning on FineWeb-Edu, reaching a perplexity of 12.2 on the WikiText dataset is quite impressive, considering that we only used 10 billion tokens. Other evaluation metrics also show strong performance, considering the limited amount of data (see Results).
It's also a good idea to smooth out the sharp increase as lambda approaches 1. To do this, consider using lambda schedulers that grow rapidly at first and then level off as they get closer to 1. This helps the model adapt more smoothly to changes in lambda and avoids abrupt jumps.
def scheduler(step, total_steps, k):
    normalized_step = step / total_steps
    return 1 - (1 - normalized_step)**k
We plotted this scheduler for different values of k, with the total number of warmup steps normalized to 1.
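Evaluating the scheduler at a few points shows how quickly lambda rises for the k values used in these experiments: the curve starts at 0 and ends at 1, and its slope at the very start equals k, so larger k values push quantization in faster during the first steps before flattening out. This sketch simply re-evaluates the scheduler function with total_steps=1 and estimates the initial slope with a finite difference.

```python
def scheduler(step, total_steps, k):
    normalized_step = step / total_steps
    return 1 - (1 - normalized_step)**k

for k in (4, 6, 8, 10):
    values = [round(scheduler(t, 1.0, k), 3) for t in (0.0, 0.1, 0.25, 0.5, 1.0)]
    # Finite-difference estimate of d(lambda)/dt at t = 0 (analytically equal to k)
    slope = (scheduler(1e-4, 1.0, k) - scheduler(0.0, 1.0, k)) / 1e-4
    print(f"k={k:>2}: lambda at t=0, 0.1, 0.25, 0.5, 1 -> {values}, initial slope ~ {slope:.2f}")
```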
We ran 4 experiments with the best-performing learning rate of 1e-4, testing k values of 4, 6, 8, and 10.
Fine-tuning plots with different exponential schedulers
The smoothing worked well: there is no spike like the one with the linear scheduler. However, the perplexity is not great, staying around 15, and performance on downstream tasks does not improve.
We also noticed spikes at the beginning from which the model struggles to recover. When lambda = 0, there is essentially no quantization, so the loss starts out low, around 2 or so. However, after the first step, there is a spike, similar to what happens with the linear scheduler (as shown in the blue graph above). Therefore, we tried another scheduler, the Sigmoid scheduler, which starts slowly, rises quickly to 1, and then stabilizes as it approaches 1.
import numpy as np

def sigmoid_scheduler(step, total_steps, k):
    # Sigmoid-like curve: slow start, fast middle, slow end
    normalized_step = step / total_steps
    return 1 / (1 + np.exp(-k * (normalized_step - 0.5)))
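Unlike the polynomial scheduler, this sigmoid never returns exactly 0 or 1: at step 0 it gives $ 1/(1+e^{k/2}) $, at the midpoint it is exactly 0.5, and the curve is symmetric around that midpoint. The sketch below uses math.exp instead of NumPy so it runs on plain floats, and prints the endpoint values for the k values tried in these experiments; the smaller k is, the further lambda starts from 0 (about 5.5e-4 for k = 15).

```python
import math

def sigmoid_scheduler(step, total_steps, k):
    # Same sigmoid schedule as above, written with math.exp for scalar inputs
    normalized_step = step / total_steps
    return 1 / (1 + math.exp(-k * (normalized_step - 0.5)))

for k in (15, 20, 25, 40, 100):
    start = sigmoid_scheduler(0, 1, k)
    end = sigmoid_scheduler(1, 1, k)
    print(f"k={k:>3}: lambda(0) = {start:.2e}, lambda(0.5) = {sigmoid_scheduler(0.5, 1, k)}, "
          f"lambda(1) = {end:.6f}")
```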
For different values of k, we obtain the following curves:
Sigmoid scheduler for different values of k
This time, we ran experiments with k values of 15, 20, 25, 40, and 100:
Fine-tuning plots with the sigmoid scheduler
The sharp increase in lambda leads to instability around step 500 and does not resolve the first divergence problem. However, for $$ k = 100 $$, we observe some improvement in the downstream tasks (see the results table), although the perplexity remains around 13.5. Nonetheless, it does not show a significant performance improvement compared to the linear scheduler.
In addition, we tried experiments where we trained the model from scratch using random weights and various learning rates. This allowed us to compare the effectiveness of our fine-tuning approach with traditional pre-training methods.
Training images at different learning rates
None of the models trained from random weights performed better than our fine-tuned models. The best perplexity we achieved with them was 26, which falls well short of the results of our fine-tuning approach.
Scaling to 100B tokens!
We extended the experiment to 100B tokens to see if we could achieve the performance level of the Llama 3 8B model. We performed longer training runs, starting with the checkpoints that performed best in the shorter runs, using a linear scheduler, and continued fine-tuning for 45,000 steps. We tried different learning rates, and while the model performed close to the Llama 3 model on some metrics, it still lagged a bit on average.
Here is an example of some of the metrics we evaluated at different checkpoints during training.
Evaluation results of multiple metrics for different learning rates in training
The average scores are as follows.
Mean evaluation results for different learning rates in training
Experiments on smaller models
In our initial experiments with smaller models such as SmolLM, we observed that the warmup quantization technique did not lead to as much improvement as it did for larger models. This suggests that the effectiveness of warmup quantization may be more closely related to the size and complexity of the model.
For example, here are the loss curves for the SmolLM 135M model, comparing warmup quantization with full quantization from the start. Interestingly, the curves closely align, and the resulting perplexities are not significantly different.
SmolLM fine-tuning experiments with and without warmup quantization
Comparison and Conclusion
BitNet delivers strong performance compared to baseline methods, especially at lower bit widths. According to the paper, BitNet achieves scores comparable to 8-bit models with significantly lower inference costs. For 4-bit models, methods that quantize only the weights outperform those that quantize both weights and activations, because activations are harder to quantize. However, BitNet with 1.58-bit weights outperforms both weight-only and weight-and-activation quantization methods.
The table below presents the results of various metrics after fine-tuning Llama3 8B on 10B tokens. These results are compared with those of other model architectures to provide a comprehensive overview of performance (all evaluations were conducted with Lighteval on models in the Nanotron format).
Comparison of metrics with Llama model: linear for linear lambda scheduler, sigmoid for sigmoid scheduler (k = 100 in our case)
After fine-tuning on 10B tokens with ternary weights, the model performs impressively, especially compared to other models that underwent much more extensive training. For instance, it outperforms the BitNet 7B model, which was trained on a significantly larger dataset of 100B tokens. It also outperforms the FBI LLM (Fully Binarized LLM), which was distilled on an even larger 1.26T tokens. This highlights the model's efficiency and effectiveness despite the relatively small scale of its fine-tuning.
For the 100B token experiment, the best performing checkpoint we have is as follows.
Comparison of metrics with Llama model after 100B tokens fine-tuning
To replicate these results, you can check out this PR to convert the models to the Nanotron format, unpack the weights (see the function unpack_weights), and use lighteval.
Note that although these models are fine-tuned from an Instruct-tuned model, they still need to be fine-tuned on an instruction dataset; they should be considered base models.
Kernels used and benchmarks
To benefit from BitNet's low-precision weights, we pack them into an int8 tensor (which brings the number of parameters down from 8B to 2.8B!). During inference, these weights must be unpacked before the matrix multiplication. We implemented custom kernels in CUDA and Triton to handle this on-the-fly unpacking during matrix multiplication. For the matrix multiplication itself, we used the cached tiled matrix multiplication technique. To fully understand this approach, let's first review some CUDA programming basics.
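Because each ternary value needs only 2 bits, four of them fit in one byte. The sketch below packs a (K, N) ternary matrix into a (K // 4, N) uint8 tensor and unpacks it again. The helper names are hypothetical, and the layout (each value v stored as v + 1, with rows i·K/4 to (i+1)·K/4 - 1 placed in bits 2i and 2i+1) is an assumption chosen to match the shift-and-mask unpacking in the Triton kernel later in this section; the unpack_weights function in the PR may be organized differently.

```python
import torch

def pack_ternary(w: torch.Tensor) -> torch.Tensor:
    """Pack a (K, N) matrix with values in {-1, 0, 1} into a (K // 4, N) uint8 matrix."""
    K, N = w.shape
    assert K % 4 == 0, "K must be divisible by 4"
    codes = (w + 1).to(torch.uint8).reshape(4, K // 4, N)  # {-1, 0, 1} -> {0, 1, 2}
    packed = torch.zeros((K // 4, N), dtype=torch.uint8)
    for i in range(4):
        packed |= codes[i] << (2 * i)  # rows [i*K/4, (i+1)*K/4) go into bits 2i, 2i+1
    return packed

def unpack_ternary(packed: torch.Tensor) -> torch.Tensor:
    """Inverse of pack_ternary: recover the (K, N) ternary matrix as int8."""
    parts = [((packed >> (2 * i)) & 3).to(torch.int8) - 1 for i in range(4)]
    return torch.cat(parts, dim=0)

torch.manual_seed(0)
w = torch.randint(-1, 2, (8, 5))  # random ternary weights
packed = pack_ternary(w)
print(w.numel(), "ternary values stored in", packed.numel(), "bytes")
```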
Basic GPU Concepts: Threads, Blocks, and Shared Memory
Before diving into cached tiled matrix multiplication, it's important to understand some basic GPU concepts:
- Threads and Blocks: The GPU executes thousands of threads simultaneously. These threads are organized into blocks, and each block runs independently. The grid consists of these blocks, which represent the entire program space. For example, in matrix multiplication, each thread may be responsible for computing one cell of the output matrix.
- Shared Memory: Each block has access to a limited amount of shared memory, which is much faster than global memory (the GPU's main memory). However, shared memory is limited in size and shared among all threads within a block. Using it efficiently is key to improving the performance of GPU programs.
Challenges in matrix multiplication
A simple implementation of matrix multiplication on the GPU may involve each thread computing individual elements of the result matrix by reading the required elements directly from global memory. However, this approach may be inefficient for the following reasons.
- Memory bandwidth: Access to global memory is relatively slow compared to the speed at which GPU cores perform computations. If each thread reads matrix elements directly from global memory, the access time may become a bottleneck.
- Redundant data access: In matrix multiplication, many elements of the input matrix are used multiple times. If each thread independently fetches the required data from global memory, the same data may be loaded into the GPU multiple times, resulting in inefficiency. For example, if each thread is used to compute a single element in the output matrix, the thread responsible for computing positions (i, j) will need to load row i of matrix A and column j of matrix B from global memory. However, other threads, such as the one responsible for computing position (i+1, j), will not be able to reuse this data and will have to load the same jth column from global memory again.
The concept of tiling
Tiling is a technique used to address these challenges, notably in FlashAttention, to improve kernel efficiency. The basic idea is to divide the matrices into smaller sub-matrices, called tiles, that fit into the GPU's shared memory. Instead of computing the entire output matrix in one go, the computation is broken down into smaller pieces that are processed tile by tile.
In the context of matrix multiplication, this means dividing matrices A and B into tiles, loading these tiles into shared memory, and then performing the multiplication on these smaller blocks. This approach allows threads to reuse data stored in fast shared memory, reducing the need for repeated accesses to global memory.
This is done as follows.
- Loading Tiles into Shared Memory: Each block of threads cooperatively loads a tile of matrix A and the corresponding tile of matrix B from global memory into shared memory. This is done once per tile, and the tile is then reused many times by the threads in the block.
- Computing Partial Products: Once the tiles are loaded into shared memory, each thread computes a partial product. Since all threads in a block work on the same tiles in shared memory, they can efficiently reuse the data without additional global memory accesses.
- Accumulating Results: After computing the partial products for one tile, the threads load the next tiles from matrices A and B into shared memory and repeat the process. The results are accumulated in registers (or local memory), and once all tiles have been processed, the final values of the output matrix elements are written back to global memory.
Illustration of tiled matrix multiplication (source: /tutorial/pages/)
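The loop structure of tiling can be mimicked on the CPU with NumPy. This sketch only illustrates the tiling and accumulation order, not real GPU behavior: the array slices stand in for tiles loaded into shared memory, acc stands in for per-thread registers, and each (m0, n0) iteration plays the role of one thread block. The function name and the tile parameter are illustrative.

```python
import numpy as np

def tiled_matmul(A: np.ndarray, B: np.ndarray, tile: int = 2) -> np.ndarray:
    """C = A @ B computed one (tile x tile) output block at a time."""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "inner dimensions must match"
    C = np.zeros((M, N), dtype=np.result_type(A, B))
    for m0 in range(0, M, tile):                 # one "thread block" per output tile
        for n0 in range(0, N, tile):
            acc = np.zeros((min(tile, M - m0), min(tile, N - n0)), dtype=C.dtype)
            for k0 in range(0, K, tile):         # walk along K one tile at a time
                a_tile = A[m0:m0 + tile, k0:k0 + tile]  # "load into shared memory"
                b_tile = B[k0:k0 + tile, n0:n0 + tile]
                acc += a_tile @ b_tile           # partial product, reusing both tiles
            C[m0:m0 + tile, n0:n0 + tile] = acc  # write back to "global memory" once
    return C

rng = np.random.default_rng(0)
A = rng.integers(-128, 128, size=(5, 7))  # INT8-range activations
B = rng.integers(-1, 2, size=(7, 3))      # ternary weights
print(tiled_matmul(A, B, tile=2))
```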
Practical considerations
Several considerations are taken into account when implementing cached tiled matrix multiplication:
- Tile Size: The tile size should be chosen to balance the amount of data that fits into shared memory against the number of global memory accesses.
- Memory Coalescing: Global memory accesses should be coalesced, meaning that adjacent threads access adjacent memory locations.
- Occupancy: The number of threads per block and the number of blocks in the grid should be chosen to ensure high occupancy, i.e., as many active warps as possible on the GPU (a warp is a set of 32 threads), to hide memory latency.
Triton kernel
Here is the Triton kernel we benchmarked:
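import torch
import triton
import triton.language as tl

def get_cuda_autotune_config():
    # Hypothetical minimal search space for illustration; the PR defines its own configs.
    # BLOCK_SIZE_K must be a divisor of K / 4 (see the comment inside the kernel).
    return [
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=4),
    ]

@triton.autotune(
    configs=get_cuda_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
):
    # Map the 1D program id to the (pid_m, pid_n) output tile, grouping rows of tiles
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Pointers to the first tile of A (M x K) and of the packed B (K/4 x N)
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)

    # Each byte of B holds 4 ternary values; pass i extracts bits 2i and 2i+1,
    # which hold rows [i*K/4, (i+1)*K/4) of the unpacked weight matrix
    for i in range(4):
        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
        for j in range(0, tl.cdiv(K // 4, BLOCK_SIZE_K)):
            k = i * tl.cdiv(K // 4, BLOCK_SIZE_K) + j

            # BLOCK_SIZE_K must be a divisor of K / 4
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0)
            b_uint8 = tl.load(b_ptrs, mask=offs_k[:, None] < K // 4 - j * BLOCK_SIZE_K, other=0)
            mask = 3 << (2 * i)
            b = ((b_uint8 & mask) >> (2 * i))

            # We accumulate the tiles along the K dimension.
            # Subtracting 1 maps the stored codes 0, 1, 2 back to -1, 0, 1
            tensor_full = tl.full((1,), 1, dtype=tl.int8)
            accumulator += tl.dot(a, (b.to(tl.int8) - tensor_full), out_dtype=tl.int32)

            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk

    c = accumulator  # tl.store casts it to the element type of c_ptr (float16)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

def matmul(a, b):
    # a: (M, K) int8 activations; b: (K // 4, N) packed weights
    assert a.shape[1] == b.shape[0] * 4, "Incompatible dimensions, the weight matrix needs to be packed"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c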
code resolution
- Determine the location of the chunks
The operator first determines the block (tile) of the output matrix that each thread block is responsible for: the
-
pid
is a unique identifier for each thread block, using thetl.program_id(axis=0)
Get. - The grid is divided into a set of thread blocks (
GROUP_SIZE_M
). Each group processes a portion of the output matrix. -
pid_m
cap (a poem)pid_n
are the coordinates of the chunks in M and N dimensions, respectively. - Calculate the offset (
offs_am
、offs_bn
、offs_k
) to determine which elements of matrices A and B will be processed by the threads in each block.
Loading and computing tiles
The kernel iterates over the K dimension in steps of `BLOCK_SIZE_K`. For each step:
- Loading tiles: the tiles of matrices A and B are loaded from global memory.
- Unpacking matrix B: the kernel assumes that matrix B is packed with `int8` values, meaning each element actually holds four smaller values packed into one byte. Unpacking happens inside the loop:
  - `b_uint8` is loaded from global memory as packed `int8`.
  - Each packed value is unpacked to obtain the actual weights used in the computation: a 2-bit field is isolated with a mask and a shift (`(b_uint8 & mask) >> (2*i)`), and 1 is subtracted to map the stored values {0, 1, 2} back to {-1, 0, 1}.
- Dot product: the kernel computes the dot product of the tiles loaded from A and B and accumulates the result in `accumulator`, which stores the partial results for the corresponding tile of the output matrix C.
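The unpacking and accumulation steps can be emulated for a single byte and a single output element. This is a sketch with hypothetical helper names; it assumes `mask` is `3 << (2 * i)` and that A holds int8 activations, consistent with the int8 weights and `out_dtype=tl.int32` in the accumulation line.

```python
def unpack_byte(byte):
    """Return the four ternary weights stored in one packed byte, field 0 first."""
    weights = []
    for i in range(4):
        mask = 3 << (2 * i)               # assumed mask: bits 2*i and 2*i+1
        field = (byte & mask) >> (2 * i)  # value in {0, 1, 2}
        weights.append(field - 1)         # subtract 1 -> {-1, 0, 1}
    return weights


def blocked_dot(a_row, w_col, block_size_k):
    """Dot product accumulated one BLOCK_SIZE_K slice at a time, like the K loop."""
    acc = 0  # Python ints never overflow; the kernel uses an int32 accumulator
    for start in range(0, len(a_row), block_size_k):
        a_tile = a_row[start:start + block_size_k]
        w_tile = w_col[start:start + block_size_k]
        acc += sum(x * w for x, w in zip(a_tile, w_tile))
    return acc


print(unpack_byte(0b10010010))                    # [1, -1, 0, 1]
# Worst case for K = 4096 int8 activations: far outside the int8 and int16 ranges.
print(blocked_dot([127] * 4096, [1] * 4096, 64))  # 520192
```

The second call shows why the accumulator is wider than the operands: even though every weight is in {-1, 0, 1}, a single dot product over K = 4096 can reach 127 × 4096 = 520,192.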
Storing results
After all tiles along the K dimension have been processed, the final result held in `accumulator` is converted to `float16` and written back to the corresponding tile of matrix C in global memory. The write uses masks to respect the matrix boundaries, so that only valid elements are written.
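To see how the boundary mask behaves on an edge tile, here is a plain-Python equivalent of the `offs_cm`, `offs_cn` and `c_mask` lines at the end of the kernel. The function name `store_mask` is hypothetical.

```python
def store_mask(pid_m, pid_n, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N):
    """Boolean mask guarding the write of one output tile of C."""
    offs_cm = [pid_m * BLOCK_SIZE_M + t for t in range(BLOCK_SIZE_M)]  # row indices
    offs_cn = [pid_n * BLOCK_SIZE_N + t for t in range(BLOCK_SIZE_N)]  # column indices
    # Same as (offs_cm[:, None] < M) & (offs_cn[None, :] < N) after broadcasting.
    return [[r < M and c < N for c in offs_cn] for r in offs_cm]


# C is 5 x 3 with 4 x 4 tiles: the tile at (pid_m=1, pid_n=0) covers rows 4-7
# and columns 0-3, but only row 4, columns 0-2 exist.
mask = store_mask(1, 0, 5, 3, 4, 4)
print(sum(v for row in mask for v in row))  # 3
```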
For a more detailed explanation of the code, check out this PR.
Benchmarking
We benchmarked our kernel against an approach that unpacks the weights using `@` and then performs the matrix multiplication in BF16 precision, and found that both approaches achieve nearly identical performance. To ensure accurate measurements, we ran the matrix multiplication for 2000 iterations and averaged the time over the last 1000 iterations, to exclude any overhead from initial loading or compilation. The chart below shows the benchmark results for various matrix sizes: the x-axis shows the number of multiplications on a logarithmic scale, and the y-axis shows the average time in milliseconds.
Triton kernel benchmark
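The post does not include its benchmarking script. The following is a minimal sketch of the methodology described above (2000 runs, average of the last 1000), with `benchmark` as a hypothetical helper name; the optional `sync` argument is our assumption for GPU timing.

```python
import time


def benchmark(fn, total_iters=2000, timed_iters=1000, sync=None):
    """Average time in ms per call of fn over the last `timed_iters` iterations.

    The first total_iters - timed_iters calls are warm-up and are not timed, so
    one-off costs such as loading or compilation are excluded. For GPU work, pass
    a synchronization function (e.g. torch.cuda.synchronize), because kernel
    launches return before the kernel has finished.
    """
    assert 0 < timed_iters <= total_iters
    for _ in range(total_iters - timed_iters):
        fn()
    if sync is not None:
        sync()  # make sure warm-up work has finished before starting the clock
    start = time.perf_counter()
    for _ in range(timed_iters):
        fn()
    if sync is not None:
        sync()  # wait for queued work before stopping the clock
    return (time.perf_counter() - start) * 1e3 / timed_iters
```

On a GPU this would be called as, for example, `benchmark(lambda: matmul(a, b), sync=torch.cuda.synchronize)`. Timing the last 1000 calls as one batch gives the same average as timing each call separately, while keeping per-call timer overhead out of the measurement.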
We also experimented with BitBlas, a software library designed to perform matrix operations using mixed precision. It helps optimize these operations by allowing calculations to be performed in lower-precision formats (e.g., INT8, INT4, or even INT2) rather than the traditional FP32 or FP16 formats.
The benchmark results are encouraging: as the graph shows, BitBlas outperforms both our custom kernel and Torch's `matmul` function at low precision.
BitBlas benchmark
However, during model loading, BitBlas needs to compile kernels that fit the shape of the weight matrix and store them in the local code base, which may increase the initial loading time.
Conclusion
In summary, as large language models continue to scale, reducing their computational requirements through quantization is critical. This blog post explored 1.58-bit quantization, which uses ternary weights. While pre-training models at 1.58 bits is resource-intensive, we have shown that, with a few tricks, existing models can be fine-tuned to this precision level, achieving efficient performance without sacrificing accuracy. By optimizing inference speed with specialized kernels, BitNet opens up new possibilities for making large language models more practical and scalable.
Acknowledgements
We would like to sincerely thank Leandro von Werra, Thomas Wolf and Marc Sun for their valuable help and insights throughout this project. We would also like to thank Omar Sanseviero and Pedro Cuenca for their contributions in refining this blog post and helping us communicate our findings clearly and effectively to the AI community. In addition, we would like to thank the GeneralAI team for their pioneering work on the BitNet project. Their research has been fundamental to our efforts, and we are particularly grateful for the clear and accurate data they provided in the paper.
More resources
- H. Wang et al., BitNet: Scaling 1-bit Transformers for Large Language Models (arXiv paper)
- S. Ma et al., The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits (arXiv paper)
- S. Ma et al., The Era of 1-bit LLMs: Training Tips, Code and FAQ (link)
- RJ. Honicky, Are All Large Language Models Really in 1.58 Bits? (blog post)
- L. Mao, CUDA Matrix Multiplication Optimization (blog post)
- Tutorial: OpenCL SGEMM tuning for Kepler (link)
- CUDAMODE (GitHub, YouTube)
- Wen-mei W. Hwu, David B. Kirk, Izzat El Hajj, Programming Massively Parallel Processors: A Hands-on Approach