Paper: MaskLLM: Learnable Semi-Structured Sparsity for Large Language Models
- Paper address: /abs/2409.17481
- Code: /NVlabs/MaskLLM
Innovations
- Proposes MaskLLM, a learnable semi-structured pruning method for LLMs that leverages large-scale datasets to learn accurate N:M masks for both general-purpose and domain-specific pruning.
- In addition, the framework enables transferring sparsity patterns across different tasks, making training with sparsity more efficient.
Content overview
Large language models (LLMs) are characterized by their enormous parameter counts, which usually lead to significant redundancy. The paper proposes MaskLLM, a learnable pruning method that establishes semi-structured (or N:M, i.e., N non-zero values within every M consecutive parameters) sparsity in LLMs to reduce the computational overhead of inference.
MaskLLM explicitly models the N:M sparsity pattern as a learnable distribution via Gumbel Softmax sampling, which can be trained end-to-end on large-scale datasets and offers two significant advantages:
- High-quality masks: the method scales effectively to large datasets and learns accurate masks.
- Transferability: probabilistic modeling of the mask distribution enables transfer learning of sparsity across domains or tasks.
MaskLLM is evaluated with 2:4 sparsity on different LLMs, such as LLaMA-2, Nemotron-4, and GPT-3, with parameter scales ranging from 843M to 15B. The empirical results show significant improvements over state-of-the-art methods: by freezing the weights and learning only the masks, MaskLLM achieves a markedly lower perplexity of 6.72.
MaskLLM
N:M Sparsity
N:M sparsity, which constrains each group of M consecutive parameters to contain at most N non-zero values, can negatively affect the accuracy of the LLM. The task can be cast as a mask selection problem over a candidate set of size \(|\mathbf{S}|=\binom{M}{N} = \frac{M!}{N!(M-N)!}\), where \(|\mathbf{S}|\) denotes the size of the candidate set and \(\binom{M}{N}\) denotes the number of possible N:M mask combinations.
For 2:4 sparsity, the binary mask \(\mathcal{M}\) must contain exactly two zeros, forming a discrete candidate set \(\mathbf{S}^{2:4}\) with \(|\mathbf{S}^{2:4}|=\binom{4}{2}=6\) candidates.
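To make the candidate set concrete, here is a minimal Python sketch (not from the official /NVlabs/MaskLLM code) that enumerates all N:M mask candidates; the helper name `candidate_masks` is illustrative:

```python
from itertools import combinations

import torch

def candidate_masks(N: int, M: int) -> torch.Tensor:
    """Enumerate all binary masks of length M containing exactly N ones."""
    masks = []
    for ones in combinations(range(M), N):
        m = torch.zeros(M)
        m[list(ones)] = 1.0
        masks.append(m)
    return torch.stack(masks)        # shape: (C(M, N), M)

S_24 = candidate_masks(N=2, M=4)     # the 6 candidates for 2:4 sparsity
print(S_24.shape)                    # torch.Size([6, 4])
```

For 2:4 sparsity this yields the six binary patterns with exactly two ones per group of four parameters.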
An LLM contains a huge number of parameter blocks, denoted \(\{\mathcal{W}_i\}\), and a mask \(\mathcal{M}_i\) must be selected for each block, giving \(\{\mathcal{M}_i\}\). To measure the performance after pruning, the following loss objective is defined for N:M sparsity:
where \(\mathcal{L}_{LM}\) refers to the pre-training language modeling loss and the operator \(\odot\) denotes element-wise multiplication, which masks out part of the parameters for sparsification.
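As a hedged sketch of this objective, the snippet below temporarily applies the masks element-wise to the weight blocks and evaluates the language-modeling loss; `lm_loss_with_masks`, `model`, and `weight_blocks` are illustrative names, and the model is assumed to return token logits of shape (batch, seq, vocab):

```python
import torch
import torch.nn.functional as F

def lm_loss_with_masks(model, weight_blocks, masks, input_ids, labels):
    """Evaluate L_LM(W ⊙ M): temporarily mask each weight block, run the LM."""
    originals = [w.data.clone() for w in weight_blocks]
    for w, m in zip(weight_blocks, masks):
        w.data.mul_(m)                            # element-wise W ⊙ M
    logits = model(input_ids)                     # assumed to return (B, T, V) logits
    loss = F.cross_entropy(logits.flatten(0, 1), labels.flatten())
    for w, orig in zip(weight_blocks, originals):
        w.data.copy_(orig)                        # restore the dense weights
    return loss
```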
Learnable Semi-Structured Sparsity
In the LLM setting, finding the optimal combination of masks \(\mathcal{M}^*\) is extremely challenging due to the non-differentiable nature of mask selection and the massive parameter scale. For this reason, the paper reformulates mask selection as a sampling process.
Directly determining the exact optimal mask for a parameter block is infeasible, because the behavior of the pruned LLM also depends on how the other parameter blocks are pruned. However, it is possible to sample a mask for each block independently and assess the overall model quality after pruning.
Define a categorical distribution with class probabilities \(p_1, p_2, \ldots, p_{|\mathbf{S}|}\) satisfying \(\sum_{j} p_j=1\). During random sampling, if a particular mask exhibits good quality after pruning, it is reasonable to update the categorical distribution by increasing the probability of that sampled mask.
With enough sampling and updating, one ends up with a set of distributions in which masks with higher probabilities are more likely to maintain good quality after pruning.
Formally, the combinatorial problem in the above formulation is modeled from the perspective of random sampling:
The above objective can be optimized by gradient descent if the gradient with respect to that distribution is available, but drawing samples from the categorical distribution is still non-differentiable.
- Differentiable mask sampling
Gumbel Max effectively models the sampling operation by decoupling the randomness of sampling into a noise variable \(\epsilon\). To draw a sample from the categorical distribution \(p\), it generates a one-hot index \(y\):

where \(\epsilon_i\) is random noise following a uniform distribution and \(g_i = -\log(-\log \epsilon_i)\) is known as Gumbel noise. Gumbel Max parameterizes the randomness of sampling as the independent variable \(g_i\); the only remaining obstacles to differentiable sampling are the \(\operatorname{argmax}\) and one-hot operations.
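The Gumbel-Max step can be sketched as follows (a toy illustration, not the paper's code); `gumbel_max_sample` is a made-up helper that perturbs the log-probabilities with Gumbel noise and takes the hard argmax:

```python
import torch
import torch.nn.functional as F

def gumbel_max_sample(probs: torch.Tensor) -> torch.Tensor:
    """Draw a one-hot sample from a categorical distribution via Gumbel-Max."""
    eps = torch.rand_like(probs)              # epsilon_i ~ Uniform(0, 1)
    g = -torch.log(-torch.log(eps))           # Gumbel noise g_i
    idx = torch.argmax(torch.log(probs) + g)  # hard, non-differentiable argmax
    return F.one_hot(idx, probs.numel()).float()

p = torch.tensor([0.1, 0.2, 0.3, 0.1, 0.2, 0.1])   # toy distribution over 6 masks
y = gumbel_max_sample(p)                            # one-hot index y
```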
To solve this problem, Gumbel Softmax approximates the index with a Softmax, yielding a smooth and differentiable index \(\tilde{\mathbf{y}}=[\tilde{y}_1, \tilde{y}_2, \ldots, \tilde{y}_{|\mathbf{S}|}]\):

The temperature \(\tau\) is a hyperparameter that controls the hardness of the sampled index. As \(\tau \rightarrow 0\), the soft index approaches a one-hot vector, so that \(\tilde{y}_i\rightarrow y_i\).
Treating the soft index \(\tilde{\mathbf{y}}\) as a row vector and the candidate set \(\mathbf{S}\) as a matrix whose \(i\)-th row is the \(i\)-th candidate mask \(\hat{\mathcal{M}}_i\), a differentiable mask is easily constructed by a simple matrix multiplication:

This operation produces a weighted average of the candidate masks according to the soft index. All operations, including sampling and weighted averaging, are differentiable, and the gradient with respect to the probabilities \(p\) is easy to compute, so the differentiable mask \(\tilde{\mathcal{M}}\) can be used to optimize the sampling problem in Eq. 4.
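Putting the two pieces together, a hedged sketch of the differentiable mask for a single 2:4 block might look like this; the candidate matrix `S_24`, the toy weights, and the temperature value are all illustrative:

```python
import torch

# the six 2:4 candidate masks as rows of a matrix S
S_24 = torch.tensor([[1, 1, 0, 0], [1, 0, 1, 0], [1, 0, 0, 1],
                     [0, 1, 1, 0], [0, 1, 0, 1], [0, 0, 1, 1]], dtype=torch.float32)

def differentiable_mask(logits: torch.Tensor, S: torch.Tensor, tau: float) -> torch.Tensor:
    """Gumbel-Softmax soft index, then a weighted average of candidate masks."""
    g = -torch.log(-torch.log(torch.rand_like(logits)))   # Gumbel noise
    y_soft = torch.softmax((logits + g) / tau, dim=-1)    # differentiable soft index
    return y_soft @ S                                      # (|S|,) x (|S|, M) -> (M,)

logits = torch.zeros(6, requires_grad=True)    # one logit per candidate mask
mask = differentiable_mask(logits, S_24, tau=0.5)
w = torch.tensor([0.9, -0.1, 0.4, -0.8])       # one block of 4 weights (toy values)
loss = (w * mask).pow(2).sum()                 # stand-in for a downstream loss
loss.backward()                                # gradients flow back to the logits
print(logits.grad)                             # generally non-zero
```

Because the soft mask is a convex combination of the candidates, gradients from the pruned loss propagate back to the logits of every candidate.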
- Learning masks for LLMs
With a differentiable mask sampled from the distribution \(p\), the gradient flow can easily reach the probabilities \(p_i\), making them optimizable variables in the system. However, the probabilities are usually not learned directly; instead, the method maintains logits \(\pi_i\) together with a scaling factor \(\kappa\) and generates the probabilities via \(p_i = \frac{\exp(\pi_i \cdot \kappa)}{\sum_j \exp( \pi_j \cdot \kappa )}\).
The scaling factor \(\kappa\) balances the relative magnitudes of the logits and the Gumbel noise, thereby controlling the randomness of sampling. During training, every parameter block in \(\{\mathcal{W}_i\}\) is associated with a corresponding distribution \(\{p_\pi(\mathcal{M}_i)\}\), and the optimal distributions are learned in an end-to-end manner.
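The effect of \(\kappa\) can be illustrated with a small toy experiment (not from the paper): since the log-softmax of the \(\kappa\)-scaled logits plus Gumbel noise differs from \(\pi_i \cdot \kappa + g_i\) only by a constant, sampling reduces to an argmax over that sum, so a small \(\kappa\) lets the noise dominate while a large \(\kappa\) makes sampling nearly deterministic; all names and values below are made up:

```python
import torch

def sample_index(pi: torch.Tensor, kappa: float) -> int:
    """Sample one candidate index from kappa-scaled logits plus Gumbel noise."""
    g = -torch.log(-torch.log(torch.rand_like(pi)))   # Gumbel noise
    return int(torch.argmax(pi * kappa + g))          # same as sampling from softmax(pi * kappa)

pi = torch.tensor([2.0, 0.5, 0.0, 0.0, 0.0, 0.0])     # toy logits over 6 candidates
for kappa in (0.1, 10.0):
    hits = sum(sample_index(pi, kappa) == 0 for _ in range(1000))
    print(kappa, hits / 1000)   # small kappa: noise dominates; large kappa: index 0 almost always
```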
However, experiments on several large language models have revealed a new problem with learnable masks: since pruning operations produce zero parameters in the network, the gradient may vanish.
To address this problem, sparse weight regularization is introduced, which maintains appropriately large magnitudes in the remaining weights, leading to the following learning objective:
where the regularization term weighted by \(\lambda\) encourages larger weight magnitudes to be maintained after pruning.
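One plausible way to write this objective in code, assuming (as stated above) that the \(\lambda\)-weighted term rewards larger magnitudes among the surviving weights and is therefore subtracted from the loss, is sketched below; this is an interpretation, not the paper's exact formulation:

```python
import torch

def maskllm_objective(lm_loss: torch.Tensor, weight_blocks, soft_masks,
                      lam: float = 1e-4) -> torch.Tensor:
    """LM loss with sparse weight regularization (interpretive sketch).

    The lambda-weighted term rewards large magnitudes in the remaining
    (unmasked) weights, counteracting vanishing gradients from zeroed parameters.
    """
    reg = sum((w * m).pow(2).sum() for w, m in zip(weight_blocks, soft_masks))
    return lm_loss - lam * reg   # subtracting the term encourages larger surviving weights
```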
- Transfer learning of sparsity
Transfer learning is one of the most popular paradigms in deep learning, and sparsity transfer learning constructs new masks from precomputed masks.
The paper proposes a mask prior (Mask Prior) for initializing the distribution, which can significantly improve training efficiency and mask quality. The mask prior can be obtained from one-shot pruning methods such as magnitude pruning, SparseGPT, and Wanda.
Given a prior mask \(\mathcal{M}_0\), compute its similarity to all candidate masks:
For candidate masks that are highly similar to the prior mask, their probabilities are increased in the initialization phase:
where \(\sigma(o)\) is the standard deviation of the logits and \(\alpha\) is a hyperparameter controlling the strength of the prior. When \(\alpha=0\), the differentiable mask is learned without any prior.
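A hedged sketch of this initialization is shown below; the inner-product similarity and the exact update rule are assumptions made for illustration, with the boost scaled by the logits' standard deviation and \(\alpha\) as described above:

```python
import torch

def init_logits_with_prior(logits: torch.Tensor, S: torch.Tensor,
                           prior_mask: torch.Tensor, alpha: float) -> torch.Tensor:
    """Boost the logits of candidates that resemble the prior mask (sketch)."""
    sim = S @ prior_mask                        # inner-product similarity, shape (|S|,)
    return logits + alpha * logits.std() * sim  # alpha = 0 leaves the logits untouched

S_24 = torch.tensor([[1, 1, 0, 0], [1, 0, 1, 0], [1, 0, 0, 1],
                     [0, 1, 1, 0], [0, 1, 0, 1], [0, 0, 1, 1]], dtype=torch.float32)
prior = torch.tensor([1.0, 0.0, 1.0, 0.0])      # e.g. a 2:4 mask from magnitude pruning
logits = init_logits_with_prior(torch.randn(6), S_24, prior, alpha=3.0)
```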
- Method summary
Training starts from randomly initialized logits; when a prior mask is available, the logits are updated as in Eq. 10. The logits are then optimized to solve the objective in Eq. 8. For inference, the mask \(\mathcal{M}_i\) corresponding to the largest logit is used as the final mask.
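At inference time the per-block selection reduces to an argmax over the learned logits; a toy sketch (with made-up logit values) follows:

```python
import torch

def final_mask(logits: torch.Tensor, S: torch.Tensor) -> torch.Tensor:
    """Pick the candidate mask with the largest logit (no sampling at inference)."""
    return S[torch.argmax(logits)]

S_24 = torch.tensor([[1, 1, 0, 0], [1, 0, 1, 0], [1, 0, 0, 1],
                     [0, 1, 1, 0], [0, 1, 0, 1], [0, 0, 1, 1]], dtype=torch.float32)
logits = torch.tensor([0.2, 1.7, -0.3, 0.0, 0.5, -1.1])   # learned logits (toy values)
W_block = torch.randn(4)                                   # one block of 4 weights
W_sparse = W_block * final_mask(logits, S_24)              # 2:4-sparse block
```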
Main experiments