{"id":2858,"date":"2025-04-04T07:03:00","date_gmt":"2025-04-04T07:03:00","guid":{"rendered":"https:\/\/mailitics.com\/index.php\/2025\/04\/04\/kernel-case-study-flash-attention\/"},"modified":"2025-04-04T07:03:00","modified_gmt":"2025-04-04T07:03:00","slug":"kernel-case-study-flash-attention","status":"publish","type":"post","link":"https:\/\/mailitics.com\/index.php\/2025\/04\/04\/kernel-case-study-flash-attention\/","title":{"rendered":"Kernel Case Study: Flash Attention"},"content":{"rendered":"<div>\n<p class=\"wp-block-paragraph\"><mdspan datatext=\"el1743624162175\" class=\"mdspan-comment\">The attention<\/mdspan> mechanism is at the core of modern-day transformers. Scaling the context window of these transformers has been a major challenge, and it still is, even though we are now in the era of million-token context windows (Qwen 2.5\u00a0<a href=\"https:\/\/huggingface.co\/Qwen\/Qwen2.5-7B-Instruct-1M\">[1]<\/a>). Growing the context window brings considerable complexity on both the compute and the memory side (a naive <a href=\"https:\/\/towardsdatascience.com\/tag\/attention-mechanism\/\" title=\"Attention Mechanism\">Attention Mechanism<\/a> scales quadratically with sequence length in both compute and memory requirements). 
Revisiting Flash Attention helps us understand the complexities of optimizing the underlying operations on GPUs and, more importantly, gives us a better grip on what might come next.<\/p>\n<p class=\"wp-block-paragraph\">Let\u2019s quickly revisit a naive attention algorithm to see what\u2019s going on.<\/p>\n<figure class=\"wp-block-image size-full\"><img data-recalc-dims=\"1\" decoding=\"async\" src=\"https:\/\/i0.wp.com\/contributor.insightmediagroup.io\/wp-content\/uploads\/2025\/04\/image-8.png?ssl=1\" alt=\"\" class=\"wp-image-601051\"><figcaption class=\"wp-element-caption\">Attention Algorithm. Image by Author<\/figcaption><\/figure>\n<p class=\"wp-block-paragraph\">As you can see, if we are not careful, we end up materializing the full N\u00d7M attention matrix in GPU HBM, which means the memory requirement grows quadratically with context length.<\/p>\n<blockquote class=\"wp-block-quote is-layout-flow wp-block-quote-is-layout-flow\">\n<p class=\"wp-block-paragraph\">If you want to learn more about the GPU memory hierarchy, <a href=\"https:\/\/aarunjith.substack.com\/p\/simplifying-cuda-kernels-with-triton\" data-type=\"link\" data-id=\"https:\/\/aarunjith.substack.com\/p\/simplifying-cuda-kernels-with-triton\">my previous post on Triton<\/a> is a good starting point. It will also come in handy later in this post, when we implement the <a href=\"https:\/\/towardsdatascience.com\/tag\/flash-attention\/\" title=\"Flash Attention\">Flash Attention<\/a> kernel in Triton. 
The <a href=\"https:\/\/arxiv.org\/pdf\/2205.14135\" target=\"_blank\" rel=\"noreferrer noopener\">flash attention paper<\/a> also has a really good introduction to this.<\/p>\n<\/blockquote>\n<p class=\"wp-block-paragraph\">Additionally, look at the steps involved in executing this algorithm and at how it accesses the slow HBM. As explained later in the post, HBM access can be a major bottleneck. We notice a few things:<\/p>\n<ol class=\"wp-block-list\">\n<li class=\"wp-block-list-item\">Q, K and V start out in HBM<\/li>\n<li class=\"wp-block-list-item\">We read Q and K from HBM to compute the dot product<\/li>\n<li class=\"wp-block-list-item\">We write the resulting scores back to HBM<\/li>\n<li class=\"wp-block-list-item\">We read the scores again to compute the softmax. For causal attention, as in LLMs, we also have to mask the scores before the softmax. The resulting full attention matrix is written back to HBM<\/li>\n<li class=\"wp-block-list-item\">We read both the attention weights and the Value matrix from HBM once more to compute the final dot product, and write the output back to the slow GPU memory<\/li>\n<\/ol>\n<p class=\"wp-block-paragraph\">I think you get the point: by reading from and writing to HBM more carefully, we could avoid redundant operations and make real gains. 
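<\/p>
<p class=\"wp-block-paragraph\">The round trips above can be made concrete with a minimal PyTorch sketch of the naive algorithm. This is my own illustration, not code from the paper or the gist, and the function name naive_attention is hypothetical. There is no mask and no 1\/\u221aD scaling, to match the kernel later in this post. Each commented line corresponds to one of the HBM reads and writes listed above. The last two lines estimate the size of a single float32 score matrix at a 32K context length.<\/p>
<pre class=\"wp-block-prismatic-blocks\"><code class=\"language-python\">import torch

def naive_attention(Q, K, V):
    # Q: (B, N, D); K and V: (B, M, D); every tensor here lives in GPU global memory (HBM)
    S = Q @ K.transpose(-2, -1)    # read Q and K, write the full (B, N, M) score matrix
    P = torch.softmax(S, dim=-1)   # read S back, write the full (B, N, M) attention matrix
    return P @ V                   # read P and V, write the (B, N, D) output

N = 32768
print(N * N * 4 / 2**30)  # 4.0 GiB for one float32 N x N matrix (per head, per batch element)<\/code><\/pre>
<p class=\"wp-block-paragraph\">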
Cutting down these HBM round trips is exactly what motivated the original Flash Attention algorithm.<\/p>\n<p class=\"wp-block-paragraph\">Flash Attention first came out in 2022 <a href=\"https:\/\/arxiv.org\/pdf\/2205.14135\" target=\"_blank\" rel=\"noreferrer noopener\">[2]<\/a>. Some much-needed improvements followed in 2023 as Flash Attention v2 <a href=\"https:\/\/arxiv.org\/pdf\/2307.08691\" target=\"_blank\" rel=\"noreferrer noopener\">[3]<\/a>. In 2024 came Flash Attention v3 [<a href=\"https:\/\/arxiv.org\/abs\/2407.08608\" target=\"_blank\" rel=\"noreferrer noopener\">5<\/a>], with additional improvements for NVIDIA Hopper GPUs [<a href=\"https:\/\/www.nvidia.com\/en-us\/data-center\/technologies\/hopper-architecture\/\" target=\"_blank\" rel=\"noreferrer noopener\">4<\/a>]. The original Flash Attention paper identified that the attention operation is limited by memory bandwidth rather than compute. (There have been earlier attempts to reduce the computational complexity of attention from <a href=\"https:\/\/arxiv.org\/pdf\/2001.04451\" target=\"_blank\" rel=\"noreferrer noopener\">O(N**2) to O(N log N)<\/a> and lower through approximate algorithms.)<\/p>\n<p class=\"wp-block-paragraph\">Flash Attention proposed a fused kernel that performs all of the above attention operations in one pass, block by block. It produces the final attention output without ever materializing the full N**2 attention matrix in memory, which makes the algorithm significantly faster. The term `fused` simply means that we combine multiple operations in GPU SRAM instead of making a slow round trip to HBM between each of them. And it does all this while producing the exact attention output, without any approximation.<\/p>\n<p class=\"wp-block-paragraph\">This lecture from Stanford CS149 brilliantly demonstrates the impact a well-thought-out memory access pattern can have on an algorithm. 
I highly recommend checking it out if you haven\u2019t already.<\/p>\n<figure class=\"wp-block-embed is-type-video is-provider-youtube wp-block-embed-youtube wp-embed-aspect-16-9 wp-has-aspect-ratio\">\n<div class=\"wp-block-embed__wrapper\">\n<iframe loading=\"lazy\" title=\"Stanford CS149 I Parallel Computing I 2023 I Lecture 1 - Why Parallelism? Why Efficiency?\" width=\"500\" height=\"281\" src=\"https:\/\/www.youtube.com\/embed\/V1tINV2-9p4?feature=oembed\" frameborder=\"0\" allow=\"accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share\" referrerpolicy=\"strict-origin-when-cross-origin\" allowfullscreen><\/iframe>\n<\/div>\n<\/figure>\n<p class=\"wp-block-paragraph\">Before we dive into Flash Attention <mdspan datatext=\"el1743623764484\" class=\"mdspan-comment\">(it\u2019s getting tedious to type this over and over, so let\u2019s agree<\/mdspan> to call it FA, shall we?) in Triton, there is something else I want to get out of the way.<\/p>\n<h2 class=\"wp-block-heading\"><strong>Numerical Stability in exponents<\/strong><\/h2>\n<p class=\"wp-block-paragraph\">Let\u2019s take the example of FP32 numbers. <strong>float32<\/strong> (standard 32-bit float) uses 1 sign bit, 8 exponent bits, and 23 mantissa bits [<a href=\"https:\/\/en.wikipedia.org\/wiki\/Single-precision_floating-point_format\" target=\"_blank\" rel=\"noreferrer noopener\">6<\/a>]. The largest exponent float32 can represent is 127, so the largest finite value is just under 2<sup>128<\/sup> \u2248 3.4\u00d710<sup>38<\/sup>. For exponentials, this means e<sup>88<\/sup> \u2248 1.65\u00d710<sup>38<\/sup> still fits, but e<sup>89<\/sup> \u2248 4.5\u00d710<sup>38<\/sup> does not. Any input above roughly 88.7 overflows to infinity, and in practice you want to stay well below that. 
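<\/p>
<p class=\"wp-block-paragraph\">A small NumPy sketch shows the problem in practice. It is an illustration added here, not code from the original post, and the names naive_softmax and stable_softmax are hypothetical. Exponentiating large float32 inputs directly overflows to inf, and the softmax turns into nan. Subtracting the maximum first keeps every exponent at or below e<sup>0<\/sup> = 1 and gives the same probabilities.<\/p>
<pre class=\"wp-block-prismatic-blocks\"><code class=\"language-python\">import numpy as np

def naive_softmax(x):
    # Exponentiate directly: any float32 input above ~88.7 overflows to inf
    e = np.exp(x)
    return e / e.sum()

def stable_softmax(x):
    # Subtract the max first so every exponent is at most exp(0) = 1
    e = np.exp(x - x.max())
    return e / e.sum()

with np.errstate(over='ignore', invalid='ignore'):
    print(np.exp(np.float32(88.0)))   # about 1.65e+38, still finite
    print(np.exp(np.float32(89.0)))   # inf
    x = np.array([1000.0, 1001.0, 1002.0], dtype=np.float32)
    print(naive_softmax(x))           # [nan nan nan], because inf / inf is nan
print(stable_softmax(x))              # about [0.090 0.245 0.665]<\/code><\/pre>
<p class=\"wp-block-paragraph\">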
Here\u2019s a <a href=\"https:\/\/chatgpt.com\/share\/679d0ed9-8f48-8011-926e-e274b15ae8ae\" target=\"_blank\" rel=\"noreferrer noopener\">very interesting chat with OpenAI o1<\/a> shared by the folks at <a href=\"https:\/\/allenai.org\/\">AllenAI<\/a> in their <a href=\"https:\/\/github.com\/allenai\/open-instruct\" target=\"_blank\" rel=\"noreferrer noopener\">OpenInstruct<\/a> repo. Although it is about stabilizing KL divergence calculations in the RLHF\/RL setting, the same ideas apply directly to exponents. So, to deal with the softmax in attention, we do the following:<\/p>\n<figure class=\"wp-block-image size-full\"><img data-recalc-dims=\"1\" decoding=\"async\" src=\"https:\/\/i0.wp.com\/contributor.insightmediagroup.io\/wp-content\/uploads\/2025\/04\/image-6.png?ssl=1\" alt=\"\" class=\"wp-image-601049\"><figcaption class=\"wp-element-caption\">Softmax with rescaling. Image by Author<\/figcaption><\/figure>\n<blockquote class=\"wp-block-quote is-layout-flow wp-block-quote-is-layout-flow\">\n<p class=\"wp-block-paragraph\">TRICK: Let\u2019s also observe that if you do this:<\/p>\n<\/blockquote>\n<figure class=\"wp-block-image aligncenter size-full\"><img data-recalc-dims=\"1\" decoding=\"async\" src=\"https:\/\/i0.wp.com\/contributor.insightmediagroup.io\/wp-content\/uploads\/2025\/04\/image-7.png?ssl=1\" alt=\"\" class=\"wp-image-601050\"><figcaption class=\"wp-element-caption\">Rescaling Trick. Image by Author<\/figcaption><\/figure>\n<blockquote class=\"wp-block-quote is-layout-flow wp-block-quote-is-layout-flow\">\n<p class=\"wp-block-paragraph\">then you can rescale\/readjust values without affecting the final softmax value. This is really useful when you have an initial estimate of the maximum value that might change when you encounter a new set of values. 
I know, I know; stay with me and let me explain.<\/p>\n<\/blockquote>\n<p class=\"wp-block-paragraph\"><strong>Setting the scene<\/strong><\/p>\n<p class=\"wp-block-paragraph\">Let\u2019s take a small detour into matrix multiplication.<\/p>\n<figure class=\"wp-block-image size-large\"><img data-recalc-dims=\"1\" height=\"358\" width=\"1024\" decoding=\"async\" src=\"https:\/\/i0.wp.com\/contributor.insightmediagroup.io\/wp-content\/uploads\/2025\/04\/2-1024x358.png?resize=1024%2C358&#038;ssl=1\" alt=\"\" class=\"wp-image-601052\"><figcaption class=\"wp-element-caption\">Blocked Matrix Multiplication. Image by Author<\/figcaption><\/figure>\n<p class=\"wp-block-paragraph\">This shows a toy example of blocked matrix multiplication, except that we only block the rows of A (green) and the columns of B (orange? beige?). As you can see above, the outputs O1, O2, O3 and O4 are complete (those positions need no further calculation). We just need to fill in the remaining columns of the initial rows using the remaining columns of B, like below:<\/p>\n<figure class=\"wp-block-image size-large\"><img data-recalc-dims=\"1\" height=\"369\" width=\"1024\" decoding=\"async\" src=\"https:\/\/i0.wp.com\/contributor.insightmediagroup.io\/wp-content\/uploads\/2025\/04\/3-1024x369.png?resize=1024%2C369&#038;ssl=1\" alt=\"\" class=\"wp-image-601053\"><figcaption class=\"wp-element-caption\">The next set of blocks fills in the remaining spaces. Image by Author<\/figcaption><\/figure>\n<p class=\"wp-block-paragraph\">So we can fill in the output one block of rows from A and one block of columns from B at a time.<\/p>\n<h2 class=\"wp-block-heading\"><strong>Connecting the dots<\/strong><\/h2>\n<p class=\"wp-block-paragraph\">When I introduced FA, I said that we never have to compute and store the whole attention matrix. 
So here\u2019s what we do:<\/p>\n<ol class=\"wp-block-list\">\n<li class=\"wp-block-list-item\">Compute a block of the attention matrix using a block of rows from Q and a block of columns from K. Once you have this partial attention matrix, compute a few statistics and keep them in memory.<\/li>\n<\/ol>\n<figure class=\"wp-block-image size-large\"><img data-recalc-dims=\"1\" height=\"396\" width=\"1024\" decoding=\"async\" src=\"https:\/\/i0.wp.com\/contributor.insightmediagroup.io\/wp-content\/uploads\/2025\/04\/4-1024x396.png?resize=1024%2C396&#038;ssl=1\" alt=\"\" class=\"wp-image-601054\"><figcaption class=\"wp-element-caption\">Computing block attention scores S<sub>b<\/sub>, and computing the row-wise maximums. Image by Author<\/figcaption><\/figure>\n<p class=\"wp-block-paragraph\">I have greyed out O5 to O12 because we don\u2019t know those values yet; they come from subsequent blocks. We then transform S<sub>b<\/sub> as below:<\/p>\n<figure class=\"wp-block-image size-full\"><img data-recalc-dims=\"1\" decoding=\"async\" src=\"https:\/\/i0.wp.com\/contributor.insightmediagroup.io\/wp-content\/uploads\/2025\/04\/image-9.png?ssl=1\" alt=\"\" class=\"wp-image-601055\"><figcaption class=\"wp-element-caption\">Keeping track of the current row sums and row maxes. Image by Author<\/figcaption><\/figure>\n<figure class=\"wp-block-image size-large\"><img data-recalc-dims=\"1\" height=\"202\" width=\"1024\" decoding=\"async\" src=\"https:\/\/i0.wp.com\/contributor.insightmediagroup.io\/wp-content\/uploads\/2025\/04\/5-1024x202.png?resize=1024%2C202&#038;ssl=1\" alt=\"\" class=\"wp-image-601056\"><figcaption class=\"wp-element-caption\">Exponents with the scaling trick. 
Image by Author<\/figcaption><\/figure>\n<p class=\"wp-block-paragraph\">Now you have <mdspan datatext=\"el1743623960664\" class=\"mdspan-comment\">the<\/mdspan> setup for a partial softmax:<\/p>\n<figure class=\"wp-block-image size-full\"><img data-recalc-dims=\"1\" decoding=\"async\" src=\"https:\/\/i0.wp.com\/contributor.insightmediagroup.io\/wp-content\/uploads\/2025\/04\/image-10.png?ssl=1\" alt=\"\" class=\"wp-image-601057\"><figcaption class=\"wp-element-caption\">Partial Softmax, as the denominator is still a partial sum. Image by Author<\/figcaption><\/figure>\n<p class=\"wp-block-paragraph\"><strong>But:<\/strong><\/p>\n<ol class=\"wp-block-list\">\n<li class=\"wp-block-list-item\">What if the true maximum is in one of the O<sub>i<\/sub>s that are yet to come?<\/li>\n<li class=\"wp-block-list-item\">The sum is still local, so we need to update it every time we see new P<sub>i<\/sub>s. We know how to keep track of a sum, but what about rebasing it to the true maximum?<\/li>\n<\/ol>\n<p class=\"wp-block-paragraph\">Recall the trick above. All we have to do is keep track of the maximum value seen so far for each row. We then update it iteratively as new maximums appear in the remaining column blocks of K for the same block of rows from Q.<\/p>\n<figure class=\"wp-block-image size-full\"><img data-recalc-dims=\"1\" decoding=\"async\" src=\"https:\/\/i0.wp.com\/contributor.insightmediagroup.io\/wp-content\/uploads\/2025\/04\/image-11.png?ssl=1\" alt=\"\" class=\"wp-image-601058\"><figcaption class=\"wp-element-caption\">Two consecutive blocks and their row-max updates. 
Image by Author<\/figcaption><\/figure>\n<figure class=\"wp-block-image size-full\"><img data-recalc-dims=\"1\" decoding=\"async\" src=\"https:\/\/i0.wp.com\/contributor.insightmediagroup.io\/wp-content\/uploads\/2025\/04\/image-12.png?ssl=1\" alt=\"\" class=\"wp-image-601059\"><figcaption class=\"wp-element-caption\">Updating the estimate of our current sum with rescaling<\/figcaption><\/figure>\n<p class=\"wp-block-paragraph\">We still do not want to write our partial softmax matrix into HBM; we keep it for the next step.<\/p>\n<h2 class=\"wp-block-heading\"><strong>The final dot product<\/strong><\/h2>\n<p class=\"wp-block-paragraph\">The last step in our attention computation is the dot product with V. To start, we initialize an output matrix of zeros with shape N\u00d7D in HBM, where N is the number of queries as above. We use the same block size for V as we did for K, applied row-wise as below (the subscripts just denote that this is only a block, not the full matrix):<\/p>\n<figure class=\"wp-block-image size-full\"><img data-recalc-dims=\"1\" decoding=\"async\" src=\"https:\/\/i0.wp.com\/contributor.insightmediagroup.io\/wp-content\/uploads\/2025\/04\/image-13.png?ssl=1\" alt=\"\" class=\"wp-image-601060\"><figcaption class=\"wp-element-caption\">A single block of attention scores creating a partial output. Image by Author<\/figcaption><\/figure>\n<figure class=\"wp-block-image size-full\"><img data-recalc-dims=\"1\" decoding=\"async\" src=\"https:\/\/i0.wp.com\/contributor.insightmediagroup.io\/wp-content\/uploads\/2025\/04\/image-14.png?ssl=1\" alt=\"\" class=\"wp-image-601061\"><figcaption class=\"wp-element-caption\">The full output, however, requires the sum of all these dot products, some of which will be filled in by the blocks to come. Image by Author<\/figcaption><\/figure>\n<p class=\"wp-block-paragraph\">Notice that we need the attention scores from all the blocks to get the final product. 
But we can compute the local product and `accumulate` it, the same way we accumulated the Ls. We can then form the full output once we have processed all the column blocks (K<sub>b<\/sub>) for a given row block (Q<sub>b<\/sub>).<\/p>\n<h2 class=\"wp-block-heading\"><strong>Putting it all together<\/strong><\/h2>\n<p class=\"wp-block-paragraph\">Let\u2019s put all these ideas together to form the final algorithm:<\/p>\n<figure class=\"wp-block-image size-full\"><img data-recalc-dims=\"1\" decoding=\"async\" src=\"https:\/\/i0.wp.com\/contributor.insightmediagroup.io\/wp-content\/uploads\/2025\/04\/image-15.png?ssl=1\" alt=\"\" class=\"wp-image-601062\"><figcaption class=\"wp-element-caption\">Flash Attention V1 Algorithm. Source: <em>Tri Dao et al. [<a href=\"https:\/\/arxiv.org\/pdf\/2205.14135\">2<\/a>]<\/em><\/figcaption><\/figure>\n<p class=\"wp-block-paragraph\">On notation: a subscript <sub>ij<\/sub> denotes local values for a given block of rows (i) and block of columns (j). A subscript <sub>i<\/sub> denotes the running, global values for the i<sup>th<\/sup> row block of the output and of Q. The only part we haven\u2019t explained so far is the final update to O<sub>i<\/sub>. 
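<\/p>
<p class=\"wp-block-paragraph\">Before moving to Triton, it can help to check this update rule on the CPU. The following is a minimal pure-PyTorch sketch of the loop in the algorithm above for a single N\u00d7D query matrix. It is my own illustration rather than the code from the gist. The name flash_attn_v1_reference and its default block sizes are hypothetical. Like the Triton kernel below, it loops over column blocks of K and V on the outside and row blocks of Q on the inside. It also stores a running row max m and row sum l for every query row. The assignment to O[rows] is the O<sub>i<\/sub> update.<\/p>
<pre class=\"wp-block-prismatic-blocks\"><code class=\"language-python\">import torch

def flash_attn_v1_reference(Q, K, V, Br=16, Bc=16):
    # Q: (N, D); K and V: (M, D); no mask and no 1/sqrt(D) scaling, as in the kernel
    N, D = Q.shape
    M = K.shape[0]
    O = torch.zeros(N, D, dtype=Q.dtype)                  # running output, kept normalized
    l = torch.zeros(N, 1, dtype=Q.dtype)                  # running row sums of exponentials
    m = torch.full((N, 1), float('-inf'), dtype=Q.dtype)  # running row maxes
    for j in range(0, M, Bc):                             # outer loop: column blocks of K and V
        Kj, Vj = K[j:j + Bc], V[j:j + Bc]
        for i in range(0, N, Br):                         # inner loop: row blocks of Q
            rows = slice(i, i + Br)
            S_ij = Q[rows] @ Kj.T                         # local block of scores
            m_ij = S_ij.max(dim=1, keepdim=True).values   # local row maxes
            P_ij = torch.exp(S_ij - m_ij)                 # exponentials shifted by the local max
            l_ij = P_ij.sum(dim=1, keepdim=True)          # local row sums
            m_new = torch.maximum(m[rows], m_ij)
            # Rebase the old running sum and the local sum onto the new max
            l_new = torch.exp(m[rows] - m_new) * l[rows] + torch.exp(m_ij - m_new) * l_ij
            # Undo the old normalization, rescale, add this block, then renormalize
            O[rows] = (l[rows] * torch.exp(m[rows] - m_new) * O[rows]
                       + torch.exp(m_ij - m_new) * (P_ij @ Vj)) / l_new
            m[rows], l[rows] = m_new, l_new
    return O

torch.manual_seed(0)
Q, K, V = torch.randn(50, 32), torch.randn(70, 32), torch.randn(70, 32)
ref = torch.softmax(Q @ K.T, dim=-1) @ V
print(torch.allclose(flash_attn_v1_reference(Q, K, V), ref, atol=1e-4))  # True<\/code><\/pre>
<p class=\"wp-block-paragraph\">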
The O<sub>i<\/sub> update is where all the ideas from above come together to get the right scaling.<\/p>\n<blockquote class=\"wp-block-quote is-layout-flow wp-block-quote-is-layout-flow\">\n<p class=\"wp-block-paragraph\">The whole code is available\u00a0<a href=\"https:\/\/gist.github.com\/aarunjith\/adba4c4f67f9392d6e5789f7e92858b0\">as a gist here<\/a>.<\/p>\n<\/blockquote>\n<p class=\"wp-block-paragraph\">Let\u2019s see what the initialization looks like in PyTorch:<\/p>\n<pre class=\"wp-block-prismatic-blocks\"><code class=\"language-python\">def flash_attn_v1(Q, K, V, Br, Bc):\n  \"\"\"Flash Attention V1\"\"\"\n  B, N, D = Q.shape\n  M = K.shape[1]\n  Nr = int(np.ceil(N\/Br))  # number of row blocks (queries)\n  Nc = int(np.ceil(M\/Bc))  # number of column blocks (keys\/values): based on M, not N\n  \n  # The kernel computes offsets assuming contiguous (B, N, D) and (B, M, D) layouts\n  Q = Q.to('cuda').contiguous()\n  K = K.to('cuda').contiguous()\n  V = V.to('cuda').contiguous()\n  \n  batch_stride = Q.stride(0)\n  \n  O = torch.zeros_like(Q)\n  # Running row sums (l) and row maxes (m): one Br-sized vector per row block per batch element\n  lis = torch.zeros((B, Nr, int(Br)), dtype=torch.float32, device='cuda')\n  mis = torch.full((B, Nr, int(Br)), -torch.inf, dtype=torch.float32, device='cuda')\n  \n  grid = (B, )\n  flash_attn_v1_kernel[grid](\n      Q, K, V,\n      N, M, D,\n      Br, Bc,\n      Nr, Nc,\n      batch_stride,\n      Q.stride(1),\n      K.stride(1),\n      V.stride(1),\n      lis, mis,\n      O,\n      O.stride(1),\n  )\n  return O<\/code><\/pre>\n<blockquote class=\"wp-block-quote is-layout-flow wp-block-quote-is-layout-flow\">\n<p class=\"wp-block-paragraph\">If you are unsure about the launch grid, check out\u00a0<a href=\"https:\/\/substack.com\/home\/post\/p-159116483\">my introduction to Triton<\/a>.<\/p>\n<\/blockquote>\n<p class=\"wp-block-paragraph\">Take a closer look at how we initialized our Ls and Ms. We keep one vector of size B<sub>r<\/sub> for each row block of the output\/query. There are N<sub>r<\/sub> such blocks per batch element.<\/p>\n<p class=\"wp-block-paragraph\">The illustrations above simply used B<sub>r<\/sub> = 2 and B<sub>c<\/sub> = 2. In practice, Br and Bc are chosen based on the device\u2019s SRAM capacity and passed into the function above. 
I have included the calculation for a T4 GPU. For any other GPU, you need to look up its SRAM capacity and adjust these numbers accordingly. Now for the actual kernel implementation:<\/p>\n<pre class=\"wp-block-prismatic-blocks\"><code class=\"language-python\"># Flash Attention V1\nimport triton\nimport triton.language as tl\nimport torch\nimport numpy as np\n\n@triton.jit\ndef flash_attn_v1_kernel(\n    Q, K, V,\n    N: tl.constexpr, M: tl.constexpr, D: tl.constexpr,\n    Br: tl.constexpr,\n    Bc: tl.constexpr,\n    Nr: tl.constexpr,\n    Nc: tl.constexpr,\n    batch_stride: tl.constexpr,\n    q_rstride: tl.constexpr,\n    k_rstride: tl.constexpr,\n    v_rstride: tl.constexpr,\n    lis, mis,\n    O,\n    o_rstride: tl.constexpr):\n\n    \"\"\"Flash Attention V1 kernel\"\"\"\n\n    # One program per batch element\n    pid = tl.program_id(0)\n\n    for j in range(Nc):\n        # Using k_rstride and v_rstride as we are looking at the entire row at once, for each k v block\n        k_offset = ((tl.arange(0, Bc) + j*Bc) * k_rstride)[:, None] + (tl.arange(0, D))[None, :] + pid * M * D\n        v_offset = ((tl.arange(0, Bc) + j*Bc) * v_rstride)[:, None] + (tl.arange(0, D))[None, :] + pid * M * D\n        k_mask = k_offset &lt; (pid + 1) * M*D\n        v_mask = v_offset &lt; (pid + 1) * M*D\n        k_load = tl.load(K + k_offset, mask=k_mask, other=0)\n        v_load = tl.load(V + v_offset, mask=v_mask, other=0)\n        for i in range(Nr):\n            q_offset = ((tl.arange(0, Br) + i*Br) * q_rstride)[:, None] + (tl.arange(0, D))[None, :] + pid * N * D\n            q_mask = q_offset &lt; (pid + 1) * N*D\n            q_load = tl.load(Q + q_offset, mask=q_mask, other=0)\n            # Compute attention scores for this block\n            s_ij = tl.dot(q_load, tl.trans(k_load))\n            # Padded key columns (beyond M) were loaded as zeros; set their scores to -inf so they do not enter the softmax\n            col_ids = tl.arange(0, Bc) + j*Bc\n            s_ij = tl.where(col_ids[None, :] &lt; M, s_ij, float(\"-inf\"))\n            m_ij = tl.max(s_ij, axis=1, keep_dims=True)\n            p_ij = tl.exp(s_ij - m_ij)\n            l_ij = tl.sum(p_ij, axis=1, keep_dims=True)\n\n            ml_offset = tl.arange(0, Br) + Br * i + pid * Nr * 
Br\n            # Running row maxes (m) and row sums (l) for this row block, kept in HBM between iterations\n            m = tl.load(mis + ml_offset)[:, None]\n            l = tl.load(lis + ml_offset)[:, None]\n\n            m_new = tl.where(m &lt; m_ij, m_ij, m)\n\n            l_new = tl.exp(m - m_new) * l + tl.exp(m_ij - m_new) * l_ij\n\n            o_ij = tl.dot(p_ij, v_load)\n\n            output_offset = ((tl.arange(0, Br) + i*Br) * o_rstride)[:, None] + (tl.arange(0, D))[None, :] + pid * N * D\n            output_mask = output_offset &lt; (pid + 1) * N*D\n            o_current = tl.load(O + output_offset, mask=output_mask)\n\n            # Undo the old normalization, rebase to the new max, add this block, and renormalize\n            o_new = (1\/l_new) * (l * tl.exp(m - m_new) * o_current + tl.exp(m_ij - m_new) * o_ij)\n\n            tl.store(O + output_offset, o_new, mask=output_mask)\n            tl.store(mis + ml_offset, tl.reshape(m_new, (Br,)))\n            tl.store(lis + ml_offset, tl.reshape(l_new, (Br,)))<\/code><\/pre>\n<p class=\"wp-block-paragraph\">Let\u2019s understand what\u2019s happening here:<\/p>\n<ol class=\"wp-block-list\">\n<li class=\"wp-block-list-item\">We launch one program for each N\u00d7D matrix in the batch (grid = (B, )). In practice there would be one more dimension to parallelize across, the attention heads, but this is enough to understand the implementation.<\/li>\n<li class=\"wp-block-list-item\">In each program we do the following:\n<ol class=\"wp-block-list\">\n<li class=\"wp-block-list-item\">For each column block of K and V, we load the relevant B<sub>c<\/sub> \u00d7 D slice of each matrix into GPU SRAM (current total SRAM usage = 2B<sub>c<\/sub>D). 
These blocks stay in SRAM until we are done with all the row blocks.<\/li>\n<li class=\"wp-block-list-item\">For each row block of Q, we load that block into SRAM as well (current total SRAM usage = 2B<sub>c<\/sub>D + B<sub>r<\/sub>D).<\/li>\n<li class=\"wp-block-list-item\">On chip, we compute the dot product (s<sub>ij<\/sub>), the local row maxes (m<sub>ij<\/sub>), the exponentials (p<sub>ij<\/sub>) and their row sums (l<sub>ij<\/sub>).<\/li>\n<li class=\"wp-block-list-item\">We load the running stats for the i<sup>th<\/sup> row block. These are two B<sub>r<\/sub> \u00d7 1 vectors holding the current global row maxes (m<sub>i<\/sub>) and exponential sums (l<sub>i<\/sub>) (current SRAM usage: 2B<sub>c<\/sub>D + B<sub>r<\/sub>D + 2B<sub>r<\/sub>).<\/li>\n<li class=\"wp-block-list-item\">We compute the new estimates of the global m<sub>i<\/sub> and l<sub>i<\/sub>.<\/li>\n<li class=\"wp-block-list-item\">We load the part of the output for this block of Q, update it using the new running stats and the exponent trick, and write it back to HBM (current SRAM usage: 2B<sub>c<\/sub>D + 2B<sub>r<\/sub>D + 2B<sub>r<\/sub>).<\/li>\n<li class=\"wp-block-list-item\">We also write the updated running stats back to HBM.<\/li>\n<\/ol>\n<\/li>\n<li class=\"wp-block-list-item\">For any matrix size, i.e. any context length, we never materialize the full attention matrix; only one B<sub>r<\/sub> \u00d7 B<sub>c<\/sub> block of it exists at a time.<\/li>\n<li class=\"wp-block-list-item\">We managed to fuse all the operations into a single kernel, considerably reducing HBM access.<\/li>\n<\/ol>\n<p class=\"wp-block-paragraph\">The final SRAM usage, however, comes to 4BD + 2B (taking B<sub>r<\/sub> = B<sub>c<\/sub> = B), where B was initially calculated as M\/4d and M is the SRAM capacity. Since 4BD = M when D = d, that is slightly more than the whole SRAM, so I am not sure if I am missing something here. 
Please comment if you know why this is the case!<\/p>\n<h2 class=\"wp-block-heading\"><strong>Block Sparse Attention, V2 and V3<\/strong><\/h2>\n<p class=\"wp-block-paragraph\">I will keep this short, as these versions keep the core idea but find better and better ways to do the same thing.<\/p>\n<p class=\"wp-block-paragraph\">For Block Sparse Attention:<\/p>\n<ol class=\"wp-block-list\">\n<li class=\"wp-block-list-item\">Suppose we have a mask for each block, as in causal attention. If a block\u2019s mask is all zeros, we can simply skip that entire block without computing anything, saving FLOPs. To put the gains into perspective, the FA paper reports the following results. BERT pre-training runs 15% faster than the best-performing training setup at the time. GPT-2 training runs 3x faster than the Hugging Face implementation and roughly 2x faster than a Megatron setup.<\/li>\n<\/ol>\n<figure class=\"wp-block-image size-full\"><img data-recalc-dims=\"1\" decoding=\"async\" src=\"https:\/\/i0.wp.com\/contributor.insightmediagroup.io\/wp-content\/uploads\/2025\/04\/image-16.png?ssl=1\" alt=\"\" class=\"wp-image-601063\"><figcaption class=\"wp-element-caption\">Performance gain for autoregressive models, where we have a sparse mask. Source: <em>Tri Dao et al. [<a href=\"https:\/\/arxiv.org\/pdf\/2205.14135\">2<\/a>]<\/em><\/figcaption><\/figure>\n<p class=\"wp-block-paragraph\">2. You can get the same GPT-2 performance in a fraction of the time, shaving days off the training run, which is awesome!<\/p>\n<p class=\"wp-block-paragraph\">In V2:<\/p>\n<ol class=\"wp-block-list\">\n<li class=\"wp-block-list-item\">Notice that currently we can only parallelize across the batch and head dimensions. 
But if we simply flip the loop order and process all the column blocks for a given row block, we get the following advantages:\n<ol class=\"wp-block-list\">\n<li class=\"wp-block-list-item\">Each row block becomes embarrassingly parallel. You can see this from the illustrations above: you need all the column blocks for a given row block to fully form its attention output. If you ran the column blocks in parallel, you would end up with a race condition, with several programs trying to update the same rows of the output at the same time. That does not happen if you do it the other way around. Triton does have atomic add operations that could help, but they may set us back in performance.<\/li>\n<li class=\"wp-block-list-item\">We no longer need to go to HBM for the global Ms and Ls; each program can initialize and keep its own copy on chip.<\/li>\n<li class=\"wp-block-list-item\">We also do not have to rescale every output update by the new estimate of L. We can accumulate the output without dividing by L. After the last column block, we divide it once by the final estimate of L, saving some FLOPs again!<\/li>\n<\/ol>\n<\/li>\n<li class=\"wp-block-list-item\">Much of the improvement also comes from the backward kernel. I am omitting all the backward kernels from this post. 
But they are a fun exercise to implement, although they are significantly more complex.<\/li>\n<\/ol>\n<p class=\"wp-block-paragraph\">Here are some benchmarks:<\/p>\n<figure class=\"wp-block-image\"><a href=\"https:\/\/i0.wp.com\/substackcdn.com\/image\/fetch\/f_auto%2Cq_auto%3Agood%2Cfl_progressive%3Asteep\/https%253A%252F%252Fsubstack-post-media.s3.amazonaws.com%252Fpublic%252Fimages%252Ff4fe5982-f319-4aaf-891f-da96490a1b8a_1788x1366.png?ssl=1\" target=\"_blank\" rel=\"noreferrer noopener\"><img data-recalc-dims=\"1\" decoding=\"async\" src=\"https:\/\/i0.wp.com\/substackcdn.com\/image\/fetch\/w_1456%2Cc_limit%2Cf_auto%2Cq_auto%3Agood%2Cfl_progressive%3Asteep\/https%253A%252F%252Fsubstack-post-media.s3.amazonaws.com%252Fpublic%252Fimages%252Ff4fe5982-f319-4aaf-891f-da96490a1b8a_1788x1366.png?ssl=1\" alt=\"\"><\/a><figcaption class=\"wp-element-caption\">Performance benchmark of FA v2 against existing attention algorithms. Source: <em>Tri Dao [<a href=\"https:\/\/arxiv.org\/pdf\/2307.08691\">3<\/a>]<\/em><\/figcaption><\/figure>\n<p class=\"wp-block-paragraph\">The actual implementations of these kernels need to account for various nuances that we encounter in the real world. I have tried to keep things simple, but do\u00a0<a href=\"https:\/\/github.com\/Dao-AILab\/flash-attention\/tree\/main\">check them out here<\/a>.<\/p>\n<p class=\"wp-block-paragraph\">More recently, in V3:<\/p>\n<ol class=\"wp-block-list\">\n<li class=\"wp-block-list-item\">Newer GPUs, especially Hopper and Blackwell, have low-precision modes (FP8 on Hopper and FP4 on Blackwell) that can double or quadruple throughput for the same power and chip area. They also have more specialized GEMM (General Matrix Multiply) hardware, which the previous versions of the algorithm fail to capitalize on. 
This is because many of the operations, such as the softmax, are non-GEMM, which reduces the utilization of this specialized hardware.<\/li>\n<li class=\"wp-block-list-item\">FA v1 and v2 are essentially synchronous. Recall from the v2 description that we are limited when column blocks try to write to the same output pointers, or when we have to go step by step using the output of the previous steps. These modern GPUs provide special asynchronous instructions that can break this synchrony.<\/li>\n<\/ol>\n<figure class=\"wp-block-pullquote\">\n<blockquote>\n<p>We overlap the comparatively low-throughput non-GEMM operations involved in softmax, such as floating point multiply-add and exponential, with the asynchronous WGMMA instructions for GEMM. As part of this, we rework the FlashAttention-2 algorithm to circumvent certain sequential dependencies between softmax and the GEMMs. For example, in the 2-stage version of our algorithm, while softmax executes on one block of the scores matrix, WGMMA executes in the asynchronous proxy to compute the next block.<\/p>\n<p><cite>Flash Attention v3, Shah et al.<\/cite>\n<\/p><\/blockquote>\n<\/figure>\n<ol start=\"3\" class=\"wp-block-list\">\n<li class=\"wp-block-list-item\">They also adapted the algorithm to target the specialized low-precision Tensor Cores on these new devices, significantly increasing the achieved FLOPs.<\/li>\n<\/ol>\n<p class=\"wp-block-paragraph\">Some more benchmarks:<\/p>\n<figure class=\"wp-block-image\"><a href=\"https:\/\/i0.wp.com\/substackcdn.com\/image\/fetch\/f_auto%2Cq_auto%3Agood%2Cfl_progressive%3Asteep\/https%253A%252F%252Fsubstack-post-media.s3.amazonaws.com%252Fpublic%252Fimages%252F12fc597c-0f7a-46c9-b4f3-19a125264caa_1844x760.png?ssl=1\" target=\"_blank\" rel=\"noreferrer noopener\"><img data-recalc-dims=\"1\" decoding=\"async\" 
src=\"https:\/\/i0.wp.com\/substackcdn.com\/image\/fetch\/w_1456%2Cc_limit%2Cf_auto%2Cq_auto%3Agood%2Cfl_progressive%3Asteep\/https%253A%252F%252Fsubstack-post-media.s3.amazonaws.com%252Fpublic%252Fimages%252F12fc597c-0f7a-46c9-b4f3-19a125264caa_1844x760.png?ssl=1\" alt=\"\"><\/a><figcaption class=\"wp-element-caption\">FA v3 performance gain over v2. Source: <em>Shah et al. [<a href=\"https:\/\/arxiv.org\/abs\/2407.08608\">5<\/a>]<\/em><\/figcaption><\/figure>\n<h2 class=\"wp-block-heading\"><strong>Conclusion<\/strong><\/h2>\n<p class=\"wp-block-paragraph\">There is much to admire in this work. The barrier to entry for this kind of kernel engineering has often seemed high because of all the low-level details, but tools like Triton can hopefully lower it and bring more people into the field. The future is bright.<\/p>\n<h2 class=\"wp-block-heading\">References<\/h2>\n<p class=\"wp-block-paragraph\">[1] <a href=\"https:\/\/huggingface.co\/Qwen\/Qwen2.5-7B-Instruct-1M\">Qwen2.5-7B-Instruct-1M, Hugging Face model page<\/a><\/p>\n<p class=\"wp-block-paragraph\">[2] Tri Dao, Daniel Y. 
Fu, Stefano Ermon, Atri Rudra, and Christopher R\u00e9, <a href=\"https:\/\/arxiv.org\/pdf\/2205.14135\"><em>FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness<\/em><\/a><\/p>\n<p class=\"wp-block-paragraph\">[3] Tri Dao, <a href=\"https:\/\/arxiv.org\/pdf\/2307.08691\"><em>FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning<\/em><\/a><\/p>\n<p class=\"wp-block-paragraph\">[4] <a href=\"https:\/\/www.nvidia.com\/en-us\/data-center\/technologies\/hopper-architecture\/\">NVIDIA Hopper Architecture page<\/a><\/p>\n<p class=\"wp-block-paragraph\">[5] <a href=\"https:\/\/arxiv.org\/search\/cs?searchtype=author&amp;query=Shah,+J\">Jay Shah<\/a>,\u00a0<a href=\"https:\/\/arxiv.org\/search\/cs?searchtype=author&amp;query=Bikshandi,+G\">Ganesh Bikshandi<\/a>,\u00a0<a href=\"https:\/\/arxiv.org\/search\/cs?searchtype=author&amp;query=Zhang,+Y\">Ying Zhang<\/a>,\u00a0<a href=\"https:\/\/arxiv.org\/search\/cs?searchtype=author&amp;query=Thakkar,+V\">Vijay Thakkar<\/a>,\u00a0<a href=\"https:\/\/arxiv.org\/search\/cs?searchtype=author&amp;query=Ramani,+P\">Pradeep Ramani<\/a>,\u00a0<a href=\"https:\/\/arxiv.org\/search\/cs?searchtype=author&amp;query=Dao,+T\">Tri Dao<\/a>, <a href=\"https:\/\/arxiv.org\/abs\/2407.08608\"><em>FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision<\/em><\/a><\/p>\n<p class=\"wp-block-paragraph\">[6] <a href=\"https:\/\/en.wikipedia.org\/wiki\/Single-precision_floating-point_format\">Single-precision floating-point format, Wikipedia<\/a><\/p>\n<p>The post <a href=\"https:\/\/towardsdatascience.com\/kernel-case-study-flash-attention\/\">Kernel Case Study: Flash Attention<\/a> appeared first on <a href=\"https:\/\/towardsdatascience.com\/\">Towards Data Science<\/a>.<\/p>\n<\/div>\n<p>Arunjith A<br \/>\n<a 
href=\"https:\/\/towardsdatascience.com\/kernel-case-study-flash-attention\/\">Go to original source<\/a><\/p>\n","protected":false},"excerpt":{"rendered":"<p>Kernel Case Study: Flash Attention The attention mechanism is at the core of modern day transformers. But scaling the context window of these transformers was a major challenge, and it still is even though we are in the era of a million tokens + context window (Qwen 2.5\u00a0[1]). There are both considerable compute and memory [&hellip;]<\/p>\n","protected":false},"author":2,"featured_media":0,"comment_status":"closed","ping_status":"closed","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[62,2222,67,2268,71,70,229],"tags":[960,2270,2269],"class_list":["post-2858","post","type-post","status-publish","format-standard","hentry","category-aimldsaimlds","category-attention-mechanism","category-deep-dives","category-flash-attention","category-large-language-models","category-machine-learning","category-math","tag-attention","tag-flash","tag-hbm"],"_links":{"self":[{"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/posts\/2858"}],"collection":[{"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/users\/2"}],"replies":[{"embeddable":true,"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/comments?post=2858"}],"version-history":[{"count":0,"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/posts\/2858\/revisions"}],"wp:attachment":[{"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/media?parent=2858"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/categories?post=2858"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/mailitics.com\/index
.php\/wp-json\/wp\/v2\/tags?post=2858"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}