{"id":212,"date":"2024-11-27T07:05:11","date_gmt":"2024-11-27T07:05:11","guid":{"rendered":"https:\/\/mailitics.com\/index.php\/2024\/11\/27\/optimizing-transformer-models-for-variable-length-input-sequences-19fb88fddf71\/"},"modified":"2024-11-27T07:05:11","modified_gmt":"2024-11-27T07:05:11","slug":"optimizing-transformer-models-for-variable-length-input-sequences-19fb88fddf71","status":"publish","type":"post","link":"https:\/\/mailitics.com\/index.php\/2024\/11\/27\/optimizing-transformer-models-for-variable-length-input-sequences-19fb88fddf71\/","title":{"rendered":"Optimizing Transformer Models for Variable-Length Input Sequences"},"content":{"rendered":"<div>\n<h4>How PyTorch NestedTensors, FlashAttention2, and xFormers can Boost Performance and Reduce AI\u00a0Costs<\/h4>\n<figure><img decoding=\"async\" alt=\"\" src=\"https:\/\/cdn-images-1.medium.com\/max\/1024\/0*KTgbpA3zQGTR4ugq\"><figcaption>Photo by <a href=\"https:\/\/unsplash.com\/@tanjazoellner?utm_source=medium&amp;utm_medium=referral\">Tanja Z\u00f6llner<\/a> on\u00a0<a href=\"https:\/\/unsplash.com\/?utm_source=medium&amp;utm_medium=referral\">Unsplash<\/a><\/figcaption><\/figure>\n<p>As generative AI (genAI) models grow in both popularity and scale, so do the computational demands and costs associated with their training and deployment. Optimizing these models is crucial for enhancing their runtime performance and reducing their operational expenses. 
At the heart of modern genAI systems is the Transformer architecture and its attention mechanism, which is notably compute-intensive.<\/p>\n<p>In a <a href=\"https:\/\/towardsdatascience.com\/increasing-transformer-model-efficiency-through-attention-layer-optimization-fefa6f87b1d6\">previous post<\/a>, we demonstrated how using optimized attention kernels can significantly accelerate the performance of Transformer models. In this post, we continue our exploration by addressing the challenge of variable-length input sequences\u200a\u2014\u200aan inherent property of real-world data, including documents, code, time-series, and\u00a0more.<\/p>\n<h4>The Challenge of Batching Variable-Length Input<\/h4>\n<p>In a typical deep learning workload, individual samples are grouped into batches before being copied to the GPU and fed to the AI model. Batching improves computational efficiency and often aids model convergence during training. Usually, batching involves <a href=\"https:\/\/pytorch.org\/docs\/stable\/generated\/torch.stack.html\">stacking<\/a> all of the sample tensors along a new dimension\u200a\u2014\u200athe <em>batch<\/em> dimension. However, <a href=\"https:\/\/pytorch.org\/docs\/stable\/generated\/torch.stack.html\">torch.stack<\/a> requires all tensors to have the same shape, which is not the case with variable-length sequences.<\/p>\n<h4>Padding and its Inefficiencies<\/h4>\n<p>The traditional way to address this challenge is to pad the input sequences to a fixed length and then perform <a href=\"https:\/\/pytorch.org\/docs\/stable\/generated\/torch.stack.html\">stacking<\/a>. This solution requires appropriate masking within the model so that the output is not affected by the irrelevant tensor elements. 
In the case of attention layers, a padding mask indicates which tokens are padding and should not be attended to (e.g., see <a href=\"https:\/\/github.com\/pytorch\/pytorch\/blob\/v2.5.1\/torch\/nn\/modules\/activation.py#L1139\">PyTorch MultiheadAttention<\/a>). However, padding can waste considerable GPU resources, increasing costs and slowing development. This is especially true for large-scale AI\u00a0models.<\/p>\n<h4>Don\u2019t Pad, Concatenate<\/h4>\n<p>One way to avoid padding is to <a href=\"https:\/\/pytorch.org\/docs\/stable\/generated\/torch.cat.html#torch.cat\">concatenate<\/a> sequences along an existing dimension instead of <a href=\"https:\/\/pytorch.org\/docs\/stable\/generated\/torch.stack.html\">stacking<\/a> them along a new dimension. Unlike <a href=\"https:\/\/pytorch.org\/docs\/stable\/generated\/torch.stack.html\">torch.stack<\/a>, <a href=\"https:\/\/pytorch.org\/docs\/stable\/generated\/torch.cat.html#torch.cat\">torch.cat<\/a> accepts inputs whose sizes differ along the concatenation dimension. The output of concatenation is a single sequence whose length equals the sum of the lengths of the individual sequences. For this solution to work, our single sequence would need to be supplemented by an attention mask that would ensure that each token only attends to other tokens in the same original sequence, in a process sometimes referred to as <a href=\"https:\/\/pytorch.org\/blog\/flexattention\/#document-maskingjagged-sequences\">document masking<\/a>. Denoting the sum of the lengths of all of the individual sequences by <em>N<\/em> and adopting <a href=\"https:\/\/en.wikipedia.org\/wiki\/Big_O_notation\">\u201cbig O\u201d notation<\/a>, the size of this mask would need to be <em>O(N\u00b2)<\/em>, as would the compute complexity of a standard attention layer, making this solution highly inefficient.<\/p>\n<h4>Attention Layer Optimization<\/h4>\n<p>The solution to this problem comes in the form of specialized attention layers. 
Unlike the standard attention layer, which computes the full set of <em>O(N\u00b2) attention scores<\/em> only to mask out the irrelevant ones, these optimized attention kernels are designed to calculate only the <em>scores<\/em> that matter. In this post we will explore several solutions, each with its own distinct characteristics. These\u00a0include:<\/p>\n<ul>\n<li>\n<a href=\"https:\/\/pytorch.org\/tutorials\/intermediate\/scaled_dot_product_attention_tutorial.html#nestedtensor-and-dense-tensor-support\">PyTorch&#8217;s SDPA (Scaled Dot Product Attention) with NestedTensors<\/a>,<\/li>\n<li>\n<a href=\"https:\/\/github.com\/Dao-AILab\/flash-attention\">FlashAttention2<\/a>, and<\/li>\n<li>\n<a href=\"https:\/\/facebookresearch.github.io\/xformers\/components\/ops.html\">xFormers&#8217; memory-efficient attention<\/a>.<\/li>\n<\/ul>\n<h4>Integration into Existing HuggingFace Models<\/h4>\n<p>For teams working with pre-trained models, transitioning to these optimizations might seem challenging. We will demonstrate how <a href=\"https:\/\/huggingface.co\/\">HuggingFace\u2019s<\/a> APIs simplify this process, enabling developers to integrate these techniques with minimal code changes and\u00a0effort.<\/p>\n<h4><strong>Disclaimers<\/strong><\/h4>\n<ul>\n<li>Please do not interpret our use of any platforms, libraries, or optimization techniques as an endorsement of their use. The best options for you will depend greatly on the specifics of your own use-case.<\/li>\n<li>Some of the APIs discussed here are in prototype or beta stages and may change in the\u00a0future.<\/li>\n<li>The code examples provided are for demonstrative purposes only. 
We make no claims regarding their accuracy, optimality, or robustness.<\/li>\n<\/ul>\n<p>Special thanks to <a href=\"https:\/\/www.linkedin.com\/in\/yitzhak-levi-49a217201\/\">Yitzhak Levi<\/a> and <a href=\"https:\/\/www.linkedin.com\/in\/peleg-nahaliel-b304a61a5\/?originalSubdomain=il\">Peleg Nahaliel<\/a> for their contributions to this\u00a0post.<\/p>\n<h3>Toy LLM\u00a0Model<\/h3>\n<p>To facilitate our discussion we will define a simple generative model (partially inspired by the <a href=\"https:\/\/en.wikipedia.org\/wiki\/GPT\">GPT<\/a> model defined <a href=\"https:\/\/github.com\/karpathy\/nanoGPT\/tree\/master\">here<\/a>). For a more comprehensive guide on building language models, please see one of the many excellent tutorials available online (e.g.,\u00a0<a href=\"https:\/\/www.youtube.com\/watch?v=kCc8FmEb1nY\">here<\/a>).<\/p>\n<h4>Transformer Block<\/h4>\n<p>We begin by constructing a basic Transformer block, specifically designed to facilitate experimentation with different attention mechanisms and optimizations. 
While our block performs the same computation as standard Transformer blocks, we make slight modifications to the usual choice of operators in order to support the possibility of PyTorch <a href=\"https:\/\/pytorch.org\/docs\/stable\/nested.html#supported-operations\">NestedTensor<\/a> inputs (as described <a href=\"https:\/\/pytorch.org\/tutorials\/intermediate\/scaled_dot_product_attention_tutorial.html#causal-self-attention\">here<\/a>).<\/p>\n<pre># general imports<br>import time, functools<br><br># torch imports<br>import torch<br>from torch.utils.data import Dataset, DataLoader<br>import torch.nn as nn<br><br># Define Transformer settings<br>BATCH_SIZE = 32<br>NUM_HEADS = 16<br>HEAD_DIM = 64<br>DIM = NUM_HEADS * HEAD_DIM<br>DEPTH = 24<br>NUM_TOKENS = 1024<br>MAX_SEQ_LEN = 1024<br>PAD_ID = 0<br>DEVICE = 'cuda'<br><br>class MyAttentionBlock(nn.Module):<br>    def __init__(<br>            self,<br>            attn_fn,<br>            dim,<br>            num_heads,<br>            format=None,<br>            **kwargs<br>    ):<br>        super().__init__()<br>        self.attn_fn = attn_fn<br>        self.num_heads = num_heads<br>        self.dim = dim<br>        self.head_dim = dim \/\/ num_heads<br>        self.norm1 = nn.LayerNorm(dim, bias=False)<br>        self.norm2 = nn.LayerNorm(dim, bias=False)<br>        self.qkv = nn.Linear(dim, dim * 3)<br>        self.proj = nn.Linear(dim, dim)<br><br>        # mlp layers<br>        self.fc1 = nn.Linear(dim, dim * 4)<br>        self.act = nn.GELU()<br>        self.fc2 = nn.Linear(dim * 4, dim)<br><br>        self.permute = functools.partial(torch.transpose, dim0=1, dim1=2)<br>        if format == 'bshd':<br>            self.permute = nn.Identity()<br><br>    def mlp(self, x):<br>        x = self.fc1(x)<br>        x = self.act(x)<br>        x = self.fc2(x)<br>        return x<br><br>    def reshape_and_permute(self,x, batch_size):<br>        x = x.view(batch_size, -1, self.num_heads, self.head_dim)<br>        return 
self.permute(x)<br><br>    def forward(self, x_in, attn_mask=None):<br>        batch_size = x_in.size(0)<br>        x = self.norm1(x_in)<br>        qkv = self.qkv(x)<br><br>        # rather than first reformatting and then splitting the input<br>        # state, we first split and then reformat q, k, v in order to<br>        # support PyTorch Nested Tensors<br>        q, k, v = qkv.chunk(3, -1)<br>        q = self.reshape_and_permute(q, batch_size)<br>        k = self.reshape_and_permute(k, batch_size)<br>        v = self.reshape_and_permute(v, batch_size)<br>        <br>        # call the attn_fn with the input attn_mask<br>        x = self.attn_fn(q, k, v, attn_mask=attn_mask)<br><br>        # reformat output<br>        x = self.permute(x).reshape(batch_size, -1, self.dim)<br>        x = self.proj(x)<br>        x = x + x_in<br>        x = x + self.mlp(self.norm2(x))<br>        return x<\/pre>\n<h4>Transformer Decoder\u00a0Model<\/h4>\n<p>Building on our programmable Transformer block, we construct a typical Transformer decoder\u00a0model.<\/p>\n<pre>class MyDecoder(nn.Module):<br>    def __init__(<br>            self,<br>            block_fn,<br>            num_tokens,<br>            dim,<br>            num_heads,<br>            num_layers,<br>            max_seq_len,<br>            pad_idx=None<br>    ):<br>        super().__init__()<br>        self.num_heads = num_heads<br>        self.pad_idx = pad_idx<br>        self.embedding = nn.Embedding(num_tokens, dim, padding_idx=pad_idx)<br>        self.positional_embedding = nn.Embedding(max_seq_len, dim)<br>        self.blocks = nn.ModuleList([<br>            block_fn(<br>                dim=dim,<br>                num_heads=num_heads<br>            )<br>            for _ in range(num_layers)])<br>        self.output = nn.Linear(dim, num_tokens)<br><br>    def embed_tokens(self, input_ids, position_ids=None):<br>        x = self.embedding(input_ids)<br>        if position_ids is None:<br>            position_ids = 
torch.arange(input_ids.shape[1],<br>                                        device=x.device)<br>        x = x + self.positional_embedding(position_ids)<br>        return x<br><br>    def forward(self, input_ids, position_ids=None, attn_mask=None):<br>        # Embed tokens and add positional encoding<br>        x = self.embed_tokens(input_ids, position_ids)<br>        if self.pad_idx is not None:<br>            assert attn_mask is None<br>            # create a padding mask - we assume boolean masking<br>            attn_mask = (input_ids != self.pad_idx)<br>            attn_mask = attn_mask.view(BATCH_SIZE, 1, 1, -1)<br>            attn_mask = attn_mask.expand(-1, self.num_heads, -1, -1)<br><br>        for b in self.blocks:<br>            x = b(x, attn_mask)<br><br>        logits = self.output(x)<br>        return logits<\/pre>\n<h4>Variable Length Sequence\u00a0Input<\/h4>\n<p>Next, we create a dataset containing sequences of variable lengths, where each sequence is made up of randomly generated tokens. For simplicity, we (arbitrarily) select a fixed distribution for the sequence lengths. In real-world scenarios, the distribution of sequence lengths typically reflects the nature of the data, such as the length of documents or audio segments. 
Note, that the distribution of lengths directly affects the computational inefficiencies caused by\u00a0padding.<\/p>\n<pre># Use random data<br>class FakeDataset(Dataset):<br>    def __len__(self):<br>        return 1000000<br><br>    def __getitem__(self, index):<br>        length = torch.randint(1, MAX_SEQ_LEN, (1,))<br>        sequence = torch.randint(1, NUM_TOKENS, (length + 1,))<br>        input = sequence[:-1]<br>        target = sequence[1:]<br>        return input, target<br><br>def pad_sequence(sequence, length, pad_val):<br>    return torch.nn.functional.pad(<br>        sequence,<br>        (0, length - sequence.shape[0]),<br>        value=pad_val<br>    )<br><br>def collate_with_padding(batch):<br>    padded_inputs = []<br>    padded_targets = []<br>    for b in batch:<br>        padded_inputs.append(pad_sequence(b[0], MAX_SEQ_LEN, PAD_ID))<br>        padded_targets.append(pad_sequence(b[1], MAX_SEQ_LEN, PAD_ID))<br>    padded_inputs = torch.stack(padded_inputs, dim=0)<br>    padded_targets = torch.stack(padded_targets, dim=0)<br>    return {<br>        'inputs': padded_inputs,<br>        'targets': padded_targets<br>    }<br><br>def data_to_device(data, device):<br>    if isinstance(data, dict):<br>        return {<br>            key: data_to_device(val,device)<br>            for key, val in data.items()<br>        }<br>    elif isinstance(data, (list, tuple)):<br>        return type(data)(<br>            data_to_device(val, device) for val in data<br>        )<br>    elif isinstance(data, torch.Tensor):<br>        return data.to(device=device, non_blocking=True)<br>    else:<br>        return data.to(device=device)<\/pre>\n<h4>Training\/Evaluation Loop<\/h4>\n<p>Lastly, we implement a <em>main<\/em> function that performs training\/evaluation on input sequences of varying\u00a0length.<\/p>\n<pre>def main(<br>    block_fn, <br>    data_collate_fn=collate_with_padding,<br>    pad_idx=None,<br>    train=True,<br>    compile=False<br>):<br>    
torch.random.manual_seed(0)<br>    device = torch.device(DEVICE)<br>    torch.set_float32_matmul_precision(\"high\")<br><br>    # Create dataset and dataloader<br>    data_set = FakeDataset()<br>    data_loader = DataLoader(<br>        data_set,<br>        batch_size=BATCH_SIZE,<br>        collate_fn=data_collate_fn,<br>        num_workers=12,<br>        pin_memory=True,<br>        drop_last=True<br>    )<br><br>    model = MyDecoder(<br>        block_fn=block_fn,<br>        num_tokens=NUM_TOKENS,<br>        dim=DIM,<br>        num_heads=NUM_HEADS,<br>        num_layers=DEPTH,<br>        max_seq_len=MAX_SEQ_LEN,<br>        pad_idx=pad_idx<br>    ).to(device)<br><br>    if compile:<br>        model = torch.compile(model)<br><br>    # Define loss and optimizer<br>    criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_ID)<br>    optimizer = torch.optim.SGD(model.parameters())<br><br>    def train_step(model, inputs, targets, <br>                   position_ids=None, attn_mask=None):<br>        with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):<br>            outputs = model(inputs, position_ids, attn_mask)<br>            outputs = outputs.view(-1, NUM_TOKENS)<br>            targets = targets.flatten()<br>            loss = criterion(outputs, targets)<br>        optimizer.zero_grad(set_to_none=True)<br>        loss.backward()<br>        optimizer.step()<br><br>    @torch.no_grad()<br>    def eval_step(model, inputs, targets, <br>                  position_ids=None, attn_mask=None):<br>        with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):<br>            outputs = model(inputs, position_ids, attn_mask)<br>            if outputs.is_nested:<br>                outputs = outputs.data._values<br>                targets = targets.data._values<br>            else:<br>                outputs = outputs.view(-1, NUM_TOKENS)<br>                targets = targets.flatten()<br>            loss = criterion(outputs, targets)<br>        return loss<br><br>    if 
train:<br>        model.train()<br>        step_fn = train_step<br>    else:<br>        model.eval()<br>        step_fn = eval_step<br><br>    t0 = time.perf_counter()<br>    summ = 0<br>    count = 0<br><br>    for step, data in enumerate(data_loader):<br>        # Copy data to GPU<br>        data = data_to_device(data, device=device)<br>        step_fn(model, data['inputs'], data['targets'],<br>                       position_ids=data.get('indices'),<br>                       attn_mask=data.get('attn_mask'))<br><br>        # Capture step time<br>        batch_time = time.perf_counter() - t0<br>        if step &gt; 20:  # Skip first steps<br>            summ += batch_time<br>            count += 1<br>        t0 = time.perf_counter()<br>        if step &gt;= 100:<br>            break<br>    print(f'average step time: {summ \/ count}')<\/pre>\n<h4>PyTorch SDPA with\u00a0Padding<\/h4>\n<p>For our baseline experiments, we configure our Transformer block to utilize PyTorch\u2019s <a href=\"https:\/\/pytorch.org\/tutorials\/intermediate\/scaled_dot_product_attention_tutorial.html\">SDPA<\/a> mechanism. In our experiments, we run both training and evaluation, both with and without <a href=\"https:\/\/pytorch.org\/tutorials\/intermediate\/torch_compile_tutorial.html\">torch.compile<\/a>. 
These were run on an <a href=\"https:\/\/www.nvidia.com\/en-eu\/data-center\/h100\/\">NVIDIA H100<\/a> with <a href=\"https:\/\/developer.nvidia.com\/cuda-toolkit\">CUDA 12.4<\/a> and <a href=\"https:\/\/pytorch.org\/\">PyTorch<\/a>\u00a02.5.1.<\/p>\n<pre>from torch.nn.functional import scaled_dot_product_attention as sdpa<br>block_fn = functools.partial(MyAttentionBlock, attn_fn=sdpa)<br>causal_block_fn = functools.partial(<br>    MyAttentionBlock,<br>    attn_fn=functools.partial(sdpa, is_causal=True)<br>)<br><br>for mode in ['eval', 'train']:<br>    for compile in [False, True]:<br>        # causal attention for training, non-causal for evaluation<br>        block_func = (causal_block_fn<br>                      if mode == 'train' else block_fn)<br>        print(f'{mode} with padding, '<br>              f'{\"compiled\" if compile else \"uncompiled\"}')<br>        main(block_fn=block_func,<br>             pad_idx=PAD_ID,<br>             train=mode=='train',<br>             compile=compile)<\/pre>\n<p>Performance Results:<\/p>\n<ul>\n<li>\n<strong>Evaluation<\/strong>: 132 milliseconds (ms) without torch.compile, 130 ms with torch.compile<\/li>\n<li>\n<strong>Training<\/strong>: 342 ms without torch.compile, 299 ms with torch.compile<\/li>\n<\/ul>\n<h3>Optimizing for Variable Length\u00a0Input<\/h3>\n<p>In this section, we will explore several optimization techniques for handling variable-length input sequences in Transformer models.<\/p>\n<h4>Padding Optimization<\/h4>\n<p>Our first optimization relates not to the attention kernel but to our padding mechanism. Rather than padding the sequences in each batch to a constant length, we pad to the length of the longest sequence in the batch. 
The following block of code consists of our revised collation function and updated experiments.<\/p>\n<pre>def collate_pad_to_longest(batch):<br>    padded_inputs = []<br>    padded_targets = []<br>    max_length = max([b[0].shape[0] for b in batch])<br>    for b in batch:<br>        padded_inputs.append(pad_sequence(b[0], max_length, PAD_ID))<br>        padded_targets.append(pad_sequence(b[1], max_length, PAD_ID))<br>    padded_inputs = torch.stack(padded_inputs, dim=0)<br>    padded_targets = torch.stack(padded_targets, dim=0)<br>    return {<br>        'inputs': padded_inputs,<br>        'targets': padded_targets<br>    }<br><br>for mode in ['eval', 'train']:<br>    for compile in [False, True]:<br>        # causal attention for training, non-causal for evaluation<br>        block_func = (causal_block_fn<br>                      if mode == 'train' else block_fn)<br>        print(f'{mode} with pad-to-longest, '<br>              f'{\"compiled\" if compile else \"uncompiled\"}')<br>        main(block_fn=block_func,<br>             data_collate_fn=collate_pad_to_longest,<br>             pad_idx=PAD_ID,<br>             train=mode=='train',<br>             compile=compile)<\/pre>\n<p>Padding to the longest sequence in each batch yields a slight speedup:<\/p>\n<ul>\n<li>\n<strong>Evaluation<\/strong>: 129 ms without torch.compile, 116 ms with torch.compile<\/li>\n<li>\n<strong>Training<\/strong>: 337 ms without torch.compile, 294 ms with torch.compile<\/li>\n<\/ul>\n<h4>SDPA with PyTorch NestedTensors<\/h4>\n<p>Next, we take advantage of the built-in support for <a href=\"https:\/\/pytorch.org\/tutorials\/intermediate\/scaled_dot_product_attention_tutorial.html#nestedtensor-and-dense-tensor-support\">PyTorch NestedTensors<\/a> in SDPA in evaluation mode. Currently a prototype feature, <a href=\"https:\/\/pytorch.org\/tutorials\/prototype\/nestedtensor.html\">PyTorch NestedTensors<\/a> allow for grouping together tensors of varying length. These are sometimes referred to as <em>jagged<\/em> or <em>ragged<\/em> tensors. 
In the code block below, we define a collation function for grouping our sequences into NestedTensors. We also define an <em>indices <\/em>entry so that we can properly calculate the <a href=\"https:\/\/pytorch.org\/docs\/stable\/generated\/torch.nn.Embedding.html\">positional embeddings<\/a>.<\/p>\n<p>PyTorch NestedTensors are supported by a <a href=\"https:\/\/pytorch.org\/tutorials\/prototype\/nestedtensor.html#nested-tensor-operations\">limited number of PyTorch ops<\/a>. Working around these limitations can require some creativity. For example, addition between NestedTensors is only supported when they share precisely the same \u201cjagged\u201d shape. In the code below we use a workaround to ensure that the <em>indices <\/em>entry shares the same shape as the model\u00a0inputs.<\/p>\n<pre>def nested_tensor_collate(batch):<br>    inputs = torch.nested.as_nested_tensor([b[0] for b in batch],<br>                                           layout=torch.jagged)<br>    targets = torch.nested.as_nested_tensor([b[1] for b in batch],<br>                                            layout=torch.jagged)<br>    indices = torch.concat([torch.arange(b[0].shape[0]) for b in batch])<br><br>    # workaround for creating a NestedTensor with identical \"jagged\" shape<br>    xx = torch.empty_like(inputs)<br>    xx.data._values[:] = indices<br><br>    return {<br>        'inputs': inputs,<br>        'targets': targets,<br>        'indices': xx<br>    }<br><br>for compile in [False, True]:<br>    print(f'eval with nested tensors, '<br>          f'{\"compiled\" if compile else \"uncompiled\"}')<br>    main(<br>        block_fn=block_fn,<br>        data_collate_fn=nested_tensor_collate,<br>        train=False,<br>        compile=compile<br>    )<\/pre>\n<p>Without torch.compile, the NestedTensor optimization results in a step time of 131 ms, similar to our baseline result. In compiled mode, however, the step time drops to 42 ms for an impressive ~3x 
improvement.<\/p>\n<h4>FlashAttention2<\/h4>\n<p>In our <a href=\"https:\/\/towardsdatascience.com\/increasing-transformer-model-efficiency-through-attention-layer-optimization-fefa6f87b1d6\">previous post<\/a> we demonstrated the use of <a href=\"https:\/\/github.com\/Dao-AILab\/flash-attention\">FlashAttention<\/a> and its impact on the performance of a transformer model. In this post we demonstrate the use of <a href=\"https:\/\/github.com\/Dao-AILab\/flash-attention\/blob\/v2.7.0\/hopper\/flash_attn_interface.py#L429\">flash_attn_varlen_func<\/a> from <a href=\"https:\/\/pypi.org\/project\/flash-attn\/\">flash-attn (2.7.0)<\/a>, an API designed for use with variable-sized inputs. To use this function, we concatenate all of the sequences in the batch into a single sequence. We also create a <em>cu_seqlens <\/em>tensor that points to the indices within the concatenated tensor where each of the individual sequences starts. The code block below includes our collation function followed by evaluation and training experiments. 
Note that <a href=\"https:\/\/github.com\/Dao-AILab\/flash-attention\/blob\/v2.7.0\/hopper\/flash_attn_interface.py#L429\">flash_attn_varlen_func<\/a> does not support torch.compile (at the time of this writing).<\/p>\n<pre>def collate_concat(batch):<br>    inputs = torch.concat([b[0] for b in batch]).unsqueeze(0)<br>    targets = torch.concat([b[1] for b in batch]).unsqueeze(0)<br>    indices = torch.concat([torch.arange(b[0].shape[0]) for b in batch])<br>    seqlens = torch.tensor([b[0].shape[0] for b in batch])<br>    seqlens = torch.cumsum(seqlens, dim=0, dtype=torch.int32)<br>    cu_seqlens = torch.nn.functional.pad(seqlens, (1, 0))<br><br>    return {<br>        'inputs': inputs,<br>        'targets': targets,<br>        'indices': indices,<br>        'attn_mask': cu_seqlens<br>    }<br><br>from flash_attn import flash_attn_varlen_func<br>fa_varlen = lambda q, k, v, attn_mask: flash_attn_varlen_func(<br>    q.squeeze(0),<br>    k.squeeze(0),<br>    v.squeeze(0),<br>    cu_seqlens_q=attn_mask,<br>    cu_seqlens_k=attn_mask,<br>    max_seqlen_q=MAX_SEQ_LEN,<br>    max_seqlen_k=MAX_SEQ_LEN<br>).unsqueeze(0)<br><br>fa_varlen_causal = lambda q, k, v, attn_mask: flash_attn_varlen_func(<br>    q.squeeze(0),<br>    k.squeeze(0),<br>    v.squeeze(0),<br>    cu_seqlens_q=attn_mask,<br>    cu_seqlens_k=attn_mask,<br>    max_seqlen_q=MAX_SEQ_LEN,<br>    max_seqlen_k=MAX_SEQ_LEN,<br>    causal=True<br>).unsqueeze(0)<br><br>block_fn = functools.partial(MyAttentionBlock,<br>                             attn_fn=fa_varlen,<br>                             format='bshd')<br><br>causal_block_fn = functools.partial(MyAttentionBlock,<br>                                    attn_fn=fa_varlen_causal,<br>                                    format='bshd')<br><br>print('flash-attn eval')<br>main(<br>    block_fn=block_fn,<br>    data_collate_fn=collate_concat,<br>    train=False<br>)<br><br>print('flash-attn train')<br>main(<br>    block_fn=causal_block_fn,<br>    
data_collate_fn=collate_concat,<br>    train=True,<br>)<\/pre>\n<p>The impact of this optimization is dramatic: 51 ms for evaluation and 160 ms for training, amounting to 2.6x and 2.1x performance boosts compared to our baseline experiment.<\/p>\n<h4>xFormers Memory-Efficient Attention<\/h4>\n<p>In our previous post we demonstrated the use of the <a href=\"https:\/\/facebookresearch.github.io\/xformers\/components\/ops.html#xformers.ops.memory_efficient_attention\">memory_efficient_attention<\/a> operator from <a href=\"https:\/\/pypi.org\/project\/xformers\/\">xFormers (0.0.28)<\/a>. Here we demonstrate the use of <a href=\"https:\/\/facebookresearch.github.io\/xformers\/_modules\/xformers\/ops\/fmha\/attn_bias.html#BlockDiagonalMask\">BlockDiagonalMask<\/a>, specifically designed for input sequences of arbitrary length. The required collation function appears in the code block below followed by the evaluation and training experiments. Note that torch.compile failed in training\u00a0mode.<\/p>\n<pre>from xformers.ops import fmha<br>from xformers.ops import memory_efficient_attention as mea<br><br>def collate_xformer(batch):<br>    inputs = torch.concat([b[0] for b in batch]).unsqueeze(0)<br>    targets = torch.concat([b[1] for b in batch]).unsqueeze(0)<br>    indices = torch.concat([torch.arange(b[0].shape[0]) for b in batch])<br>    seqlens = [b[0].shape[0] for b in batch]<br>    batch_sizes = [1 for b in batch]<br>    block_diag = fmha.BlockDiagonalMask.from_seqlens(seqlens, device='cpu')<br>    block_diag._batch_sizes = batch_sizes<br><br>    return {<br>        'inputs': inputs,<br>        'targets': targets,<br>        'indices': indices,<br>        'attn_mask': block_diag<br>    }<br><br>mea_eval = lambda q, k, v, attn_mask: mea(<br>    q,k,v, attn_bias=attn_mask)<br><br>mea_train = lambda q, k, v, attn_mask: mea(<br>    q,k,v, attn_bias=attn_mask.make_causal())<br><br>block_fn = functools.partial(MyAttentionBlock,<br>                             
attn_fn=mea_eval,<br>                             format='bshd')<br><br>causal_block_fn = functools.partial(MyAttentionBlock,<br>                             attn_fn=mea_train,<br>                             format='bshd')<br><br>print(f'xFormer Attention ')<br>for compile in [False, True]:<br>    print(f'eval with xFormer Attention, '<br>          f'{\"compiled\" if compile else \"uncompiled\"}')<br>    main(block_fn=block_fn,<br>         train=False,<br>         data_collate_fn=collate_xformer,<br>         compile=compile)<br><br>print(f'train with xFormer Attention')<br>main(block_fn=causal_block_fn,<br>     train=True,<br>     data_collate_fn=collate_xformer)<\/pre>\n<p>The resulting step times were 50 ms for evaluation and 159 ms for training, both without torch.compile. Evaluation with torch.compile resulted in a step time of 42\u00a0ms.<\/p>\n<h4>Results<\/h4>\n<p>The table below summarizes the results of our optimization methods.<\/p>\n<figure><img data-recalc-dims=\"1\" decoding=\"async\" alt=\"\" src=\"https:\/\/i0.wp.com\/cdn-images-1.medium.com\/max\/1024\/1%2AoNIilOLnAXOGMTW3gZmzYg.png?ssl=1\"><figcaption>Step time results for different optimization methods (lower is better)\u200a\u2014\u200aby\u00a0Author<\/figcaption><\/figure>\n<p>The best performer for our toy model is <a href=\"https:\/\/facebookresearch.github.io\/xformers\/components\/ops.html#xformers.ops.memory_efficient_attention\">xFormers\u2019 memory_efficient_attention<\/a>, which delivered a ~3x performance boost for evaluation and a ~2x boost for training. We caution against deriving any conclusions from these results as the performance impact of different attention functions can vary significantly depending on the specific model and use\u00a0case.<\/p>\n<h3>Optimizing a HuggingFace Model for Variable-Length Input<\/h3>\n<p>The tools and techniques described above are easy to implement when creating a model from scratch. 
However, these days it is not uncommon for ML developers to adopt existing (pretrained) models and finetune them for their use case. While the optimizations we have described can be integrated without changing the set of model weights and without altering the model behavior, it is not entirely clear what the best way to do this is. In an ideal world, our ML framework would allow us to program the use of an attention mechanism that is optimized for variable-length inputs. In this section we demonstrate how to optimize HuggingFace models for variable-length inputs.<\/p>\n<h4>A Toy HuggingFace Model &#8211; GPT2LMHeadModel<\/h4>\n<p>To facilitate the discussion, we create a toy example in which we train a HuggingFace <a href=\"https:\/\/huggingface.co\/docs\/transformers\/v4.46.3\/en\/model_doc\/gpt2#transformers.GPT2LMHeadModel\">GPT2LMHeadModel<\/a> on variable-length sequences. This requires adapting our random dataset and data-padding collation function according to HuggingFace&#8217;s input specifications.<\/p>\n<pre>from transformers import GPT2Config, GPT2LMHeadModel<br><br># Use random data<br>class HuggingFaceFakeDataset(Dataset):<br>    def __len__(self):<br>        return 1000000<br><br>    def __getitem__(self, index):<br>        length = torch.randint(1, MAX_SEQ_LEN, (1,))<br>        input_ids = torch.randint(1, NUM_TOKENS, (length,))<br>        labels = input_ids.clone()<br>        labels[0] = PAD_ID # ignore first token<br>        return {<br>            'input_ids': input_ids,<br>            'labels': labels<br>        }<br><br>def hf_collate_with_padding(batch):<br>    padded_inputs = []<br>    padded_labels = []<br>    for b in batch:<br>        input_ids = b['input_ids']<br>        labels = b['labels']<br>        padded_inputs.append(pad_sequence(input_ids, MAX_SEQ_LEN, PAD_ID))<br>        padded_labels.append(pad_sequence(labels, MAX_SEQ_LEN, PAD_ID))<br>    padded_inputs = torch.stack(padded_inputs, dim=0)<br>    padded_labels = torch.stack(padded_labels, dim=0)<br>    return {<br>        'input_ids': padded_inputs,<br>        'labels': padded_labels,<br>        'attention_mask': (padded_inputs != PAD_ID)<br>    }<\/pre>\n
<h4>Training Function<\/h4>\n<p>Our training function instantiates a <a href=\"https:\/\/huggingface.co\/docs\/transformers\/v4.46.3\/en\/model_doc\/gpt2#transformers.GPT2LMHeadModel\">GPT2LMHeadModel<\/a> based on the requested <a href=\"https:\/\/huggingface.co\/docs\/transformers\/v4.46.3\/en\/model_doc\/gpt2#transformers.GPT2Config\">GPT2Config<\/a> and proceeds to train it on our variable-length sequences.<\/p>\n<pre>def hf_main(<br>    config,<br>    collate_fn=hf_collate_with_padding,<br>    compile=False<br>):<br>    torch.random.manual_seed(0)<br>    device = torch.device(DEVICE)<br>    torch.set_float32_matmul_precision(\"high\")<br><br>    # Create dataset and dataloader<br>    data_set = HuggingFaceFakeDataset()<br>    data_loader = DataLoader(<br>        data_set,<br>        batch_size=BATCH_SIZE,<br>        collate_fn=collate_fn,<br>        num_workers=12 if DEVICE == \"cuda\" else 0,<br>        pin_memory=True,<br>        drop_last=True<br>    )<br><br>    model = GPT2LMHeadModel(config).to(device)<br><br>    if compile:<br>        model = torch.compile(model)<br><br>    # Define loss and optimizer<br>    criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_ID)<br>    optimizer = torch.optim.SGD(model.parameters())<br><br>    model.train()<br><br>    t0 = time.perf_counter()<br>    summ = 0<br>    count = 0<br><br>    for step, data in enumerate(data_loader):<br>        # Copy data to GPU<br>        data = data_to_device(data, device=device)<br>        input_ids = data['input_ids']<br>        labels = data['labels']<br>        position_ids = data.get('position_ids')<br>        attn_mask = data.get('attention_mask')<br>        with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):<br>            outputs = model(input_ids=input_ids,<br>    
                        position_ids=position_ids,<br>                            attention_mask=attn_mask)<br>            logits = outputs.logits[..., :-1, :].contiguous()<br>            labels = labels[..., 1:].contiguous()<br>            loss = criterion(logits.view(-1, NUM_TOKENS), labels.flatten())<br><br>        optimizer.zero_grad(set_to_none=True)<br>        loss.backward()<br>        optimizer.step()<br><br>        # Capture step time<br>        batch_time = time.perf_counter() - t0<br>        if step &gt; 20:  # Skip first steps<br>            summ += batch_time<br>            count += 1<br>        t0 = time.perf_counter()<br>        if step &gt;= 100:<br>            break<br>    print(f'average step time: {summ \/ count}')<\/pre>\n<h4>SDPA with\u00a0Padding<\/h4>\n<p>In the code block below, we call our training function with the default sequence-padding collator.<\/p>\n<pre>config = GPT2Config(<br>        n_layer=DEPTH,<br>        n_embd=DIM,<br>        n_head=NUM_HEADS,<br>        vocab_size=NUM_TOKENS,<br>    )<br><br>for compile in [False, True]:<br>    print(f\"HF GPT2 train with SDPA, compile={compile}\")<br>    hf_main(config=config, compile=compile)<\/pre>\n<p>The resultant step times are 815 ms without torch.compile and 440 ms with torch.compile.<\/p>\n<h4>FlashAttention2<\/h4>\n<p>We now take advantage of HuggingFace\u2019s <a href=\"https:\/\/huggingface.co\/docs\/transformers\/v4.46.3\/en\/model_doc\/gpt2#using-flash-attention-2\">built-in support for FlashAttention2<\/a> by setting the <em>attn_implementation <\/em>parameter to \u201cflash_attention_2\u201d. 
Behind the scenes, HuggingFace will <a href=\"https:\/\/github.com\/huggingface\/transformers\/blob\/v4.46.3\/src\/transformers\/modeling_flash_attention_utils.py#L246\"><em>unpad<\/em><\/a> the padded input data and then pass it to the optimized <a href=\"https:\/\/github.com\/Dao-AILab\/flash-attention\/blob\/v2.7.0\/hopper\/flash_attn_interface.py#L429\">flash_attn_varlen_func<\/a> function we saw\u00a0above:<\/p>\n<pre>flash_config = GPT2Config(<br>        n_layer=DEPTH,<br>        n_embd=DIM,<br>        n_head=NUM_HEADS,<br>        vocab_size=NUM_TOKENS,<br>        attn_implementation='flash_attention_2'<br>    )<br><br>print(f\"HF GPT2 train with flash\")<br>hf_main(config=flash_config)<\/pre>\n<p>The resultant step time is 620 ms, amounting to a 30% boost (in uncompiled mode) with just a simple flick of a\u00a0switch.<\/p>\n<h4>FlashAttention2 with Unpadded\u00a0Input<\/h4>\n<p>Of course, padding the sequences in the collation function only to have them unpadded hardly seems sensible. In a recent <a href=\"https:\/\/huggingface.co\/blog\/packing-with-FA2\">update to HuggingFace<\/a>, support was added for passing in concatenated (unpadded) sequences to a select number of models. Unfortunately (as of the time of this writing), our GPT2 model did not make the cut. However, adding support requires just five small line additions to <a href=\"https:\/\/github.com\/huggingface\/transformers\/blob\/v4.46.3\/src\/transformers\/models\/gpt2\/modeling_gpt2.py\">modeling_gpt2.py<\/a> in order to propagate the sequence <a href=\"https:\/\/github.com\/huggingface\/transformers\/blob\/v4.46.3\/src\/transformers\/models\/gpt2\/modeling_gpt2.py#L985\"><em>position_ids<\/em><\/a><em> <\/em>to the <a href=\"https:\/\/github.com\/huggingface\/transformers\/blob\/v4.46.3\/src\/transformers\/models\/gpt2\/modeling_gpt2.py#L436\">flash-attention kernel<\/a>. 
The full <em>patch <\/em>appears in the block\u00a0below:<\/p>\n<pre>@@ -370,0 +371 @@<br>+        position_ids = None<br>@@ -444,0 +446 @@<br>+            position_ids=position_ids<br>@@ -611,0 +614 @@<br>+        position_ids=None<br>@@ -621,0 +625 @@<br>+            position_ids=position_ids<br>@@ -1140,0 +1145 @@<br>+                    position_ids=position_ids<\/pre>\n<p>We define a collate function that concatenates our sequences and train our HuggingFace model on unpadded sequences. (Also see the built-in <a href=\"https:\/\/huggingface.co\/docs\/transformers\/main\/en\/main_classes\/data_collator#transformers.DataCollatorWithFlattening\">DataCollatorWithFlattening<\/a> utility.)<\/p>\n<pre>def collate_flatten(batch):<br>    input_ids = torch.concat([b['input_ids'] for b in batch]).unsqueeze(0)<br>    labels = torch.concat([b['labels'] for b in batch]).unsqueeze(0)<br>    position_ids = [torch.arange(b['input_ids'].shape[0]) for b in batch]<br>    position_ids = torch.concat(position_ids)<br><br>    return {<br>        'input_ids': input_ids,<br>        'labels': labels,<br>        'position_ids': position_ids<br>    }<br><br>print(f\"HF GPT2 train with flash, no padding\")<br>hf_main(config=flash_config, collate_fn=collate_flatten)<\/pre>\n<p>The resulting step time is 323 ms, 90% faster than running flash-attention on the padded\u00a0input.<\/p>\n<h4>Results<\/h4>\n<p>The results of our HuggingFace experiments are summarized below.<\/p>\n<figure><img data-recalc-dims=\"1\" decoding=\"async\" alt=\"\" src=\"https:\/\/i0.wp.com\/cdn-images-1.medium.com\/max\/658\/1%2AZNq4Hw1nKM4L7QMC5rOVHg.png?ssl=1\"><figcaption>Step time results for different optimization methods (lower is better)\u200a\u2014\u200aby\u00a0Author<\/figcaption><\/figure>\n<p>With little effort, we were able to boost our runtime performance by 2.5x when compared to the uncompiled baseline experiment, and by 36% when compared to the compiled\u00a0version.<\/p>\n<p>In this section, we 
demonstrated how the HuggingFace APIs allow us to leverage the optimized kernels in FlashAttention2, significantly boosting the training performance of existing models on sequences of varying\u00a0length.<\/p>\n<h3>Summary<\/h3>\n<p>As AI models continue to grow in both popularity and complexity, optimizing their performance has become essential for reducing runtime and costs. This is especially true for compute-intensive components like attention layers. In this post, we have continued our exploration of attention layer optimization, and demonstrated new tools and techniques for enhancing Transformer model performance. For more insights on AI model optimization, be sure to check out the <a href=\"https:\/\/towardsdatascience.com\/increasing-transformer-model-efficiency-through-attention-layer-optimization-fefa6f87b1d6\">first post<\/a> in this series as well as our <a href=\"https:\/\/chaimrand.medium.com\/\">many other posts<\/a> on this\u00a0topic.<\/p>\n<hr>\n<p><a href=\"https:\/\/towardsdatascience.com\/optimizing-transformer-models-for-variable-length-input-sequences-19fb88fddf71\">Optimizing Transformer Models for Variable-Length Input Sequences<\/a> was originally published in <a href=\"https:\/\/towardsdatascience.com\/\">Towards Data Science<\/a> on Medium, where people are continuing the conversation by highlighting and responding to this story.<\/p>\n<\/div>\n<p> \t<BR><br \/>\n <BR><\/BR><br \/>\n    Chaim Rand<br \/>\n \t<BR><br \/>\n<BR><\/BR><br \/>\n<a href=\"https:\/\/medium.com\/m\/global-identity-2?redirectUrl=https%3A%2F%2Ftowardsdatascience.com%2Foptimizing-transformer-models-for-variable-length-input-sequences-19fb88fddf71\">Go to original source<\/a><br \/>\n \t<BR><br \/>\n 
<BR><\/BR><\/p>\n","protected":false},"excerpt":{"rendered":"<p>Optimizing Transformer Models for Variable-Length Input Sequences How PyTorch NestedTensors, FlashAttention2, and xFormers can Boost Performance and Reduce AI\u00a0Costs Photo by Tanja Z\u00f6llner on\u00a0Unsplash As generative AI (genAI) models grow in both popularity and scale, so do the computational demands and costs associated with their training and deployment. Optimizing these models is crucial for enhancing [&hellip;]<\/p>\n","protected":false},"author":2,"featured_media":0,"comment_status":"closed","ping_status":"closed","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[62,69,77,70,75,76],"tags":[78,73,79],"class_list":["post-212","post","type-post","status-publish","format-standard","hentry","category-aimldsaimlds","category-artificial-intelligence","category-genai","category-machine-learning","category-pytorch","category-transformer-model","tag-length","tag-models","tag-sequences"],"_links":{"self":[{"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/posts\/212"}],"collection":[{"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/users\/2"}],"replies":[{"embeddable":true,"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/comments?post=212"}],"version-history":[{"count":0,"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/posts\/212\/revisions"}],"wp:attachment":[{"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/media?parent=212"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/categories?post=212"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/tags?post=212"}],"curies":[{"name":"wp","href":"https:\/\/api
.w.org\/{rel}","templated":true}]}}