mlcolvar.cvs.MultiTaskCV
- class mlcolvar.cvs.MultiTaskCV(main_cv: BaseCV, auxiliary_loss_fns: Sequence, loss_coefficients: Sequence[float] | None = None)
Bases: object
Multi-task collective variable.
This class wraps an existing CV object and adds to its training loss a linear combination of auxiliary loss functions, each of which can target a different dataset carrying different information (for example, state labels). The class works only if main_cv does not use __slots__.
Examples
A semi-supervised autoencoder that combines the reconstruction loss of AutoEncoderCV with Fisher’s discriminant loss.
>>> import torch
>>> from mlcolvar.cvs import AutoEncoderCV, MultiTaskCV
>>> from mlcolvar.core.loss import FisherDiscriminantLoss
>>> from mlcolvar.data import DictDataset, DictModule
>>> n_descriptors = 5
>>> n_labels = 2  # Number of states
>>> n_cvs = 2
Initialize the multi-task CV. Fisher’s discriminant loss has half the weight of the reconstruction loss.
>>> main_cv = AutoEncoderCV(encoder_layers=[n_descriptors, 10, n_cvs])
>>> aux_loss_fn = FisherDiscriminantLoss(n_states=n_labels)
>>> multi_cv = MultiTaskCV(main_cv, auxiliary_loss_fns=[aux_loss_fn], loss_coefficients=[0.5])
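The weighting can be written out explicitly. The symbols are introduced here for illustration only: L_main is the loss of the wrapped main CV, L_i is the i-th of the K auxiliary losses (each computed on the dataset it targets), and c_i is the i-th entry of loss_coefficients.

$$
\mathcal{L} = \mathcal{L}_{\text{main}} + \sum_{i=1}^{K} c_i \, \mathcal{L}_i
$$

In this example K = 1 and c_1 = 0.5, so the total loss is L_main + 0.5 L_Fisher. If loss_coefficients were omitted, every c_i would be 1.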
MultiTaskCV now exposes the same API as AutoEncoderCV; for example, its input normalization layer norm_in can be configured directly.
>>> multi_cv.norm_in.set_custom(mean=torch.tensor(0.0), range=torch.tensor(1.0))
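Setting mean=0 and range=1 effectively disables input normalization. The pure-Python sketch below assumes (from the layer's name and the set_custom arguments, not from this page) that norm_in maps each input x to (x - mean) / range; the helper normalize is hypothetical.

```python
def normalize(values, mean, range_):
    """Hypothetical helper mimicking an assumed (x - mean) / range normalization."""
    # `range_` avoids shadowing Python's built-in range().
    return [(v - mean) / range_ for v in values]


x = [0.25, 0.5, 0.75]
identity = normalize(x, mean=0.0, range_=1.0)  # mean 0, range 1: inputs unchanged
scaled = normalize(x, mean=0.5, range_=0.25)   # other values shift and rescale
print(identity)  # [0.25, 0.5, 0.75]
print(scaled)    # [-1.0, 0.0, 1.0]
```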
Create a multi-dataset datamodule for this CV.
>>> n_samples = 100
>>> unsupervised_dataset = DictDataset({
...     'data': torch.rand(n_samples, n_descriptors),
... })
>>> supervised_dataset = DictDataset({
...     'data': torch.rand(n_samples, n_descriptors),
...     'labels': torch.tensor([0., 1]).repeat(n_samples//2)
... })
>>> datamodule = DictModule(dataset=[unsupervised_dataset, supervised_dataset])
Create a PyTorch Lightning trainer and fit the model.
>>> import lightning
>>> trainer = lightning.Trainer(max_epochs=1, log_every_n_steps=5, logger=None, enable_checkpointing=False)
>>> trainer.fit(multi_cv, datamodule=datamodule)
- __init__(main_cv: BaseCV, auxiliary_loss_fns: Sequence, loss_coefficients: Sequence[float] | None = None)
Constructor.
- Parameters:
main_cv (BaseCV) – The main collective variable. The multi-task CV dynamically inherits from this object’s class and exposes all of its members.
auxiliary_loss_fns (list) – A list of auxiliary loss functions.
loss_coefficients (list-like of floats, optional) – A list of length len(auxiliary_loss_fns) with the coefficients of the linear combination of loss functions. If not provided, every auxiliary loss function is assigned coefficient 1. The loss of the main CV always has coefficient 1.
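The dynamic inheritance described for main_cv, and the __slots__ restriction stated in the class description, can be illustrated with a pure-Python toy. This is a minimal sketch under an explicit assumption, not mlcolvar's code: it assumes the wrapper builds a class on the fly with bases (wrapper, type(main_cv)) and copies the state of main_cv from its instance __dict__. ToyCV, SlottedCV, and ToyMultiTask are hypothetical names.

```python
class ToyCV:
    """Hypothetical stand-in for a CV class such as AutoEncoderCV."""

    def __init__(self, n_cvs):
        self.n_cvs = n_cvs

    def describe(self):
        return f"CV with {self.n_cvs} components"


class SlottedCV:
    """Hypothetical CV that keeps its state in __slots__, not in __dict__."""

    __slots__ = ("n_cvs",)

    def __init__(self, n_cvs):
        self.n_cvs = n_cvs


class ToyMultiTask:
    """Hypothetical wrapper that dynamically inherits from type(main_cv)."""

    def __new__(cls, main_cv, auxiliary_loss_fns, loss_coefficients=None):
        # Create a class on the fly whose bases are the wrapper and the
        # main CV's class, so the instance exposes the members of both.
        name = cls.__name__ + type(main_cv).__name__
        dynamic_cls = type(name, (cls, type(main_cv)), {})
        return object.__new__(dynamic_cls)

    def __init__(self, main_cv, auxiliary_loss_fns, loss_coefficients=None):
        # Copy the wrapped object's state. Only attributes stored in the
        # instance __dict__ are copied; values held in __slots__ are not.
        self.__dict__.update(getattr(main_cv, "__dict__", {}))
        self.auxiliary_loss_fns = list(auxiliary_loss_fns)
        if loss_coefficients is None:
            # Documented default: coefficient 1 for every auxiliary loss.
            loss_coefficients = [1.0] * len(self.auxiliary_loss_fns)
        self.loss_coefficients = list(loss_coefficients)


multi = ToyMultiTask(ToyCV(n_cvs=2), auxiliary_loss_fns=[])
print(type(multi).__name__)                        # ToyMultiTaskToyCV
print(isinstance(multi, ToyCV), multi.describe())  # True CV with 2 components

slotted = ToyMultiTask(SlottedCV(n_cvs=2), auxiliary_loss_fns=[])
print(hasattr(slotted, "n_cvs"))  # False: the slot value was not copied
```

Because slot values live outside the instance __dict__, a state-copying design like this one silently drops them, which is one plausible reason for the restriction.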
Methods
__init__(main_cv, auxiliary_loss_fns[, ...]) – Constructor.
training_step(train_batch, batch_idx)
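No description of training_step is given here. Judging only from the example (the unlabeled dataset is listed first, the labeled dataset used by Fisher’s discriminant loss second), a plausible routing is that the first dataset's batch feeds the main loss and the batch of dataset i + 1 feeds auxiliary loss i. The pure-Python sketch below illustrates that assumed scheme together with the documented coefficient defaults; toy_training_step, its list-of-batches input, and the toy losses are hypothetical and not mlcolvar's API.

```python
from typing import Callable, Optional, Sequence

Batch = Sequence[float]
LossFn = Callable[[Batch], float]


def toy_training_step(
    batches: Sequence[Batch],
    main_loss_fn: LossFn,
    auxiliary_loss_fns: Sequence[LossFn],
    loss_coefficients: Optional[Sequence[float]] = None,
) -> float:
    """Hypothetical: main loss plus a weighted sum of auxiliary losses.

    Assumes batches[0] feeds the main loss and batches[i + 1] feeds
    auxiliary_loss_fns[i].
    """
    if loss_coefficients is None:
        # Documented default: coefficient 1 for every auxiliary loss.
        loss_coefficients = [1.0] * len(auxiliary_loss_fns)
    if len(loss_coefficients) != len(auxiliary_loss_fns):
        raise ValueError("expected one coefficient per auxiliary loss function")
    if len(batches) != 1 + len(auxiliary_loss_fns):
        raise ValueError("expected one batch for the main loss plus one per auxiliary loss")

    # The main loss always enters with coefficient 1.
    total = main_loss_fn(batches[0])
    for coeff, loss_fn, batch in zip(loss_coefficients, auxiliary_loss_fns, batches[1:]):
        total += coeff * loss_fn(batch)
    return total


def mean_square(batch: Batch) -> float:
    # Toy stand-in for an unsupervised (e.g., reconstruction-style) loss.
    return sum(v * v for v in batch) / len(batch)


def mean_abs(batch: Batch) -> float:
    # Toy stand-in for a supervised auxiliary loss.
    return sum(abs(v) for v in batch) / len(batch)


unlabeled = [1.0, 3.0]  # mean_square -> (1 + 9) / 2 = 5.0
labeled = [2.0, -4.0]   # mean_abs    -> (2 + 4) / 2 = 3.0
print(toy_training_step([unlabeled, labeled], mean_square, [mean_abs], [0.5]))  # 6.5
print(toy_training_step([unlabeled, labeled], mean_square, [mean_abs]))         # 8.0
```

Raising on a length mismatch is a choice made for this sketch; this page does not say how MultiTaskCV handles mismatched lengths.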