mlcolvar.cvs.BaseCV

class mlcolvar.cvs.BaseCV(model: List[int] | FeedForward | BaseGNN, preprocessing: Module = None, postprocessing: Module = None, *args, **kwargs)[source]

Bases: LightningModule

Base collective variable class.

To inherit from this class, the class must define a BLOCKS class attribute.

__init__(model: List[int] | FeedForward | BaseGNN, preprocessing: Module = None, postprocessing: Module = None, *args, **kwargs)[source]

Base CV class options.

Parameters:
  • preprocessing (torch.nn.Module, optional) – Preprocessing module, default None

  • postprocessing (torch.nn.Module, optional) – Postprocessing module, default None

Methods

__init__(model[, preprocessing, postprocessing])

Base CV class options.

configure_optimizers()

Initialize the optimizer based on self._optimizer_name and self.optimizer_kwargs.

forward(x[, cell])

Evaluation of the CV

forward_cv(x)

Execute sequentially all the blocks in self.BLOCKS unless they are not initialized.

initialize_blocks()

Initialize the blocks as attributes of the CV class.

initialize_transforms(datamodule)

on_fit_start()

Called at the very beginning of fit.

parse_model(model)

parse_options([options])

Sanitize options and create defaults ({}) if not in options.

setup([stage])

Called at the beginning of fit (train + validate), validate, test, or predict.

test_step(test_batch, batch_idx)

Equal to training step if not overridden.

to_torchscript([file_path, method, ...])

By default, compiles the whole model to a torch.jit.ScriptModule. Tracing can be used with the argument method='trace'.

validation_step(val_batch, batch_idx)

Equal to training step if not overridden.

configure_model() None

Hook to create modules in a strategy and precision aware context.

This is particularly useful for when using sharded strategies (FSDP and DeepSpeed), where we’d like to shard the model instantly to save memory and initialization time. For non-sharded strategies, you can choose to override this hook or to initialize your model under the init_module() context manager.

This hook is called during each of fit/val/test/predict stages in the same process, so ensure that implementation of this hook is idempotent, i.e., after the first time the hook is called, subsequent calls to it should be a no-op.
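
The idempotency requirement can be illustrated with a plain-Python sketch. LazyModel, layers and build_count are hypothetical names used only for this illustration; they are not part of mlcolvar or Lightning, and the placeholder strings stand in for real modules.

```python
class LazyModel:
    """Hypothetical stand-in for a LightningModule, used only to show the guard."""

    def __init__(self):
        self.layers = None      # created lazily in configure_model
        self.build_count = 0    # only here to show that the build runs once

    def configure_model(self):
        # Guard: if the layers already exist, the call is a no-op.
        if self.layers is not None:
            return
        self.layers = ["linear_1", "linear_2"]  # placeholders for real modules
        self.build_count += 1


model = LazyModel()
for stage in ("fit", "validate", "test", "predict"):
    model.configure_model()  # the hook may be called once per stage
```

After the loop the layers have been built a single time, which is the behavior the hook is expected to have.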

configure_optimizers()[source]

Initialize the optimizer based on self._optimizer_name and self.optimizer_kwargs. It also adds the learning rate scheduler if self.lr_scheduler_kwargs is not empty. The scheduler is given as a dictionary with the key ‘scheduler’ containing the scheduler class and the rest of the keys are config options for the scheduler.

Returns:

  • torch.optim – Torch optimizer

  • dict, optional – Learning rate scheduler configuration (if any)
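
A minimal sketch of how such a scheduler dictionary could be unpacked is shown below. It assumes only what is stated above (a 'scheduler' key holding the class, all other keys being options). split_scheduler_config and DummyScheduler are hypothetical names, and this is not the mlcolvar implementation.

```python
def split_scheduler_config(lr_scheduler_kwargs):
    """Separate the scheduler class from its config options (illustrative helper)."""
    if not lr_scheduler_kwargs:
        return None, {}
    options = dict(lr_scheduler_kwargs)        # copy so the input is left untouched
    scheduler_cls = options.pop("scheduler")   # the 'scheduler' key holds the class
    return scheduler_cls, options


class DummyScheduler:
    """Stand-in for a learning rate scheduler class."""

    def __init__(self, optimizer, gamma=1.0):
        self.optimizer = optimizer
        self.gamma = gamma


kwargs = {"scheduler": DummyScheduler, "gamma": 0.99}
scheduler_cls, options = split_scheduler_config(kwargs)
scheduler = scheduler_cls("optimizer-placeholder", **options)
```

With a real optimizer, the class would typically come from torch.optim.lr_scheduler and the remaining keys would be its constructor arguments.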

configure_sharded_model() None

Deprecated.

Use configure_model() instead.

cpu() Self

See torch.nn.Module.cpu().

cuda(device: device | int | None = None) Self

Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized.

Arguments:

device: If specified, all parameters will be copied to that device. If None, the current CUDA device index will be used.

Returns:

Module: self

double() Self

See torch.nn.Module.double().

property example_input_array

The example input array is a specification of what the module can consume in the forward() method. The return type is interpreted as follows:

  • Single tensor: It is assumed the model takes a single argument, i.e., model.forward(model.example_input_array)

  • Tuple: The input array should be interpreted as a sequence of positional arguments, i.e., model.forward(*model.example_input_array)

  • Dict: The input array represents named keyword arguments, i.e., model.forward(**model.example_input_array)
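
The three cases above amount to a dispatch on the type of the example input. The sketch below uses plain Python values instead of tensors; call_forward and the toy forward function are hypothetical helpers written for this illustration.

```python
def call_forward(forward, example_input_array):
    """Call forward according to the type of the example input (illustrative helper)."""
    if isinstance(example_input_array, tuple):
        return forward(*example_input_array)     # positional arguments
    if isinstance(example_input_array, dict):
        return forward(**example_input_array)    # named keyword arguments
    return forward(example_input_array)          # single argument


def forward(x, cell=None):
    # Toy forward that only reports what it received.
    return (x, cell)


single = call_forward(forward, [1.0, 2.0])
positional = call_forward(forward, ([1.0, 2.0], "cell"))
keyword = call_forward(forward, {"x": [1.0, 2.0], "cell": "cell"})
```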

float() Self

See torch.nn.Module.float().

forward(x: Tensor, cell=None) Tensor[source]

Evaluation of the CV

  • Apply preprocessing if any

  • Execute sequentially all the blocks in self.BLOCKS unless they are not initialized

  • Apply postprocessing if any

Parameters:

x (torch.Tensor) – Input of the forward operation of the model

Returns:

Output of the forward operation of the model

Return type:

torch.Tensor
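
The three steps of forward() can be sketched in plain Python as follows. ToyCV is a hypothetical stand-in, its block names and callables are chosen only for illustration, and floats are used in place of tensors; the actual class may handle more cases (for example the cell argument).

```python
class ToyCV:
    """Plain-Python stand-in for a CV class; not the mlcolvar implementation."""

    BLOCKS = ["norm_in", "nn"]

    def __init__(self, preprocessing=None, postprocessing=None):
        self.preprocessing = preprocessing
        self.postprocessing = postprocessing
        self.norm_in = None             # left uninitialized, so it is skipped
        self.nn = lambda x: 2.0 * x     # placeholder for a neural network

    def forward_cv(self, x):
        # Run the blocks in order, skipping any that are still None.
        for name in self.BLOCKS:
            block = getattr(self, name)
            if block is not None:
                x = block(x)
        return x

    def forward(self, x):
        if self.preprocessing is not None:
            x = self.preprocessing(x)
        x = self.forward_cv(x)
        if self.postprocessing is not None:
            x = self.postprocessing(x)
        return x


cv = ToyCV(preprocessing=lambda x: x + 1.0, postprocessing=lambda x: x - 3.0)
full = cv.forward(1.0)        # preprocessing, blocks, postprocessing
blocks_only = cv.forward_cv(1.0)  # blocks only
```

Keeping the block loop in forward_cv means a training step can call it directly and bypass the pre/postprocessing modules.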

forward_cv(x: Tensor) Tensor[source]

Execute sequentially all the blocks in self.BLOCKS unless they are not initialized.

No pre/post processing will be executed here. This is supposed to be called during training/validation and to be overloaded if necessary.

Parameters:

x (torch.Tensor) – Input of the forward operation of the model

Returns:

Output of the forward operation of the model

Return type:

torch.Tensor

half() Self

See torch.nn.Module.half().

property hparams: AttributeDict | MutableMapping

The collection of hyperparameters saved with save_hyperparameters(). It is mutable by the user. For the frozen set of initial hyperparameters, use hparams_initial.

Returns:

Mutable hyperparameters dictionary

property hparams_initial: AttributeDict

The collection of hyperparameters saved with save_hyperparameters(). These contents are read-only. Manual updates to the saved hyperparameters can instead be performed through hparams.

Returns:

AttributeDict: immutable initial hyperparameters
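
The difference between the two properties can be illustrated with a small sketch. HParamsHolder is a hypothetical class that uses plain dictionaries; the real mixin stores AttributeDict objects and does more bookkeeping.

```python
import copy


class HParamsHolder:
    """Illustrative stand-in for the hyperparameter storage."""

    def __init__(self, **hparams):
        self._hparams = dict(hparams)
        self._hparams_initial = copy.deepcopy(self._hparams)  # frozen snapshot

    @property
    def hparams(self):
        # Mutable view: changes made by the user are kept.
        return self._hparams

    @property
    def hparams_initial(self):
        # Hand out a copy so callers cannot alter the stored snapshot.
        return copy.deepcopy(self._hparams_initial)


holder = HParamsHolder(lr=0.1)
holder.hparams["lr"] = 0.01           # updates the mutable collection
holder.hparams_initial["lr"] = 5.0    # only changes a throwaway copy
```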

initialize_blocks()[source]

Initialize the blocks as attributes of the CV class.

property n_cvs

Number of CVs.

on_after_backward() None

Called after loss.backward() and before optimizers are stepped.

Note:

If using native AMP, the gradients will not be unscaled at this point. Use the on_before_optimizer_step if you need the unscaled gradients.

on_after_batch_transfer(batch: Any, dataloader_idx: int) Any

Override to alter or apply batch augmentations to your batch after it is transferred to the device.

Note:

To check the current state of execution of this hook you can use self.trainer.training/testing/validating/predicting so that you can add different logic as per your requirement.

Args:

batch: A batch of data that needs to be altered or augmented.

dataloader_idx: The index of the dataloader to which the batch belongs.

Returns:

A batch of data

Example:

def on_after_batch_transfer(self, batch, dataloader_idx):
    batch['x'] = gpu_transforms(batch['x'])
    return batch
on_before_backward(loss: Tensor) None

Called before loss.backward().

Args:

loss: Loss divided by number of batches for gradient accumulation and scaled if using AMP.

on_before_batch_transfer(batch: Any, dataloader_idx: int) Any

Override to alter or apply batch augmentations to your batch before it is transferred to the device.

Note:

To check the current state of execution of this hook you can use self.trainer.training/testing/validating/predicting so that you can add different logic as per your requirement.

Args:

batch: A batch of data that needs to be altered or augmented.

dataloader_idx: The index of the dataloader to which the batch belongs.

Returns:

A batch of data

Example:

def on_before_batch_transfer(self, batch, dataloader_idx):
    batch['x'] = transforms(batch['x'])
    return batch
on_before_optimizer_step(optimizer: Optimizer) None

Called before optimizer.step().

If using gradient accumulation, the hook is called once the gradients have been accumulated. See Trainer.accumulate_grad_batches.

If using AMP, the loss will be unscaled before calling this hook. See these docs for more information on the scaling of gradients.

If clipping gradients, the gradients will not have been clipped yet.

Args:

optimizer: Current optimizer being used.

Example:

def on_before_optimizer_step(self, optimizer):
    # example to inspect gradient information in tensorboard
    if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
        for k, v in self.named_parameters():
            self.logger.experiment.add_histogram(
                tag=k, values=v.grad, global_step=self.trainer.global_step
            )
on_before_zero_grad(optimizer: Optimizer) None

Called after training_step() and before optimizer.zero_grad().

Called in the training loop after taking an optimizer step and before zeroing grads. Good place to inspect weight information with weights updated.

This is where it is called:

for optimizer in optimizers:
    out = training_step(...)

    model.on_before_zero_grad(optimizer) # < ---- called here
    optimizer.zero_grad()

    backward()
Args:

optimizer: The optimizer for which grads should be zeroed.

on_fit_end() None

Called at the very end of fit.

If on DDP, it is called on every process.

on_fit_start()[source]

Called at the very beginning of fit.

If on DDP, it is called on every process.

on_load_checkpoint(checkpoint: dict[str, Any]) None

Called by Lightning to restore your model. If you saved something with on_save_checkpoint() this is your chance to restore this.

Args:

checkpoint: Loaded checkpoint

Example:

def on_load_checkpoint(self, checkpoint):
    # 99% of the time you don't need to implement this method
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
Note:

Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training.

on_predict_batch_end(outputs: Any | None, batch: Any, batch_idx: int, dataloader_idx: int = 0) None

Called in the predict loop after the batch.

Args:

outputs: The outputs of predict_step(x)

batch: The batched data as it is returned by the prediction DataLoader.

batch_idx: the index of the batch

dataloader_idx: the index of the dataloader

on_predict_batch_start(batch: Any, batch_idx: int, dataloader_idx: int = 0) None

Called in the predict loop before anything happens for that batch.

Args:

batch: The batched data as it is returned by the prediction DataLoader.

batch_idx: the index of the batch

dataloader_idx: the index of the dataloader

on_predict_end() None

Called at the end of predicting.

on_predict_epoch_end() None

Called at the end of predicting.

on_predict_epoch_start() None

Called at the beginning of predicting.

on_predict_model_eval() None

Called when the predict loop starts.

The predict loop by default calls .eval() on the LightningModule before it starts. Override this hook to change the behavior.

on_predict_start() None

Called at the beginning of predicting.

on_save_checkpoint(checkpoint: dict[str, Any]) None

Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.

Args:

checkpoint: The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.

Example:

def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_picklable_object
Note:

Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.

on_test_batch_end(outputs: Tensor | Mapping[str, Any] | None, batch: Any, batch_idx: int, dataloader_idx: int = 0) None

Called in the test loop after the batch.

Args:

outputs: The outputs of test_step(x)

batch: The batched data as it is returned by the test DataLoader.

batch_idx: the index of the batch

dataloader_idx: the index of the dataloader

on_test_batch_start(batch: Any, batch_idx: int, dataloader_idx: int = 0) None

Called in the test loop before anything happens for that batch.

Args:

batch: The batched data as it is returned by the test DataLoader.

batch_idx: the index of the batch

dataloader_idx: the index of the dataloader

on_test_end() None

Called at the end of testing.

on_test_epoch_end() None

Called in the test loop at the very end of the epoch.

on_test_epoch_start() None

Called in the test loop at the very beginning of the epoch.

on_test_model_eval() None

Called when the test loop starts.

The test loop by default calls .eval() on the LightningModule before it starts. Override this hook to change the behavior. See also on_test_model_train().

on_test_model_train() None

Called when the test loop ends.

The test loop by default restores the training mode of the LightningModule to what it was before starting testing. Override this hook to change the behavior. See also on_test_model_eval().

on_test_start() None

Called at the beginning of testing.

on_train_batch_end(outputs: Tensor | Mapping[str, Any] | None, batch: Any, batch_idx: int) None

Called in the training loop after the batch.

Args:

outputs: The outputs of training_step(x)

batch: The batched data as it is returned by the training DataLoader.

batch_idx: the index of the batch

Note:

The value outputs["loss"] here will be the normalized value w.r.t accumulate_grad_batches of the loss returned from training_step.
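
As a small worked example of this normalization, assuming it is a plain division of the returned loss by accumulate_grad_batches (normalized_loss is a hypothetical helper, not a Lightning function):

```python
def normalized_loss(raw_loss, accumulate_grad_batches):
    """Loss as it would appear in outputs["loss"], assuming plain division."""
    return raw_loss / accumulate_grad_batches


raw = 2.0                                             # value returned by training_step
seen = normalized_loss(raw, accumulate_grad_batches=4)  # value seen in the hook
```

With accumulate_grad_batches=1 the two values coincide.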

on_train_batch_start(batch: Any, batch_idx: int) int | None

Called in the training loop before anything happens for that batch.

If you return -1 here, you will skip training for the rest of the current epoch. Learning rate scheduler will still be stepped at the end of epoch.

Args:

batch: The batched data as it is returned by the training DataLoader.

batch_idx: the index of the batch
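
The effect of returning -1 can be sketched with a toy loop. run_epoch and stop_after_two are hypothetical names written for this illustration; the real training loop does considerably more work per batch.

```python
def run_epoch(batches, on_train_batch_start):
    """Toy training loop; returns the indices of the batches that were trained on."""
    trained = []
    for batch_idx, batch in enumerate(batches):
        if on_train_batch_start(batch, batch_idx) == -1:
            break  # skip training for the rest of the current epoch
        trained.append(batch_idx)
    return trained


def stop_after_two(batch, batch_idx):
    # Hypothetical hook: give up on the epoch once two batches have been seen.
    return -1 if batch_idx >= 2 else None


trained = run_epoch(["a", "b", "c", "d"], stop_after_two)
```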

on_train_end() None

Called at the end of training before logger experiment is closed.

on_train_epoch_end() None

Called in the training loop at the very end of the epoch.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the LightningModule and access them in this hook:

class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss

    def on_train_epoch_end(self):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(self.training_step_outputs).mean()
        self.log("training_epoch_mean", epoch_mean)
        # free up the memory
        self.training_step_outputs.clear()
on_train_epoch_start() None

Called in the training loop at the very beginning of the epoch.

on_train_start() None

Called at the beginning of training after sanity check.

on_validation_batch_end(outputs: Tensor | Mapping[str, Any] | None, batch: Any, batch_idx: int, dataloader_idx: int = 0) None

Called in the validation loop after the batch.

Args:

outputs: The outputs of validation_step(x)

batch: The batched data as it is returned by the validation DataLoader.

batch_idx: the index of the batch

dataloader_idx: the index of the dataloader

on_validation_batch_start(batch: Any, batch_idx: int, dataloader_idx: int = 0) None

Called in the validation loop before anything happens for that batch.

Args:

batch: The batched data as it is returned by the validation DataLoader.

batch_idx: the index of the batch

dataloader_idx: the index of the dataloader

on_validation_end() None

Called at the end of validation.

on_validation_epoch_end() None

Called in the validation loop at the very end of the epoch.

on_validation_epoch_start() None

Called in the validation loop at the very beginning of the epoch.

on_validation_model_eval() None

Called when the validation loop starts.

The validation loop by default calls .eval() on the LightningModule before it starts. Override this hook to change the behavior. See also on_validation_model_train().

on_validation_model_train() None

Called when the validation loop ends.

The validation loop by default restores the training mode of the LightningModule to what it was before starting validation. Override this hook to change the behavior. See also on_validation_model_eval().

on_validation_model_zero_grad() None

Called by the training loop to release gradients before entering the validation loop.

on_validation_start() None

Called at the beginning of validation.

property optimizer_name: str

Optimizer name. Options can be set using optimizer_kwargs. The actual optimizer is returned during training by configure_optimizers().

parse_options(options: dict = None)[source]

Sanitize options and create defaults ({}) if not in options. Furthermore, it sets the optimizer kwargs, if given.

Parameters:

options (dict[str, Any], optional) – Options for the building blocks of the model, by default None.
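
A rough sketch of the default-filling step is given below. It is a hypothetical re-implementation for illustration only: the function takes the block names as an explicit argument, and the "optimizer" key used to carry the optimizer kwargs is an assumption, not something stated on this page.

```python
def parse_options(options, blocks):
    """Return options with an empty dict for every missing block (illustrative)."""
    options = {} if options is None else dict(options)  # copy, tolerate None
    for block in blocks:
        options.setdefault(block, {})                    # default is {}
    # Assumption: optimizer settings are passed under an "optimizer" key.
    optimizer_kwargs = options.get("optimizer", {})
    return options, optimizer_kwargs


opts, opt_kwargs = parse_options(
    {"nn": {"activation": "relu"}, "optimizer": {"lr": 1e-3}},
    blocks=["norm_in", "nn"],
)
```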

predict_dataloader() Any

An iterable or collection of iterables specifying prediction samples.

For more information about multiple dataloaders, see this section.

It’s recommended that all data downloads and preparation happen in prepare_data().

Note:

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Return:

A torch.utils.data.DataLoader or a sequence of them specifying prediction samples.

prepare_data() None

Use this to download and prepare data. Downloading and saving data with multiple processes (distributed settings) will result in corrupted data. Lightning ensures this method is called only within a single process, so you can safely add your downloading logic within.

Warning

DO NOT set state to the model (use setup instead) since this is NOT called on every device

Example:

def prepare_data(self):
    # good
    download_data()
    tokenize()
    etc()

    # bad
    self.split = data_split
    self.some_state = some_other_state()

In a distributed environment, prepare_data can be called in two ways (using prepare_data_per_node)

  1. Once per node. This is the default and is only called on LOCAL_RANK=0.

  2. Once in total. Only called on GLOBAL_RANK=0.

Example:

# DEFAULT
# called once per node on LOCAL_RANK=0 of that node
class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = True


# call on GLOBAL_RANK=0 (great for shared file systems)
class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = False
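
The two modes can be made concrete with a small rank check. calls_prepare_data is a hypothetical helper, and the cluster layout (2 nodes with 2 processes each) is an assumption chosen for the example.

```python
def calls_prepare_data(local_rank, global_rank, prepare_data_per_node):
    """Return True if this process would run prepare_data (illustrative helper)."""
    if prepare_data_per_node:
        return local_rank == 0   # once per node
    return global_rank == 0      # once in total


# Hypothetical cluster: global_rank = node * 2 + local_rank.
processes = [(node * 2 + local, local) for node in range(2) for local in range(2)]

per_node = [g for g, l in processes if calls_prepare_data(l, g, True)]
in_total = [g for g, l in processes if calls_prepare_data(l, g, False)]
```

Here per_node holds the global ranks of the first process on each node, while in_total holds only global rank 0.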

This is called before requesting the dataloaders:

model.prepare_data()
initialize_distributed()
model.setup(stage)
model.train_dataloader()
model.val_dataloader()
model.test_dataloader()
model.predict_dataloader()
remove_ignored_hparams(ignore_list: list[str]) None

Remove ignored hyperparameters from the stored state.

This allows derived classes to drop hyperparameters previously saved by base classes.

Args:

ignore_list: Names of hyperparameters to remove.

save_hyperparameters(*args: Any, ignore: Sequence[str] | str | None = None, frame: FrameType | None = None, logger: bool = True) None

Save arguments to hparams attribute.

Args:

args: single object of dict, NameSpace or OmegaConf, or string names or arguments from class __init__

ignore: an argument name or a list of argument names from class __init__ to be ignored

frame: a frame object. Default is None

logger: Whether to send the hyperparameters to the logger. Default: True

Example:
>>> from lightning.pytorch.core.mixins import HyperparametersMixin
>>> class ManuallyArgsModel(HyperparametersMixin):
...     def __init__(self, arg1, arg2, arg3):
...         super().__init__()
...         # manually assign arguments
...         self.save_hyperparameters('arg1', 'arg3')
...     def forward(self, *args, **kwargs):
...         ...
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg3": 3.14
>>> from lightning.pytorch.core.mixins import HyperparametersMixin
>>> class AutomaticArgsModel(HyperparametersMixin):
...     def __init__(self, arg1, arg2, arg3):
...         super().__init__()
...         # equivalent automatic
...         self.save_hyperparameters()
...     def forward(self, *args, **kwargs):
...         ...
>>> model = AutomaticArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg2": abc
"arg3": 3.14
>>> from argparse import Namespace
>>> class SingleArgModel(HyperparametersMixin):
...     def __init__(self, params):
...         super().__init__()
...         # manually assign single argument
...         self.save_hyperparameters(params)
...     def forward(self, *args, **kwargs):
...         ...
>>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
>>> model.hparams
"p1": 1
"p2": abc
"p3": 3.14
>>> from lightning.pytorch.core.mixins import HyperparametersMixin
>>> class ManuallyArgsModel(HyperparametersMixin):
...     def __init__(self, arg1, arg2, arg3):
...         super().__init__()
...         # pass argument(s) to ignore as a string or in a list
...         self.save_hyperparameters(ignore='arg2')
...     def forward(self, *args, **kwargs):
...         ...
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg3": 3.14
setup(stage=None)[source]

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Args:

stage: either 'fit', 'validate', 'test', or 'predict'

Example:

class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this: prepare_data is not called on every device
        self.something = something_else

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
teardown(stage: str) None

Called at the end of fit (train + validate), validate, test, or predict.

Args:

stage: either 'fit', 'validate', 'test', or 'predict'

test_dataloader() Any

An iterable or collection of iterables specifying test samples.

For more information about multiple dataloaders, see this section.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

Note:

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note:

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

test_step(test_batch, batch_idx)[source]

Equal to training step if not overridden. Different behaviors for train/valid step can be enforced in training_step() based on the self.training variable.

to(*args: Any, **kwargs: Any) Self

See torch.nn.Module.to().

to_torchscript(file_path: str | Path | None = None, method: str | None = 'script', example_inputs: Any | None = None, **kwargs: Any) ScriptModule | Dict[str, ScriptModule][source]

By default, compiles the whole model to a torch.jit.ScriptModule. Tracing can be used with the argument method='trace'. In that case, you can provide example_inputs; otherwise, the default example_input_array will be used.

Args:

file_path: Path where to save the torchscript. Default: None (no file saved).

method: Whether to use TorchScript's script or trace method. Default: 'script'

example_inputs: An input to be used to do tracing when method is set to 'trace'. Default: None (uses example_input_array)

**kwargs: Additional arguments that will be passed to the torch.jit.script() or torch.jit.trace() function.

Return:

This LightningModule as a torchscript, regardless of whether file_path is defined or not.

train_dataloader() Any

An iterable or collection of iterables specifying training samples.

For more information about multiple dataloaders, see this section.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

Note:

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

transfer_batch_to_device(batch: Any, device: device, dataloader_idx: int) Any

Override this hook if your DataLoader returns tensors wrapped in a custom data structure.

The data types listed below (and any arbitrary nesting of them) are supported out of the box:

  • torch.Tensor or anything that implements .to(…)

  • list

  • dict

  • tuple

For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, …).

Note:

This hook should only transfer the data and not modify it, nor should it move the data to any other device than the one passed in as argument (unless you know what you are doing). To check the current state of execution of this hook you can use self.trainer.training/testing/validating/predicting so that you can add different logic as per your requirement.

Args:

batch: A batch of data that needs to be transferred to a new device.

device: The target device as defined in PyTorch.

dataloader_idx: The index of the dataloader to which the batch belongs.

Returns:

A reference to the data on the new device.

Example:

def transfer_batch_to_device(self, batch, device, dataloader_idx):
    if isinstance(batch, CustomBatch):
        # move all tensors in your custom data structure to the device
        batch.samples = batch.samples.to(device)
        batch.targets = batch.targets.to(device)
    elif dataloader_idx == 0:
        # skip device transfer for the first dataloader or anything you wish
        pass
    else:
        batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
    return batch
See Also:
  • move_data_to_device()

  • apply_to_collection()

type(dst_type: str | dtype) Self

See torch.nn.Module.type().

val_dataloader() Any

An iterable or collection of iterables specifying validation samples.

For more information about multiple dataloaders, see this section.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data().

Note:

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note:

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

validation_step(val_batch, batch_idx)[source]

Equal to training step if not overridden. Different behaviors for train/valid step can be enforced in training_step() based on the self.training variable.
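
The delegation described above can be sketched in plain Python. ToyStepCV, its mean-based loss and the logged list are hypothetical and stand in for a real CV class, loss function and Lightning logging.

```python
class ToyStepCV:
    """Illustrative stand-in showing one step function shared by train and valid."""

    def __init__(self):
        self.training = True   # toggled by the framework in a real module
        self.logged = []       # stand-in for self.log(...)

    def training_step(self, batch, batch_idx):
        loss = sum(batch) / len(batch)                   # placeholder loss
        name = "train" if self.training else "valid"     # branch on self.training
        self.logged.append((f"{name}_loss", loss))
        return loss

    def validation_step(self, val_batch, batch_idx):
        # Equal to the training step unless a subclass overrides it.
        return self.training_step(val_batch, batch_idx)


model = ToyStepCV()
model.training = False                       # as during validation
val_loss = model.validation_step([1.0, 3.0], 0)
```

A subclass that needs a different validation metric would override validation_step instead of branching inside training_step.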

Attributes

CHECKPOINT_HYPER_PARAMS_KEY

CHECKPOINT_HYPER_PARAMS_NAME

CHECKPOINT_HYPER_PARAMS_TYPE

DEFAULT_BLOCKS

MODEL_BLOCKS

T_destination

automatic_optimization

If set to False you are responsible for calling .backward(), .step(), .zero_grad().

call_super_init

current_epoch

The current epoch in the Trainer, or 0 if not attached.

device

device_mesh

Strategies like ModelParallelStrategy will create a device mesh that can be accessed in the configure_model() hook to parallelize the LightningModule.

dtype

dump_patches

example_input_array

The example input array is a specification of what the module can consume in the forward() method.

fabric

global_rank

The index of the current process across all nodes and devices.

global_step

Total training batches seen across all epochs.

hparams

The collection of hyperparameters saved with save_hyperparameters().

hparams_initial

The collection of hyperparameters saved with save_hyperparameters().

local_rank

The index of the current process within a single node.

logger

Reference to the logger object in the Trainer.

loggers

Reference to the list of loggers in the Trainer.

n_cvs

Number of CVs.

on_gpu

Returns True if this model is currently located on a GPU.

optimizer_name

Optimizer name.

strict_loading

Determines how Lightning loads this model using .load_state_dict(..., strict=model.strict_loading).

trainer

training