{"id":1110,"date":"2025-01-11T07:02:42","date_gmt":"2025-01-11T07:02:42","guid":{"rendered":"https:\/\/mailitics.com\/index.php\/2025\/01\/11\/linearizing-llama-ef7266d03050\/"},"modified":"2025-01-11T07:02:42","modified_gmt":"2025-01-11T07:02:42","slug":"linearizing-llama-ef7266d03050","status":"publish","type":"post","link":"https:\/\/mailitics.com\/index.php\/2025\/01\/11\/linearizing-llama-ef7266d03050\/","title":{"rendered":"Linearizing Llama"},"content":{"rendered":"<p>    Linearizing Llama<\/p>\n<div>\n<h4>Speeding up Llama: A hybrid approach to attention mechanisms<\/h4>\n<figure><img data-recalc-dims=\"1\" decoding=\"async\" alt=\"\" src=\"https:\/\/i0.wp.com\/cdn-images-1.medium.com\/max\/1024\/1%2AXvOrjQGJ21ggNyA8bJQkPw.jpeg?ssl=1\"><figcaption>Source: Image by Author (Generated using Gemini 1.5\u00a0Flash)<\/figcaption><\/figure>\n<p>In this article, we will see how to replace softmax self-attention in Llama-3.2-1B with a hybrid attention mechanism that combines softmax sliding window attention and linear attention. This implementation will help us better understand the growing interest in linear attention research, while also examining its limitations and potential future directions.<\/p>\n<p>This walkthrough builds upon the following works:<\/p>\n<ul>\n<li><a href=\"https:\/\/arxiv.org\/abs\/2410.10254\">LoLCATs: On Low-Rank Linearizing of Large Language Models<\/a><\/li>\n<li><a href=\"https:\/\/arxiv.org\/abs\/2406.07887\">An Empirical Study of Mamba-based Language Models<\/a><\/li>\n<li><a href=\"https:\/\/towardsdatascience.com\/linearizing-attention-204d3b86cc1e\">Linearizing Attention<\/a><\/li>\n<\/ul>\n<p>This article is mostly a recreation of the LoLCATs paper using Llama 3.2 1B, in which we replace 50% of the self-attention layers in a pretrained Llama model. 
The article consists of four main\u00a0parts:<\/p>\n<ul>\n<li><strong>Hybrid Attention Block<\/strong><\/li>\n<li><strong>Attention Transfer<\/strong><\/li>\n<li><strong>LoRA finetuning<\/strong><\/li>\n<li><strong>Evaluation<\/strong><\/li>\n<\/ul>\n<p>The main question this article asks is whether we can replace softmax attention in an already-trained model to speed up inference without losing too much accuracy. If we can, we can bring the cost of using LLMs down drastically!<\/p>\n<h3>LlamaSdpaAttention<\/h3>\n<p>Let\u2019s see what the Llama-3.2-1B model looks\u00a0like:<\/p>\n<figure><img data-recalc-dims=\"1\" decoding=\"async\" alt=\"\" src=\"https:\/\/i0.wp.com\/cdn-images-1.medium.com\/max\/1024\/1%2AcLBonCZ1BdaGMBlS4o3r7Q.png?ssl=1\"><figcaption>Source: Image by\u00a0Author<\/figcaption><\/figure>\n<p>As we can see, the model has 16 repeating decoder blocks. Our focus is the <em>self_attn<\/em> part, so the goal of this section is to understand how the LlamaSdpaAttention block works. Let\u2019s look at the definition of LlamaSdpaAttention:<\/p>\n<pre>class LlamaSdpaAttention(LlamaAttention):<br>    \"\"\"<br>    Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from<br>    `LlamaAttention` as the weights of the module stays untouched. 
The only changes are on the forward pass to adapt to<br>    SDPA API.<br>    \"\"\"<\/pre>\n<p>You can check what this class looks like using the following code:<\/p>\n<pre>import inspect<br>from transformers import AutoModelForCausalLM<br><br># Load the pretrained model (the Llama 3.2 weights are gated on the Hugging Face Hub)<br>model = AutoModelForCausalLM.from_pretrained(\"meta-llama\/Llama-3.2-1B\")<br><br># Print the source of the attention class used by the first decoder layer<br>attention_layer = model.model.layers[0].self_attn<br>print(inspect.getsource(attention_layer.__class__))<\/pre>\n<p>Let\u2019s go over the main parts of this code, understand what each part does, and see where we need to make a\u00a0change:<\/p>\n<figure><img data-recalc-dims=\"1\" decoding=\"async\" alt=\"\" src=\"https:\/\/i0.wp.com\/cdn-images-1.medium.com\/max\/1024\/1%2AZHbyUZlAVIIm6xOW66XsPQ.png?ssl=1\"><figcaption>Source: Image by\u00a0Author<\/figcaption><\/figure>\n<p>Let\u2019s take a dummy input of shape [2,4,2048] \u2192 [batch_size, seq_len, embedding dimension]. Llama-3.2-1B uses multi-head attention with 32 query\u00a0heads.<\/p>\n<h4><strong>Block 1:<\/strong><\/h4>\n<p>After the projections, query_states is a tensor of shape [2,4,2048], while key_states and value_states are tensors of shape [2,4,512].<\/p>\n<p>After view and transpose, the shapes are: query_states \u2192 [2,32,4,64], key_states \u2192 [2,8,4,64], value_states \u2192 [2,8,4,64]<\/p>\n<p>Here 64 is the head dimension (2048 \/ 32 = 64). Keys and values have only 8 heads because Llama uses grouped-query attention: the 32 query heads are split into 8 groups of 4, and all heads in a group share the same key_states and value_states.<\/p>\n<h4><strong>Block 2:<\/strong><\/h4>\n<p>In this block, we apply positional encoding; specifically, Llama uses Rotary Position Embeddings (RoPE). 
I won\u2019t go into detail about why this is needed, but you can read the following article to get a better\u00a0idea:<\/p>\n<p><a href=\"https:\/\/towardsdatascience.com\/master-positional-encoding-part-i-63c05d90a0c3\">Master Positional Encoding: Part I<\/a><\/p>\n<h4><strong>Block 3:<\/strong><\/h4>\n<p>Here we use past_key_value to reuse previously computed key and value states (the KV cache), so we don\u2019t have to recompute them. We then apply the repeat_kv function, which repeats each key\/value head 4 times so that every group of 4 query heads gets its shared keys and values.<\/p>\n<h4>Block 4:<\/h4>\n<p>Block 4 handles two main preparation steps for attention: setting up the causal mask so that tokens only attend to the current and previous positions, and optimizing memory layout with contiguous tensors for efficient GPU operations.<\/p>\n<h4>Block 5:<\/h4>\n<p>This is where we apply softmax attention, the component we\u2019ll be replacing in our implementation.<\/p>\n<h4>Block 6:<\/h4>\n<p>The attention output is a tensor of shape [2, 32, 4, 64]. 
We convert it back to [2, 4, 2048] and apply the final output projection.<\/p>\n<p>And that\u2019s the journey of an input through Llama self-attention!<\/p>\n<h3>Hybrid Attention Block<\/h3>\n<p>So now let\u2019s look at our HybridAttention block:<\/p>\n<pre>import torch<br>import torch.nn.functional as F<br>from typing import Optional, Tuple<br>from transformers.cache_utils import Cache<br># Assumes a transformers release that still defines LlamaSdpaAttention (later versions removed it)<br>from transformers.models.llama.modeling_llama import LlamaSdpaAttention, apply_rotary_pos_emb, repeat_kv<br><br>class HybridAttention(LlamaSdpaAttention):<br>    def __init__(self, config, layer_idx=None):<br>        super().__init__(config, layer_idx=layer_idx)<br>        self.window_size = 64<br>        #self.layer_idx = layer_idx<br><br>        # Initialize learnable factors<br>        # Create one factor pair per attention head<br>        num_heads = config.num_attention_heads<br>        self.window_factors = torch.nn.Parameter(torch.ones(1, num_heads, 1, 1) * 0.5)<br>        self.linear_factors = torch.nn.Parameter(torch.ones(1, num_heads, 1, 1) * 0.5)<br><br>        self.factor_activation = torch.nn.Sigmoid()<br><br>    def sliding_window_attention(self, query_states, key_states, value_states, window_size, window_factor):<br>        \"\"\"Compute sliding window attention\"\"\"<br>        batch_size, num_heads, seq_len, head_dim = query_states.shape<br><br>        key_windows = F.pad(key_states, (0, 0, window_size - 1, 0), value=0)<br>        key_windows = key_windows.unfold(2, window_size, 1)<br><br>        value_windows = F.pad(value_states, (0, 0, window_size - 1, 0), value=0)<br>        value_windows = value_windows.unfold(2, window_size, 1)<br><br>        attn_weights = torch.einsum('bhld,bhldw-&gt;bhlw', query_states, key_windows) * (head_dim ** -0.5)<br>        # Zero-padded key positions score exactly 0; mask them with -inf before the softmax<br>        attn_weights = torch.where(attn_weights == 0,<br>                                 torch.tensor(-float('inf'), device=attn_weights.device),<br>                                 attn_weights)<br><br>        # Apply learnable window factor (with sigmoid to ensure positivity)<br>        attn_weights = self.factor_activation(window_factor) * F.softmax(attn_weights, dim=-1)<br><br>        attn_output = torch.einsum('bhlw,bhldw-&gt;bhld', attn_weights, 
value_windows)<br>        sum_weights = attn_weights.sum(dim=-1, keepdim=True)<br><br>        return attn_output, sum_weights<br><br>    def linear_attention(self, query_states, key_states, value_states, window_size, linear_factor):<br>        \"\"\"Compute linear attention with cumsum\"\"\"<br>        def feature_map(x):<br>            return F.elu(x) + 1<br><br>        query_prime = feature_map(query_states)<br>        key_prime = feature_map(key_states)<br><br>        key_prime = F.pad(key_prime, (0, 0, window_size, 0), value=0)[:, :, :-window_size, :]<br>        value_padded = F.pad(value_states, (0, 0, window_size, 0), value=0)[:, :, :-window_size, :]<br><br>        # Compute KV<br>        kv = torch.einsum('bhlf,bhld-&gt;bhlfd', key_prime, value_padded)<br>        # Apply learnable linear factor (with sigmoid to ensure positivity)<br>        qkv = self.factor_activation(linear_factor) * torch.einsum('bhlf,bhlfd-&gt;bhld',<br>                                                                  query_prime,<br>                                                                  kv.cumsum(dim=2))<br><br>        sum_k = key_prime.cumsum(dim=2)<br>        sum_qk = self.factor_activation(linear_factor) * torch.einsum('bhld,bhld-&gt;bhl',<br>                                                                     query_prime,<br>                                                                     sum_k)[..., None]<br>        sum_qk = torch.where(sum_qk == 0, torch.tensor(1e-12, device=sum_qk.device), sum_qk)<br><br>        return qkv, sum_qk<br><br>    def hybrid_attention(self, query_states, key_states, value_states):<br>        \"\"\"Combine sliding window and linear attention with learnable factors\"\"\"<br>        qkv_window, sum_window = self.sliding_window_attention(<br>            query_states, key_states, value_states,<br>            self.window_size, self.window_factors<br>        )<br><br>        qkv_linear, sum_linear = self.linear_attention(<br>            
query_states, key_states, value_states,<br>            self.window_size, self.linear_factors<br>        )<br><br>        output = (qkv_window + qkv_linear) \/ (sum_window + sum_linear)<br>        return output<br><br>    def forward(<br>        self,<br>        hidden_states: torch.Tensor,<br>        attention_mask: Optional[torch.Tensor] = None,<br>        position_ids: Optional[torch.LongTensor] = None,<br>        past_key_value: Optional[Cache] = None,<br>        output_attentions: bool = False,<br>        use_cache: bool = False,<br>        cache_position: Optional[torch.LongTensor] = None,<br>        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,<br>        **kwargs,<br>    ):<br>        bsz, q_len, _ = hidden_states.size()<br><br>        query_states = self.q_proj(hidden_states)<br>        key_states = self.k_proj(hidden_states)<br>        value_states = self.v_proj(hidden_states)<br><br>        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)<br>        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)<br>        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)<br><br>        if position_embeddings is None:<br>            cos, sin = self.rotary_emb(value_states, position_ids)<br>        else:<br>            cos, sin = position_embeddings<br>        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)<br><br>        if past_key_value is not None:<br>            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}<br>            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)<br><br>        key_states = repeat_kv(key_states, self.num_key_value_groups)<br>        value_states = repeat_kv(value_states, self.num_key_value_groups)<br><br>        attn_output = self.hybrid_attention(<br>            query_states,<br>            
key_states,<br>            value_states<br>        )<br><br>        attn_output = attn_output.transpose(1, 2).contiguous()<br>        attn_output = attn_output.view(bsz, q_len, -1)<br>        attn_output = self.o_proj(attn_output)<br><br>        return attn_output, None, past_key_value<\/pre>\n<p>The key change in forward() is that block 5 is replaced with the following call. The causal-mask setup from block 4 is also dropped, since both attention paths are causal by construction:<\/p>\n<pre>attn_output = self.hybrid_attention(<br>            query_states,<br>            key_states,<br>            value_states<br>        )<\/pre>\n<p>In effect, we split the attention mechanism into a <strong>sliding window<\/strong> part and a <strong>linear attention<\/strong> part.<\/p>\n<h4>Sliding Window Attention:<\/h4>\n<pre>def sliding_window_attention(self, query_states, key_states, value_states, window_size, window_factor):<br>        \"\"\"Compute sliding window attention\"\"\"<br>        batch_size, num_heads, seq_len, head_dim = query_states.shape<br><br>        key_windows = F.pad(key_states, (0, 0, window_size - 1, 0), value=0)<br>        key_windows = key_windows.unfold(2, window_size, 1)<br><br>        value_windows = F.pad(value_states, (0, 0, window_size - 1, 0), value=0)<br>        value_windows = value_windows.unfold(2, window_size, 1)<br><br>        attn_weights = torch.einsum('bhld,bhldw-&gt;bhlw', query_states, key_windows) * (head_dim ** -0.5)<br>        attn_weights = torch.where(attn_weights == 0,<br>                                 torch.tensor(-float('inf'), device=attn_weights.device),<br>                                 attn_weights)<br><br>        # Apply learnable window factor (with sigmoid to ensure positivity)<br>        attn_weights = self.factor_activation(window_factor) * F.softmax(attn_weights, dim=-1)<br><br>        attn_output = torch.einsum('bhlw,bhldw-&gt;bhld', attn_weights, value_windows)<br>        sum_weights = attn_weights.sum(dim=-1, keepdim=True)<br><br>        return attn_output, sum_weights<\/pre>\n<p>For a deeper understanding of window 
attention concepts, I recommend referring to this\u00a0paper:<\/p>\n<p><a href=\"https:\/\/arxiv.org\/abs\/2309.17453\">Efficient Streaming Language Models with Attention Sinks<\/a><\/p>\n<p>The idea implemented here is that, instead of computing attention over all key-value pairs at once (where each token attends to every previous token), each token attends only to a sliding window of the \u2018w\u2019 most recent tokens, itself included. In the code above, this brings the time complexity down from O(n\u00b2) to O(n*w), since each token attends to only w tokens instead of all n tokens. It could be improved further with ideas such as attention sinks, and by only computing the window for the last w tokens, which I might implement in future\u00a0updates.<\/p>\n<h4>Linear Attention:<\/h4>\n<pre>def linear_attention(self, query_states, key_states, value_states, window_size, linear_factor):<br>        \"\"\"Compute linear attention with cumsum\"\"\"<br>        def feature_map(x):<br>            return F.elu(x) + 1<br><br>        query_prime = feature_map(query_states)<br>        key_prime = feature_map(key_states)<br><br>        key_prime = F.pad(key_prime, (0, 0, window_size, 0), value=0)[:, :, :-window_size, :]<br>        value_padded = F.pad(value_states, (0, 0, window_size, 0), value=0)[:, :, :-window_size, :]<br><br>        # Compute KV<br>        kv = torch.einsum('bhlf,bhld-&gt;bhlfd', key_prime, value_padded)<br>        # Apply learnable linear factor (with sigmoid to ensure positivity)<br>        qkv = self.factor_activation(linear_factor) * torch.einsum('bhlf,bhlfd-&gt;bhld',<br>                                                                  query_prime,<br>                                                                  kv.cumsum(dim=2))<br><br>        sum_k = key_prime.cumsum(dim=2)<br>        sum_qk = self.factor_activation(linear_factor) * torch.einsum('bhld,bhld-&gt;bhl',<br>                                                                     query_prime,<br> 
                                                                     sum_k)[..., None]<br>        sum_qk = torch.where(sum_qk == 0, torch.tensor(1e-12, device=sum_qk.device), sum_qk)<br><br>        return qkv, sum_qk<\/pre>\n<p>For linear attention, I use a very simple feature map, elu(x) + 1. The main thing to note is the initial padding, which shifts keys and values right by window_size positions. As a result, the linear term at each position only covers tokens that fall outside its window, so linear attention only needs to handle the first [sequence length \u2212 window size] tokens. The sliding window already keeps track of the most recent\u00a0context.<\/p>\n<p>Together, these two types of attention form our new hybrid attention. We use <em>window_factor<\/em> and <em>linear_factor<\/em> as learnable parameters, passed through a sigmoid, that control how much each type of attention contributes to the final\u00a0output.<\/p>\n<p>Now that we have our hybrid block, we take inspiration from the \u201c<a href=\"https:\/\/arxiv.org\/abs\/2406.07887\"><strong>An Empirical Study of Mamba-based Language Models<\/strong><\/a>\u201d paper and replace only half of the softmax attention layers, alternating between softmax and hybrid layers. Llama-3.2-1B has 16 softmax attention layers, and we replace 8 of them: layers [0,2,4,6,8,10,12,14].<\/p>\n<h3>Attention Transfer<\/h3>\n<p>The implementation follows the methodology described in \u201c<a href=\"https:\/\/arxiv.org\/abs\/2410.10254\"><strong>LoLCATs: On Low-Rank Linearizing of Large Language Models<\/strong><\/a>\u201d. The attention transfer step initializes 8 hybrid blocks with the weights from the original blocks. For training, I used 1M tokens from the 10B version of <a href=\"https:\/\/huggingface.co\/datasets\/HuggingFaceFW\/fineweb-edu\">fineweb-edu<\/a>[1].<\/p>\n<p>The basic idea is as follows. We freeze all the parameters in Llama-3.2-1B and then do a forward pass with one training input. This gives us the input and output of each of our self-attention blocks. 
We then pass the same input through the corresponding hybrid block, compute the MSE loss between the two outputs, and train the hybrid block on that loss. This explicitly trains each hybrid block to mimic the output of softmax attention, which helps preserve accuracy. We do this separately for each block. Once they are trained, we replace the corresponding self-attention layers in Llama-3.2-1B with our hybrid blocks. A sample output from this new model looks like this:<\/p>\n<figure><img data-recalc-dims=\"1\" decoding=\"async\" alt=\"\" src=\"https:\/\/i0.wp.com\/cdn-images-1.medium.com\/max\/1024\/1%2Ai_HI5Kj0A3v4_941I7Y-kA.png?ssl=1\"><figcaption>Source: Image by\u00a0Author<\/figcaption><\/figure>\n<p>The current model\u2019s outputs lack coherence and meaning. The next phase of the implementation specifically targets this\u00a0issue.<\/p>\n<p>The code for this step: <a href=\"https:\/\/github.com\/shitanshubhushan\/Linearizing-Llama-3.2-1B\/blob\/main\/Llama_attn_transfer.ipynb\">Llama_attn_transfer.ipynb<\/a><\/p>\n<h3>LoRA Finetune<\/h3>\n<p>I won\u2019t go into the details of LoRA; you can read the following article if you want to understand LoRA\u00a0better:<\/p>\n<p><a href=\"https:\/\/towardsdatascience.com\/lora-intuitively-and-exhaustively-explained-e944a6bff46b\">LoRA\u200a\u2014\u200aIntuitively and Exhaustively Explained<\/a><\/p>\n<p>The main goal of this step is as follows. So far, we have trained each hybrid block separately to mimic softmax attention. However, we haven\u2019t yet trained or finetuned the full model with these blocks in place, so that they work together for text generation. 
So in this step, we use the <a href=\"https:\/\/huggingface.co\/datasets\/databricks\/databricks-dolly-15k\">Dolly-15K Dataset<\/a>[2], an instruction-tuning dataset, to finetune our model for text generation with LoRA. Only the parameters in the hybrid attention blocks are finetuned; every other parameter stays\u00a0frozen.<\/p>\n<figure><img data-recalc-dims=\"1\" decoding=\"async\" alt=\"\" src=\"https:\/\/i0.wp.com\/cdn-images-1.medium.com\/max\/1024\/1%2AEYL8XCes-zv8kwf6sTqgFw.png?ssl=1\"><figcaption>Source: Image by\u00a0Author<\/figcaption><\/figure>\n<p>The model clearly generates much better text after this finetuning. After attention transfer and finetuning, we now have a model we can actually benchmark!<\/p>\n<p>The code for this step: <a href=\"https:\/\/github.com\/shitanshubhushan\/Linearizing-Llama-3.2-1B\/blob\/main\/llama_lora_finetune.ipynb\">llama_lora_finetune.ipynb<\/a><\/p>\n<h3>Evaluation<\/h3>\n<p>With all these steps done, it\u2019s time to compare our hybrid model with the original Llama-3.2-1B. We expect our model to be faster at inference while keeping its accuracy reasonably close to that of Llama-3.2-1B.<\/p>\n<figure><img data-recalc-dims=\"1\" decoding=\"async\" alt=\"\" src=\"https:\/\/i0.wp.com\/cdn-images-1.medium.com\/max\/1024\/1%2AcgXvPhPVTpUg-bexswnRcQ.png?ssl=1\"><figcaption>Source: Image by\u00a0Author<\/figcaption><\/figure>\n<p>Evaluating the throughput of both models for sequence lengths from 2\u2070 to 2\u00b9\u2075, we can see that both models initially perform about the same. However, as the sequence length increases, the hybrid model becomes notably faster than the base model, matching our expectations. 
It\u2019s important to note that these tokens\/sec measurements vary significantly depending on the GPU\u00a0used.<\/p>\n<figure><img data-recalc-dims=\"1\" decoding=\"async\" alt=\"\" src=\"https:\/\/i0.wp.com\/cdn-images-1.medium.com\/max\/1024\/1%2Acf59v9ZzmaAQLbs5k6g5qg.png?ssl=1\"><figcaption>Source: Image by\u00a0Author<\/figcaption><\/figure>\n<p>Looking at seconds taken per token, we see a similar pattern. Initially, both models have nearly the same speed, but as the sequence length increases, we observe the computational advantage that linear + sliding window attention brings.<\/p>\n<p>\u2611\ufe0f We meet our first expectation: our hybrid model is faster than Llama-3.2-1B.<\/p>\n<p>Now let\u2019s look at accuracy. For this, I benchmarked the models on <a href=\"https:\/\/huggingface.co\/datasets\/cais\/mmlu\">MMLU<\/a>[3], where each model has to answer multiple-choice questions with 4 options. The model\u2019s prediction is whichever of the tokens [\u2018A\u2019, \u2018B\u2019, \u2018C\u2019, \u2018D\u2019] receives the highest logit.<\/p>\n<pre>\u2554\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2566\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2566\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2566\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2557<br>\u2551          Model          \u2551 Num Shot \u2551    GPU    \u2551 macro_avg\/acc_char 
\u2551<br>\u2560\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u256c\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u256c\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u256c\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2563<br>\u2551 Hybrid                  \u2551        5 \u2551 RTX A6000 \u2551              27.36 \u2551<br>\u2551 Llama 3.2 1B (No Cache) \u2551        5 \u2551 RTX A6000 \u2551              25.38 \u2551<br>\u2551 Llama 3.2 1B (No Cache) \u2551        5 \u2551 L40S      \u2551              32.13 \u2551<br>\u2551 Hybrid                  \u2551        0 \u2551 RTX A6000 \u2551              27.26 \u2551<br>\u2551 Llama 3.2 1B (No Cache) \u2551        0 \u2551 RTX A6000 \u2551              25.50 \u2551<br>\u255a\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2569\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2569\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2569\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u255d<\/pre>\n<p>The test results reveal an intriguing insight into model evaluation. While the Hybrid model slightly outperforms Llama-3.2-1B, this difference (approximately 2%) should be considered insignificant, especially given that the Hybrid model underwent additional training, particularly with instruction tuning datasets.<\/p>\n<p>The most fascinating observation is the substantial performance variance when running identical code on different GPUs. 
When Llama-3.2-1B was run on an L40S GPU instead of an RTX A6000, the accuracy jumped from 25.38% to 32.13%. That is a significant difference, considering that all other variables remained constant. This difference most likely comes down to how different GPUs handle floating-point operations. It shows how much hardware choices can unexpectedly affect a model\u2019s measured performance.<\/p>\n<p>Another striking finding is the lack of difference between 5-shot and 0-shot performance in these results, particularly on the RTX A6000. This is unexpected, as 5-shot prompting typically improves performance, especially for base models like Llama-3.2-1B. In fact, when running Llama-3.2-1B on the L40S GPU, I observed a notable gap between 5-shot and 0-shot scores. This again highlights how GPU differences can affect benchmark scores.<\/p>\n<p>It would be a fun future exercise to benchmark the same model on several different GPUs while keeping every other variable fixed.<\/p>\n<ul>\n<li>MMLU 0-shot evaluation code: <a href=\"https:\/\/github.com\/shitanshubhushan\/Linearizing-Llama-3.2-1B\/blob\/main\/MMLU_eval-0shot.ipynb\">MMLU_eval-0shot.ipynb<\/a>\n<\/li>\n<li>MMLU 5-shot evaluation code: <a href=\"https:\/\/github.com\/shitanshubhushan\/Linearizing-Llama-3.2-1B\/blob\/main\/MMLU_eval-5shot.ipynb\">MMLU_eval-5shot.ipynb<\/a>\n<\/li>\n<li>Inference speed evaluation code: <a href=\"https:\/\/github.com\/shitanshubhushan\/Linearizing-Llama-3.2-1B\/blob\/main\/Linear_llama_eval_inference_speed.ipynb\">Linear_llama_eval_inference_speed.ipynb<\/a>\n<\/li>\n<\/ul>\n<h3>Conclusion<\/h3>\n<p>I hope this article has demonstrated both the potential of softmax attention alternatives and the inherent strengths of traditional softmax attention. 
Using relatively modest computational resources and a small dataset, we were able to achieve faster inference speeds while maintaining comparable accuracy levels with our hybrid approach.<\/p>\n<p>Another point to understand is that softmax-based attention transformers have benefited from extensive hardware-level optimization. This makes them competitive with linear alternatives in practice, despite their worse computational complexity. If the same effort were put into architectures like Mamba, they might become more competitive.<\/p>\n<p>A promising approach is to use a hybrid of softmax attention and linear attention alternatives to try to get the best of both worlds. Nvidia did this in \u201c<a href=\"https:\/\/arxiv.org\/abs\/2406.07887\"><strong>An Empirical Study of Mamba-based Language Models<\/strong><\/a>\u201d and showed that a hybrid approach is an effective alternative.<\/p>\n<p>Hopefully you all learned something from this\u00a0article!<\/p>\n<p>All the code for this can be found at: <a href=\"https:\/\/github.com\/shitanshubhushan\/Linearizing-Llama-3.2-1B\/tree\/main\">Linearizing-Llama-3.2-1B<\/a><\/p>\n<h3>Acknowledgment<\/h3>\n<p>This blog post was inspired by coursework from my graduate studies during Fall 2024 at the University of Michigan. While the courses provided the foundational knowledge and motivation to explore these topics, any errors or misinterpretations in this article are entirely my own. 
This represents my personal understanding and exploration of the material.<\/p>\n<h3>License References<\/h3>\n<p>[1]\u200a\u2014\u200afineweb-edu: The dataset is released under the Open Data Commons Attribution License (ODC-By) v1.0\u00a0<a href=\"https:\/\/opendatacommons.org\/licenses\/by\/1-0\/\">license<\/a>.<\/p>\n<p>[2]\u200a\u2014\u200aDolly-15K: The dataset is subject to CC BY-SA 3.0\u00a0license.<\/p>\n<p>[3]\u200a\u2014\u200aMMLU: MIT\u00a0license<\/p>\n<hr>\n<p><a href=\"https:\/\/towardsdatascience.com\/linearizing-llama-ef7266d03050\">Linearizing Llama<\/a> was originally published in <a href=\"https:\/\/towardsdatascience.com\/\">Towards Data Science<\/a> on Medium, where people are continuing the conversation by highlighting and responding to this story.<\/p>\n<\/div>\n<p>Shitanshu Bhushan<br \/>\n<a href=\"https:\/\/medium.com\/m\/global-identity-2?redirectUrl=https%3A%2F%2Ftowardsdatascience.com%2Flinearizing-llama-ef7266d03050\">Go to original source<\/a><\/p>\n","protected":false},"excerpt":{"rendered":"<p>Linearizing Llama Speeding up Llama: A hybrid approach to attention mechanisms Source: Image by Author (Generated using Gemini 1.5\u00a0Flash) In this article, we will see how to replace softmax self-attention in Llama-3.2-1B with hybrid attention combining softmax sliding window and linear attention. 
This implementation will help us better understand the growing interest in linear attention [&hellip;]<\/p>\n","protected":false},"author":2,"featured_media":0,"comment_status":"closed","ping_status":"closed","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[62,959,71,164,87,260],"tags":[960,474,878],"class_list":["post-1110","post","type-post","status-publish","format-standard","hentry","category-aimldsaimlds","category-attention","category-large-language-models","category-llama-3","category-llm","category-nlp","tag-attention","tag-llama","tag-using"],"_links":{"self":[{"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/posts\/1110"}],"collection":[{"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/users\/2"}],"replies":[{"embeddable":true,"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/comments?post=1110"}],"version-history":[{"count":0,"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/posts\/1110\/revisions"}],"wp:attachment":[{"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/media?parent=1110"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/categories?post=1110"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/tags?post=1110"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}