Source code for mlcolvar.cvs.committor.committor

import torch
import lightning
from mlcolvar.cvs import BaseCV
from mlcolvar.core import FeedForward, Normalization, BaseGNN
from mlcolvar.core.loss import CommittorLoss
from mlcolvar.core.nn.utils import Custom_Sigmoid
from typing import Union, List

# Public names exported by this module on `from ... import *`.
__all__ = ["Committor"]


class Committor(BaseCV):
    """Base class for data-driven learning of committor function.

    The committor function q is expressed as the output of a neural network
    optimized with a self-consistent approach based on the Kolmogorov's
    variational principle for the committor and on the imposition of its
    boundary conditions (see Refs. [1,2]).
    It is also possible to use an approximated variational approach without
    explicit dependence on the atomic coordinates (see Ref. [3]).

    **Data**: for training it requires a DictDataset containing:
        - If using descriptors as input, the keys 'data', 'labels' and 'weights'.
        - If using graphs as input, `torch_geometric.data` with 'graph_labels'
          and 'weight' in their 'data_list'.

    **Loss**: Minimize Kolmogorov's variational functional of q and impose
    boundary condition on the metastable states (CommittorLoss) from
    Refs. [1,2]. It is also possible to use an approximated variational
    approach without explicit dependence on the atomic coordinates.

    References
    ----------
    .. [1] P. Kang, E. Trizio, and M. Parrinello, "Computing the committor
        using the committor to study the transition state ensemble",
        Nat. Comput. Sci., 2024, DOI: 10.1038/s43588-024-00645-0
    .. [2] E. Trizio, P. Kang, and M. Parrinello, "Everything everywhere all
        at once: a probability-based enhanced sampling approach to rare
        events", Nat. Comput. Sci., 2025, DOI: 10.1038/s43588-025-00799-5
    .. [3] E. Trizio, G. Rossi, and M. Parrinello, "Ceci n'est pas un
        committor: Efficient sampling via approximated committor functions",
        J Chem. Phys., 2026, DOI: 10.1063/5.0331622

    See also
    --------
    mlcolvar.cvs.committor.utils.compute_committor_weights
        Utils to compute the appropriate weights for the training set
    mlcolvar.cvs.committor.utils.initialize_committor_masses
        Utils to initialize the masses tensor for the training
    mlcolvar.core.loss.CommittorLoss
        Kolmogorov's variational optimization of committor and imposition of
        boundary conditions
    mlcolvar.core.loss.utils.SmartDerivatives
        Class to optimize the gradients calculation improving speed and
        memory efficiency.
    """

    DEFAULT_BLOCKS = ["norm_in", "nn", "sigmoid"]
    MODEL_BLOCKS = ["nn", "sigmoid"]

    def __init__(
        self,
        model: Union[List[int], FeedForward, BaseGNN],
        alpha: float,
        atomic_masses: torch.Tensor = None,
        gamma: float = 10000,
        delta_f: float = 0,
        separate_boundary_dataset: bool = True,
        descriptors_derivatives: torch.nn.Module = None,
        log_var: bool = False,
        use_gradients_wrt_positions: bool = True,
        z_regularization: float = 0.0,
        z_threshold: float = None,
        n_dim: int = None,
        norm_in: bool = False,
        options: dict = None,
        **kwargs,
    ):
        """Define a NN-based committor model

        Parameters
        ----------
        model : list of int, FeedForward or BaseGNN
            Either the number of neurons per layer, from which a FeedForward
            network is built internally, or an already instantiated model
            that is used as-is for the 'nn' block.
        alpha : float
            Hyperparameter that scales the boundary conditions contribution
            to loss, i.e. alpha*(loss_bound_A + loss_bound_B)
        atomic_masses : torch.Tensor
            List of masses of all the atoms we are using, for each atom we
            need to repeat three times for x,y,z, by default None.
            The mlcolvar.cvs.committor.utils.initialize_committor_masses can
            be used to simplify this.
            If the position-less loss is used, this must be set to None.
        gamma : float, optional
            Hyperparameter that scales the whole loss to avoid too small
            numbers, i.e. gamma*(loss_var + loss_bound), by default 10000
        delta_f : float, optional
            Delta free energy between A (label 0) and B (label 1), units is
            kBT, by default 0.
            State B is supposed to be higher in energy.
        separate_boundary_dataset : bool, optional
            Switch to exclude boundary condition labeled data from the
            variational loss, by default True
        descriptors_derivatives : torch.nn.Module, optional
            `SmartDerivatives` object to save memory and time when using
            descriptors. Cannot be used with GNN models.
            See also mlcolvar.core.loss.committor_loss.SmartDerivatives
        log_var : bool, optional
            Switch to minimize the log of the variational functional,
            by default False.
        use_gradients_wrt_positions : bool, optional
            Whether to use gradients with respect to positions as prescribed
            in the original Kolmogorov variational functional, by default
            True. Set to false to use the approximated variational principle
            defined in Ref. [3] without explicit dependence on the atomic
            coordinates derivatives.
        z_regularization : float, optional
            Scales a regularization on the learned z space preventing it from
            exceeding the threshold given with 'z_threshold'.
            The magnitude of the regularization is scaled by the given
            number, by default 0.0
        z_threshold : float, optional
            Sets a maximum threshold for the z value during the training,
            by default None. The magnitude of the regularization term is
            scaled via the `z_regularization` key.
        n_dim : int
            Number of dimensions, by default None. If None, it defaults to 3
            for the position-based loss and to 1 for the position-less loss.
        norm_in : bool
            Whether to normalize the input of the NN model, by default False.
        options : dict[str, Any], optional
            Options for the building blocks of the model, by default None.
            They are parsed and sanitized with `parse_options`.
            Blocks read here: ['norm_in', 'nn', 'sigmoid'].

        Raises
        ------
        ValueError
            If `atomic_masses` is missing while `use_gradients_wrt_positions`
            is True, if `atomic_masses` or `descriptors_derivatives` are given
            while it is False, or if the model output dimension is not 1.
        """
        super().__init__(model, **kwargs)

        # the two variational functionals have mutually exclusive requirements
        if use_gradients_wrt_positions and atomic_masses is None:
            raise ValueError("atomic_masses must be provided when using Kolmogorov variational functional (use_gradients_wrt_positions is True)")
        elif not use_gradients_wrt_positions:
            if atomic_masses is not None:
                raise ValueError("atomic_masses must be None when using approximated variational principle (use_gradients_wrt_positions is False)")
            if descriptors_derivatives is not None:
                raise ValueError("descriptors_derivatives must be None when using approximated variational principle (use_gradients_wrt_positions is False)")

        # ======= LOSS =======
        self.loss_fn = CommittorLoss(
            alpha=alpha,
            atomic_masses=atomic_masses,
            gamma=gamma,
            delta_f=delta_f,
            separate_boundary_dataset=separate_boundary_dataset,
            descriptors_derivatives=descriptors_derivatives,
            log_var=log_var,
            use_gradients_wrt_positions=use_gradients_wrt_positions,
            z_regularization=z_regularization,
            z_threshold=z_threshold,
            n_dim=n_dim,
        )

        # ======= OPTIONS =======
        # parse and sanitize
        options = self.parse_options(options)

        # ======= BLOCKS =======
        if not self._override_model:
            # Initialize norm_in, only if requested and not disabled via options
            o = "norm_in"
            if norm_in and (options[o] is not False) and (options[o] is not None):
                self.norm_in = Normalization(self.in_features, **options[o])

            # initialize NN
            o = "nn"
            # set default activation to tanh
            if "activation" not in options[o]:
                options[o]["activation"] = "tanh"
            self.nn = FeedForward(self.layers, **options[o])
        else:
            # an external model was passed: use it directly
            self.nn = model

        # the committor is a scalar function
        if self.nn.out_features != 1:
            raise ValueError('Output of the model must be of dimension 1')

        # separately add sigmoid activation on last layer, this way it can be deactivated
        o = "sigmoid"
        if (options[o] is not False) and (options[o] is not None):
            self.sigmoid = Custom_Sigmoid(**options[o])

    def forward_nn(self, x, cell=None):
        """Evaluate the model up to the 'nn' block, i.e. before the sigmoid.

        Parameters
        ----------
        x : input of the model
            It is first passed through `self.preprocessing` (if set), then
            through `self.norm_in` (if set and the model is built
            internally), and finally through `self.nn`.
        cell : optional
            Forwarded to the preprocessing module, by default None.

        Returns
        -------
        z
            Output of `self.nn`.
        """
        if self.preprocessing is not None:
            x = self._apply_module(self.preprocessing, x, cell=cell)
        # input normalization is only set up when the NN is built internally
        if not self._override_model and self.norm_in is not None:
            x = self.norm_in(x)
        z = self.nn(x)
        return z

    def training_step(self, train_batch, batch_idx):
        """Compute and return the training loss and record metrics.

        Parameters
        ----------
        train_batch
            Batch of data. With a FeedForward model the keys 'data', 'labels'
            and 'weights' are read; with a BaseGNN model the batch is
            converted with `_setup_graph_data` and 'graph_labels' and
            'weight' are read from the result. The key 'ref_idx' is optional.
        batch_idx
            Index of the batch (unused).

        Returns
        -------
        loss
            Total loss returned by `self.loss_fn`.

        Raises
        ------
        TypeError
            If `self.nn` is neither a FeedForward nor a BaseGNN.
        """
        # Force-enable autograd.
        # NOTE(review): presumably needed because this step is also reused for
        # validation (see the train/valid log prefix below), where gradients
        # may be disabled - confirm.
        torch.set_grad_enabled(True)

        # =================get data===================
        if isinstance(self.nn, FeedForward):
            x = train_batch["data"]
            # check data have shape (n_data, -1)
            x = x.reshape((x.shape[0], -1))
            x.requires_grad = True

            labels = train_batch["labels"]
            weights = train_batch["weights"]
        elif isinstance(self.nn, BaseGNN):
            x = self._setup_graph_data(train_batch)
            labels = x['graph_labels']
            weights = x['weight'].clone()
        else:
            # fail early with a clear message instead of an UnboundLocalError below
            raise TypeError(
                f"Unsupported model type {type(self.nn).__name__}: "
                "expected FeedForward or BaseGNN"
            )

        # ref_idx is optional
        try:
            ref_idx = train_batch["ref_idx"]
        except KeyError:
            ref_idx = None

        cell = self._get_batch_cell(train_batch)

        # =================forward====================
        z = self.forward_nn(x, cell=cell)

        # the sigmoid block can be deactivated, in that case q coincides with z
        if self.sigmoid is not None:
            q = self.sigmoid(z)
        else:
            q = z

        # ===================loss=====================
        loss, loss_var, loss_bound_A, loss_bound_B = self.loss_fn(
            x, z, q, labels, weights, ref_idx
        )

        # ====================log=====================
        name = "train" if self.training else "valid"
        self.log(f"{name}_loss", loss, on_epoch=True)
        self.log(f"{name}_loss_var", loss_var, on_epoch=True)
        self.log(f"{name}_loss_bound_A", loss_bound_A, on_epoch=True)
        self.log(f"{name}_loss_bound_B", loss_bound_B, on_epoch=True)
        return loss
def test_committor_1():
    """Regression test of Committor on random descriptors against stored reference outputs.

    Covers: default loss, `separate_boundary_dataset=False`, `log_var=True`,
    z regularization, the position-less loss, and a few error cases.
    """
    from mlcolvar.data import DictDataset, DictModule
    from mlcolvar.cvs.committor.utils import initialize_committor_masses, KolmogorovBias

    # the reference tensors below depend on this seed and on the order of the following statements
    torch.manual_seed(42)
    # create two fake atoms and use their fake positions
    atomic_masses = initialize_committor_masses(atom_types=[0,1], masses=[15.999, 1.008])
    # create dataset
    samples = 20
    X = torch.randn((4*samples, 6))

    # create labels: four consecutive chunks of `samples` entries labeled 0, 1, 2, 3
    y = torch.zeros(X.shape[0])
    y[samples:] += 1
    y[int(2*samples):] += 1
    y[int(3*samples):] += 1

    # create weights
    w = torch.ones(X.shape[0])

    dataset = DictDataset({"data": X, "labels": y, "weights": w})
    datamodule = DictModule(dataset, lengths=[1])

    # train model
    # NOTE(review): logger=None here, while the other tests in this file use logger=False - confirm intended
    trainer = lightning.Trainer(max_epochs=5, logger=None, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0)

    # dataset separation
    ref_out = torch.Tensor([[0.6544],[0.6197],[0.5898],[0.5733],[0.6533],[0.5534],[0.5616],[0.6202],[0.5582],[0.5430],
                            [0.6364],[0.5984],[0.6382],[0.5967],[0.6440],[0.5697],[0.6061],[0.6010],[0.6399],[0.6172],
                            [0.6164],[0.6528],[0.6583],[0.6236],[0.6641],[0.5834],[0.5832],[0.6204],[0.6409],[0.6558],
                            [0.5891],[0.5879],[0.5890],[0.6583],[0.6577],[0.6467],[0.6405],[0.6590],[0.6463],[0.5581],
                            [0.6154],[0.6368],[0.6196],[0.5162],[0.5998],[0.6041],[0.5513],[0.6476],[0.5742],[0.6162],
                            [0.6462],[0.6371],[0.5295],[0.6148],[0.5999],[0.5870],[0.6352],[0.6145],[0.5708],[0.4992],
                            [0.6539],[0.6014],[0.6470],[0.6299],[0.6254],[0.5268],[0.6286],[0.6056],[0.6077],[0.6055],
                            [0.5861],[0.5991],[0.6449],[0.6500],[0.6295],[0.5627],[0.6269],[0.6392],[0.5961],[0.6694]])

    ref_bias = torch.Tensor([-6.2043, -6.8591, -7.7645, -7.8704, -5.8342, -7.5036, -7.8780, -6.9957, -7.8679, -7.7473,
                             -7.2451, -7.6833, -6.7631, -7.7863, -6.6693, -7.6212, -7.6929, -7.5685, -6.6894, -7.4857,
                             -7.5187, -4.9488, -6.4961, -7.3898, -6.0350, -7.8837, -7.8748, -7.2552, -7.1221, -5.8647,
                             -7.9190, -7.7184, -7.7073, -4.7898, -5.4073, -5.9113, -6.5451, -4.7149, -5.8899, -7.7421,
                             -7.3999, -7.3456, -7.3005, -7.5067, -7.7396, -7.7099, -7.8664, -6.3275, -7.8864, -7.7243,
                             -6.4288, -5.7041, -7.9351, -7.1991, -7.7027, -7.7947, -6.7121, -7.6094, -7.9009, -7.0479,
                             -5.2398, -7.8241, -5.8642, -7.0701, -7.0348, -7.2577, -6.6142, -7.6322, -7.3279, -7.6393,
                             -7.8608, -7.7037, -6.6949, -6.3947, -7.2246, -7.7009, -6.7359, -7.2186, -7.7849, -5.6882])

    model = Committor(model=[6, 4, 2, 1],
                      atomic_masses=atomic_masses,
                      alpha=1e-1)
    trainer.fit(model, datamodule)
    out = model(X)
    # check that the output is differentiable
    out.sum().backward()
    assert( torch.allclose(out, ref_out, atol=1e-3) )

    # bias built on top of the trained committor model
    bias_model = KolmogorovBias(input_model=model, beta=1, epsilon=1e-6, lambd=1)
    bias = bias_model(X)
    assert( torch.allclose(bias, ref_bias, atol=1e-3) )

    # naive whole dataset
    ref_out = torch.Tensor([[0.1206],[0.0688],[0.0941],[0.1026],[0.0739],[0.1279],[0.1115],[0.0629],[0.0994],[0.1012],
                            [0.0886],[0.1218],[0.0785],[0.0704],[0.0948],[0.1193],[0.0877],[0.0964],[0.0774],[0.0874],
                            [0.0948],[0.0636],[0.0869],[0.0664],[0.0659],[0.0927],[0.0654],[0.0927],[0.0743],[0.0787],
                            [0.0802],[0.1074],[0.1105],[0.0595],[0.0693],[0.0620],[0.0688],[0.0669],[0.0591],[0.0986],
                            [0.0706],[0.1180],[0.0894],[0.1030],[0.1012],[0.0606],[0.1408],[0.0766],[0.1063],[0.1049],
                            [0.0749],[0.0588],[0.1177],[0.1127],[0.1090],[0.0806],[0.0954],[0.0799],[0.1048],[0.1378],
                            [0.0783],[0.1384],[0.0689],[0.0649],[0.0983],[0.1548],[0.0778],[0.0934],[0.0858],[0.1203],
                            [0.1073],[0.1139],[0.0716],[0.0988],[0.0918],[0.1109],[0.0918],[0.0928],[0.1070],[0.0742]])

    trainer = lightning.Trainer(max_epochs=5, logger=None, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0)
    model = Committor(model=[6, 4, 2, 1],
                      atomic_masses=atomic_masses,
                      alpha=1e-1,
                      separate_boundary_dataset=False)
    trainer.fit(model, datamodule)
    out = model(X)
    out.sum().backward()
    assert( torch.allclose(out, ref_out, atol=1e-3) )

    # test log loss
    ref_out = torch.Tensor([[0.7287],[0.6505],[0.5594],[0.6758],[0.7482],[0.6804],[0.7313],[0.6762],[0.6873],[0.6267],
                            [0.6362],[0.8129],[0.5853],[0.5262],[0.6359],[0.5263],[0.4839],[0.7291],[0.6884],[0.6375],
                            [0.6231],[0.6997],[0.5906],[0.6247],[0.5876],[0.7198],[0.6356],[0.5933],[0.6229],[0.7093],
                            [0.5618],[0.5005],[0.7924],[0.6965],[0.6540],[0.5476],[0.6151],[0.7042],[0.6190],[0.5362],
                            [0.6275],[0.5959],[0.7194],[0.6122],[0.4873],[0.6653],[0.6741],[0.7011],[0.7207],[0.5863],
                            [0.6040],[0.7643],[0.6696],[0.6424],[0.6886],[0.5775],[0.6620],[0.7104],[0.7517],[0.7387],
                            [0.7714],[0.5825],[0.6442],[0.5796],[0.6131],[0.5923],[0.7023],[0.5730],[0.7307],[0.6404],
                            [0.5780],[0.6850],[0.5959],[0.6718],[0.6626],[0.6068],[0.7319],[0.5498],[0.6772],[0.5846]])

    trainer = lightning.Trainer(max_epochs=5, logger=None, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0)
    model = Committor(model=[6, 4, 2, 1],
                      atomic_masses=atomic_masses,
                      alpha=1e-1,
                      log_var=True)
    trainer.fit(model, datamodule)
    out = model(X)
    out.sum().backward()
    assert( torch.allclose(out, ref_out, atol=1e-3) )

    # test z regularization
    ref_out = torch.Tensor([[0.2878],[0.1591],[0.1665],[0.1166],[0.1349],[0.1053],[0.1544],[0.1113],[0.1435],[0.1232],
                            [0.1130],[0.1261],[0.1726],[0.2098],[0.2091],[0.1407],[0.1942],[0.1400],[0.1382],[0.1630],
                            [0.1573],[0.1742],[0.1613],[0.1289],[0.1703],[0.1390],[0.1184],[0.2557],[0.1520],[0.1328],
                            [0.2220],[0.2254],[0.1823],[0.1426],[0.1744],[0.2594],[0.1105],[0.1390],[0.1557],[0.1985],
                            [0.1340],[0.1971],[0.1429],[0.1270],[0.2239],[0.1134],[0.1999],[0.1416],[0.1707],[0.2238],
                            [0.2054],[0.1560],[0.2357],[0.2971],[0.1445],[0.1906],[0.2130],[0.1457],[0.1382],[0.1432],
                            [0.1337],[0.1444],[0.1603],[0.1396],[0.2043],[0.1964],[0.1459],[0.2243],[0.1930],[0.1893],
                            [0.2634],[0.1868],[0.1340],[0.2483],[0.1550],[0.1559],[0.1614],[0.2020],[0.1270],[0.2555]])

    trainer = lightning.Trainer(max_epochs=5, logger=None, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0)
    model = Committor(model=[6, 4, 2, 1],
                      atomic_masses=atomic_masses,
                      alpha=1e-1,
                      z_regularization=100,
                      z_threshold=0.000001)
    trainer.fit(model, datamodule)
    out = model(X)
    out.sum().backward()
    assert( torch.allclose(out, ref_out, atol=1e-3) )

    # test position-less loss
    ref_out = torch.Tensor([[0.2318],[0.2119],[0.3039],[0.2349],[0.1933],[0.2506],[0.1453],[0.2849],[0.2042],[0.2514],
                            [0.3391],[0.1043],[0.3083],[0.3091],[0.2147],[0.4049],[0.4225],[0.1488],[0.2421],[0.2429],
                            [0.2354],[0.1662],[0.3195],[0.3682],[0.2881],[0.2027],[0.3813],[0.2461],[0.2892],[0.2725],
                            [0.2833],[0.3431],[0.1060],[0.2795],[0.2566],[0.3266],[0.3747],[0.3010],[0.2916],[0.3081],
                            [0.3136],[0.2971],[0.1736],[0.2491],[0.3451],[0.3594],[0.1713],[0.2217],[0.1426],[0.2170],
                            [0.2296],[0.1287],[0.1386],[0.1911],[0.1898],[0.2731],[0.1899],[0.1999],[0.1325],[0.1380],
                            [0.1678],[0.3036],[0.2935],[0.3828],[0.2623],[0.2214],[0.2149],[0.2332],[0.1143],[0.2179],
                            [0.2237],[0.1641],[0.3423],[0.2643],[0.2220],[0.2521],[0.1774],[0.3679],[0.2753],[0.1834]])

    trainer = lightning.Trainer(max_epochs=5, logger=None, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0)
    # the position-less loss requires atomic_masses=None
    model = Committor(model=[6, 4, 2, 1],
                      atomic_masses=None,
                      alpha=1e-1,
                      use_gradients_wrt_positions=False)
    trainer.fit(model, datamodule)
    out = model(X)
    # NOTE(review): looks like a leftover debug print - confirm it is still wanted
    print(out)
    out.sum().backward()
    assert( torch.allclose(out, ref_out, atol=1e-3) )

    # test z_regularization errors
    # NOTE(review): if a combination does NOT raise ValueError the test still passes silently - confirm intended
    trainer = lightning.Trainer(max_epochs=5, logger=None, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0)
    for z_regularization, z_threshold in zip([10, 0, -1, 10], [None, 10, 1, -1]):
        try:
            model = Committor(model=[6, 4, 2, 1],
                              atomic_masses=atomic_masses,
                              alpha=1e-1,
                              z_regularization=z_regularization,
                              z_threshold=z_threshold,
                              n_dim=2)
            trainer.fit(model, datamodule)
        except ValueError as e:
            print("[TEST LOG] Checked this error: ", e)

    # test dimension error
    # NOTE(review): as above, a missing RuntimeError would go unnoticed
    try:
        trainer = lightning.Trainer(max_epochs=5, logger=None, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0)
        model = Committor(model=[6, 4, 2, 1],
                          atomic_masses=atomic_masses,
                          alpha=1e-1,
                          z_regularization=10,
                          z_threshold=1,
                          n_dim=2)
        trainer.fit(model, datamodule)
    except RuntimeError as e:
        print("[TEST LOG] Checked this error: ", e)


def test_committor_2():
    """Smoke test of Committor with an internal NN, an external FeedForward and an external GNN.

    No reference values are checked: each model is only trained, evaluated and backpropagated.
    """
    from mlcolvar.data import DictDataset, DictModule
    from mlcolvar.cvs.committor.utils import initialize_committor_masses, KolmogorovBias

    # create two fake atoms and use their fake positions
    atomic_masses = initialize_committor_masses(atom_types=[0,1], masses=[15.999, 1.008])
    # create dataset
    samples = 50
    X = torch.randn((4*samples, 6))

    # create labels: four consecutive chunks of `samples` entries labeled 0, 1, 2, 3
    y = torch.zeros(X.shape[0])
    y[samples:] += 1
    y[int(2*samples):] += 1
    y[int(3*samples):] += 1

    # create weights
    w = torch.ones(X.shape[0])

    dataset = DictDataset({"data": X, "labels": y, "weights": w})
    datamodule = DictModule(dataset, lengths=[1])

    # train model
    trainer = lightning.Trainer(max_epochs=5, logger=False, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0)

    print()
    print('NORMAL')
    print()
    # dataset separation
    model = Committor(model=[6, 4, 2, 1],
                      atomic_masses=atomic_masses,
                      alpha=1e-1,
                      delta_f=0)
    trainer.fit(model, datamodule)
    model(X).sum().backward()
    bias_model = KolmogorovBias(input_model=model, beta=1, epsilon=1e-6, lambd=1)
    bias_model(X)

    # naive whole dataset
    trainer = lightning.Trainer(max_epochs=5, logger=False, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0)
    model = Committor(model=[6, 4, 2, 1],
                      atomic_masses=atomic_masses,
                      alpha=1e-1,
                      delta_f=0,
                      separate_boundary_dataset=False)
    trainer.fit(model, datamodule)
    model(X).sum().backward()

    print()
    print('EXTERNAL FEEDFORWARD')
    print()
    # dataset separation
    ff_model = FeedForward([6, 4, 2, 1])
    model = Committor(model=ff_model,
                      atomic_masses=atomic_masses,
                      alpha=1e-1,
                      delta_f=0)
    trainer.fit(model, datamodule)
    model(X).sum().backward()
    bias_model = KolmogorovBias(input_model=model, beta=1, epsilon=1e-6, lambd=1)
    bias_model(X)

    # naive whole dataset
    trainer = lightning.Trainer(max_epochs=5, logger=False, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0)
    # note: the same ff_model instance trained above is passed again here
    model = Committor(model=ff_model,
                      atomic_masses=atomic_masses,
                      alpha=1e-1,
                      delta_f=0,
                      separate_boundary_dataset=False)
    trainer.fit(model, datamodule)
    model(X).sum().backward()

    print()
    print('EXTERNAL GNN')
    print()
    from mlcolvar.core.nn.graph import SchNetModel
    from mlcolvar.data.graph.utils import create_test_graph_input
    gnn_model = SchNetModel(n_out=1, cutoff=0.1, atomic_numbers=[1, 8])
    model = Committor(model=gnn_model,
                      atomic_masses=atomic_masses,
                      alpha=1e-1,
                      delta_f=0)
    # graph datamodule replaces the descriptor-based one used above
    datamodule = create_test_graph_input(output_type='datamodule', n_samples=100, n_states=3, n_atoms=3)
    trainer = lightning.Trainer(max_epochs=5, logger=False, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0, enable_model_summary=False)
    trainer.fit(model, datamodule)
    example_input_graph_test = create_test_graph_input(output_type='example', n_atoms=4, n_samples=3, n_states=2)
    model(example_input_graph_test).sum().backward()


def test_committor_with_derivatives():
    """Check that the different ways of providing descriptor derivatives give the same trained model.

    Compared strategies: positions as input with preprocessing, descriptors as input with
    explicitly passed derivatives, and descriptors as input with SmartDerivatives.
    """
    from mlcolvar.cvs.committor.utils import initialize_committor_masses
    from mlcolvar.data import DictModule, DictDataset
    from mlcolvar.core.loss.utils.smart_derivatives import SmartDerivatives, compute_descriptors_derivatives
    from mlcolvar.core.transform import PairwiseDistances

    torch.manual_seed(42)
    n_atoms = 10
    # NOTE(review): kT is not used anywhere in this test
    kT = 2.49432

    # input positions for alanine example: 5 frames, each with 30 flattened coordinates
    ref_pos = torch.Tensor([[ 1.2980, 0.5370, 1.3370, 1.3270, 0.5710, 1.1960, 1.4110, 0.5070, 1.1310, 1.2520,
                              0.6710, 1.1440, 1.2490, 0.6890, 0.9990, 1.1270, 0.6130, 0.9550, 1.2340, 0.8420,
                              0.9810, 1.1860, 0.9140, 1.0700, 1.2790, 0.8870, 0.8630, 1.2550, 1.0230, 0.8240 ],
                            [ 2.7530, 0.7150, 0.5170, 2.8460, 0.6150, 0.5780, 2.9520, 0.6560, 0.6220, 2.8150,
                              0.4870, 0.5730, 2.9100, 0.3830, 0.6150, 2.9310, 0.3890, 0.7690, 2.8520, 0.2450,
                              0.5830, 2.7300, 0.2380, 0.5550, 2.9420, 0.1390, 0.5840, 2.9030, -0.0030, 0.5690 ],
                            [ 0.4830, 2.5610, 2.9980, 0.5620, 2.5410, 2.8660, 0.5080, 2.4950, 2.7660, 0.6960,
                              2.5590, 2.8790, 0.8060, 2.5410, 2.7750, 0.7890, 2.6570, 2.6680, 0.9450, 2.5390,
                              2.8400, 0.9620, 2.5380, 2.9610, 1.0510, 2.5430, 2.7590, 1.1860, 2.5410, 2.7990 ],
                            [ 1.0680, 0.1770, 0.1670, 0.9560, 0.2290, 0.0920, 0.9320, 0.1730, -0.0070, 0.8770,
                              0.3280, 0.1460, 0.7710, 0.4040, 0.0760, 0.7230, 0.5180, 0.1660, 0.8270, 0.4640,
                              -0.0530, 0.9010, 0.5650, -0.0450, 0.7790, 0.4160, -0.1670, 0.8260, 0.4500, -0.2950 ],
                            [ 2.4600, 0.5670, 2.4940, 2.6050, 0.5640, 2.5060, 2.6660, 0.4630, 2.5020, 2.6640,
                              0.6830, 2.5220, 2.8040, 0.7250, 2.5200, 2.8880, 0.6370, 2.6190, 2.8690, 0.7270,
                              2.3820, 2.9600, 0.8080, 2.3570, 2.8260, 0.6310, 2.3010, 2.8630, 0.6170, 2.1580 ]]
                           )

    # weights for inputs
    ref_weights = torch.Tensor([1.4809, 0.0736, 0.3693, 0.1849, 0.0885])

    # initialize dataset with positions, each frame gets a different label
    dataset = DictDataset({"data": ref_pos, "weights": ref_weights, "labels": torch.arange((len(ref_pos)))})

    # initialize descriptors calculations: all pairwise distances
    ComputeDistances = PairwiseDistances(n_atoms=10,
                                         PBC=False,
                                         cell=[1, 1, 1],
                                         scaled_coords=False)

    # create friction tensor
    masses = initialize_committor_masses(atom_types=[0,0,1,2,0,0,0,1,2,0],
                                         masses=[ 12.011, 12.011, 15.999, 14.0067, 12.011, 12.011, 12.011, 15.999, 14.0067, 12.011])

    # --------------------------------- TRAIN MODELS ---------------------------------
    # Train the models: positions as input, desc as input with smartderivatives and passing derivatives
    for separate_boundary_dataset in [False, True]:
        # 1 ------------ Positions as input ------------
        # initialize datamodule
        torch.manual_seed(42)
        datamodule = DictModule(dataset, lengths=[1.0])

        # seed for reproducibility
        model = Committor(model=[45, 20, 1],
                          atomic_masses=masses,
                          alpha=1,
                          separate_boundary_dataset=separate_boundary_dataset)
        # here we use the preprocessing
        model.preprocessing = ComputeDistances

        trainer = lightning.Trainer(
            accelerator='cpu',
            callbacks=None,
            max_epochs=6,
            enable_progress_bar=False,
            enable_checkpointing=False,
            logger=False,
            limit_val_batches=0,
            num_sanity_val_steps=0,
        )

        # fit
        trainer.fit(model, datamodule)

        # save outputs as a reference
        X = dataset["data"]
        # this is to check other strategies
        ref_output = model(X)

        # stored reference values for the two settings of separate_boundary_dataset
        if separate_boundary_dataset:
            ref_output_check = torch.Tensor([[0.4759], [0.4765], [0.4828], [0.4786], [0.4725]])
        else:
            ref_output_check = torch.Tensor([[0.4756], [0.4762], [0.4825], [0.4783], [0.4723]])
        assert( (torch.allclose(ref_output, ref_output_check, atol=1e-3)))

        # explicit derivatives are only tested without dataset separation
        if not separate_boundary_dataset:
            # 2 ------------ Descriptors as input + explicit pass derivatives ------------
            # get descriptor and their derivatives
            pos, desc, d_desc_d_pos = compute_descriptors_derivatives(dataset=dataset,
                                                                      descriptor_function=ComputeDistances,
                                                                      n_atoms=n_atoms,
                                                                      separate_boundary_dataset=separate_boundary_dataset)

            dataset_desc = DictDataset({"data": desc, "weights": ref_weights, "labels": torch.arange((len(ref_pos)))}, create_ref_idx=True)

            # seed for reproducibility
            torch.manual_seed(42)
            datamodule = DictModule(dataset_desc, lengths=[1.0])
            model = Committor(model=[45, 20, 1],
                              atomic_masses=masses,
                              alpha=1,
                              separate_boundary_dataset=separate_boundary_dataset,
                              descriptors_derivatives=d_desc_d_pos)

            trainer = lightning.Trainer(
                accelerator='cpu',
                callbacks=None,
                max_epochs=6,
                enable_progress_bar=False,
                enable_checkpointing=False,
                logger=False,
                limit_val_batches=0,
                num_sanity_val_steps=0,
            )

            # fit
            trainer.fit(model, datamodule)

            # save outputs as a reference
            X = dataset_desc["data"]
            # this is to check other strategies
            ref_output = model(X)
            assert( (torch.allclose(ref_output, ref_output_check, atol=1e-3)))

            # test errors
            # NOTE(review): if no ValueError is raised the test still passes silently - confirm intended
            try:
                # separate boundary with explicit derivatives
                model = Committor(model=[45, 20, 1],
                                  atomic_masses=masses,
                                  alpha=1,
                                  separate_boundary_dataset=True,
                                  descriptors_derivatives=d_desc_d_pos)

                trainer = lightning.Trainer(
                    accelerator='cpu',
                    callbacks=None,
                    max_epochs=6,
                    enable_progress_bar=False,
                    enable_checkpointing=False,
                    logger=False,
                    limit_val_batches=0,
                    num_sanity_val_steps=0,
                )

                trainer.fit(model, datamodule)
            except ValueError as e:
                print("[TEST LOG] Checked this error: ", e)

        # 3 ------------ Descriptors as input + SmartDerivatives ------------
        # initialize smart derivatives, we do it explicitly to test different functionalities
        smart_derivatives = SmartDerivatives()
        smart_dataset = smart_derivatives.setup(dataset=dataset,
                                                descriptor_function=ComputeDistances,
                                                n_atoms=n_atoms,
                                                separate_boundary_dataset=separate_boundary_dataset)

        # seed for reproducibility
        torch.manual_seed(42)
        datamodule = DictModule(smart_dataset, lengths=[1.0])
        model = Committor(model=[45, 20, 1],
                          atomic_masses=masses,
                          alpha=1,
                          separate_boundary_dataset=separate_boundary_dataset,
                          descriptors_derivatives=smart_derivatives)

        trainer = lightning.Trainer(
            accelerator='cpu',
            callbacks=None,
            max_epochs=6,
            enable_progress_bar=False,
            enable_checkpointing=False,
            logger=False,
            limit_val_batches=0,
            num_sanity_val_steps=0,
        )

        # fit
        trainer.fit(model, datamodule)

        # save outputs as a reference
        X = smart_dataset["data"]
        # this is to check other strategies
        ref_output = model(X)
        assert( (torch.allclose(ref_output, ref_output_check, atol=1e-3)))

        # test errors
        # NOTE(review): if no ValueError is raised the test still passes silently - confirm intended
        try:
            # no ref_idx!
            wrong_dataset = DictDataset(data=smart_dataset['data'], labels=smart_dataset['labels'], weights=smart_dataset['weights'])
            wrong_datamodule = DictModule(wrong_dataset, lengths=[1.0])
            trainer = lightning.Trainer(
                accelerator='cpu',
                callbacks=None,
                max_epochs=6,
                enable_progress_bar=False,
                enable_checkpointing=False,
                logger=False,
                limit_val_batches=0,
                num_sanity_val_steps=0,
            )
            trainer.fit(model, wrong_datamodule)
        except ValueError as e:
            print("[TEST LOG] Checked this error: ", e)

    print()
    print('EXTERNAL GNN')
    print()
    from mlcolvar.core.nn.graph import SchNetModel
    from mlcolvar.data.graph.utils import create_test_graph_input
    gnn_model = SchNetModel(n_out=1, cutoff=0.1, atomic_numbers=[1, 8])
    model = Committor(model=gnn_model,
                      atomic_masses=masses,
                      alpha=1e-1,
                      delta_f=0)
    datamodule = create_test_graph_input(output_type='datamodule', n_samples=100, n_states=3, n_atoms=3)
    trainer = lightning.Trainer(max_epochs=5, logger=False, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0, enable_model_summary=False)
    trainer.fit(model, datamodule)
    example_input_graph_test = create_test_graph_input(output_type='example', n_atoms=4, n_samples=3, n_states=2)
    model(example_input_graph_test).sum().backward()


def test_committor_runtime_cell_training():
    """Integration tests for runtime-cell descriptor preprocessing in committor training."""
    from mlcolvar.data import DictDataset, DictModule
    from mlcolvar.core.transform import PairwiseDistances
    from mlcolvar.cvs.committor.utils import initialize_committor_masses
    import pytest

    # -------- positive case: the dataset provides a per-sample "cell" --------
    torch.manual_seed(42)
    n_atoms = 2
    n_samples = 16

    # Positions in a flattened (B, n_atoms*3) format.
    x = torch.rand((n_samples, n_atoms * 3))
    w = torch.ones(n_samples)

    # Labels with two boundary regions and intermediate data.
    y = torch.zeros(n_samples)
    y[n_samples // 4:] += 1
    y[n_samples // 2:] += 1
    y[3 * n_samples // 4:] += 1

    # Runtime cells (batch, 3), with a mild frame-to-frame variation.
    cell = torch.ones((n_samples, 3))
    cell[:, 0] *= torch.linspace(0.95, 1.05, n_samples)
    cell[:, 1] *= 1.0
    cell[:, 2] *= 1.1

    dataset = DictDataset({"data": x, "labels": y, "weights": w, "cell": cell})
    datamodule = DictModule(dataset, lengths=[1.0], batch_size=8)

    # PBC with cell=None: no fixed cell is given to the descriptor
    preprocessing = PairwiseDistances(
        n_atoms=n_atoms,
        PBC=True,
        cell=None,
        scaled_coords=False,
        slicing_pairs=[[0, 1]],
    )
    masses = initialize_committor_masses(atom_types=[0, 1], masses=[12.011, 1.008])

    # a single pairwise distance is the input of the NN
    model = Committor(model=[1, 8, 1], atomic_masses=masses, alpha=1e-1, delta_f=0)
    model.preprocessing = preprocessing

    trainer = lightning.Trainer(
        accelerator="cpu",
        max_epochs=1,
        logger=False,
        enable_checkpointing=False,
        limit_val_batches=0,
        num_sanity_val_steps=0,
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    trainer.fit(model, datamodule)

    out = model(x, cell=cell)
    assert torch.isfinite(out).all()

    # -------- negative case: missing runtime cell should fail --------
    torch.manual_seed(42)
    n_atoms = 2
    n_samples = 12

    x = torch.rand((n_samples, n_atoms * 3))
    w = torch.ones(n_samples)
    y = torch.zeros(n_samples)
    y[n_samples // 3:] += 1
    y[2 * n_samples // 3:] += 1

    # Intentionally no "cell" key.
    dataset = DictDataset({"data": x, "labels": y, "weights": w})
    datamodule = DictModule(dataset, lengths=[1.0], batch_size=6)

    preprocessing = PairwiseDistances(
        n_atoms=n_atoms,
        PBC=True,
        cell=None,
        scaled_coords=False,
        slicing_pairs=[[0, 1]],
    )
    masses = initialize_committor_masses(atom_types=[0, 1], masses=[12.011, 1.008])

    model = Committor(model=[1, 8, 1], atomic_masses=masses, alpha=1e-1, delta_f=0)
    model.preprocessing = preprocessing

    trainer = lightning.Trainer(
        accelerator="cpu",
        max_epochs=1,
        logger=False,
        enable_checkpointing=False,
        limit_val_batches=0,
        num_sanity_val_steps=0,
        enable_progress_bar=False,
        enable_model_summary=False,
    )

    # training is expected to raise a ValueError whose message mentions "cell"
    with pytest.raises(ValueError, match="cell"):
        trainer.fit(model, datamodule)