{"id":1712,"date":"2025-02-07T07:02:21","date_gmt":"2025-02-07T07:02:21","guid":{"rendered":"https:\/\/mailitics.com\/index.php\/2025\/02\/07\/efficient-metric-collection-in-pytorch-avoiding-the-performance-pitfalls-of-torchmetrics\/"},"modified":"2025-02-07T07:02:21","modified_gmt":"2025-02-07T07:02:21","slug":"efficient-metric-collection-in-pytorch-avoiding-the-performance-pitfalls-of-torchmetrics","status":"publish","type":"post","link":"https:\/\/mailitics.com\/index.php\/2025\/02\/07\/efficient-metric-collection-in-pytorch-avoiding-the-performance-pitfalls-of-torchmetrics\/","title":{"rendered":"Efficient Metric Collection in PyTorch: Avoiding the Performance Pitfalls of TorchMetrics"},"content":{"rendered":"<div>\n<p class=\"wp-block-paragraph\">Metric collection is an essential part of every machine learning project, enabling us to track model performance and monitor training progress. Ideally, <a href=\"https:\/\/towardsdatascience.com\/tag\/metrics\/\" title=\"Metrics\">metrics<\/a> should be collected and computed without introducing any additional overhead to the training process. In practice, however, just like other components of the training loop, inefficient metric computation can add unnecessary overhead, increasing training-step times and inflating training costs.<\/p>\n<p class=\"wp-block-paragraph\">This post is the seventh in our series on <a href=\"https:\/\/medium.com\/@chaimrand\/pytorch-model-performance-analysis-and-optimization-10c3c5822869\" target=\"_blank\" rel=\"noreferrer noopener\">performance profiling and optimization in PyTorch<\/a>. 
The series has aimed to emphasize the critical role of performance analysis and <a href=\"https:\/\/towardsdatascience.com\/tag\/optimization\/\" title=\"Optimization\">optimization<\/a> in machine learning development. Each post has focused on different stages of the training pipeline, demonstrating practical tools and techniques for analyzing and boosting resource utilization and runtime efficiency.<\/p>\n<p class=\"wp-block-paragraph\">In this installment, we focus on metric collection. We will demonstrate how a na\u00efve implementation of metric collection can negatively impact runtime performance and explore tools and techniques for its analysis and optimization.<\/p>\n<p class=\"wp-block-paragraph\">To implement our metric collection, we will use <a href=\"https:\/\/pypi.org\/project\/torchmetrics\/\" rel=\"noreferrer noopener\" target=\"_blank\">TorchMetrics<\/a>, a popular library designed to simplify and standardize metric computation in <a href=\"https:\/\/towardsdatascience.com\/tag\/pytorch\/\" title=\"Pytorch\">PyTorch<\/a>. Our goals will be to:<\/p>\n<ol class=\"wp-block-list\">\n<li class=\"wp-block-list-item\">\n<strong>Demonstrate the runtime overhead<\/strong> caused by a na\u00efve implementation of metric collection.<\/li>\n<li class=\"wp-block-list-item\">\n<strong>Use PyTorch Profiler<\/strong> to pinpoint performance bottlenecks introduced by metric computation.<\/li>\n<li class=\"wp-block-list-item\">\n<strong>Demonstrate optimization techniques<\/strong> to reduce metric collection overhead.<\/li>\n<\/ol>\n<p class=\"wp-block-paragraph\">To facilitate our discussion, we will define a toy PyTorch model and assess how metric collection can impact its runtime performance. 
We will run our experiments on an NVIDIA A40 GPU, with a <a href=\"https:\/\/hub.docker.com\/layers\/pytorch\/pytorch\/2.5.1-cuda12.4-cudnn9-devel\/images\/sha256-14611869895df612b7b07227d5925f30ec3cd6673bad58ce3d84ed107950e014\" rel=\"noreferrer noopener\" target=\"_blank\">PyTorch 2.5.1 Docker<\/a> image and <a href=\"https:\/\/pypi.org\/project\/torchmetrics\/\" rel=\"noreferrer noopener\" target=\"_blank\">TorchMetrics 1.6.1<\/a>.<\/p>\n<p class=\"wp-block-paragraph\">It\u2019s important to note that metric collection behavior can vary greatly depending on the hardware, runtime environment, and model architecture. The code snippets provided in this post are intended for demonstrative purposes only. Please do not interpret our mention of any tool or technique as an endorsement for its use.<\/p>\n<h3 class=\"wp-block-heading\">Toy ResNet\u00a0Model<\/h3>\n<p class=\"wp-block-paragraph\">In the code block below, we define a simple image classification model with a <a href=\"https:\/\/pytorch.org\/vision\/main\/models\/generated\/torchvision.models.resnet18\" rel=\"noreferrer noopener\" target=\"_blank\">ResNet-18<\/a> backbone.<\/p>\n<pre class=\"wp-block-code\"><code>import time\nimport torch\nimport torchvision\n\ndevice = \"cuda\"\n\nmodel = torchvision.models.resnet18().to(device)\ncriterion = torch.nn.CrossEntropyLoss()\noptimizer = torch.optim.SGD(model.parameters())<\/code><\/pre>\n<p class=\"wp-block-paragraph\">We define a synthetic dataset which we will use to train our toy model.<\/p>\n<pre class=\"wp-block-code\"><code>from torch.utils.data import Dataset, DataLoader\n\n# A dataset with random images and labels\nclass FakeDataset(Dataset):\n    def __len__(self):\n        return 100000000\n\n    def __getitem__(self, index):\n        rand_image = torch.randn([3, 224, 224], dtype=torch.float32)\n        label = torch.tensor(data=index % 1000, dtype=torch.int64)\n        return rand_image, label\n\ntrain_set = FakeDataset()\n\nbatch_size = 
128\nnum_workers = 12\n\ntrain_loader = DataLoader(\n    dataset=train_set,\n    batch_size=batch_size,\n    num_workers=num_workers,\n    pin_memory=True\n)<\/code><\/pre>\n<p class=\"wp-block-paragraph\">We define a collection of standard metrics from TorchMetrics, along with a control flag to enable or disable metric calculation.<\/p>\n<pre class=\"wp-block-code\"><code>from torchmetrics import (\n    MeanMetric,\n    Accuracy,\n    Precision,\n    Recall,\n    F1Score,\n)\n\n# toggle to enable\/disable metric collection\ncapture_metrics = False\n\nif capture_metrics:\n    metrics = {\n        \"avg_loss\": MeanMetric(),\n        \"accuracy\": Accuracy(task=\"multiclass\", num_classes=1000),\n        \"precision\": Precision(task=\"multiclass\", num_classes=1000),\n        \"recall\": Recall(task=\"multiclass\", num_classes=1000),\n        \"f1_score\": F1Score(task=\"multiclass\", num_classes=1000),\n    }\n\n    # Move all metrics to the device\n    metrics = {name: metric.to(device) for name, metric in metrics.items()}<\/code><\/pre>\n<p class=\"wp-block-paragraph\">Next, we define a <a href=\"https:\/\/pytorch.org\/tutorials\/recipes\/recipes\/profiler_recipe.html\" target=\"_blank\" rel=\"noreferrer noopener\">PyTorch Profiler<\/a> instance, along with a control flag that allows us to enable or disable profiling. 
For a detailed tutorial on using PyTorch Profiler, please refer to the <a href=\"https:\/\/medium.com\/towards-data-science\/pytorch-model-performance-analysis-and-optimization-10c3c5822869\" target=\"_blank\" rel=\"noreferrer noopener\">first post<\/a> in this series.<\/p>\n<pre class=\"wp-block-code\"><code>from torch import profiler\n\n# toggle to enable\/disable profiling\nenable_profiler = True\n\nif enable_profiler:\n    prof = profiler.profile(\n        schedule=profiler.schedule(wait=10, warmup=2, active=3, repeat=1),\n        on_trace_ready=profiler.tensorboard_trace_handler(\".\/logs\/\"),\n        profile_memory=True,\n        with_stack=True\n    )\n    prof.start()<\/code><\/pre>\n<p class=\"wp-block-paragraph\">Lastly, we define a standard training step:<\/p>\n<pre class=\"wp-block-code\"><code>model.train()\n\nt0 = time.perf_counter()\ntotal_time = 0\ncount = 0\n\nfor idx, (data, target) in enumerate(train_loader):\n    data = data.to(device, non_blocking=True)\n    target = target.to(device, non_blocking=True)\n    optimizer.zero_grad()\n    output = model(data)\n    loss = criterion(output, target)\n    loss.backward()\n    optimizer.step()\n\n    if capture_metrics:\n        # update metrics\n        metrics[\"avg_loss\"].update(loss)\n        for name, metric in metrics.items():\n            if name != \"avg_loss\":\n                metric.update(output, target)\n\n        if (idx + 1) % 100 == 0:\n            # compute metrics\n            metric_results = {\n                name: metric.compute().item() \n                    for name, metric in metrics.items()\n            }\n            # print metrics\n            print(f\"Step {idx + 1}: {metric_results}\")\n            # reset metrics\n            for metric in metrics.values():\n                metric.reset()\n\n    elif (idx + 1) % 100 == 0:\n        # print last loss value\n        print(f\"Step {idx + 1}: Loss = {loss.item():.4f}\")\n\n    batch_time = time.perf_counter() - t0\n    t0 = 
time.perf_counter()\n    if idx &gt; 10:  # skip first steps\n        total_time += batch_time\n        count += 1\n\n    if enable_profiler:\n        prof.step()\n\n    if idx &gt; 200:\n        break\n\nif enable_profiler:\n    prof.stop()\n\navg_time = total_time\/count\nprint(f'Average step time: {avg_time}')\nprint(f'Throughput: {batch_size\/avg_time:.2f} images\/sec')<\/code><\/pre>\n<h4 class=\"wp-block-heading\">Metric Collection Overhead<\/h4>\n<p class=\"wp-block-paragraph\">To measure the impact of metric collection on training step time, we ran our training script both with and without metric calculation. The results are summarized in the following table.<\/p>\n<figure class=\"wp-block-image size-full\"><img data-recalc-dims=\"1\" data-dominant-color=\"efefef\" data-has-transparency=\"false\" style=\"--dominant-color: #efefef;\" loading=\"lazy\" decoding=\"async\" width=\"721\" height=\"106\" src=\"https:\/\/i0.wp.com\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_glGBSjobt7wpDFXNyEvmAw.png?resize=721%2C106&#038;ssl=1\" alt=\"\" class=\"wp-image-597511 not-transparent\" srcset=\"https:\/\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_glGBSjobt7wpDFXNyEvmAw.png 721w, https:\/\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_glGBSjobt7wpDFXNyEvmAw-300x44.png 300w\" sizes=\"auto, (max-width: 721px) 100vw, 721px\"><figcaption class=\"wp-element-caption\">The Overhead of Naive Metric Collection (by Author)<\/figcaption><\/figure>\n<p class=\"wp-block-paragraph\">Our na\u00efve metric collection resulted in a nearly 10% drop in runtime performance!! While metric collection is essential for machine learning development, it usually involves relatively simple mathematical operations and hardly warrants such a significant overhead. 
What is going on?!!<\/p>\n<h2 class=\"wp-block-heading\">Identifying Performance Issues with PyTorch\u00a0Profiler<\/h2>\n<p class=\"wp-block-paragraph\">To better understand the source of the performance degradation, we reran the training script with the PyTorch Profiler enabled. The resultant trace is shown below:<\/p>\n<figure class=\"wp-block-image size-large\"><img data-recalc-dims=\"1\" data-dominant-color=\"dfe0e5\" data-has-transparency=\"true\" style=\"--dominant-color: #dfe0e5;\" loading=\"lazy\" decoding=\"async\" width=\"1024\" height=\"526\" src=\"https:\/\/i0.wp.com\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_07ZVxGG-Sb6Boj1WJRz6Ag-1024x526.png?resize=1024%2C526&#038;ssl=1\" alt=\"\" class=\"wp-image-597512 has-transparency\" srcset=\"https:\/\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_07ZVxGG-Sb6Boj1WJRz6Ag-1024x526.png 1024w, https:\/\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_07ZVxGG-Sb6Boj1WJRz6Ag-300x154.png 300w, https:\/\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_07ZVxGG-Sb6Boj1WJRz6Ag-768x394.png 768w, https:\/\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_07ZVxGG-Sb6Boj1WJRz6Ag.png 1364w\" sizes=\"auto, (max-width: 1024px) 100vw, 1024px\"><figcaption class=\"wp-element-caption\">Trace of Metric Collection Experiment (by Author)<\/figcaption><\/figure>\n<p class=\"wp-block-paragraph\">The trace reveals recurring \u201ccudaStreamSynchronize\u201d operations that coincide with noticeable drops in GPU utilization. These types of \u201cCPU-GPU sync\u201d events were discussed in detail in <a href=\"https:\/\/medium.com\/towards-data-science\/pytorch-model-performance-analysis-and-optimization-part-2-3bc241be91\" target=\"_blank\" rel=\"noreferrer noopener\">part two<\/a> of our series. 
In a typical training step, the CPU and GPU work in parallel: The CPU manages tasks like data transfers to the GPU and kernel loading, and the GPU executes the model on the input data and updates its weights. Ideally, we would like to minimize the points of synchronization between the CPU and GPU in order to maximize performance. Here, however, we can see that the metric collection has triggered a sync event by performing a CPU-to-GPU data copy. This requires the CPU to suspend its processing until the GPU catches up, which in turn causes the GPU to wait for the CPU to resume loading the subsequent kernel operations. The bottom line is that these synchronization points lead to inefficient utilization of both the CPU and GPU. Our metric collection implementation adds eight such synchronization events to each training step.<\/p>\n<p class=\"wp-block-paragraph\">A closer examination of the trace shows that the sync events are coming from the <a href=\"https:\/\/github.com\/Lightning-AI\/torchmetrics\/blob\/v1.6.1\/src\/torchmetrics\/aggregation.py#L547\" rel=\"noreferrer noopener\" target=\"_blank\">update<\/a> call of the TorchMetrics <a href=\"https:\/\/github.com\/Lightning-AI\/torchmetrics\/blob\/v1.6.1\/src\/torchmetrics\/aggregation.py#L494\" rel=\"noreferrer noopener\" target=\"_blank\">MeanMetric<\/a> class. 
For the experienced profiling expert, this may be sufficient to identify the root cause, but we will go a step further and use the <a href=\"https:\/\/pytorch.org\/tutorials\/beginner\/profiler.html#performance-debugging-using-profiler\" rel=\"noreferrer noopener\" target=\"_blank\">torch.profiler.record_function<\/a> utility to identify the exact offending line of code.<\/p>\n<h3 class=\"wp-block-heading\">Profiling with <a href=\"https:\/\/pytorch.org\/tutorials\/beginner\/profiler.html#performance-debugging-using-profiler\" rel=\"noreferrer noopener\" target=\"_blank\">record_function<\/a><\/h3>\n<p class=\"wp-block-paragraph\">To pinpoint the exact source of the sync event, we extended the <a href=\"https:\/\/github.com\/Lightning-AI\/torchmetrics\/blob\/v1.6.1\/src\/torchmetrics\/aggregation.py#L494\" rel=\"noreferrer noopener\" target=\"_blank\">MeanMetric<\/a> class and overrode the <a href=\"https:\/\/github.com\/Lightning-AI\/torchmetrics\/blob\/v1.6.1\/src\/torchmetrics\/aggregation.py#L547\" rel=\"noreferrer noopener\" target=\"_blank\">update<\/a> method, wrapping each of its operations in a <a href=\"https:\/\/pytorch.org\/tutorials\/beginner\/profiler.html#performance-debugging-using-profiler\" rel=\"noreferrer noopener\" target=\"_blank\">record_function<\/a> context block. 
This approach allows us to profile individual operations within the method and identify performance bottlenecks.<\/p>\n<pre class=\"wp-block-code\"><code>class ProfileMeanMetric(MeanMetric):\n    def update(self, value, weight = 1.0):\n        # broadcast weight to value shape\n        with profiler.record_function(\"process value\"):\n            if not isinstance(value, torch.Tensor):\n                value = torch.as_tensor(value, dtype=self.dtype,\n                                        device=self.device)\n        with profiler.record_function(\"process weight\"):\n            if weight is not None and not isinstance(weight, torch.Tensor):\n                weight = torch.as_tensor(weight, dtype=self.dtype,\n                                         device=self.device)\n        with profiler.record_function(\"broadcast weight\"):\n            weight = torch.broadcast_to(weight, value.shape)\n        with profiler.record_function(\"cast_and_nan_check\"):\n            value, weight = self._cast_and_nan_check_input(value, weight)\n\n        if value.numel() == 0:\n            return\n\n        with profiler.record_function(\"update value\"):\n            self.mean_value += (value * weight).sum()\n        with profiler.record_function(\"update weight\"):\n            self.weight += weight.sum()<\/code><\/pre>\n<p class=\"wp-block-paragraph\">We then updated our avg_loss metric to use the newly created ProfileMeanMetric and reran the training script.<\/p>\n<figure class=\"wp-block-image size-large\"><img data-recalc-dims=\"1\" data-dominant-color=\"d9dedf\" data-has-transparency=\"true\" style=\"--dominant-color: #d9dedf;\" loading=\"lazy\" decoding=\"async\" width=\"1024\" height=\"348\" src=\"https:\/\/i0.wp.com\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_v_aeVv1TrZkaSZWSnuhvcQ-1024x348.png?resize=1024%2C348&#038;ssl=1\" alt=\"\" class=\"wp-image-597513 has-transparency\" 
srcset=\"https:\/\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_v_aeVv1TrZkaSZWSnuhvcQ-1024x348.png 1024w, https:\/\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_v_aeVv1TrZkaSZWSnuhvcQ-300x102.png 300w, https:\/\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_v_aeVv1TrZkaSZWSnuhvcQ-768x261.png 768w, https:\/\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_v_aeVv1TrZkaSZWSnuhvcQ.png 1375w\" sizes=\"auto, (max-width: 1024px) 100vw, 1024px\"><figcaption class=\"wp-element-caption\">Trace of Metric Collection with <a href=\"https:\/\/pytorch.org\/tutorials\/beginner\/profiler.html#performance-debugging-using-profiler\" target=\"_blank\" rel=\"noreferrer noopener\">record_function<\/a> (by Author)<\/figcaption><\/figure>\n<p class=\"wp-block-paragraph\">The updated trace reveals that the sync event originates from the following line:<\/p>\n<pre class=\"wp-block-code\"><code>weight = torch.as_tensor(weight, dtype=self.dtype, device=self.device)<\/code><\/pre>\n<p class=\"wp-block-paragraph\">This operation converts the default scalar value <code>weight=1.0<\/code> into a PyTorch tensor and places it on the GPU. The sync event occurs because this action triggers a CPU-to-GPU data copy, which requires the CPU to wait for the GPU to process the copied value.<\/p>\n<h3 class=\"wp-block-heading\">Optimization 1: Specify Weight\u00a0Value<\/h3>\n<p class=\"wp-block-paragraph\">Now that we have found the source of the issue, we can overcome it easily by specifying a <em>weight <\/em>value in our <em>update<\/em> call. 
This prevents the runtime from converting the default scalar <code>weight=1.0<\/code> into a tensor on the GPU, avoiding the sync event:<\/p>\n<pre class=\"wp-block-code\"><code># update metrics\nif capture_metrics:\n    metrics[\"avg_loss\"].update(loss, weight=torch.ones_like(loss))<\/code><\/pre>\n<p class=\"wp-block-paragraph\">Rerunning the script after applying this change reveals that we have succeeded in eliminating the initial sync event\u2026 only to have uncovered a new one, this time coming from the <a href=\"https:\/\/github.com\/Lightning-AI\/torchmetrics\/blob\/v1.6.1\/src\/torchmetrics\/aggregation.py#L76\" target=\"_blank\" rel=\"noreferrer noopener\">_cast_and_nan_check_input<\/a> function:<\/p>\n<figure class=\"wp-block-image size-large\"><img data-recalc-dims=\"1\" data-dominant-color=\"dcdbde\" data-has-transparency=\"true\" style=\"--dominant-color: #dcdbde;\" loading=\"lazy\" decoding=\"async\" width=\"1024\" height=\"332\" src=\"https:\/\/i0.wp.com\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_3TkHbqLnl0ePvAd8ad0pEw-1024x332.png?resize=1024%2C332&#038;ssl=1\" alt=\"\" class=\"wp-image-597514 has-transparency\" srcset=\"https:\/\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_3TkHbqLnl0ePvAd8ad0pEw-1024x332.png 1024w, https:\/\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_3TkHbqLnl0ePvAd8ad0pEw-300x97.png 300w, https:\/\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_3TkHbqLnl0ePvAd8ad0pEw-768x249.png 768w, https:\/\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_3TkHbqLnl0ePvAd8ad0pEw.png 1354w\" sizes=\"auto, (max-width: 1024px) 100vw, 1024px\"><figcaption class=\"wp-element-caption\">Trace of Metric Collection following Optimization 1 (by Author)<\/figcaption><\/figure>\n<h3 class=\"wp-block-heading\">Profiling with <a href=\"https:\/\/pytorch.org\/tutorials\/beginner\/profiler.html#performance-debugging-using-profiler\" rel=\"noreferrer noopener\" 
target=\"_blank\">record_function<\/a>\u200a\u2014\u200aPart\u00a02<\/h3>\n<p class=\"wp-block-paragraph\">To explore our new sync event, we extended our custom metric with additional profiling probes and reran our script.<\/p>\n<pre class=\"wp-block-code\"><code>class ProfileMeanMetric(MeanMetric):\n    def update(self, value, weight = 1.0):\n        # broadcast weight to value shape\n        with profiler.record_function(\"process value\"):\n            if not isinstance(value, torch.Tensor):\n                value = torch.as_tensor(value, dtype=self.dtype,\n                                        device=self.device)\n        with profiler.record_function(\"process weight\"):\n            if weight is not None and not isinstance(weight, torch.Tensor):\n                weight = torch.as_tensor(weight, dtype=self.dtype,\n                                         device=self.device)\n        with profiler.record_function(\"broadcast weight\"):\n            weight = torch.broadcast_to(weight, value.shape)\n        with profiler.record_function(\"cast_and_nan_check\"):\n            value, weight = self._cast_and_nan_check_input(value, weight)\n\n        if value.numel() == 0:\n            return\n\n        with profiler.record_function(\"update value\"):\n            self.mean_value += (value * weight).sum()\n        with profiler.record_function(\"update weight\"):\n            self.weight += weight.sum()\n\n    def _cast_and_nan_check_input(self, x, weight = None):\n        \"\"\"Convert input ``x`` to a tensor and check for Nans.\"\"\"\n        with profiler.record_function(\"process x\"):\n            if not isinstance(x, torch.Tensor):\n                x = torch.as_tensor(x, dtype=self.dtype,\n                                    device=self.device)\n        with profiler.record_function(\"process weight\"):\n            if weight is not None and not isinstance(weight, torch.Tensor):\n                weight = torch.as_tensor(weight, dtype=self.dtype,\n              
                           device=self.device)\n            nans = torch.isnan(x)\n            if weight is not None:\n                nans_weight = torch.isnan(weight)\n            else:\n                nans_weight = torch.zeros_like(nans).bool()\n                weight = torch.ones_like(x)\n\n        with profiler.record_function(\"any nans\"):\n            anynans = nans.any() or nans_weight.any()\n\n        with profiler.record_function(\"process nans\"):\n            if anynans:\n                if self.nan_strategy == \"error\":\n                    raise RuntimeError(\"Encountered `nan` values in tensor\")\n                if self.nan_strategy in (\"ignore\", \"warn\"):\n                    if self.nan_strategy == \"warn\":\n                        print(\"Encountered `nan` values in tensor.\"\n                              \" Will be removed.\")\n                    x = x[~(nans | nans_weight)]\n                    weight = weight[~(nans | nans_weight)]\n                else:\n                    if not isinstance(self.nan_strategy, float):\n                        raise ValueError(f\"`nan_strategy` shall be float\"\n                                         f\" but you pass {self.nan_strategy}\")\n                    x[nans | nans_weight] = self.nan_strategy\n                    weight[nans | nans_weight] = self.nan_strategy\n\n        with profiler.record_function(\"return value\"):\n            retval = x.to(self.dtype), weight.to(self.dtype)\n        return retval<\/code><\/pre>\n<p class=\"wp-block-paragraph\">The resultant trace is captured below:<\/p>\n<figure class=\"wp-block-image size-large\"><img data-recalc-dims=\"1\" data-dominant-color=\"e3dee5\" data-has-transparency=\"true\" style=\"--dominant-color: #e3dee5;\" loading=\"lazy\" decoding=\"async\" width=\"1024\" height=\"354\" src=\"https:\/\/i0.wp.com\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_ad4KRI8de2zKtwV5BKoAcg-1024x354.png?resize=1024%2C354&#038;ssl=1\" alt=\"\" 
class=\"wp-image-597515 has-transparency\" srcset=\"https:\/\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_ad4KRI8de2zKtwV5BKoAcg-1024x354.png 1024w, https:\/\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_ad4KRI8de2zKtwV5BKoAcg-300x104.png 300w, https:\/\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_ad4KRI8de2zKtwV5BKoAcg-768x266.png 768w, https:\/\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_ad4KRI8de2zKtwV5BKoAcg-1536x531.png 1536w, https:\/\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_ad4KRI8de2zKtwV5BKoAcg.png 1561w\" sizes=\"auto, (max-width: 1024px) 100vw, 1024px\"><figcaption class=\"wp-element-caption\">Trace of Metric Collection with <a href=\"https:\/\/pytorch.org\/tutorials\/beginner\/profiler.html#performance-debugging-using-profiler\" target=\"_blank\" rel=\"noreferrer noopener\">record_function<\/a>\u200a\u2014\u200apart 2 (by Author)<\/figcaption><\/figure>\n<p class=\"wp-block-paragraph\">The trace points directly to the offending line:<\/p>\n<pre class=\"wp-block-code\"><code>anynans = nans.any() or nans_weight.any()<\/code><\/pre>\n<p class=\"wp-block-paragraph\">This operation checks for <code>NaN<\/code> values in the input tensors, but it introduces a costly CPU-GPU synchronization event: evaluating the Python <code>or<\/code> expression converts the GPU result of <code>nans.any()<\/code> into a boolean, which requires copying data from the GPU to the CPU.<\/p>\n<p class=\"wp-block-paragraph\">Upon closer inspection of the TorchMetrics <a href=\"https:\/\/github.com\/Lightning-AI\/torchmetrics\/blob\/v1.6.1\/src\/torchmetrics\/aggregation.py#L31\" rel=\"noreferrer noopener\" target=\"_blank\">BaseAggregator<\/a> class, we find several options for handling <code>NaN<\/code> value updates, all of which pass through the offending line of code. 
However, for our use case\u200a\u2014\u200acalculating the average loss metric\u200a\u2014\u200athis check is unnecessary and does not justify the runtime performance penalty.<\/p>\n<h3 class=\"wp-block-heading\">Optimization 2: Disable NaN Value\u00a0Checks<\/h3>\n<p class=\"wp-block-paragraph\">To eliminate the overhead, we propose disabling the <code>NaN<\/code> value checks by overriding the <code>_cast_and_nan_check_input<\/code> function. Instead of a static override, we implemented a dynamic solution that can be applied flexibly to any descendant of the <a href=\"https:\/\/github.com\/Lightning-AI\/torchmetrics\/blob\/v1.6.1\/src\/torchmetrics\/aggregation.py#L31\" target=\"_blank\" rel=\"noreferrer noopener\">BaseAggregator<\/a> class.<\/p>\n<pre class=\"wp-block-code\"><code>from torchmetrics.aggregation import BaseAggregator\n\ndef suppress_nan_check(MetricClass):\n    assert issubclass(MetricClass, BaseAggregator), MetricClass\n    class DisableNanCheck(MetricClass):\n        # same input casting as the base class, minus the NaN check\n        def _cast_and_nan_check_input(self, x, weight=None):\n            if not isinstance(x, torch.Tensor):\n                x = torch.as_tensor(x, dtype=self.dtype,\n                                    device=self.device)\n            if weight is not None and not isinstance(weight, torch.Tensor):\n                weight = torch.as_tensor(weight, dtype=self.dtype,\n                                         device=self.device)\n            if weight is None:\n                weight = torch.ones_like(x)\n            return x.to(self.dtype), weight.to(self.dtype)\n    return DisableNanCheck\n\nNoNanMeanMetric = suppress_nan_check(MeanMetric)\n\nmetrics[\"avg_loss\"] = NoNanMeanMetric().to(device)<\/code><\/pre>\n<h3 class=\"wp-block-heading\">Post Optimization Results:\u00a0Success<\/h3>\n<p class=\"wp-block-paragraph\">After implementing the two optimizations\u200a\u2014\u200aspecifying the weight value and disabling the <code>NaN<\/code> checks\u200a\u2014\u200awe find the step time performance and 
the GPU utilization to match those of our baseline experiment. In addition, the resultant PyTorch Profiler trace shows that all of the added \u201ccudaStreamSynchronize\u201d events that were associated with the metric collection have been eliminated. With a few small changes, we have reduced the cost of training by ~10% with no change to the reported metric values (other than forgoing the <code>NaN<\/code> handling).<\/p>\n<p class=\"wp-block-paragraph\">In the next section, we will explore an additional metric collection optimization.<\/p>\n<h2 class=\"wp-block-heading\">Example 2: Optimizing Metric Device Placement<\/h2>\n<p class=\"wp-block-paragraph\">In the previous section, the metric values resided on the GPU, making it logical to store and compute the metrics on the GPU. However, in scenarios where the values we wish to aggregate reside on the CPU, it might be preferable to store the metrics on the CPU to avoid unnecessary device transfers.<\/p>\n<p class=\"wp-block-paragraph\">In the code block below, we modify our script to calculate the average step time using a <a href=\"https:\/\/github.com\/Lightning-AI\/torchmetrics\/blob\/v1.6.1\/src\/torchmetrics\/aggregation.py#L494\" rel=\"noreferrer noopener\" target=\"_blank\">MeanMetric<\/a> on the CPU. 
This change has no impact on the runtime performance of our training step:<\/p>\n<pre class=\"wp-block-code\"><code>avg_time = NoNanMeanMetric()\nt0 = time.perf_counter()\n\nfor idx, (data, target) in enumerate(train_loader):\n    # move data to device\n    data = data.to(device, non_blocking=True)\n    target = target.to(device, non_blocking=True)\n\n    optimizer.zero_grad()\n    output = model(data)\n    loss = criterion(output, target)\n    loss.backward()\n    optimizer.step()\n\n    if capture_metrics:\n        metrics[\"avg_loss\"].update(loss)\n        for name, metric in metrics.items():\n            if name != \"avg_loss\":\n                metric.update(output, target)\n\n        if (idx + 1) % 100 == 0:\n            # compute metrics\n            metric_results = {\n                name: metric.compute().item()\n                    for name, metric in metrics.items()\n            }\n            # print metrics\n            print(f\"Step {idx + 1}: {metric_results}\")\n            # reset metrics\n            for metric in metrics.values():\n                metric.reset()\n\n    elif (idx + 1) % 100 == 0:\n        # print last loss value\n        print(f\"Step {idx + 1}: Loss = {loss.item():.4f}\")\n\n    batch_time = time.perf_counter() - t0\n    t0 = time.perf_counter()\n    if idx &gt; 10:  # skip first steps\n        avg_time.update(batch_time)\n\n    if enable_profiler:\n        prof.step()\n\n    if idx &gt; 200:\n        break\n\nif enable_profiler:\n    prof.stop()\n\navg_time = avg_time.compute().item()\nprint(f'Average step time: {avg_time}')\nprint(f'Throughput: {batch_size\/avg_time:.2f} images\/sec')<\/code><\/pre>\n<p class=\"wp-block-paragraph\">The problem arises when we attempt to extend our script to support distributed training. 
To demonstrate the problem, we modified our model definition to use <a href=\"https:\/\/pytorch.org\/tutorials\/intermediate\/ddp_tutorial.html\" target=\"_blank\" rel=\"noreferrer noopener\">DistributedDataParallel (DDP)<\/a>:<\/p>\n<pre class=\"wp-block-code\"><code># toggle to enable\/disable ddp\nuse_ddp = True\n\nif use_ddp:\n    import os\n    import torch.distributed as dist\n    from torch.nn.parallel import DistributedDataParallel as DDP\n    os.environ[\"MASTER_ADDR\"] = \"127.0.0.1\"\n    os.environ[\"MASTER_PORT\"] = \"29500\"\n    dist.init_process_group(\"nccl\", rank=0, world_size=1)\n    torch.cuda.set_device(0)\n    model = DDP(torchvision.models.resnet18().to(device))\nelse:\n    model = torchvision.models.resnet18().to(device)\n\n# insert training loop\n\n# append to end of the script:\nif use_ddp:\n    # destroy the process group\n    dist.destroy_process_group()<\/code><\/pre>\n<p class=\"wp-block-paragraph\">The DDP modification results in the following error:<\/p>\n<pre class=\"wp-block-code\"><code>RuntimeError: No backend type associated with device type cpu<\/code><\/pre>\n<p class=\"wp-block-paragraph\">By default, metrics in distributed training are programmed to synchronize across all devices in use. However, the synchronization backend used by DDP does not support metrics stored on the CPU.<\/p>\n<p class=\"wp-block-paragraph\">One way to solve this is to disable the cross-device metric synchronization:<\/p>\n<pre class=\"wp-block-code\"><code>avg_time = NoNanMeanMetric(sync_on_compute=False)<\/code><\/pre>\n<p class=\"wp-block-paragraph\">In our case, where we are measuring the average time, this solution is acceptable. 
However, in some cases, metric synchronization is essential, and we may have no choice but to move the metric onto the GPU:<\/p>\n<pre class=\"wp-block-code\"><code>avg_time = NoNanMeanMetric().to(device)<\/code><\/pre>\n<p class=\"wp-block-paragraph\">Unfortunately, this gives rise to a new CPU-GPU sync event coming from the <a href=\"https:\/\/github.com\/Lightning-AI\/torchmetrics\/blob\/v1.6.1\/src\/torchmetrics\/aggregation.py#L547\" target=\"_blank\" rel=\"noreferrer noopener\">update<\/a> function.<\/p>\n<figure class=\"wp-block-image size-full\"><img data-recalc-dims=\"1\" data-dominant-color=\"c9c2ca\" data-has-transparency=\"true\" style=\"--dominant-color: #c9c2ca;\" loading=\"lazy\" decoding=\"async\" width=\"236\" height=\"190\" src=\"https:\/\/i0.wp.com\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/1_b-EcMc_rNqSXzKr1rD3-mw.png?resize=236%2C190&#038;ssl=1\" alt=\"\" class=\"wp-image-597516 has-transparency\"><figcaption class=\"wp-element-caption\">Trace of avg_time Metric Collection (by Author)<\/figcaption><\/figure>\n<p class=\"wp-block-paragraph\">This sync event should hardly come as a surprise: after all, we are updating a GPU metric with a value residing on the CPU, which should necessitate a memory copy. 
However, in the case of a scalar metric, this data transfer can be completely avoided with a simple optimization.<\/p>\n<h3 class=\"wp-block-heading\">Optimization 3: Perform Metric Updates with Tensors instead of\u00a0Scalars<\/h3>\n<p class=\"wp-block-paragraph\">The solution is straightforward: instead of updating the metric with a Python float, we convert the value to a tensor before calling <code>update<\/code>, and pass an explicit weight tensor along with it.<\/p>\n<pre class=\"wp-block-code\"><code>batch_time = torch.as_tensor(batch_time)\navg_time.update(batch_time, torch.ones_like(batch_time))<\/code><\/pre>\n<p class=\"wp-block-paragraph\">This minor change bypasses the problematic line of code, eliminates the sync event, and restores the step time to its baseline value.<\/p>\n<p class=\"wp-block-paragraph\">At first glance, this result may seem surprising: we would expect that updating a GPU metric with a CPU tensor should still require a memory copy. However, PyTorch optimizes operations on scalar (zero-dimensional) tensors by using a dedicated kernel that performs the addition without an explicit data transfer. This avoids the expensive synchronization event that would otherwise occur.<\/p>\n<h2 class=\"wp-block-heading\">Summary<\/h2>\n<p class=\"wp-block-paragraph\">In this post, we explored how a na\u00efve approach to TorchMetrics can introduce CPU-GPU synchronization events and significantly degrade PyTorch training performance. 
Using PyTorch Profiler, we identified the lines of code responsible for these sync events and applied targeted optimizations to eliminate them:<\/p>\n<ul class=\"wp-block-list\">\n<li class=\"wp-block-list-item\">Explicitly specify a weight tensor when calling the <code>MeanMetric.update<\/code> function instead of relying on the default value.<\/li>\n<li class=\"wp-block-list-item\">Disable NaN checks in the base <code>Aggregator<\/code> class or replace them with a more efficient alternative.<\/li>\n<li class=\"wp-block-list-item\">Carefully manage the device placement of each metric to minimize unnecessary transfers.<\/li>\n<li class=\"wp-block-list-item\">Disable cross-device metric synchronization when not required.<\/li>\n<li class=\"wp-block-list-item\">When the metric resides on a GPU, convert floating-point scalars to tensors before passing them to the <code>update<\/code> function to avoid implicit synchronization.<\/li>\n<\/ul>\n<p class=\"wp-block-paragraph\">We have created a dedicated <a href=\"https:\/\/github.com\/Lightning-AI\/torchmetrics\/pull\/2943\" rel=\"noreferrer noopener\" target=\"_blank\">pull request<\/a> on the <a href=\"https:\/\/github.com\/Lightning-AI\/torchmetrics\/tree\/master\" rel=\"noreferrer noopener\" target=\"_blank\">TorchMetrics GitHub<\/a> page covering some of the optimizations discussed in this post. 
Please feel free to contribute your own improvements and optimizations!<\/p>\n<p>The post <a href=\"https:\/\/towardsdatascience.com\/efficient-metric-collection-in-pytorch-avoiding-the-performance-pitfalls-of-torchmetrics\/\">Efficient Metric Collection in PyTorch: Avoiding the Performance Pitfalls of TorchMetrics<\/a> appeared first on <a href=\"https:\/\/towardsdatascience.com\/\">Towards Data Science<\/a>.<\/p>\n<\/div>\n<p>Chaim Rand<br \/>\n<a href=\"https:\/\/towardsdatascience.com\/efficient-metric-collection-in-pytorch-avoiding-the-performance-pitfalls-of-torchmetrics\/\">Go to original source<\/a><\/p>\n","protected":false},"excerpt":{"rendered":"<p>Efficient Metric Collection in PyTorch: Avoiding the Performance Pitfalls of TorchMetrics Metric collection is an essential part of every machine learning project, enabling us to track model performance and monitor training progress. Ideally, metrics should be collected and computed without introducing any additional overhead to the training process. 
However, just like other components of the [&hellip;]<\/p>\n","protected":false},"author":2,"featured_media":0,"comment_status":"closed","ping_status":"closed","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[62,70,977,222,402,157,75],"tags":[1667,769,1194],"class_list":["post-1712","post","type-post","status-publish","format-standard","hentry","category-aimldsaimlds","category-machine-learning","category-metrics","category-mlops","category-optimization","category-python","category-pytorch","tag-collection","tag-metric","tag-performance"],"_links":{"self":[{"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/posts\/1712"}],"collection":[{"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/users\/2"}],"replies":[{"embeddable":true,"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/comments?post=1712"}],"version-history":[{"count":0,"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/posts\/1712\/revisions"}],"wp:attachment":[{"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/media?parent=1712"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/categories?post=1712"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/tags?post=1712"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}