Preface: Large language model technology built on today's Transformer architecture demands heavy investment, high costs, and scarce talent, which has led many enterprises to shy away from it. How many companies can afford to spend tens or even hundreds of millions of dollars? Truly competitive technology should cost less and deliver more. For this reason, major universities and commercial companies no longer focus only on model parameter counts. They are actively exploring ways to sharply reduce the cost of developing and using large language models, so that most enterprises can adopt them easily. This latest research from Stanford University is a big step toward that goal: it cuts the cost of converting an existing 8-billion-parameter model into an efficient linear-attention model to under $20. At the same time, Chinese companies have already used this research to launch vertical solutions for enterprises, covering private AI models, servers, and front-end applications, paving the way for private deployment of AI in enterprises.
A group of researchers at Stanford University has introduced LoLCATs, a new approach to linearizing standard Transformer LLMs. It dramatically reduces computational requirements while retaining much of the models' state-of-the-art (SOTA) performance. The whole process takes a few hours of GPU time and costs less than $20 in total, making it up to 35,500 times more efficient in terms of training inputs. It sounds incredible, but three tricks the researchers demonstrate make it all possible. This technique may soon become an essential tool for AI engineering teams in their quest for best-in-class performance.
So, how do they do it?
The Problem with Standard LLMs: Large Language Models (LLMs) have brought enormous excitement and money to the AI industry. They seem to be on a straight path to global dominance as a superior technology that can move civilization into a new era (at least, that's what we are told). This huge bet has led big tech companies to invest billions in a single type of architecture, the Transformer. Yet Transformers are extremely inefficient and therefore massively expensive, so improving their efficiency will inevitably reduce costs.
Transformer is all we need. Simply put, AI is nothing more than a data compression algorithm. It takes in data, learns patterns from it, and uses this acquired knowledge to make useful predictions. No existing architecture does this better than the Transformer, for two reasons:
- They are fully parallelizable and well suited to training on large amounts of data.
- The larger the model, the better the results. This has triggered a frenzy of investment and research into scaling models and training budgets.
However, the unrivaled expressive power of Transformer-based models doesn't make them perfect. Their strengths come with their own problems.
Quadratic Complexity Problems. The biggest problem is that Transformers can't compress state or memory. In other words, if you want the model to remember certain information, it has to be stored in memory in its original, uncompressed form. That's not how humans remember. Humans don't remember every detail of what they experience; they retain only the parts they consider important.
Imagine you're reading a book. In chapter 11, the main character mentions something that happened much earlier, perhaps in chapter 1. If you paid attention, you might remember it; otherwise, you might need to go back and reread that part. Either way, your memory load stays manageable, because if nothing special happened in chapter 2, you can feel free to forget it. This is not how a Transformer handles information. When it reads chapter 11 and needs to recall chapter 1, it remembers immediately, because it still has access to the entire text of chapters 1 through 11. In fact, without the KV cache used during LLM inference, a Transformer would effectively reread all the previous content every time it read a new word.
Does this seem inefficient? I hope so, because it does.
But what does this mean? Simply put, this uncompressed memory grows as the sequence gets longer. By chapter 15, all the previous chapters have accumulated, and the memory requirement is much greater than it was at chapter 8, when only seven chapters preceded it.
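To put numbers on this growth, here is a back-of-the-envelope estimate of how the KV cache (the stored keys and values) scales with sequence length. This is a sketch under stated assumptions: the model dimensions are chosen to resemble Llama 3 8B, and `kv_cache_bytes` is a hypothetical helper, not a library function.

```python
# Sketch: KV-cache size for one sequence in a decoder-only Transformer.
# Assumed config, roughly Llama 3 8B: 32 layers, 8 key/value heads,
# head dimension 128, 16-bit (2-byte) values.

def kv_cache_bytes(seq_len, num_layers=32, num_kv_heads=8, head_dim=128, bytes_per_value=2):
    """Bytes needed to cache keys and values for `seq_len` tokens (batch size 1)."""
    # Factor 2: one key vector and one value vector per head, per layer, per token.
    per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_value
    return seq_len * per_token


for n in (1024, 8192, 32768):
    print(f"{n:>6} tokens -> {kv_cache_bytes(n) / 2**30:.3f} GiB")
# prints 0.125, 1.000 and 4.000 GiB respectively
```

Under these assumptions, every token adds 128 KiB that is never compressed, so the cache grows in lockstep with the sequence.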
To make matters worse, the underlying attention mechanism (which we won't detail here to save space) has two costly properties. The model must store two kinds of information, keys and values, for every word in every layer, and every new word is compared against all the words before it. As a result, the computational complexity of attention is O(n²) in the sequence length n. The memory needed is also O(n²) if the full attention matrix is materialized. Every time the sequence length doubles, these requirements quadruple, and tripling it means a ninefold increase.
FlashAttention (partly developed by some of the same researchers) avoids materializing the entire attention matrix in memory, which reduces the memory requirement to linear in the sequence length and partially solves the problem above. However, FlashAttention does not change the computational complexity, which remains quadratic: specifically O(n²·d), where n is the sequence length and d is the model dimension, i.e., the number of entries in each embedding vector. Because memory is not compressed, each word must attend to every previous word. The core operation of attention is a pairwise dot product between words, so the computation stays quadratic in the sequence length.
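A minimal pure-Python sketch makes the quadratic cost concrete. It is an illustration, not an optimized kernel: the function below runs naive causal softmax attention and counts the query-key dot products it performs. Each dot product costs O(d), which gives the O(n²·d) total. The function name and toy sizes are illustrative choices.

```python
import math
import random


def causal_softmax_attention(q, k, v):
    """Naive causal softmax attention over lists of vectors.

    Also returns how many query-key dot products were computed, to show
    that the work grows quadratically with the sequence length n.
    """
    d = len(q[0])
    outputs, dot_products = [], 0
    for i in range(len(q)):
        # Query i is scored against every key j <= i (the causal mask).
        scores = []
        for j in range(i + 1):
            scores.append(sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(d))
            dot_products += 1
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        outputs.append([sum(w * v[j][c] for j, w in enumerate(weights)) / total
                        for c in range(len(v[0]))])
    return outputs, dot_products


random.seed(0)
for n in (100, 200, 400):
    x = [[random.gauss(0, 1) for _ in range(4)] for _ in range(n)]
    _, ops = causal_softmax_attention(x, x, x)
    print(n, ops)  # n(n+1)/2: 5050, 20100, 80200
```

Doubling n from 200 to 400 roughly quadruples the count (20,100 to 80,200), as the O(n²) analysis predicts.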
But now, a group of researchers has found a way to bring this computational cost into the sub-quadratic realm.
How to Linearize Attention. As mentioned before, the Transformer relies on a mathematical operation called the attention mechanism, whose complexity is quadratic in the sequence length. Simply put, longer sequences dramatically increase computational requirements. Other attention mechanisms, such as linear attention, have linear complexity. However, in practice, models built on linear attention have performed worse than standard Transformers.
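The key property of linear attention is that the exponential similarity is replaced by a dot product of positive feature maps φ(q)·φ(k). This lets the sum over all previous words be carried as a fixed-size running state, so each new word costs the same amount of work regardless of sequence length. The sketch below makes an assumption: it uses the simple elu(x) + 1 feature map from earlier linear-attention work, not the learned feature map LoLCATs uses. It checks that the recurrent O(n) form matches the pairwise O(n²) form.

```python
import math
import random


def feature_map(x):
    # elu(x) + 1: a simple positive feature map from earlier linear-attention work.
    # Stand-in only: LoLCATs learns its feature map instead.
    return [a + 1.0 if a > 0 else math.exp(a) for a in x]


def linear_attention_recurrent(q, k, v):
    """Causal linear attention with a fixed-size running state: O(n) in sequence length."""
    f, dv = len(feature_map(k[0])), len(v[0])
    S = [[0.0] * dv for _ in range(f)]  # running sum of phi(k_j) v_j^T, shape f x dv
    z = [0.0] * f                        # running sum of phi(k_j), shape f
    outputs = []
    for qi, ki, vi in zip(q, k, v):
        pk = feature_map(ki)
        for a in range(f):
            z[a] += pk[a]
            for c in range(dv):
                S[a][c] += pk[a] * vi[c]
        pq = feature_map(qi)
        den = sum(pq[a] * z[a] for a in range(f))
        outputs.append([sum(pq[a] * S[a][c] for a in range(f)) / den for c in range(dv)])
    return outputs


def linear_attention_quadratic(q, k, v):
    """The same outputs computed pairwise, like a materialized attention matrix: O(n^2)."""
    outputs = []
    for i in range(len(q)):
        pq = feature_map(q[i])
        w = [sum(a * b for a, b in zip(pq, feature_map(k[j]))) for j in range(i + 1)]
        total = sum(w)
        outputs.append([sum(w[j] * v[j][c] for j in range(i + 1)) / total
                        for c in range(len(v[0]))])
    return outputs


random.seed(1)
q, k, v = ([[random.gauss(0, 1) for _ in range(4)] for _ in range(50)] for _ in range(3))
fast, slow = linear_attention_recurrent(q, k, v), linear_attention_quadratic(q, k, v)
print(max(abs(a - b) for ra, rb in zip(fast, slow) for a, b in zip(ra, rb)))  # float rounding only
```

The running state (S, z) is exactly the "compressed memory" the book analogy asks for. Its size depends on the feature dimension, not on how many words have been read.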
But what happens if we train linear attention layers to mimic their more computationally expensive softmax counterparts?
That is the goal of LoLCATs: to create linearized LLMs that are cost-efficient while retaining the performance of their quadratic-complexity peers.
To do this, they divide the problem into three steps.
Step 1: Replacement Layers. The first step in linearizing the model is to insert a set of linear attention layers and train them to mimic the outputs of the standard attention layers they replace. The authors call this step attention transfer. By minimizing the mean squared error (MSE) between the outputs of the two layers, we can train each new layer to behave like the original one.
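The attention-transfer objective itself is simple: run the original softmax layer (the "teacher") and the linear layer (the "student") on the same inputs, then score the student by the MSE between their outputs. The sketch below is a toy under explicit assumptions. The feature map is hypothetical, with a single learnable scalar, and is loosely modeled on the [exp(xW), exp(-xW)] pairing of Hedgehog-style learned feature maps with W reduced to one number. Grid search stands in for the gradient descent a real implementation would use.

```python
import math
import random


def softmax_attention(q, k, v):
    """Causal softmax attention (the 'teacher' layer)."""
    d, out = len(q[0]), []
    for i in range(len(q)):
        s = [sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(d) for j in range(i + 1)]
        m = max(s)
        w = [math.exp(x - m) for x in s]
        total = sum(w)
        out.append([sum(w[j] * v[j][c] for j in range(i + 1)) / total for c in range(len(v[0]))])
    return out


def linear_attention(q, k, v, phi):
    """Causal linear attention with feature map `phi` (the 'student' layer)."""
    out = []
    for i in range(len(q)):
        pq = phi(q[i])
        w = [sum(a * b for a, b in zip(pq, phi(k[j]))) for j in range(i + 1)]
        total = sum(w)
        out.append([sum(w[j] * v[j][c] for j in range(i + 1)) / total for c in range(len(v[0]))])
    return out


def mse(a, b):
    """Mean squared error between two lists of output vectors."""
    diffs = [(x - y) ** 2 for ra, rb in zip(a, b) for x, y in zip(ra, rb)]
    return sum(diffs) / len(diffs)


def make_phi(scale):
    # Hypothetical one-parameter feature map [exp(s*x), exp(-s*x)]; always positive.
    return lambda x: [math.exp(scale * a) for a in x] + [math.exp(-scale * a) for a in x]


random.seed(2)
q, k, v = ([[random.gauss(0, 1) for _ in range(4)] for _ in range(32)] for _ in range(3))
target = softmax_attention(q, k, v)  # teacher outputs, kept fixed

# "Training" here is a grid search over the single parameter, minimizing the MSE loss.
losses = {s: mse(linear_attention(q, k, v, make_phi(s)), target) for s in (0.1, 0.25, 0.5, 0.75, 1.0)}
best = min(losses, key=losses.get)
print(f"best scale = {best}, MSE = {losses[best]:.5f}")
```

Only the student's parameters change during this step; the teacher's outputs serve purely as regression targets.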
However, at this point the model as a whole still performs poorly, even though each new layer mimics its softmax counterpart in isolation. To fix this, we need to train the model further on the standard next-word prediction task.
Step 2: LoRA fine-tuning. Admittedly, retraining a model just to make it more efficient seems like a hard-to-justify use of resources. Fortunately, LoRA adapters make full-model fine-tuning unnecessary. I've mentioned them in multiple issues of this newsletter, but the idea is that the weight changes needed to adapt a model to a given task are inherently low-rank. In short, a small number of extra parameters is enough. We can therefore add small sets of trainable weights, called adapters, to each layer while keeping the original model weights frozen. In effect, these adapters adjust the model's behavior based on the data they are trained on.
Importantly, these adapters are much smaller than the actual models, so they can be trained quickly and inexpensively.
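For readers who want the mechanics, here is a minimal sketch of a LoRA-adapted linear layer in plain Python. The class and function names are hypothetical. The rank r, the alpha / r scaling, and the zero initialization of B follow the common LoRA recipe but are assumptions here, not details from the LoLCATs work. The point is the parameter count: the adapter trains r × (d_in + d_out) numbers instead of d_in × d_out.

```python
import random


class LoRALinear:
    """Hypothetical minimal LoRA layer: y = W x + (alpha / r) * B (A x).

    W stays frozen; only the small matrices A (r x d_in) and B (d_out x r) are trained.
    """

    def __init__(self, W, r=8, alpha=16, seed=0):
        rng = random.Random(seed)
        self.W = W  # frozen pre-trained weight, d_out x d_in
        d_out, d_in = len(W), len(W[0])
        self.A = [[rng.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(r)]
        # B starts at zero, so the adapted layer initially matches the original exactly.
        self.B = [[0.0] * r for _ in range(d_out)]
        self.scale = alpha / r

    def __call__(self, x):
        base = [sum(w * xi for w, xi in zip(row, x)) for row in self.W]
        ax = [sum(a * xi for a, xi in zip(row, x)) for row in self.A]
        delta = [sum(b * h for b, h in zip(row, ax)) for row in self.B]
        return [y + self.scale * dy for y, dy in zip(base, delta)]

    def num_trainable(self):
        return len(self.A) * len(self.A[0]) + len(self.B) * len(self.B[0])


def lora_fraction(d_in, d_out, r):
    """Trainable LoRA parameters as a fraction of the full weight matrix."""
    return r * (d_in + d_out) / (d_in * d_out)


layer = LoRALinear([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]], r=2)
print(layer([1.0, 1.0, 1.0]))          # [1.0, 2.0]: identical to W x before training
print(lora_fraction(4096, 4096, r=8))  # 0.00390625, i.e. about 0.4% of a 4096 x 4096 matrix
```

Because B starts at zero, fine-tuning begins from exactly the model produced by Step 1 and only gradually moves away from it.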
After LoRA training, we obtain a model that performs similarly to the original on next-word prediction, despite using linear attention layers.
The model is not completely linearized, however. The researchers combine standard (quadratic) attention with linear attention. For each position in a sequence, suppose the context holds D = N + M words. The M most recent words are handled with standard attention, while the N earlier (and far more numerous) words use linear attention.
This strikes a good balance between the expressiveness of standard attention and the computational efficiency of linear attention. Language usually has a proximity bias: nearby words tend to matter more for predicting the next word than distant ones. Softmax (standard) attention is therefore used for the closest words and linear attention for the rest.
The M most recent words are computed conventionally, while the earlier ones are handled in linear form.
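The combination can be sketched as follows, with `window` playing the role of M. This is a minimal illustration under assumptions. It reuses the elu(x) + 1 feature map as a stand-in for LoLCATs' learned one. It also simply adds the softmax-style and linear-attention scores under one shared normalizer; a real implementation would also need numerically stable scaling and some way to balance the two kinds of scores. Keys that slide out of the window are folded into a fixed-size running state, so the work per word stays bounded.

```python
import math
import random


def phi(x):
    # Stand-in positive feature map (elu(x) + 1); LoLCATs learns its feature map instead.
    return [a + 1.0 if a > 0 else math.exp(a) for a in x]


def hybrid_attention(q, k, v, window):
    """Sliding-window softmax for the last `window` keys + linear attention for older keys.

    Older keys are folded into a fixed-size running state as they leave the window,
    so the cost per word is O(window) instead of O(sequence length).
    """
    d, dv, f = len(q[0]), len(v[0]), len(phi(k[0]))
    S = [[0.0] * dv for _ in range(f)]  # sum of phi(k_j) v_j^T over keys outside the window
    z = [0.0] * f                        # sum of phi(k_j) over keys outside the window
    outputs = []
    for i in range(len(q)):
        start = max(0, i - window + 1)
        if start > 0:  # key start-1 just slid out of the window: add it to the linear state
            pk = phi(k[start - 1])
            for a in range(f):
                z[a] += pk[a]
                for c in range(dv):
                    S[a][c] += pk[a] * v[start - 1][c]
        pq = phi(q[i])
        num = [sum(pq[a] * S[a][c] for a in range(f)) for c in range(dv)]
        den = sum(pq[a] * z[a] for a in range(f))
        for j in range(start, i + 1):  # exact softmax-style terms for recent keys
            w = math.exp(sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(d))
            num = [n + w * vj for n, vj in zip(num, v[j])]
            den += w
        outputs.append([n / den for n in num])
    return outputs


random.seed(3)
q, k, v = ([[random.gauss(0, 1) for _ in range(4)] for _ in range(20)] for _ in range(3))
out = hybrid_attention(q, k, v, window=4)
print(len(out), len(out[0]))  # 20 4
```

With `window` at least as long as the sequence, the linear state never fills and the function reduces to ordinary causal softmax attention. With `window=1`, almost everything goes through the linear path.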
And on top of that, the researchers added a final step.
Step 3: Layer-wise Optimization. During attention transfer, all layers of the model were initially trained together. However, the researchers observed that this left the last layers of the model with a larger mean squared error (MSE). The effect was especially strong in deeper models like Llama 3.1 405B, which has more layers than the smaller versions. To avoid this, they trained the layers in blocks instead.
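The scheduling side of this step can be sketched in a few lines. Each block of consecutive layers is fitted on its own attention-transfer loss rather than sharing a single objective with the rest of the network. The helper names are hypothetical. The 126-layer, 9-layers-per-block configuration is only an illustration, not a figure from this article. In a real setup, `train_block` would run gradient descent on that block's MSE.

```python
def layer_blocks(num_layers, block_size):
    """Split layer indices 0..num_layers-1 into consecutive blocks."""
    return [list(range(start, min(start + block_size, num_layers)))
            for start in range(0, num_layers, block_size)]


def blockwise_attention_transfer(num_layers, block_size, train_block):
    """Fit each block of layers separately instead of all layers at once.

    `train_block` is a hypothetical callback that trains the linear attention layers in
    one block on that block's own MSE loss and returns the final loss.
    """
    return {(b[0], b[-1]): train_block(b) for b in layer_blocks(num_layers, block_size)}


# Illustrative numbers: a 126-layer model split into blocks of 9 layers.
blocks = layer_blocks(126, 9)
print(len(blocks), blocks[0][0], blocks[-1][-1])  # 14 0 125
losses = blockwise_attention_transfer(126, 9, train_block=lambda layers: 0.0)  # dummy trainer
print(len(losses))  # 14
```

Because each block has its own target, the blocks can also be trained one at a time, which keeps memory needs modest for very deep models.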
So what's the end result?
State of the art...and cheap.
LoLCATs dramatically improves the efficiency of linearizing large LLMs. By training just 0.2% of the model parameters on 40 million tokens, it closes 80% of the performance gap with a fully softmax (standard) Transformer. It does so with orders of magnitude fewer training tokens, and its fine-tuning is up to 35,500 times more efficient than standard fine-tuning.
More impressively, at all three Llama sizes, the LoLCATs models perform similarly to the originals. At the same time, they gain the significant computational efficiency of linear attention layers over full quadratic attention.
For Llama 3.1 8B, these results took only a few training hours on a single GPU, costing less than $20.
Of course, this approach still requires a pre-trained quadratic-attention Transformer to distill from. Even so, inference-focused optimizations like LoLCATs will be an important part of any enterprise's plan to adopt generative AI. They save significant computational costs while retaining most of the original model's performance.
Summary: LoLCATs converts existing large language models (LLMs) into efficient linear-attention models at a fraction of the usual cost. With this technique, an 8-billion-parameter model can be linearized for less than $20 while largely maintaining its original performance. Compared with conventional Transformer training and fine-tuning, LoLCATs dramatically improves efficiency, making the process fast and inexpensive. This breakthrough makes powerful AI capabilities available to more organizations at lower cost, paving the way for the popularization of AI.