diff --git a/avalanche/evaluation/metrics/checkpoint.py b/avalanche/evaluation/metrics/checkpoint.py index 82007a5c4..bf0481283 100644 --- a/avalanche/evaluation/metrics/checkpoint.py +++ b/avalanche/evaluation/metrics/checkpoint.py @@ -80,8 +80,9 @@ def _package_result(self, strategy) -> 'MetricResult': self.get_global_counter())] def after_eval_exp(self, strategy: 'BaseStrategy') -> 'MetricResult': - model_params = copy.deepcopy(strategy.model.parameters()) + model_params = copy.deepcopy(list(strategy.model.parameters())) self.update(model_params) + return self._package_result(strategy) def __str__(self): return "WeightCheckpoint" diff --git a/avalanche/logging/wandb_logger.py b/avalanche/logging/wandb_logger.py index f54e922fd..2eb91ae38 100644 --- a/avalanche/logging/wandb_logger.py +++ b/avalanche/logging/wandb_logger.py @@ -27,6 +27,7 @@ from avalanche.evaluation.metric_results import AlternativeValues, \ MetricValue, TensorImage +from avalanche.evaluation.metric_utils import phase_and_task from avalanche.logging import StrategyLogger @@ -84,7 +85,8 @@ def __init__(self, project_name: str = "Avalanche", def import_wandb(self): try: import wandb - except ImportError: + assert hasattr(wandb, '__version__') + except (ImportError, AssertionError): raise ImportError( 'Please run "pip install wandb" to install wandb') self.wandb = wandb @@ -107,31 +109,8 @@ def before_run(self): self.wandb.run._label(repo="Avalanche") def log_single_metric(self, name, value, x_plot): - if isinstance(value, AlternativeValues): - value = value.best_supported_value(Image, Tensor, TensorImage, - Figure, float, int, - self.wandb.viz.CustomChart) - if not isinstance(value, (Image, Tensor, Figure, float, int, - self.wandb.viz.CustomChart)): - # Unsupported type - return - - if isinstance(value, Image): - self.wandb.log({name: self.wandb.Image(value)}) - - elif isinstance(value, Tensor): - value = np.histogram(value.view(-1).numpy()) - self.wandb.log({name: self.wandb.Histogram(np_histogram=value)}) - - elif isinstance(value, (float, int, Figure, - self.wandb.viz.CustomChart)): - self.wandb.log({name: value}) - - elif isinstance(value, TensorImage): - self.wandb.log({name: self.wandb.Image(array(value))}) - - elif name.startswith("WeightCheckpoint"): + if name.startswith("WeightCheckpoint"): if self.log_artifacts: cwd = os.getcwd() ckpt = os.path.join(cwd, self.path) @@ -140,18 +119,52 @@ def log_single_metric(self, name, value, x_plot): except OSError as e: if e.errno != errno.EEXIST: raise - suffix = '.pth' - dir_name = os.path.join(ckpt, name+suffix) - artifact_name = os.path.join('Models', name+suffix) + ckpt_name = "Model_{}.pth".format(phase_and_task(self.strategy)[1]) + dir_name = os.path.join(ckpt, ckpt_name) + artifact_name = os.path.join('Models', ckpt_name) + if isinstance(value, Tensor): torch.save(value, dir_name) - name = os.path.splittext(self.checkpoint) - artifact = self.wandb.Artifact(name, type='model') + model_name = os.path.splittext(self.checkpoint) + metadata = {'experience': + self.strategy.experience.current_experience, + **({'task_id': + phase_and_task(self.strategy)[1]} + if phase_and_task(self.strategy)[1] + else {})} + artifact = self.wandb.Artifact(model_name, type='model', + metadata=metadata) artifact.add_file(dir_name, name=artifact_name) self.wandb.run.log_artifact(artifact) if self.uri is not None: artifact.add_reference(self.uri, name=artifact_name) + else: + if isinstance(value, AlternativeValues): + value = value.best_supported_value(Image, Tensor, TensorImage, + Figure, float, int, + self.wandb.viz.CustomChart) + + if not isinstance(value, (Image, Tensor, Figure, float, int, + self.wandb.viz.CustomChart)): + # Unsupported type + return + + if isinstance(value, Image): + self.wandb.log({name: self.wandb.Image(value)}) + + elif isinstance(value, Tensor): + value = np.histogram(value.view(-1).numpy()) + self.wandb.log({name: self.wandb.Histogram(np_histogram=value)}) + + elif isinstance(value, (float, int, Figure, + self.wandb.viz.CustomChart)): + self.wandb.log({name: value}) + + elif isinstance(value, TensorImage): + self.wandb.log({name: self.wandb.Image(array(value))}) + + __all__ = [ 'WandBLogger' diff --git a/examples/wandb_logger.py b/examples/wandb_logger.py index 2f7271eb4..c2c6f77f8 100644 --- a/examples/wandb_logger.py +++ b/examples/wandb_logger.py @@ -17,10 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from avalanche.evaluation.metrics.checkpoint import WeightCheckpoint from os.path import expanduser import argparse + import torch from torch.nn import CrossEntropyLoss from torch.optim import SGD @@ -34,7 +36,7 @@ from avalanche.evaluation.metrics import forgetting_metrics, \ accuracy_metrics, loss_metrics, cpu_usage_metrics, \ timing_metrics, gpu_usage_metrics, ram_usage_metrics, disk_usage_metrics, \ - MAC_metrics, confusion_matrix_metrics + MAC_metrics, confusion_matrix_metrics, WeightCheckpoint from avalanche.models import SimpleMLP from avalanche.training.strategies import Naive @@ -74,6 +76,8 @@ def main(args): interactive_logger = InteractiveLogger() wandb_logger = WandBLogger(project_name=args.project, run_name=args.run, + log_artifacts=args.artifacts, + path=args.path if args.path else None, config=args) eval_plugin = EvaluationPlugin( @@ -100,7 +104,8 @@ def main(args): minibatch=True, epoch=True, experience=True, stream=True), MAC_metrics( minibatch=True, epoch=True, experience=True), - loggers=[interactive_logger, wandb_logger] + WeightCheckpoint(), + loggers=[wandb_logger] ) # CREATE THE STRATEGY INSTANCE (NAIVE) @@ -127,8 +132,13 @@ def main(args): parser = argparse.ArgumentParser() parser.add_argument('--cuda', type=int, default=0, help='Select zero-indexed cuda device. -1 to use CPU.') - parser.add_argument('--run', type=str, help='Provide a run name for WandB') parser.add_argument('--project', type=str, help='Define the name of the WandB project') + parser.add_argument('--run', type=str, help='Provide a run name for WandB') + parser.add_argument('--artifacts', default=False, + action="store_true", + help='Log Model Checkpoints as W&B Artifacts') + parser.add_argument('--path', type=str, default="Checkpoint", + help='Local path to save the model checkpoints') args = parser.parse_args() main(args)