mlcolvar.cvs.MultiTaskCV

class mlcolvar.cvs.MultiTaskCV(main_cv: BaseCV, auxiliary_loss_fns: Sequence, loss_coefficients: Sequence[float] | None = None)

Bases: object

Multi-task collective variable.

This class wraps an existing CV object and adds to its loss a linear combination of auxiliary loss functions. Each auxiliary loss can target a different dataset carrying different information (for example, labeled rather than unlabeled data). The class works only if main_cv does not use __slots__.
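The dynamic inheritance and the __slots__ restriction can be illustrated in plain Python. This is a hypothetical sketch, not mlcolvar's implementation: MainCV, SlottedCV and MultiTaskWrapper are illustrative names, and copying the instance __dict__ and then reassigning __class__ is only one possible way to obtain this behavior.

```python
class MainCV:
    """Hypothetical stand-in for a main CV class such as AutoEncoderCV."""

    def __init__(self, n_in: int) -> None:
        self.n_in = n_in

    def describe(self) -> str:
        return f"{type(self).__name__} with {self.n_in} inputs"


class SlottedCV:
    """Hypothetical CV using __slots__: its instances have no __dict__."""

    __slots__ = ("n_in",)

    def __init__(self, n_in: int) -> None:
        self.n_in = n_in


class MultiTaskWrapper:
    """Hypothetical wrapper illustrating dynamic inheritance (not mlcolvar's code)."""

    def __init__(self, main_cv, auxiliary_loss_fns, loss_coefficients=None) -> None:
        if not hasattr(main_cv, "__dict__"):
            # No instance __dict__ means there is no state we can copy over.
            raise TypeError("main_cv must not use __slots__")
        # Copy the wrapped object's attributes into this instance.
        self.__dict__.update(main_cv.__dict__)
        # Build a class inheriting from both the wrapper and main_cv's class,
        # then re-type this instance so it exposes all of main_cv's methods.
        self.__class__ = type(
            "MultiTask" + type(main_cv).__name__,
            (MultiTaskWrapper, type(main_cv)),
            {},
        )
        self.auxiliary_loss_fns = list(auxiliary_loss_fns)
        if loss_coefficients is None:
            loss_coefficients = [1.0] * len(self.auxiliary_loss_fns)
        self.loss_coefficients = list(loss_coefficients)


wrapped = MultiTaskWrapper(MainCV(5), auxiliary_loss_fns=[abs])
print(wrapped.describe())           # MultiTaskMainCV with 5 inputs
print(isinstance(wrapped, MainCV))  # True
```

Because the instance is re-typed rather than delegating to the wrapped object, isinstance checks against the main CV's class keep working; the price, in this sketch, is that the main CV must keep its state in an instance __dict__.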

Examples

A semi-supervised autoencoder combining the autoencoder's reconstruction loss with Fisher's discriminant loss.

>>> import torch
>>> from mlcolvar.cvs import AutoEncoderCV, MultiTaskCV
>>> from mlcolvar.core.loss import FisherDiscriminantLoss
>>> from mlcolvar.data import DictDataset, DictModule
>>> n_descriptors = 5
>>> n_labels = 2  # Number of states
>>> n_cvs = 2

Initialize the multi-task CV. Fisher's discriminant loss gets coefficient 0.5, i.e. half the weight of the reconstruction loss.

>>> main_cv = AutoEncoderCV(encoder_layers=[n_descriptors, 10, n_cvs])
>>> aux_loss_fn = FisherDiscriminantLoss(n_states=n_labels)
>>> multi_cv = MultiTaskCV(main_cv, auxiliary_loss_fns=[aux_loss_fn], loss_coefficients=[0.5])

multi_cv now exposes the same API as AutoEncoderCV, including its input normalization layer:

>>> multi_cv.norm_in.set_custom(mean=torch.tensor(0.0), range=torch.tensor(1.0))
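Assuming the input normalization layer computes (x - mean) / range (an assumption about mlcolvar's normalization module, not stated on this page), setting mean=0 and range=1 leaves the inputs unchanged. A pure-Python sketch with a hypothetical normalize helper:

```python
def normalize(x, mean, range_):
    """Hypothetical mirror of an input normalization layer (assumed form)."""
    # Assumed form: shift by the mean, then divide by the range.
    return [(xi - mean) / range_ for xi in x]


inputs = [0.2, 0.5, 0.9]
print(normalize(inputs, mean=0.0, range_=1.0))  # [0.2, 0.5, 0.9]
print(normalize([1.0, 3.0], mean=1.0, range_=2.0))  # [0.0, 1.0]
```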

Create a multi-dataset datamodule for this CV.

>>> n_samples = 100
>>> unsupervised_dataset = DictDataset({
...     'data': torch.rand(n_samples, n_descriptors),
... })
>>> supervised_dataset = DictDataset({
...     'data': torch.rand(n_samples, n_descriptors),
...     'labels': torch.tensor([0., 1]).repeat(n_samples//2)
... })
>>> datamodule = DictModule(dataset=[unsupervised_dataset, supervised_dataset])
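In the supervised dataset, Tensor.repeat applied to a 1-D tensor tiles it, so the labels read 0, 1, 0, 1, ... with n_samples // 2 samples per state, one distinct label for each of the n_labels states given to FisherDiscriminantLoss. A pure-Python mirror of that construction, with plain lists standing in for tensors:

```python
from collections import Counter

n_samples = 100
n_labels = 2

# Same tiling as torch.tensor([0., 1]).repeat(n_samples // 2), using a plain list.
labels = [0.0, 1.0] * (n_samples // 2)

counts = Counter(labels)
print(labels[:4])  # [0.0, 1.0, 0.0, 1.0]
print(counts)      # Counter({0.0: 50, 1.0: 50})
# One distinct label per state.
print(len(counts) == n_labels)  # True
```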

Create a PyTorch Lightning trainer.

>>> import lightning
>>> trainer = lightning.Trainer(max_epochs=1, log_every_n_steps=5, logger=None, enable_checkpointing=False)

__init__(main_cv: BaseCV, auxiliary_loss_fns: Sequence, loss_coefficients: Sequence[float] | None = None)

Constructor.

Parameters:
  • main_cv (BaseCV) – The main collective variable. The resulting CV dynamically inherits from this object's class and exposes all of its members.

  • auxiliary_loss_fns (list) – A list of auxiliary loss functions whose weighted sum is added to the loss of the main CV.

  • loss_coefficients (list-like of floats, optional) – A list of length len(auxiliary_loss_fns) with the coefficients of the linear combination of loss functions. If not provided, every auxiliary loss function is assigned coefficient 1 (the loss of the main CV always has coefficient 1).
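The weighting described by loss_coefficients amounts to loss = main_loss + c_1 * aux_loss_1 + c_2 * aux_loss_2 + .... A minimal sketch with a hypothetical combine_losses helper working on plain floats (in training the losses would be tensors):

```python
from typing import Optional, Sequence


def combine_losses(
    main_loss: float,
    aux_losses: Sequence[float],
    loss_coefficients: Optional[Sequence[float]] = None,
) -> float:
    """Hypothetical helper: main loss (coefficient 1) plus weighted auxiliary losses."""
    if loss_coefficients is None:
        # Default: every auxiliary loss gets coefficient 1.
        loss_coefficients = [1.0] * len(aux_losses)
    if len(loss_coefficients) != len(aux_losses):
        raise ValueError("need one coefficient per auxiliary loss")
    # The main CV's loss always enters with coefficient 1.
    return main_loss + sum(c * loss for c, loss in zip(loss_coefficients, aux_losses))


print(combine_losses(2.0, [4.0], [0.5]))  # 4.0
print(combine_losses(2.0, [4.0, 1.0]))    # 7.0
```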

Methods

__init__(main_cv, auxiliary_loss_fns[, ...])

Constructor.

training_step(train_batch, batch_idx)
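This page does not document how training_step distributes the datasets among the losses. Judging from the example, where the unlabeled dataset comes first and the labeled one feeds Fisher's discriminant loss, a plausible scheme is that the first dataset drives the main CV's loss and dataset i + 1 drives auxiliary loss i. The sketch below assumes exactly that; multitask_training_step, the list-of-batches layout, and the toy mean-based losses are hypothetical, not mlcolvar's API.

```python
from typing import Callable, Mapping, Optional, Sequence

Batch = Mapping[str, Sequence[float]]
LossFn = Callable[[Batch], float]


def multitask_training_step(
    dataset_batches: Sequence[Batch],
    main_loss_fn: LossFn,
    auxiliary_loss_fns: Sequence[LossFn],
    loss_coefficients: Optional[Sequence[float]] = None,
) -> float:
    """Hypothetical sketch: batch 0 feeds the main loss, batch i + 1 feeds auxiliary loss i."""
    if len(dataset_batches) != len(auxiliary_loss_fns) + 1:
        raise ValueError("expected one dataset for the main CV plus one per auxiliary loss")
    if loss_coefficients is None:
        loss_coefficients = [1.0] * len(auxiliary_loss_fns)
    # Main CV loss on the first dataset, with coefficient 1.
    total = main_loss_fn(dataset_batches[0])
    # Each auxiliary loss sees only its own dataset.
    for coeff, loss_fn, batch in zip(loss_coefficients, auxiliary_loss_fns, dataset_batches[1:]):
        total += coeff * loss_fn(batch)
    return total


def mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


unsupervised_batch = {"data": [1.0, 3.0]}
supervised_batch = {"data": [0.0, 0.0], "labels": [0.0, 1.0]}
total = multitask_training_step(
    [unsupervised_batch, supervised_batch],
    main_loss_fn=lambda batch: mean(batch["data"]),            # toy loss: 2.0
    auxiliary_loss_fns=[lambda batch: mean(batch["labels"])],  # toy loss: 0.5
    loss_coefficients=[0.5],
)
print(total)  # 2.25
```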