{"id":2122,"date":"2025-02-28T07:05:20","date_gmt":"2025-02-28T07:05:20","guid":{"rendered":"https:\/\/mailitics.com\/index.php\/2025\/02\/28\/debugging-the-dreaded-nan\/"},"modified":"2025-02-28T07:05:20","modified_gmt":"2025-02-28T07:05:20","slug":"debugging-the-dreaded-nan","status":"publish","type":"post","link":"https:\/\/mailitics.com\/index.php\/2025\/02\/28\/debugging-the-dreaded-nan\/","title":{"rendered":"Debugging the Dreaded NaN"},"content":{"rendered":"<div>\n<p class=\"wp-block-paragraph\" id=\"b483\">You are training your latest AI model, anxiously watching as the loss steadily decreases when suddenly \u2014 boom! Your logs are flooded with NaNs (Not a Number) \u2014 your model is irreparably corrupted and you\u2019re left staring at your screen in despair. To make matters worse, the NaNs don\u2019t appear consistently. Sometimes your model trains just fine; other times, it fails inexplicably. Sometimes it will crash immediately, sometimes after many days of training.<\/p>\n<p class=\"wp-block-paragraph\" id=\"ede5\">NaNs in <a href=\"https:\/\/towardsdatascience.com\/tag\/deep-learning\/\" title=\"Deep Learning\">Deep Learning<\/a> workloads are amongst the most frustrating issues to encounter. And because they often appear sporadically \u2014 triggered by a specific combination of model state, input data, and stochastic factors \u2014 they can be incredibly difficult to reproduce and debug.<\/p>\n<p class=\"wp-block-paragraph\" id=\"860d\">Given the considerable cost of training AI models and the potential waste caused by NaN failures, it is recommended to have dedicated tools for capturing and analyzing NaN occurrences. 
In a\u00a0<a href=\"https:\/\/towardsdatascience.com\/capturing-a-training-state-in-tensorflow-7d643e3fb20b\/\" rel=\"noreferrer noopener\" target=\"_blank\">previous post<\/a>, we discussed the challenge of debugging NaNs in a TensorFlow training workload. We proposed an efficient scheme for capturing and reproducing NaNs and shared a sample TensorFlow implementation. In this post, we adopt and demonstrate a similar mechanism for debugging NaNs in PyTorch workloads. The general scheme is as follows:<\/p>\n<p class=\"wp-block-paragraph\" id=\"9495\">On each training step:<\/p>\n<ol class=\"wp-block-list\">\n<li class=\"wp-block-list-item\">Save a copy of the training input batch.<\/li>\n<li class=\"wp-block-list-item\">Check the gradients for NaN values. If any appear, save a checkpoint with the current model weights before the model is corrupted. Also, save the input batch and, if necessary, the stochastic state. Discontinue the training job.<\/li>\n<li class=\"wp-block-list-item\">Reproduce and debug the NaN occurrence by loading the saved experiment state.<\/li>\n<\/ol>\n<p class=\"wp-block-paragraph\" id=\"7939\">Although this scheme can be easily implemented in native PyTorch, we will take the opportunity to demonstrate some of the conveniences of\u00a0<a href=\"https:\/\/github.com\/Lightning-AI\/pytorch-lightning\" target=\"_blank\" rel=\"noreferrer noopener\">PyTorch Lightning<\/a>\u00a0\u2014 a powerful open-source framework designed to streamline the development of machine learning (ML) models. 
Built on PyTorch, Lightning abstracts away many of the boiler-plate components of an ML experiment, such as training loops, data distribution, logging, and more, enabling developers to focus on the core logic of their models.<\/p>\n<p class=\"wp-block-paragraph\" id=\"bf51\">To implement our NaN capturing scheme, we will use\u00a0<a href=\"https:\/\/lightning.ai\/docs\/pytorch\/stable\/extensions\/callbacks.html\" rel=\"noreferrer noopener\" target=\"_blank\">Lightning\u2019s callback<\/a>\u00a0interface \u2014 a dedicated structure that enables inserting custom logic at specific points during the flow of execution.<\/p>\n<p class=\"wp-block-paragraph\" id=\"6ea8\">Importantly, please do not view our choice of Lightning or any other tool or technique that we mention as an endorsement of its use. The code that we will share is intended for demonstrative purposes \u2014 please do not rely on its correctness or optimality.<\/p>\n<p class=\"wp-block-paragraph\" id=\"c589\">Many thanks to\u00a0<a href=\"https:\/\/www.linkedin.com\/in\/rom-maltser\/?originalSubdomain=il\" rel=\"noreferrer noopener\" target=\"_blank\">Rom Maltser<\/a>\u00a0for his contributions to this post.<\/p>\n<h2 class=\"wp-block-heading\" id=\"6503\">NaNCapture Callback<\/h2>\n<p class=\"wp-block-paragraph\" id=\"c365\">To implement our NaN capturing solution, we create a NaNCapture Lightning callback. The constructor receives a directory path for storing\/loading checkpoints and sets up the NaNCapture state. 
We also define utilities for checking for NaNs, storing checkpoints, and halting the training job.<\/p>\n<pre class=\"wp-block-prismatic-blocks\"><code class=\"language-python\">import os\nimport torch\nfrom copy import deepcopy\nimport lightning.pytorch as pl\n\nclass NaNCapture(pl.Callback):\n\n    def __init__(self, dirpath: str):\n        # path to checkpoint\n        self.dirpath = dirpath\n        \n        # update to True when NaN is identified\n        self.nan_captured = False\n        \n        # stores a copy of the last batch\n        self.last_batch = None\n        self.batch_idx = None\n\n    @staticmethod\n    def contains_nan(tensor):\n        return torch.isnan(tensor).any().item()\n        # alternatively check for finite\n        # return not torch.isfinite(tensor).all().item()\n\n    @staticmethod\n    def halt_training(trainer):\n        trainer.should_stop = True\n        # communicate stop command to all other ranks\n        trainer.strategy.reduce_boolean_decision(trainer.should_stop,\n                                                 all=False)\n\n    def save_ckpt(self, trainer):\n        os.makedirs(self.dirpath, exist_ok=True)\n        # include trainer.global_rank to avoid conflict\n        filename = f\"nan_checkpoint_rank_{trainer.global_rank}.ckpt\"\n        full_path = os.path.join(self.dirpath, filename)\n        print(f\"saving ckpt to {full_path}\")\n        trainer.save_checkpoint(full_path, False)<\/code><\/pre>\n<h3 class=\"wp-block-heading\" id=\"b85d\">Callback Function: on_train_batch_start<\/h3>\n<p class=\"wp-block-paragraph\" id=\"a1e5\">We begin by implementing the\u00a0<a href=\"https:\/\/lightning.ai\/docs\/pytorch\/stable\/extensions\/callbacks.html#on-train-batch-start\" rel=\"noreferrer noopener\" target=\"_blank\">on_train_batch_start<\/a>\u00a0hook to store a copy of each input batch. 
In case of a NaN event, this batch will be stored in the checkpoint.<\/p>\n<h3 class=\"wp-block-heading\" id=\"5d2c\">Callback Function: on_before_optimizer_step<\/h3>\n<p class=\"wp-block-paragraph\" id=\"0540\">Next we implement the\u00a0<a href=\"https:\/\/lightning.ai\/docs\/pytorch\/stable\/extensions\/callbacks.html#on-before-optimizer-step\" target=\"_blank\" rel=\"noreferrer noopener\">on_before_optimizer_step<\/a>\u00a0hook. Here, we check for NaN entries in all of the gradient tensors. If found, we store a checkpoint with the uncorrupted model weights and halt the training.<\/p>\n<pre class=\"wp-block-prismatic-blocks\"><code class=\"language-python\">    def on_before_optimizer_step(self, trainer, pl_module, optimizer):\n        if not self.nan_captured:\n            # Check if gradients contain NaN\n            grads = [p.grad.view(-1) for p in pl_module.parameters()\n                     if p.grad is not None]\n            all_grads = torch.cat(grads)\n            if self.contains_nan(all_grads):\n                print(\"nan found\")\n                # mark the event so that state_dict() stores the batch\n                self.nan_captured = True\n                self.save_ckpt(trainer)\n                self.halt_training(trainer)\n<\/code><\/pre>\n<h3 class=\"wp-block-heading\" id=\"1542\">Capturing the Training State<\/h3>\n<p class=\"wp-block-paragraph\" id=\"6243\">To enable reproducibility, we include the NaNCapture state in the checkpoint by appending it to the training state dictionary. 
Lightning provides dedicated utilities for saving and loading a\u00a0<a href=\"https:\/\/lightning.ai\/docs\/pytorch\/stable\/extensions\/callbacks.html#save-callback-state\" target=\"_blank\" rel=\"noreferrer noopener\">callback state<\/a>:<\/p>\n<pre class=\"wp-block-prismatic-blocks\"><code class=\"language-python\">    def state_dict(self):\n        d = {\"nan_captured\": self.nan_captured}\n        if self.nan_captured:\n            d[\"last_batch\"] = self.last_batch\n        return d\n\n    def load_state_dict(self, state_dict):\n        self.nan_captured = state_dict.get(\"nan_captured\", False)\n        if self.nan_captured:\n            self.last_batch = state_dict[\"last_batch\"]<\/code><\/pre>\n<h2 class=\"wp-block-heading\" id=\"1067\">Reproducing the NaN Occurrence<\/h2>\n<p class=\"wp-block-paragraph\" id=\"c64a\">We have described how our NaNCapture callback can be used to store the training state that resulted in a NaN, but how do we reload this state in order to reproduce the issue and debug it? To accomplish this, we leverage Lightning\u2019s dedicated data loading class,\u00a0<a href=\"https:\/\/lightning.ai\/docs\/pytorch\/stable\/data\/datamodule.html\" rel=\"noreferrer noopener\" target=\"_blank\">LightningDataModule<\/a>.<\/p>\n<h3 class=\"wp-block-heading\" id=\"159e\">DataModule Function: on_before_batch_transfer<\/h3>\n<p class=\"wp-block-paragraph\" id=\"485f\">In the code block below, we extend the\u00a0<a href=\"https:\/\/lightning.ai\/docs\/pytorch\/stable\/data\/datamodule.html\" target=\"_blank\" rel=\"noreferrer noopener\">LightningDataModule<\/a>\u00a0class to allow injecting a fixed training input batch. 
This is achieved by overriding the\u00a0<a href=\"https:\/\/lightning.ai\/docs\/pytorch\/stable\/data\/datamodule.html#on-before-batch-transfer\" target=\"_blank\" rel=\"noreferrer noopener\">on_before_batch_transfer<\/a>\u00a0hook, as shown below:<\/p>\n<pre class=\"wp-block-prismatic-blocks\"><code class=\"language-python\">from lightning.pytorch import LightningDataModule\n\nclass InjectableDataModule(LightningDataModule):\n\n    def __init__(self):\n        super().__init__()\n        self.cached_batch = None\n\n    def set_custom_batch(self, batch):\n        self.cached_batch = batch\n\n    def on_before_batch_transfer(self, batch, dataloader_idx):\n        # compare against None: the truth value of a tensor batch is ambiguous\n        if self.cached_batch is not None:\n            return self.cached_batch\n        return batch<\/code><\/pre>\n<h3 class=\"wp-block-heading\" id=\"8ce0\">Callback Function: on_train_start<\/h3>\n<p class=\"wp-block-paragraph\" id=\"40de\">The final step is modifying the\u00a0<a href=\"https:\/\/lightning.ai\/docs\/pytorch\/stable\/extensions\/callbacks.html#on-train-start\" target=\"_blank\" rel=\"noreferrer noopener\">on_train_start<\/a>\u00a0hook of our NaNCapture callback to inject the stored training batch into the\u00a0<a href=\"https:\/\/lightning.ai\/docs\/pytorch\/stable\/data\/datamodule.html\" target=\"_blank\" rel=\"noreferrer noopener\">LightningDataModule<\/a>.<\/p>\n<pre class=\"wp-block-prismatic-blocks\"><code class=\"language-python\">    def on_train_start(self, trainer, pl_module):\n        if self.nan_captured:\n            datamodule = trainer.datamodule\n            datamodule.set_custom_batch(self.last_batch)<\/code><\/pre>\n<p class=\"wp-block-paragraph\" id=\"e17d\">In the next section we will demonstrate the end-to-end solution using a toy example.<\/p>\n<h2 class=\"wp-block-heading\" id=\"5c9e\">Toy Example<\/h2>\n<p class=\"wp-block-paragraph\" id=\"aa3a\">To test our new callback, we create a\u00a0<a href=\"https:\/\/pytorch.org\/vision\/main\/models\/generated\/torchvision.models.resnet50\" 
rel=\"noreferrer noopener\" target=\"_blank\">resnet50<\/a>-based image classification model with a loss function deliberately designed to trigger NaN occurrences.<\/p>\n<p class=\"wp-block-paragraph\" id=\"8728\">Instead of using the standard\u00a0<a href=\"https:\/\/pytorch.org\/docs\/stable\/generated\/torch.nn.CrossEntropyLoss.html\" rel=\"noreferrer noopener\" target=\"_blank\">CrossEntropy<\/a>\u00a0loss, we compute\u00a0<a href=\"https:\/\/pytorch.org\/docs\/stable\/generated\/torch.nn.functional.binary_cross_entropy_with_logits.html\" rel=\"noreferrer noopener\" target=\"_blank\">binary_cross_entropy_with_logits<\/a>\u00a0for each class independently and divide the result by the number of samples belonging to that class. Inevitably, we will encounter a batch in which one or more classes are missing, leading to a divide-by-zero operation, resulting in NaN values and corrupting the model.<\/p>\n<p class=\"wp-block-paragraph\" id=\"891a\">The implementation below follows Lightning\u2019s\u00a0<a href=\"https:\/\/lightning.ai\/docs\/pytorch\/stable\/starter\/introduction.html\" target=\"_blank\" rel=\"noreferrer noopener\">introductory tutorial<\/a>.<\/p>\n<pre class=\"wp-block-prismatic-blocks\"><code class=\"language-python\">import lightning.pytorch as pl\nimport torch\nimport torchvision\nimport torch.nn.functional as F\n\nnum_classes = 20\n\n\n# define a lightning module\nclass ResnetModel(pl.LightningModule):\n    def __init__(self):\n        \"\"\"Initializes a new instance of the ResnetModel class.\"\"\"\n        super().__init__()\n        self.model = torchvision.models.resnet50(num_classes=num_classes)\n\n    def forward(self, x):\n        return self.model(x)\n\n    def training_step(self, batch, batch_nb):\n        x, y = batch\n        outputs = self(x)\n        # uncomment for default loss\n        # return F.cross_entropy(outputs, y)\n        \n        # calculate binary_cross_entropy for each class individually\n        losses = []\n        for 
c in range(num_classes):\n            count = torch.count_nonzero(y==c)\n            masked = torch.where(y==c, 1., 0.)\n            loss = F.binary_cross_entropy_with_logits(\n                outputs[..., c],\n                masked,\n                reduction='sum'\n            )\n            mean_loss = loss\/count # could result in NaN\n            losses.append(mean_loss)\n        total_loss = torch.stack(losses).mean()\n        return total_loss\n\n    def configure_optimizers(self):\n        return torch.optim.Adam(self.parameters(), lr=0.02)<\/code><\/pre>\n<p class=\"wp-block-paragraph\" id=\"0ff8\">We define a synthetic dataset and encapsulate it in our <code>InjectableDataModule<\/code> class:<\/p>\n<pre class=\"wp-block-prismatic-blocks\"><code class=\"language-python\">import os\nimport random\nfrom torch.utils.data import Dataset, DataLoader\n\nbatch_size = 128\nnum_steps = 800\n\n# A dataset with random images and labels\nclass FakeDataset(Dataset):\n    def __len__(self):\n        return batch_size*num_steps\n\n    def __getitem__(self, index):\n        rand_image = torch.randn([3, 224, 224], dtype=torch.float32)\n        label = torch.tensor(random.randint(0, num_classes-1),\n                             dtype=torch.int64)\n        return rand_image, label\n\n\n\n# define a lightning datamodule\nclass FakeDataModule(InjectableDataModule):\n\n    def train_dataloader(self):\n        dataset = FakeDataset()\n        return DataLoader(\n            dataset,\n            batch_size=batch_size,\n            num_workers=os.cpu_count(),\n            pin_memory=True\n        )<\/code><\/pre>\n<p class=\"wp-block-paragraph\" id=\"4584\">Finally, we initialize a Lightning\u00a0<a href=\"https:\/\/lightning.ai\/docs\/pytorch\/stable\/common\/trainer.html\" target=\"_blank\" rel=\"noreferrer noopener\">Trainer<\/a>\u00a0with our NaNCapture callback and call trainer.fit with our Lightning module and Lightning DataModule.<\/p>\n<pre 
class=\"wp-block-prismatic-blocks\"><code class=\"language-python\">import time\n\nif __name__ == \"__main__\":\n\n    # Initialize a lightning module\n    lit_module = ResnetModel()\n\n    # Initialize a DataModule\n    mnist_data = FakeDataModule()\n\n    # Train the model\n    ckpt_dir = \".\/ckpt_dir\"\n    trainer = pl.Trainer(\n        max_epochs=1,\n        callbacks=[NaNCapture(ckpt_dir)]\n    )\n\n    ckpt_path = None\n\n    # check if nan ckpt exists\n    if os.path.isdir(ckpt_dir):\n        dir_contents = [os.path.join(ckpt_dir, f)\n                        for f in os.listdir(ckpt_dir)]\n        ckpts = [f for f in dir_contents\n                 if os.path.isfile(f) and f.endswith('.ckpt')]\n        if ckpts:\n            ckpt_path = ckpts[0]\n\n    t0 = time.perf_counter()\n    trainer.fit(lit_module, mnist_data, ckpt_path=ckpt_path)\n    print(f\"total runtime: {time.perf_counter() - t0}\")<\/code><\/pre>\n<p class=\"wp-block-paragraph\" id=\"a3fc\">After a number of training steps, a NaN event will occur. At this point a checkpoint is saved with the full training state and the training is halted.<\/p>\n<p class=\"wp-block-paragraph\" id=\"4ddd\">When the script is run again, the exact state that caused the NaN will be reloaded, allowing us to easily reproduce the issue and debug its root cause.<\/p>\n<h2 class=\"wp-block-heading\" id=\"07f6\">Performance Overhead<\/h2>\n<p class=\"wp-block-paragraph\" id=\"720a\">To assess the impact of our NaNCapture callback on runtime performance, we modified our experiment to use\u00a0<a href=\"https:\/\/pytorch.org\/docs\/stable\/generated\/torch.nn.CrossEntropyLoss.html\" rel=\"noreferrer noopener\" target=\"_blank\">CrossEntropyLoss<\/a>\u00a0(to avoid NaNs) and measured the average throughput when running with and without the NaNCapture callback. 
The experiments were conducted on an\u00a0<a href=\"https:\/\/www.nvidia.com\/en-eu\/data-center\/l40s\/\" rel=\"noreferrer noopener\" target=\"_blank\">NVIDIA L40S GPU<\/a>, with a\u00a0<a href=\"https:\/\/hub.docker.com\/layers\/pytorch\/pytorch\/2.5.1-cuda12.4-cudnn9-devel\/images\/sha256-14611869895df612b7b07227d5925f30ec3cd6673bad58ce3d84ed107950e014\" rel=\"noreferrer noopener\" target=\"_blank\">PyTorch 2.5.1 Docker<\/a>\u00a0image.<\/p>\n<figure class=\"wp-block-image size-full\"><img data-recalc-dims=\"1\" data-dominant-color=\"f1f1f1\" data-has-transparency=\"false\" style=\"--dominant-color: #f1f1f1;\" loading=\"lazy\" decoding=\"async\" width=\"738\" height=\"106\" src=\"https:\/\/i0.wp.com\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/image-19.png?resize=738%2C106&#038;ssl=1\" alt=\"\" class=\"wp-image-598517 not-transparent\" srcset=\"https:\/\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/image-19.png 738w, https:\/\/towardsdatascience.com\/wp-content\/uploads\/2025\/02\/image-19-300x43.png 300w\" sizes=\"auto, (max-width: 738px) 100vw, 738px\"><figcaption class=\"wp-element-caption\">Overhead of NaNCapture Callback (by Author)<\/figcaption><\/figure>\n<p class=\"wp-block-paragraph\" id=\"7df4\">For our toy model, the NaNCapture callback adds a minimal 1.5% overhead to the runtime performance \u2014 a small price to pay for the valuable debugging capabilities it provides.<\/p>\n<p class=\"wp-block-paragraph\" id=\"6555\">Naturally, the actual overhead will depend on the specifics of the model and runtime environment.<\/p>\n<h2 class=\"wp-block-heading\" id=\"6222\">How to Handle Stochasticity<\/h2>\n<p class=\"wp-block-paragraph\" id=\"4212\">The solution we have described so far will succeed in reproducing the training state provided that the model does not include any randomness. However, introducing stochasticity into the model definition is often critical for convergence. 
A common example of a stochastic layer is\u00a0<a href=\"https:\/\/pytorch.org\/docs\/stable\/generated\/torch.nn.Dropout.html\" rel=\"noreferrer noopener\" target=\"_blank\">torch.nn.Dropout<\/a>.<\/p>\n<p class=\"wp-block-paragraph\" id=\"561c\">You may find that your NaN event depends on the precise state of randomness when the failure occurred. Consequently, we would like to enhance our NaNCapture callback to capture and restore the random state at the point of failure. The random state is determined by a number of libraries. In the code block below, we attempt to capture the full state of randomness:<\/p>\n<pre class=\"wp-block-prismatic-blocks\"><code class=\"language-python\">import os\nimport torch\nimport random\nimport numpy as np\nfrom copy import deepcopy\nimport lightning.pytorch as pl\n\nclass NaNCapture(pl.Callback):\n\n    def __init__(self, dirpath: str):\n        # path to checkpoint\n        self.dirpath = dirpath\n        \n        # update to True when NaN is identified\n        self.nan_captured = False\n        \n        # stores a copy of the last batch\n        self.last_batch = None\n        self.batch_idx = None\n\n        # rng state\n        self.rng_state = {\n            \"torch\": None,\n            \"torch_cuda\": None,\n            \"numpy\": None,\n            \"random\": None\n        }\n\n    @staticmethod\n    def contains_nan(tensor):\n        return torch.isnan(tensor).any().item()\n        # alternatively check for finite\n        # return not torch.isfinite(tensor).all().item()\n\n    @staticmethod\n    def halt_training(trainer):\n        trainer.should_stop = True\n        trainer.strategy.reduce_boolean_decision(trainer.should_stop,\n                                                 all=False)\n\n    def save_ckpt(self, trainer):\n        os.makedirs(self.dirpath, exist_ok=True)\n        # include trainer.global_rank to avoid conflict\n        filename = f\"nan_checkpoint_rank_{trainer.global_rank}.ckpt\"\n        full_path = 
os.path.join(self.dirpath, filename)\n        print(f\"saving ckpt to {full_path}\")\n        trainer.save_checkpoint(full_path, False)\n\n    def on_train_start(self, trainer, pl_module):\n        if self.nan_captured:\n            # inject batch\n            datamodule = trainer.datamodule\n            datamodule.set_custom_batch(self.last_batch)\n\n    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):\n        if self.nan_captured:\n            # restore random state\n            torch.random.set_rng_state(self.rng_state[\"torch\"])\n            torch.cuda.set_rng_state_all(self.rng_state[\"torch_cuda\"])\n            np.random.set_state(self.rng_state[\"numpy\"])\n            random.setstate(self.rng_state[\"random\"])\n        else:\n            # capture current batch\n            self.last_batch = deepcopy(batch)\n            self.batch_idx = batch_idx\n    \n            # capture current random state\n            self.rng_state[\"torch\"] = torch.random.get_rng_state()\n            self.rng_state[\"torch_cuda\"] = torch.cuda.get_rng_state_all()\n            self.rng_state[\"numpy\"] = np.random.get_state()\n            self.rng_state[\"random\"] = random.getstate()\n    \n    def on_before_optimizer_step(self, trainer, pl_module, optimizer):\n        if not self.nan_captured:\n            # Check if gradients contain NaN\n            grads = [p.grad.view(-1) for p in pl_module.parameters()\n                     if p.grad is not None]\n            all_grads = torch.cat(grads)\n            if self.contains_nan(all_grads):\n                print(\"nan found\")\n                # mark the event so that state_dict() stores the batch and rng state\n                self.nan_captured = True\n                self.save_ckpt(trainer)\n                self.halt_training(trainer)\n\n    def state_dict(self):\n        d = {\"nan_captured\": self.nan_captured}\n        if self.nan_captured:\n            d[\"last_batch\"] = self.last_batch\n            d[\"rng_state\"] = self.rng_state\n        return d\n\n    def load_state_dict(self, state_dict):\n        self.nan_captured = 
state_dict.get(\"nan_captured\", False)\n        if self.nan_captured:\n            self.last_batch = state_dict[\"last_batch\"]\n            self.rng_state = state_dict[\"rng_state\"]<\/code><\/pre>\n<p class=\"wp-block-paragraph\" id=\"5d67\">Importantly, setting the random state may not guarantee full\u00a0<a href=\"https:\/\/pytorch.org\/docs\/stable\/notes\/randomness.html#reproducibility\" rel=\"noreferrer noopener\" target=\"_blank\">reproducibility<\/a>. The GPU owes its power to its massive parallelism. In some GPU operations, multiple threads may read or write concurrently to the same memory locations, resulting in nondeterminism. PyTorch allows for some control over this via its\u00a0<a href=\"https:\/\/pytorch.org\/docs\/stable\/generated\/torch.use_deterministic_algorithms.html\" rel=\"noreferrer noopener\" target=\"_blank\">use_deterministic_algorithms<\/a> function, but this may impact the runtime performance. Additionally, there is a possibility that the NaN event will not be reproduced once this configuration setting is changed. Please see the PyTorch documentation on\u00a0<a href=\"https:\/\/pytorch.org\/docs\/stable\/notes\/randomness.html#reproducibility\" rel=\"noreferrer noopener\" target=\"_blank\">reproducibility<\/a>\u00a0for more details.<\/p>\n<h2 class=\"wp-block-heading\" id=\"39e0\">Summary<\/h2>\n<p class=\"wp-block-paragraph\" id=\"f512\">Encountering NaN failures is one of the most discouraging events that can happen in machine learning development. These errors not only waste valuable computation and development resources, but often indicate fundamental issues in the model architecture or experiment design. Due to their sporadic, sometimes elusive nature, debugging NaN failures can be a nightmare.<\/p>\n<p class=\"wp-block-paragraph\" id=\"4248\">This post introduced a proactive approach for capturing and reproducing NaN errors using a dedicated Lightning callback. 
The solution we shared is a proposal which can be modified and extended for your specific use case.<\/p>\n<p class=\"wp-block-paragraph\" id=\"300a\">While this solution may not address every possible NaN scenario, it significantly reduces debugging time when applicable, potentially saving developers countless hours of frustration and wasted effort.<\/p>\n<p>The post <a href=\"https:\/\/towardsdatascience.com\/debugging-the-dreaded-nan\/\">Debugging the Dreaded NaN<\/a> appeared first on <a href=\"https:\/\/towardsdatascience.com\/\">Towards Data Science<\/a>.<\/p>\n<\/div>\n<p>Chaim Rand<br \/>\n<a href=\"https:\/\/towardsdatascience.com\/debugging-the-dreaded-nan\/\">Go to original source<\/a><\/p>\n","protected":false},"excerpt":{"rendered":"<p>Debugging the Dreaded NaN You are training your latest AI model, anxiously watching as the loss steadily decreases when suddenly \u2014 boom! Your logs are flooded with NaNs (Not a Number) \u2014 your model is irreparably corrupted and you\u2019re left staring at your screen in despair. To make matters worse, the NaNs don\u2019t appear consistently. 
[&hellip;]<\/p>\n","protected":false},"author":2,"featured_media":0,"comment_status":"closed","ping_status":"closed","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[62,83,88,166,70,1498,157,1880],"tags":[103,1881,319],"class_list":["post-2122","post","type-post","status-publish","format-standard","hentry","category-aimldsaimlds","category-data-science","category-deep-learning","category-hands-on-tutorials","category-machine-learning","category-model-training","category-python","category-pytorch-lightning","tag-model","tag-nan","tag-training"],"_links":{"self":[{"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/posts\/2122"}],"collection":[{"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/users\/2"}],"replies":[{"embeddable":true,"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/comments?post=2122"}],"version-history":[{"count":0,"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/posts\/2122\/revisions"}],"wp:attachment":[{"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/media?parent=2122"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/categories?post=2122"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/mailitics.com\/index.php\/wp-json\/wp\/v2\/tags?post=2122"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}