Source code for mlcolvar.utils.trainer

from lightning import Callback
import copy


class SimpleMetricsCallback(Callback):
    """Lightning callback which appends logged metrics to a list.

    Each time a validation run ends (sanity-check runs excluded), a deep
    copy of ``trainer.callback_metrics`` is appended to ``self.metrics``.
    """

    def __init__(self):
        super().__init__()
        # One snapshot of trainer.callback_metrics per recorded validation run.
        self.metrics = []

    def on_validation_end(self, trainer, pl_module):
        # Ignore validation runs flagged as sanity checks.
        if trainer.sanity_checking:
            return
        # Deep-copy so the stored snapshot is independent of the trainer's
        # callback_metrics object.
        snapshot = copy.deepcopy(trainer.callback_metrics)
        self.metrics.append(snapshot)


class MetricsCallback(Callback):
    """Lightning callback which saves logged metrics into a dictionary.

    The metrics are recorded at the end of each training epoch (the
    ``on_train_epoch_end`` hook). Sanity-check runs are skipped.

    ``self.metrics`` maps each metric name found in
    ``trainer.callback_metrics`` to the list of its values across the
    recorded epochs. The ``"epoch"`` entry lists ``trainer.current_epoch``
    for every record.

    If the trainer has learning-rate scheduler configurations and no
    ``"lr"`` metric was logged, the current learning rate is also stored
    under ``"lr"``. It is read from the ``param_groups`` of
    ``trainer.optimizers``: a scalar for a single parameter group,
    otherwise a list with one value per group.
    """

    def __init__(self):
        super().__init__()
        self.metrics = {"epoch": []}

    def on_train_epoch_end(self, trainer, pl_module):
        if trainer.sanity_checking:
            return

        metrics = trainer.callback_metrics
        self.metrics["epoch"].append(trainer.current_epoch)

        for key, val in metrics.items():
            # Values are normally tensors: convert them to Python scalars via
            # .item(). Values without an .item() method are stored as-is
            # instead of raising AttributeError.
            if hasattr(val, "item"):
                val = val.item()
            self.metrics.setdefault(key, []).append(val)

        # Record the learning rate only when a scheduler is configured and
        # the model did not already log "lr" itself.
        has_scheduler = bool(getattr(trainer, "lr_scheduler_configs", None))
        if has_scheduler and "lr" not in metrics and trainer.optimizers:
            lrs = [
                group["lr"]
                for opt in trainer.optimizers
                for group in opt.param_groups
            ]
            # A single parameter group is stored as a scalar, several as a list.
            lr_val = lrs[0] if len(lrs) == 1 else lrs
            self.metrics.setdefault("lr", []).append(lr_val)


def test_metrics_callbacks():
    """Smoke-test both metrics callbacks by fitting small AutoEncoderCV models.

    The final run configures a StepLR scheduler and checks that
    ``MetricsCallback`` records the learning rate under ``"lr"``.
    """
    import torch
    import lightning
    from mlcolvar.cvs import AutoEncoderCV
    from mlcolvar.data import DictDataset, DictModule
    from torch.optim.lr_scheduler import StepLR

    datamodule = DictModule(DictDataset({"data": torch.rand((100, 2))}))

    def fit_with(model, callback):
        # Fit `model` for one epoch with `callback` attached; return the callback.
        trainer = lightning.Trainer(
            max_epochs=1,
            log_every_n_steps=2,
            logger=None,
            enable_checkpointing=False,
            callbacks=callback,
        )
        trainer.fit(model, datamodule)
        return callback

    fit_with(AutoEncoderCV([2, 2, 1]), SimpleMetricsCallback())
    fit_with(AutoEncoderCV([2, 2, 1]), MetricsCallback())

    scheduler_options = {
        "lr_scheduler": {
            "scheduler": StepLR,
            "step_size": 1,
            "gamma": 0.5,
        }
    }
    callback = fit_with(
        AutoEncoderCV([2, 2, 1], options=scheduler_options),
        MetricsCallback(),
    )
    assert "lr" in callback.metrics