From 5740f2344607e932c75373717e3875cd3680e86d Mon Sep 17 00:00:00 2001 From: Wenjin Wang Date: Fri, 3 Mar 2023 12:54:20 +0800 Subject: [PATCH 1/2] support num_workers in EWC --- avalanche/training/plugins/ewc.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/avalanche/training/plugins/ewc.py b/avalanche/training/plugins/ewc.py index d138a8ef6..41fd08e52 100644 --- a/avalanche/training/plugins/ewc.py +++ b/avalanche/training/plugins/ewc.py @@ -119,6 +119,7 @@ def after_training_exp(self, strategy, **kwargs): strategy.experience.dataset, strategy.device, strategy.train_mb_size, + kwargs["num_workers"], ) self.update_importances(importances, exp_counter) self.saved_params[exp_counter] = copy_params_dict(strategy.model) @@ -127,7 +128,7 @@ def after_training_exp(self, strategy, **kwargs): del self.saved_params[exp_counter - 1] def compute_importances( - self, model, criterion, optimizer, dataset, device, batch_size + self, model, criterion, optimizer, dataset, device, batch_size, num_workers ): """ Compute EWC importance matrix for each parameter @@ -150,7 +151,7 @@ def compute_importances( # list of list importances = zerolike_params_dict(model) - dataloader = DataLoader(dataset, batch_size=batch_size) + dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) for i, batch in enumerate(dataloader): # get only input, target and task_id from the batch x, y, task_labels = batch[0], batch[1], batch[-1] From 620eeeca681e3118ca550828e332d70f96d7dbd8 Mon Sep 17 00:00:00 2001 From: WangWenjin <30342227+WenjinW@users.noreply.github.com> Date: Fri, 3 Mar 2023 16:56:03 +0800 Subject: [PATCH 2/2] Update ewc.py Fix pep8 --- avalanche/training/plugins/ewc.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/avalanche/training/plugins/ewc.py b/avalanche/training/plugins/ewc.py index e8cbd931e..882654bd6 100644 --- a/avalanche/training/plugins/ewc.py +++ b/avalanche/training/plugins/ewc.py @@ -130,7 +130,8 @@ def after_training_exp(self, strategy, **kwargs): del self.saved_params[exp_counter - 1] def compute_importances( - self, model, criterion, optimizer, dataset, device, batch_size, num_workers + self, model, criterion, optimizer, dataset, device, batch_size, + num_workers ): """ Compute EWC importance matrix for each parameter @@ -158,7 +159,8 @@ def compute_importances( dataset.collate_fn if hasattr(dataset, "collate_fn") else None ) dataloader = DataLoader( - dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers) + dataset, batch_size=batch_size, collate_fn=collate_fn, + num_workers=num_workers ) for i, batch in enumerate(dataloader):