import torch
import lightning
from mlcolvar.cvs import BaseCV
from mlcolvar.core import FeedForward, Normalization
from mlcolvar.core.loss import TDALoss
from mlcolvar.data import DictModule
__all__ = ["DeepTDA"]
[docs]
class DeepTDA(BaseCV, lightning.LightningModule):
"""
Deep Targeted Discriminant Analysis (Deep-TDA) CV.
Combine the inputs with a neural-network and optimize it in a way such that
the data are distributed accordingly to a mixture of Gaussians. The method is described in [1]_.
**Data**: for training it requires a DictDataset with the keys 'data' and 'labels'.
**Loss**: distance of the samples of each class from a set of Gaussians (TDALoss)
References
----------
.. [1] E. Trizio and M. Parrinello, "From enhanced sampling to reaction profiles",
The Journal of Physical Chemistry Letters 12, 8621– 8626 (2021).
See also
--------
mlcolvar.core.loss.TDALoss
Distance from a simple Gaussian target distribution.
"""
BLOCKS = ["norm_in", "nn"]
# TODO n_states optional?
[docs]
def __init__(
self,
n_states: int,
n_cvs: int,
target_centers: list,
target_sigmas: list,
layers: list,
options: dict = None,
**kwargs,
):
"""
Define Deep Targeted Discriminant Analysis (Deep-TDA) CV composed by a neural network module.
By default a module standardizing the inputs is also used.
Parameters
----------
n_states : int
Number of states for the training
n_cvs : int
Number of collective variables to be trained
target_centers : list
Centers of the Gaussian targets
target_sigmas : list
Standard deviations of the Gaussian targets
layers : list
Number of neurons per layer
options : dict[str, Any], optional
Options for the building blocks of the model, by default {}.
Available blocks: ['norm_in', 'nn'].
Set 'block_name' = None or False to turn off that block
"""
super().__init__(in_features=layers[0], out_features=layers[-1], **kwargs)
# ======= LOSS =======
self.loss_fn = TDALoss(
n_states=n_states,
target_centers=target_centers,
target_sigmas=target_sigmas,
)
# ======= OPTIONS =======
# parse and sanitize
options = self.parse_options(options)
# Save n_states
self.n_states = n_states
if self.out_features != n_cvs:
raise ValueError(
"Number of neurons of last layer should match the number of CVs!"
)
# check size and type of targets
if not isinstance(target_centers, torch.Tensor):
target_centers = torch.Tensor(target_centers)
if not isinstance(target_sigmas, torch.Tensor):
target_sigmas = torch.Tensor(target_sigmas)
if target_centers.shape != target_sigmas.shape:
raise ValueError(
f"Size of target_centers and target_sigmas should be the same!"
)
if n_states != target_centers.shape[0]:
raise ValueError(
f"Size of target_centers at dimension 0 should match the number of states! Expected {n_states} found {target_centers.shape[0]}"
)
if len(target_centers.shape) == 2:
if n_cvs != target_centers.shape[1]:
raise ValueError(
(
f"Size of target_centers at dimension 1 should match the number of cvs! Expected {n_cvs} found {target_centers.shape[1]}"
)
)
# ======= BLOCKS =======
# Initialize norm_in
o = "norm_in"
if (options[o] is not False) and (options[o] is not None):
self.norm_in = Normalization(self.in_features, **options[o])
# initialize NN
o = "nn"
self.nn = FeedForward(layers, **options[o])
[docs]
def training_step(self, train_batch, batch_idx):
"""Compute and return the training loss and record metrics."""
# =================get data===================
x = train_batch["data"]
labels = train_batch["labels"]
# =================forward====================
z = self.forward_cv(x)
# ===================loss=====================
loss, loss_centers, loss_sigmas = self.loss_fn(
z, labels, return_loss_terms=True
)
# ====================log=====================+
name = "train" if self.training else "valid"
self.log(f"{name}_loss", loss, on_epoch=True)
self.log(f"{name}_loss_centers", loss_centers, on_epoch=True)
self.log(f"{name}_loss_sigmas", loss_sigmas, on_epoch=True)
return loss
# TODO signature of tests?
import numpy as np
def test_deeptda_cv():
from mlcolvar.data import DictDataset
for states_and_cvs in [[2, 1], [3, 1], [3, 2], [5, 4]]:
# get the number of states and cvs for the test run
n_states = states_and_cvs[0]
n_cvs = states_and_cvs[1]
in_features, out_features = 2, n_cvs
layers = [in_features, 4, 2, out_features]
target_centers = np.random.randn(n_states, n_cvs)
target_sigmas = np.random.randn(n_states, n_cvs)
# test initialize via dictionary
options = {"nn": {"activation": "relu"}}
model = DeepTDA(
n_states=n_states,
n_cvs=n_cvs,
target_centers=target_centers,
target_sigmas=target_sigmas,
layers=layers,
options=options,
)
print("----------")
print(model)
# create dataset
samples = 100
X = torch.randn((samples * n_states, 2))
# create labels
y = torch.zeros(X.shape[0])
for i in range(1, n_states):
y[samples * i :] += 1
dataset = DictDataset({"data": X, "labels": y})
datamodule = DictModule(dataset, lengths=[0.75, 0.2, 0.05], batch_size=samples)
# train model
trainer = lightning.Trainer(
accelerator="cpu", max_epochs=2, logger=None, enable_checkpointing=False
)
trainer.fit(model, datamodule)
# trace model
traced_model = model.to_torchscript(
file_path=None, method="trace", example_inputs=X[0]
)
model.eval()
assert torch.allclose(model(X), traced_model(X))