diff --git a/avalanche/training/plugins/ewc.py b/avalanche/training/plugins/ewc.py index 255e686ec..882654bd6 100644 --- a/avalanche/training/plugins/ewc.py +++ b/avalanche/training/plugins/ewc.py @@ -121,6 +121,7 @@ def after_training_exp(self, strategy, **kwargs): strategy.experience.dataset, strategy.device, strategy.train_mb_size, + kwargs["num_workers"], ) self.update_importances(importances, exp_counter) self.saved_params[exp_counter] = copy_params_dict(strategy.model) @@ -129,7 +130,8 @@ def after_training_exp(self, strategy, **kwargs): del self.saved_params[exp_counter - 1] def compute_importances( - self, model, criterion, optimizer, dataset, device, batch_size + self, model, criterion, optimizer, dataset, device, batch_size, + num_workers ): """ Compute EWC importance matrix for each parameter @@ -152,12 +154,15 @@ def compute_importances( # list of list importances = zerolike_params_dict(model) + collate_fn = ( dataset.collate_fn if hasattr(dataset, "collate_fn") else None ) dataloader = DataLoader( - dataset, batch_size=batch_size, collate_fn=collate_fn + dataset, batch_size=batch_size, collate_fn=collate_fn, + num_workers=num_workers ) + for i, batch in enumerate(dataloader): # get only input, target and task_id from the batch x, y, task_labels = batch[0], batch[1], batch[-1]