Source code for mlcolvar.cvs.timelagged.deeptica

import torch
import lightning
from mlcolvar.cvs import BaseCV
from mlcolvar.core import FeedForward, BaseGNN, Normalization
from mlcolvar.core.stats import TICA
from mlcolvar.core.loss import ReduceEigenvaluesLoss
from typing import Union, List

__all__ = ["DeepTICA"]


class DeepTICA(BaseCV):
    """Neural network-based time-lagged independent component analysis (Deep-TICA).

    It is a non-linear generalization of TICA in which a feature map is learned
    by a neural network optimized as to maximize the eigenvalues of the
    transfer operator, approximated by TICA. The method is described in [1]_.
    Note that from the point of view of the architecture DeepTICA is similar
    to the SRV [2]_ method.

    **Data**: for training it requires a DictDataset containing:

    - If using descriptors as input, the keys 'data' (input at time t) and
      'data_lag' (input at time t+lag), as well as the corresponding 'weights'
      and 'weights_lag' which will be used to weight the time correlation
      functions.
    - If using graphs as input, the keys 'data_list' and 'data_list_lag',
      each containing the respective 'weight'.

    This can be created in both cases with the helper function
    `create_timelagged_dataset`.

    **Loss**: maximize TICA eigenvalues (ReduceEigenvaluesLoss)

    References
    ----------
    .. [1] L. Bonati, G. Piccini, and M. Parrinello, “Deep learning the slow
        modes for rare events sampling,” PNAS USA 118, e2113533118 (2021)
    .. [2] W. Chen, H. Sidky, and A. L. Ferguson, “Nonlinear discovery of slow
        molecular modes using state-free reversible vampnets,”
        JCP 150, 214114 (2019).

    See also
    --------
    mlcolvar.core.stats.TICA
        Time Lagged Independent Component Analysis
    mlcolvar.core.loss.ReduceEigenvaluesLoss
        Eigenvalue reduction to a scalar quantity
    mlcolvar.utils.timelagged.create_timelagged_dataset
        Create dataset of time-lagged data.
    """

    # Blocks used when the network is built from a list of layer sizes.
    DEFAULT_BLOCKS = ["norm_in", "nn", "tica"]
    # Blocks used when an externally initialized model is passed.
    MODEL_BLOCKS = ["nn", "tica"]

    def __init__(
        self,
        model: Union[List[int], FeedForward, BaseGNN],
        n_cvs: Union[int, None] = None,
        options: Union[dict, None] = None,
        **kwargs,
    ):
        """
        Define a Deep-TICA CV, composed of a neural network module and a TICA
        object. By default a module standardizing the inputs is also used.

        Parameters
        ----------
        model : list or FeedForward or BaseGNN
            Determines the underlying machine-learning model. One can pass:

            1. A list of integers corresponding to the number of neurons per
               layer of a feed-forward NN. The model will be automatically
               initialized using a `mlcolvar.core.nn.feedforward.FeedForward`
               object. The CV class will be initialized according to the
               DEFAULT_BLOCKS.
            2. An externally initialized model (either
               `mlcolvar.core.nn.feedforward.FeedForward` or
               `mlcolvar.core.nn.graph.BaseGNN` object). The CV class will be
               initialized according to the MODEL_BLOCKS.
        n_cvs : int, optional
            Number of cvs to optimize, default None (= last layer)
        options : dict[str, Any], optional
            Options for the building blocks of the model, by default {}.
            Available blocks: ['norm_in','nn','tica'].
            Set 'block_name' = None or False to turn off that block
        """
        super().__init__(model, **kwargs)

        # =======   LOSS  =======
        # Maximize the squared sum of all the TICA eigenvalues.
        self.loss_fn = ReduceEigenvaluesLoss(mode="sum2")

        # here we need to override the self.out_features attribute
        self.out_features = n_cvs

        # ======= OPTIONS =======
        # parse and sanitize
        options = self.parse_options(options)

        # ======= BLOCKS =======
        if not self._override_model:
            # initialize norm_in (skipped when the block is set to None/False)
            o = "norm_in"
            if (options[o] is not False) and (options[o] is not None):
                self.norm_in = Normalization(self.in_features, **options[o])

            # initialize nn
            o = "nn"
            self.nn = FeedForward(self.layers, **options[o])
        else:
            # externally initialized model: use it as-is
            self.nn = model
            if self.out_features is not None:
                # keep the requested number of CVs as a buffer named 'n_out'
                # NOTE(review): consumer of this buffer is not visible here
                self.register_buffer('n_out', torch.as_tensor(self.out_features))

        # initialize tica
        o = "tica"
        self.tica = TICA(self.nn.out_features, n_cvs, **options[o])

    def forward_nn(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the (optional) input normalization and the network to `x`.

        The normalization is applied only when the model was built internally
        (i.e. not overridden by an external one) and `self.norm_in` is not
        None; the `nn` block is always applied.

        Parameters
        ----------
        x : torch.Tensor
            Input passed to the blocks through `self._apply_module`.

        Returns
        -------
        torch.Tensor
            Output of the `nn` block.
        """
        if not self._override_model:
            if self.norm_in is not None:
                x = self._apply_module(self.norm_in, x)
        x = self._apply_module(self.nn, x)
        return x

    def set_regularization(self, c0_reg=1e-6):
        """
        Set the regularization of the correlation matrix C(0).

        The value is stored in `self.tica.reg_C_0`. Per the original note, it
        corresponds to adding the identity matrix multiplied by `c0_reg` to
        C(0) in order to avoid instabilities when performing the Cholesky
        decomposition.

        Parameters
        ----------
        c0_reg : float
            Regularization value for C_0.
        """
        self.tica.reg_C_0 = c0_reg

    def training_step(self, train_batch, batch_idx):
        """Compute and return the training loss and record metrics.

        1) Calculate the NN output for the data at time t and t+lag
        2) Compute TICA on the two outputs (parameters are saved)
        3) Reduce the TICA eigenvalues to a scalar loss and log everything

        NOTE(review): the original docstring stated that the average is
        removed inside `forward_nn`; this is not visible in this class and
        presumably happens inside the called modules -- confirm.

        Parameters
        ----------
        train_batch : dict
            Batch with keys 'data', 'data_lag', 'weights', 'weights_lag' for a
            FeedForward model, or with the keys 'data_list' and
            'data_list_lag' for a BaseGNN model.
        batch_idx : int
            Index of the batch (unused here).

        Returns
        -------
        loss
            Value returned by `self.loss_fn` on the TICA eigenvalues.

        Raises
        ------
        TypeError
            If `self.nn` is neither a FeedForward nor a BaseGNN, since the
            batch layout would be unknown.
        """
        # =================get data===================
        if isinstance(self.nn, FeedForward):
            x_t = train_batch["data"]
            x_lag = train_batch["data_lag"]
            w_t = train_batch["weights"]
            w_lag = train_batch["weights_lag"]
        elif isinstance(self.nn, BaseGNN):
            x_t = self._setup_graph_data(train_batch, key='data_list')
            x_lag = self._setup_graph_data(train_batch, key='data_list_lag')
            w_t = x_t['weight']
            w_lag = x_lag['weight']
        else:
            # fail explicitly instead of hitting an UnboundLocalError below
            raise TypeError(
                "DeepTICA.training_step supports only FeedForward or BaseGNN "
                f"models, found {type(self.nn).__name__}."
            )

        # =================forward====================
        f_t = self.forward_nn(x_t)
        f_lag = self.forward_nn(x_lag)

        # ===================tica=====================
        eigvals, _ = self.tica.compute(
            data=[f_t, f_lag], weights=[w_t, w_lag], save_params=True
        )

        # ===================loss=====================
        loss = self.loss_fn(eigvals)

        # ====================log=====================
        # the same step is labelled according to the module's training flag
        name = "train" if self.training else "valid"
        loss_dict = {f"{name}_loss": loss}
        eig_dict = {
            f"{name}_eigval_{i+1}": eigval for i, eigval in enumerate(eigvals)
        }
        self.log_dict(dict(loss_dict, **eig_dict), on_step=True, on_epoch=True)
        return loss
def test_deep_tica():
    """Smoke test: fit DeepTICA for one epoch in three configurations.

    Covers (1) a model built from a list of layer sizes, (2) an externally
    initialized FeedForward and (3) an externally initialized SchNet GNN,
    evaluating the trained CV once in each case.
    """
    # tests
    from mlcolvar.data import DictModule
    from mlcolvar.utils.timelagged import create_timelagged_dataset

    # create dataset
    X = torch.randn((10000, 2))
    dataset = create_timelagged_dataset(X, lag_time=1)
    datamodule = DictModule(dataset, batch_size=10000)

    # create cv
    print()
    print('NORMAL')
    print()
    layers = [2, 10, 10, 2]
    model = DeepTICA(layers, n_cvs=1)

    # change loss options
    model.loss_fn.mode = "sum2"

    # create trainer and fit
    trainer = lightning.Trainer(
        max_epochs=1, log_every_n_steps=2, logger=None, enable_checkpointing=False
    )
    trainer.fit(model, datamodule)

    model.eval()
    with torch.no_grad():
        s = model(X).numpy()
    print(X.shape, "-->", s.shape)

    print()
    print('EXTERNAL')
    print()
    ff_model = FeedForward(layers=layers)
    model = DeepTICA(ff_model, n_cvs=1)

    # change loss options
    model.loss_fn.mode = "sum2"

    # create trainer and fit
    trainer = lightning.Trainer(
        max_epochs=1, log_every_n_steps=2, logger=None, enable_checkpointing=False
    )
    trainer.fit(model, datamodule)

    model.eval()
    with torch.no_grad():
        s = model(X).numpy()
    print(X.shape, "-->", s.shape)

    # gnn external
    print()
    print('GNN')
    print()
    from mlcolvar.core.nn.graph.schnet import SchNetModel
    from mlcolvar.data.graph.utils import create_test_graph_input

    gnn_model = SchNetModel(n_out=2, cutoff=0.1, atomic_numbers=[1, 8])
    model = DeepTICA(gnn_model, n_cvs=1)

    # change loss options
    model.loss_fn.mode = "sum2"

    # create trainer and fit
    trainer = lightning.Trainer(
        max_epochs=1,
        log_every_n_steps=2,
        logger=False,
        enable_checkpointing=False,
        enable_model_summary=False,
    )

    dataset = create_test_graph_input(output_type='dataset', n_samples=200, n_states=2)
    lagged_dataset = create_timelagged_dataset(
        dataset, logweights=torch.randn(len(dataset))
    )
    datamodule = DictModule(dataset=lagged_dataset)
    trainer.fit(model, datamodule)

    model.eval()
    with torch.no_grad():
        example_input_graph_test = create_test_graph_input(
            output_type='example', n_atoms=4, n_samples=3, n_states=2
        )
        s = model(example_input_graph_test).numpy()
    # X is the descriptor tensor of the previous sections and is unrelated to
    # the graph input, so only the output shape is reported here
    print("graph example", "-->", s.shape)