
MaskLLM: learnable `N:M` sparsification of large models, from NVIDIA | NeurIPS'24

2024-11-20 09:16:19

Paper: MaskLLM: Learnable Semi-Structured Sparsity for Large Language Models

  • Paper address: /abs/2409.17481
  • Paper code: /NVlabs/MaskLLM

Innovations


  • Proposes MaskLLM, a learnable semi-structured pruning method for LLMs that leverages large-scale datasets to learn accurate N:M masks, for both general-purpose and domain-specific pruning.
  • The framework also supports transfer learning of sparsity patterns across tasks, which makes training with sparsity more efficient.

Content overview


Large language models (LLMs) are characterized by their huge number of parameters, which usually leads to significant redundancy. The paper proposes MaskLLM, a learnable pruning method that establishes semi-structured (or N:M, meaning N non-zero values among every M consecutive parameters) sparsity in LLMs to reduce the computational overhead of inference.

Through Gumbel Softmax sampling, MaskLLM explicitly models N:M sparsity patterns as a learnable distribution that can be trained end-to-end on large-scale datasets. This offers two significant advantages:

  1. High-quality masks: the method scales effectively to large datasets and learns accurate masks.
  2. Transferability: probabilistic modeling of the mask distribution enables transfer learning of sparsity across domains or tasks.

MaskLLM is evaluated with 2:4 sparsity on several LLMs, including LLaMA-2, Nemotron-4, and GPT-3, with sizes ranging from 843M to 15B parameters. The empirical results show substantial improvements over state-of-the-art methods: by learning only the masks while keeping the weights frozen, MaskLLM achieves a significantly lower perplexity of 6.72.

MaskLLM


N:M sparsity

N:M sparsity imposes the restriction on an LLM that each group of M consecutive parameters contains at most N non-zero values. This task can be converted into a mask selection problem with a candidate set of size \(|\mathbf{S}|=\binom{M}{N} = \frac{M!}{N!(M-N)!}\), where \(|\mathbf{S}|\) denotes the size of the candidate set and \(\binom{M}{N}\) denotes the number of possible N:M masks.

For 2:4 sparsity, the binary mask \(\mathcal{M}\) must contain exactly two zeros, which forms a discrete candidate set \(\mathbf{S}^{2:4}\) with \(|\mathbf{S}^{2:4}|=\binom{4}{2}=6\) candidates:

\[\begin{align} \mathbf{S}^{2:4} & = \{\mathcal{M} \in \mathbb{B}^{1\times4} | \sum \mathcal{M} = 2\} = \{\hat{\mathcal{M}}_1, \hat{\mathcal{M}}_2, \hat{\mathcal{M}}_3, \hat{\mathcal{M}}_4, \hat{\mathcal{M}}_5, \hat{\mathcal{M}}_6 \} \\ & = \{[1,1,0,0], [1,0,1,0], [1,0,0,1],[0,1,0,1],[0,1,1,0],[0,0,1,1]\}. \end{align} \]
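
The candidate set can be enumerated mechanically. The following is a minimal sketch in plain Python; the helper name `candidate_masks` is hypothetical and not taken from the paper's code, and the enumeration order may differ from the order of \(\hat{\mathcal{M}}_1, \ldots, \hat{\mathcal{M}}_6\) listed above.

```python
from itertools import combinations
from math import comb


def candidate_masks(n, m):
    """Return every binary mask of length m with exactly n ones."""
    masks = []
    for kept in combinations(range(m), n):
        # Positions listed in `kept` stay non-zero; all others are pruned.
        masks.append([1 if j in kept else 0 for j in range(m)])
    return masks


S_24 = candidate_masks(2, 4)

# The number of candidates matches the binomial coefficient C(4, 2).
assert len(S_24) == comb(4, 2)

for mask in S_24:
    print(mask)
```

The same helper works for other patterns, for example `candidate_masks(1, 4)` for 1:4 sparsity.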

An LLM contains a large number of parameter blocks, denoted \(\{\mathcal{W}_i\}\), and a mask \(\{\mathcal{M}_i\}\) must be selected for each of them. To preserve performance after pruning, the following loss objective is defined for N:M sparsity:

\[\begin{equation} \{\mathcal{M}_i^{*}\} = \underset{\{\mathcal{M}_i | \mathcal{M}_i \in \mathbf{S}^{2:4}\} }{argmin} \mathbb{E}_{x\sim p(x)} \left[ \mathcal{L}_{LM}(x; \{\mathcal{W}_i \odot \mathcal{M}_i\}) \right], \label{eqn:objective} \end{equation} \]

where \(\mathcal{L}_{LM}\) is the language modeling loss used in pre-training. The operator \(\odot\) denotes element-wise multiplication, which masks out some of the parameters for sparsification.
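
To make the masking step \(\mathcal{W}_i \odot \mathcal{M}_i\) concrete, here is a small sketch that applies one 2:4 mask to each block of four consecutive weights. The weight values, the flat list layout, and the function name `apply_masks` are illustrative assumptions, not details from the paper.

```python
def apply_masks(weights, masks):
    """Element-wise product W ⊙ M over consecutive blocks of 4 weights."""
    pruned = []
    for block_idx, mask in enumerate(masks):
        block = weights[4 * block_idx: 4 * block_idx + 4]
        pruned.extend(w * m for w, m in zip(block, mask))
    return pruned


# Hypothetical weights: two blocks of four consecutive parameters.
weights = [0.9, -0.1, 0.05, -0.7, 0.2, 0.3, -0.8, 0.01]
# One 2:4 mask per block (exactly two ones each).
masks = [[1, 0, 0, 1], [0, 1, 1, 0]]

pruned = apply_masks(weights, masks)
print(pruned)
```

Each block of the result keeps exactly two of its four weights, which is the 2:4 constraint.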

Learnable semi-structured sparsity

In the context of LLMs, finding the optimal combination of masks \(\{\mathcal{M}_i^{*}\}\) is extremely challenging because mask selection is non-differentiable and the parameter scale is large. For this reason, the paper transforms mask selection into a sampling process.

It is not feasible to directly determine the exact optimal mask for a single parameter block, because the behavior of the pruned LLM also depends on how the other parameter blocks are pruned. However, it is possible to sample a mask for each block independently and then assess the overall quality of the pruned model.

Define a categorical distribution with class probabilities \(p_1, p_2, \ldots, p_{|\mathbf{S}|}\) that satisfy \(\sum_{j} p_j=1\). In the random sampling phase, if a particular mask yields good quality after pruning, it is reasonable to adjust the categorical distribution by increasing the probability of the sampled mask.

With enough sampling and updating, one ends up with a set of distributions where the higher probability masks are more likely to maintain good quality after pruning.

Formally, the combinatorial problem in the above formulation is modeled from the perspective of random sampling:

\[\begin{equation} \{p^{*}(\mathcal{M}_i)\} = \underset{\{p(\mathcal{M}_i)\}}{argmin}\ \mathbb{E}_{x\sim p(x), \mathcal{M}_i \sim p(\mathcal{M}_i)} \left[ \mathcal{L}_{LM}(x; \{\mathcal{W}_i \odot \mathcal{M}_i\}) \right], \label{eqn:objective_sampling} \end{equation} \]

The above objective can be optimized by gradient descent if the gradient with respect to that distribution is available, but drawing samples from the categorical distribution is still non-differentiable.

  • Differentiable mask sampling

Gumbel Max models the sampling operation effectively by decoupling the randomness of sampling into a noise variable \(\epsilon\). To draw a sample from the categorical distribution \(p\), it generates a one-hot index \(y\):

\[\begin{equation} y=\text{onehot}(\underset{i}{argmax} [\log(p_i) + g_i]), \; g_i=-\log(-\log \epsilon_i), \; \epsilon_i\sim U(0, 1), \label{eqn:gumbel_max} \end{equation} \]

where \(\epsilon_i\) is random noise that follows a uniform distribution, and \(g_i = -\log(-\log \epsilon_i)\) is known as Gumbel noise. Gumbel Max parameterizes the randomness of sampling as an independent variable \(g_i\), so the only remaining obstacles to differentiable sampling are the \(argmax\) and one-hot operations.
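
The Gumbel Max trick can be checked empirically: the index picked by \(argmax_i [\log(p_i) + g_i]\) should occur with frequency close to \(p_i\). The sketch below uses only the Python standard library; the probabilities, the seed, and the number of draws are arbitrary choices for illustration.

```python
import math
import random


def gumbel_max_sample(probs, rng):
    """Draw one category index via argmax_i [log p_i + g_i]."""
    scores = []
    for p in probs:
        eps = rng.random()
        while eps == 0.0:          # guard against log(0)
            eps = rng.random()
        g = -math.log(-math.log(eps))   # Gumbel noise
        scores.append(math.log(p) + g)
    # Position of the 1 in the one-hot vector y.
    return max(range(len(probs)), key=scores.__getitem__)


rng = random.Random(0)
probs = [0.5, 0.2, 0.1, 0.1, 0.05, 0.05]   # hypothetical distribution over 6 masks
num_draws = 20000

counts = [0] * len(probs)
for _ in range(num_draws):
    counts[gumbel_max_sample(probs, rng)] += 1

freqs = [c / num_draws for c in counts]
print([round(f, 3) for f in freqs])   # should be close to probs
```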

To solve this problem, Gumbel Softmax approximates the \(argmax\) with a Softmax, which yields a smooth and differentiable index \(\tilde{\mathbf{y}}=[\tilde{y}_1, \tilde{y}_2, \ldots, \tilde{y}_{|\mathbf{S}|}]\):

\[\begin{equation} \tilde{y}_i = \frac{\exp((\log(p_i) + g_i) / \tau)}{\sum_j \exp( (\log(p_j) + g_j) / \tau ) }. \label{eqn:gumbel_softmax} \end{equation} \]

The temperature \(\tau\) is a hyperparameter that controls the hardness of the sampled index. As \(\tau \rightarrow 0\), the soft index approaches a one-hot vector, so that \(\tilde{y}_i\rightarrow y_i\).
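
The effect of the temperature can be seen by reusing one fixed set of Gumbel noise values with several values of \(\tau\). This is a sketch under assumed values (probabilities, seed, temperatures); the function name `gumbel_softmax` is ours, not the paper's.

```python
import math
import random


def gumbel_softmax(probs, tau, noise):
    """Soft index: softmax((log p_i + g_i) / tau), computed stably."""
    scores = [(math.log(p) + g) / tau for p, g in zip(probs, noise)]
    top = max(scores)                      # subtract the max to avoid overflow
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
probs = [0.5, 0.2, 0.1, 0.1, 0.05, 0.05]   # hypothetical distribution
# One fixed draw of Gumbel noise, shared across temperatures.
noise = [-math.log(-math.log(rng.uniform(1e-9, 1.0 - 1e-9))) for _ in probs]

results = {}
for tau in (4.0, 1.0, 0.1):
    results[tau] = gumbel_softmax(probs, tau, noise)
    print(tau, [round(v, 3) for v in results[tau]])
```

With a lower temperature, more of the mass concentrates on the entry that Gumbel Max would have selected, while the position of the largest entry stays the same.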

Treating the soft index \(\tilde{\mathbf{y}}\) as a row vector and the mask set \(\mathbf{S}\) as a matrix whose \(i\)-th row is the \(i\)-th candidate mask \(\hat{\mathcal{M}}_i\), a differentiable mask is easily constructed by a simple matrix multiplication:

\[\begin{equation} \tilde{\mathcal{M}} = \tilde{\mathbf{y}} \times \mathbf{S}=\sum_{i=1}^{|\mathbf{S}|} \tilde{y}_i \cdot \hat{\mathcal{M}}_i.\label{eqn:diff_mask} \end{equation} \]

This operation produces a weighted average of the candidate masks according to the soft index. All operations (including sampling and weighted averaging) are differentiable, and the gradient with respect to the probabilities \(p\) can be easily computed, so the differentiable mask \(\tilde{\mathcal{M}}\) can be used to optimize the sampling problem in Eq. 4.
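
As a worked instance of \(\tilde{\mathcal{M}} = \tilde{\mathbf{y}} \times \mathbf{S}\), the sketch below multiplies a hand-picked soft index by the 2:4 candidate matrix. The soft index values are hypothetical. Because every candidate contains exactly two ones and the soft index sums to 1, the entries of the soft mask always sum to 2.

```python
# Candidate masks for 2:4 sparsity, one per row, in the order used in the text.
S = [
    [1, 1, 0, 0],
    [1, 0, 1, 0],
    [1, 0, 0, 1],
    [0, 1, 0, 1],
    [0, 1, 1, 0],
    [0, 0, 1, 1],
]


def soft_mask(soft_index, candidates):
    """Row vector times matrix: sum_i y_i * M_i."""
    width = len(candidates[0])
    return [
        sum(y * mask[j] for y, mask in zip(soft_index, candidates))
        for j in range(width)
    ]


soft_index = [0.70, 0.10, 0.05, 0.05, 0.05, 0.05]   # hypothetical, sums to 1
mask = soft_mask(soft_index, S)
print([round(v, 3) for v in mask])

# A one-hot index recovers the corresponding candidate mask exactly.
one_hot = [0, 0, 0, 1, 0, 0]
print(soft_mask(one_hot, S))
```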

  • Learning masks for LLMs

With a differentiable mask sampled from the underlying distribution \(p\), the gradient flow can easily reach the probabilities \(p_i\), which makes them optimizable variables in the system. In practice, however, the probabilities are usually not learned directly. Instead, logits \(\pi_i\) are learned together with a scaling factor \(\kappa\), and the probabilities are generated by \(p_i = \frac{\exp(\pi_i \cdot \kappa)}{\sum_j \exp( \pi_j \cdot \kappa )}\).

The scaling factor \(\kappa\) balances the relative magnitude of the logits and the Gumbel noise, and thereby controls the randomness of sampling. During training, every parameter block in \(\{\mathcal{W}_i\}\) is associated with a corresponding distribution in \(\{p_\pi(\mathcal{M}_i)\}\), and the optimal distributions are learned end-to-end.
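
The role of \(\kappa\) can be illustrated by computing \(p_i\) from one set of logits with different scaling factors. The logit values and the \(\kappa\) values below are made up for the example and are not the settings used in the paper.

```python
import math


def probs_from_logits(logits, kappa):
    """p_i = exp(pi_i * kappa) / sum_j exp(pi_j * kappa), computed stably."""
    scaled = [logit * kappa for logit in logits]
    top = max(scaled)
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [0.3, 0.1, 0.0, -0.1, -0.2, -0.1]   # hypothetical logits for one block

table = {}
for kappa in (1.0, 10.0, 100.0):
    table[kappa] = probs_from_logits(logits, kappa)
    print(kappa, [round(p, 3) for p in table[kappa]])
```

A small \(\kappa\) keeps the distribution close to uniform, so the Gumbel noise dominates and sampling stays random. A large \(\kappa\) makes the largest logit dominate, so sampling becomes nearly deterministic.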

However, experiments on several large language models have revealed a new problem with learnable masks: since pruning operations produce zero parameters in the network, the gradient may vanish.

To address this problem, sparse weight regularization is introduced, which maintains appropriately large magnitudes in the remaining weights, leading to the following learning objective:

\[\begin{equation} \min_{\{p_{\pi}(\mathcal{M}_i)\}} \mathbb{E}_{x, \tilde{\mathcal{M}}_i \sim p_{\pi}(\mathcal{M}_i)} \left[ \mathcal{L}_{LM}(x; \{\mathcal{W}_i \odot \tilde{\mathcal{M}}_i\}) \right] - \lambda \sum_i \|\mathcal{W}_i \odot \tilde{\mathcal{M}}_i\|^2_2. \label{eqn:final_objective} \end{equation} \]

The regularization term, weighted by \(\lambda\), encourages larger magnitudes to be retained after pruning.
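
A small numeric check shows why the term \(- \lambda \sum_i \|\mathcal{W}_i \odot \tilde{\mathcal{M}}_i\|^2_2\) favors masks that keep large weights. The weights and the value of \(\lambda\) are hypothetical, and hard masks are used in place of sampled soft masks to keep the example short.

```python
def weight_regularizer(weight_blocks, mask_blocks, lam):
    """Return -lambda * sum_i ||W_i ⊙ M_i||_2^2."""
    total = 0.0
    for block, mask in zip(weight_blocks, mask_blocks):
        total += sum((w * m) ** 2 for w, m in zip(block, mask))
    return -lam * total


blocks = [[0.9, -0.1, 0.05, -0.7]]   # one hypothetical block of four weights
lam = 1e-2                            # hypothetical regularization strength

keeps_large = weight_regularizer(blocks, [[1, 0, 0, 1]], lam)
keeps_small = weight_regularizer(blocks, [[0, 1, 1, 0]], lam)

print(keeps_large, keeps_small)
# The mask that keeps the two largest weights contributes the lower value
# to the minimized objective.
assert keeps_large < keeps_small
```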

  • Transfer learning for sparsity

Transfer learning is one of the most popular paradigms in deep learning, and transfer learning for sparsity constructs new masks starting from precomputed ones.

The paper proposes a mask prior (Mask Prior) for initializing the distribution, which can significantly improve training efficiency and quality. The mask prior can be obtained with one-shot pruning methods such as magnitude pruning, SparseGPT, and Wanda.

Given a prior mask \(\mathcal{M}_0\), compute its similarity to all candidate masks:

\[\begin{equation} \text{sim}(\mathcal{M}_0, \hat{\mathcal{M}}_i) = \mathcal{M}_0 \hat{\mathcal{M}}_i^\top - \frac{1}{|\mathbf{S}|} \sum_j (\mathcal{M}_0 \hat{\mathcal{M}}_j^\top) = \mathcal{M}_0 \hat{\mathcal{M}}_i^\top - (N/2), \end{equation} \]

For candidate masks that are highly similar to the prior mask, the probability is increased in the initialization phase:

\[\begin{equation} \pi_i^{\prime} = \pi_i + \sigma(\pi)* \text{sim}(\mathcal{M}_0, \hat{\mathcal{M}}_i) * \alpha, \label{eqn:prior_mask} \end{equation} \]

where \(\sigma(\pi)\) is the standard deviation of the logits and \(\alpha\) is a hyperparameter that controls the strength of the prior. When \(\alpha=0\), the differentiable mask is learned without any prior.
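
The prior initialization can be sketched for the 2:4 case as follows. For 2:4 masks the mean overlap between the prior and the six candidates is 1, which equals \(N/2\) for \(N=2\). The random logit initialization, the value of \(\alpha\), and the use of the population standard deviation are assumptions made for this example.

```python
import random
import statistics

S = [
    [1, 1, 0, 0],
    [1, 0, 1, 0],
    [1, 0, 0, 1],
    [0, 1, 0, 1],
    [0, 1, 1, 0],
    [0, 0, 1, 1],
]


def similarity(prior, candidates):
    """Overlap of the prior with each candidate, centered by the mean overlap."""
    overlaps = [sum(a * b for a, b in zip(prior, c)) for c in candidates]
    mean_overlap = sum(overlaps) / len(overlaps)
    return [o - mean_overlap for o in overlaps]


def apply_prior(logits, prior, candidates, alpha):
    """pi'_i = pi_i + sigma(pi) * sim(M_0, M_i) * alpha."""
    sigma = statistics.pstdev(logits)   # assumption: population std of the logits
    sims = similarity(prior, candidates)
    return [pi + sigma * s * alpha for pi, s in zip(logits, sims)]


rng = random.Random(0)
logits = [rng.gauss(0.0, 0.01) for _ in S]   # hypothetical random initialization
prior = [1, 1, 0, 0]                          # e.g. a mask from magnitude pruning

sims = similarity(prior, S)
new_logits = apply_prior(logits, prior, S, alpha=3.0)

print(sims)
print([round(v, 4) for v in new_logits])
```

The candidate identical to the prior has its logit raised, the complementary candidate has its logit lowered, and the four candidates that share one position with the prior are unchanged.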

  • Method summary

Training starts from randomly initialized logits, which are updated with the prior mask when one is available, as shown in Eq. 10. The logits are then optimized to solve the objective in Eq. 8. For each block, the mask \(\mathcal{M}_i\) with the largest logit is used as the final mask for inference.
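
The final selection step is a per-block argmax over the learned logits. A minimal sketch, with made-up logits for two parameter blocks and a hypothetical helper name `final_masks`:

```python
S = [
    [1, 1, 0, 0],
    [1, 0, 1, 0],
    [1, 0, 0, 1],
    [0, 1, 0, 1],
    [0, 1, 1, 0],
    [0, 0, 1, 1],
]


def final_masks(block_logits, candidates):
    """Pick, for every block, the candidate mask with the largest logit."""
    chosen = []
    for logits in block_logits:
        best = max(range(len(logits)), key=logits.__getitem__)
        chosen.append(candidates[best])
    return chosen


# Hypothetical learned logits: one row of 6 values per parameter block.
learned_logits = [
    [0.1, 2.3, -0.4, 0.0, 0.2, -1.1],
    [-0.5, 0.0, 0.3, 0.1, 1.9, 0.2],
]

masks = final_masks(learned_logits, S)
print(masks)
```

No sampling is involved at this stage, so the resulting masks are deterministic and satisfy the 2:4 pattern exactly.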

Main experiments



