Source code for mlcolvar.cvs.committor.committor

import torch
import lightning
from mlcolvar.cvs import BaseCV
from mlcolvar.core import FeedForward
from mlcolvar.core.loss import CommittorLoss
from mlcolvar.core.nn.utils import Custom_Sigmoid

# public API of this module: only the Committor model is exported
__all__ = ["Committor"]


class Committor(BaseCV, lightning.LightningModule):
    """Base class for data-driven learning of committor function.

    The committor function q is expressed as the output of a neural network
    optimized with a self-consistent approach based on the Kolmogorov's
    variational principle for the committor and on the imposition of its
    boundary conditions.

    **Data**: for training it requires a DictDataset with the keys 'data',
    'labels' and 'weights'. An optional 'ref_idx' key is forwarded to the loss
    when present.

    **Loss**: Minimize Kolmogorov's variational functional of q and impose
    boundary condition on the metastable states (CommittorLoss)

    References
    ----------
    .. [*] P. Kang, E. Trizio, and M. Parrinello, "Computing the committor
        using the committor to study the transition state ensemble",
        Nat. Comput. Sci., 2024, DOI: 10.1038/s43588-024-00645-0
    .. [*] E. Trizio, P. Kang and M. Parrinello, "Everything everywhere all at
        once: a probability-based enhanced sampling approach to rare events",
        Nat. Comput. Sci., 2025, DOI: 10.1038/s43588-025-00799-5

    See also
    --------
    mlcolvar.cvs.committor.utils.compute_committor_weights
        Utils to compute the appropriate weights for the training set
    mlcolvar.cvs.committor.utils.initialize_committor_masses
        Utils to initialize the masses tensor for the training
    mlcolvar.core.loss.CommittorLoss
        Kolmogorov's variational optimization of committor and imposition of
        boundary conditions
    mlcolvar.core.loss.utils.SmartDerivatives
        Class to optimize the gradients calculation improving speed and memory
        efficiency.
    """

    BLOCKS = ["nn", "sigmoid"]

    def __init__(
        self,
        layers: list,
        atomic_masses: torch.Tensor,
        alpha: float,
        gamma: float = 10000,
        delta_f: float = 0,
        cell: float = None,
        separate_boundary_dataset: bool = True,
        descriptors_derivatives: torch.nn.Module = None,
        log_var: bool = False,
        z_regularization: float = 0.0,
        z_threshold: float = None,
        n_dim: int = 3,
        options: dict = None,
        **kwargs,
    ):
        """Define a NN-based committor model.

        Parameters
        ----------
        layers : list
            Number of neurons per layer
        atomic_masses : torch.Tensor
            List of masses of all the atoms we are using, for each atom we need
            to repeat the mass three times for x, y, z.
            The mlcolvar.cvs.committor.utils.initialize_committor_masses can be
            used to simplify this.
        alpha : float
            Hyperparameter that scales the boundary conditions contribution to
            loss, i.e. alpha*(loss_bound_A + loss_bound_B)
        gamma : float, optional
            Hyperparameter that scales the whole loss to avoid too small
            numbers, i.e. gamma*(loss_var + loss_bound), by default 10000
        delta_f : float, optional
            Delta free energy between A (label 0) and B (label 1), in units of
            kBT, by default 0. State B is supposed to be higher in energy.
        cell : float, optional
            CUBIC cell size length, used to scale the positions from reduced
            coordinates to real coordinates, by default None
        separate_boundary_dataset : bool, optional
            Switch to exclude boundary condition labeled data from the
            variational loss, by default True
        descriptors_derivatives : torch.nn.Module, optional
            `SmartDerivatives` object to save memory and time when using
            descriptors, or explicitly precomputed descriptor derivatives
            (e.g. from `compute_descriptors_derivatives`).
            See also mlcolvar.core.loss.committor_loss.SmartDerivatives
        log_var : bool, optional
            Switch to minimize the log of the variational functional,
            by default False.
        z_regularization : float, optional
            Scales a regularization on the learned z space preventing it from
            exceeding the threshold given with 'z_threshold'.
            The magnitude of the regularization is scaled by the given number,
            by default 0.0
        z_threshold : float, optional
            Sets a maximum threshold for the z value during the training,
            by default None. The magnitude of the regularization term is scaled
            via the `z_regularization` key.
        n_dim : int
            Number of dimensions, by default 3.
        options : dict[str, Any], optional
            Options for the building blocks of the model, by default {}.
            Available blocks: ['nn', 'sigmoid'].
            The 'nn' activation defaults to 'tanh' unless given explicitly.
            Setting 'sigmoid' to False or None removes the final sigmoid, so
            the model outputs the raw network value z.
        """
        super().__init__(in_features=layers[0], out_features=layers[-1], **kwargs)

        # ======= LOSS =======
        self.loss_fn = CommittorLoss(
            atomic_masses=atomic_masses,
            alpha=alpha,
            gamma=gamma,
            delta_f=delta_f,
            cell=cell,
            separate_boundary_dataset=separate_boundary_dataset,
            descriptors_derivatives=descriptors_derivatives,
            log_var=log_var,
            z_regularization=z_regularization,
            z_threshold=z_threshold,
            n_dim=n_dim,
        )

        # ======= OPTIONS =======
        # parse and sanitize
        options = self.parse_options(options)

        # ======= BLOCKS =======
        # initialize NN
        o = "nn"
        # default activation is tanh; a new dict is built so that the
        # options passed by the caller are not modified in place
        nn_options = {"activation": "tanh", **options[o]}
        self.nn = FeedForward(layers, **nn_options)

        # separately add sigmoid activation on last layer, this way it can be
        # deactivated by passing False or None as its options
        o = "sigmoid"
        if (options[o] is not False) and (options[o] is not None):
            self.sigmoid = Custom_Sigmoid(**options[o])

    def forward_nn(self, x):
        """Return the raw network output z, i.e. before the sigmoid.

        The preprocessing module, if set, is applied to the input first.
        """
        if self.preprocessing is not None:
            x = self.preprocessing(x)
        return self.nn(x)

    def training_step(self, train_batch, batch_idx):
        """Compute and return the training loss and record metrics.

        When the module is not in training mode, metrics are logged with a
        ``valid`` prefix instead of ``train``.
        """
        # The loss receives the inputs with requires_grad enabled, so autograd
        # must be active even when this step runs under a no-grad context
        # (e.g. validation). A scoped context avoids changing the global grad
        # mode beyond this step.
        with torch.enable_grad():
            # =================get data===================
            x = train_batch["data"]
            # make sure data have shape (n_data, -1)
            x = x.reshape((x.shape[0], -1))
            x.requires_grad_(True)
            labels = train_batch["labels"]
            weights = train_batch["weights"]
            # 'ref_idx' is optional: when missing, None is passed to the loss
            try:
                ref_idx = train_batch["ref_idx"]
            except KeyError:
                ref_idx = None

            # =================forward====================
            z = self.forward_nn(x)
            q = self.sigmoid(z) if self.sigmoid is not None else z

            # ===================loss=====================
            loss, loss_var, loss_bound_A, loss_bound_B = self.loss_fn(
                x, z, q, labels, weights, ref_idx
            )

        # ====================log=====================
        name = "train" if self.training else "valid"
        self.log(f"{name}_loss", loss, on_epoch=True)
        self.log(f"{name}_loss_var", loss_var, on_epoch=True)
        self.log(f"{name}_loss_bound_A", loss_bound_A, on_epoch=True)
        self.log(f"{name}_loss_bound_B", loss_bound_B, on_epoch=True)
        return loss
def test_committor():
    """Regression test for Committor training on a synthetic dataset.

    Trains small models for a few epochs on 80 random 6-dimensional samples
    and compares their outputs with stored reference values (atol=1e-3) for
    several settings: separate boundary dataset (default), whole dataset,
    ``log_var=True`` and z regularization. It also checks the output of a
    ``KolmogorovBias`` built on the first trained model, and exercises
    invalid z-regularization / dimension configurations.
    """
    # NOTE(review): the reference tensors below presumably depend on the exact
    # sequence of seeded random calls (data generation, weight initialization,
    # batching); reordering statements is likely to invalidate them.
    from mlcolvar.data import DictDataset, DictModule
    from mlcolvar.cvs.committor.utils import initialize_committor_masses, KolmogorovBias

    torch.manual_seed(42)
    # create two fake atoms and use their fake positions
    atomic_masses = initialize_committor_masses(atom_types=[0,1], masses=[15.999, 1.008])
    # create dataset: 4 consecutive groups of `samples` random points, 6 features each
    samples = 20
    X = torch.randn((4*samples, 6))

    # create labels: 0, 1, 2, 3 for the four consecutive groups of samples
    y = torch.zeros(X.shape[0])
    y[samples:] += 1
    y[int(2*samples):] += 1
    y[int(3*samples):] += 1

    # create weights (uniform)
    w = torch.ones(X.shape[0])

    dataset = DictDataset({"data": X, "labels": y, "weights": w})
    datamodule = DictModule(dataset, lengths=[1])

    # train model
    trainer = lightning.Trainer(max_epochs=5, logger=None, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0)

    # dataset separation (default separate_boundary_dataset=True)
    ref_out = torch.Tensor([[0.6544],[0.6197],[0.5898],[0.5733],[0.6533],[0.5534],[0.5616],[0.6202],[0.5582],[0.5430],
                            [0.6364],[0.5984],[0.6382],[0.5967],[0.6440],[0.5697],[0.6061],[0.6010],[0.6399],[0.6172],
                            [0.6164],[0.6528],[0.6583],[0.6236],[0.6641],[0.5834],[0.5832],[0.6204],[0.6409],[0.6558],
                            [0.5891],[0.5879],[0.5890],[0.6583],[0.6577],[0.6467],[0.6405],[0.6590],[0.6463],[0.5581],
                            [0.6154],[0.6368],[0.6196],[0.5162],[0.5998],[0.6041],[0.5513],[0.6476],[0.5742],[0.6162],
                            [0.6462],[0.6371],[0.5295],[0.6148],[0.5999],[0.5870],[0.6352],[0.6145],[0.5708],[0.4992],
                            [0.6539],[0.6014],[0.6470],[0.6299],[0.6254],[0.5268],[0.6286],[0.6056],[0.6077],[0.6055],
                            [0.5861],[0.5991],[0.6449],[0.6500],[0.6295],[0.5627],[0.6269],[0.6392],[0.5961],[0.6694]])

    ref_bias = torch.Tensor([-6.2043, -6.8591, -7.7645, -7.8704, -5.8342, -7.5036, -7.8780, -6.9957, -7.8679, -7.7473,
                             -7.2451, -7.6833, -6.7631, -7.7863, -6.6693, -7.6212, -7.6929, -7.5685, -6.6894, -7.4857,
                             -7.5187, -4.9488, -6.4961, -7.3898, -6.0350, -7.8837, -7.8748, -7.2552, -7.1221, -5.8647,
                             -7.9190, -7.7184, -7.7073, -4.7898, -5.4073, -5.9113, -6.5451, -4.7149, -5.8899, -7.7421,
                             -7.3999, -7.3456, -7.3005, -7.5067, -7.7396, -7.7099, -7.8664, -6.3275, -7.8864, -7.7243,
                             -6.4288, -5.7041, -7.9351, -7.1991, -7.7027, -7.7947, -6.7121, -7.6094, -7.9009, -7.0479,
                             -5.2398, -7.8241, -5.8642, -7.0701, -7.0348, -7.2577, -6.6142, -7.6322, -7.3279, -7.6393,
                             -7.8608, -7.7037, -6.6949, -6.3947, -7.2246, -7.7009, -6.7359, -7.2186, -7.7849, -5.6882])

    model = Committor(layers=[6, 4, 2, 1], atomic_masses=atomic_masses, alpha=1e-1, delta_f=0)
    trainer.fit(model, datamodule)
    out = model(X)
    # check that gradients can be propagated through the trained model
    out.sum().backward()
    assert( torch.allclose(out, ref_out, atol=1e-3) )

    # bias built on top of the trained committor model
    bias_model = KolmogorovBias(input_model=model, beta=1, epsilon=1e-6, lambd=1)
    bias = bias_model(X)
    assert( torch.allclose(bias, ref_bias, atol=1e-3) )

    # naive whole dataset (separate_boundary_dataset=False)
    ref_out = torch.Tensor([[0.1206],[0.0688],[0.0941],[0.1026],[0.0739],[0.1279],[0.1115],[0.0629],[0.0994],[0.1012],
                            [0.0886],[0.1218],[0.0785],[0.0704],[0.0948],[0.1193],[0.0877],[0.0964],[0.0774],[0.0874],
                            [0.0948],[0.0636],[0.0869],[0.0664],[0.0659],[0.0927],[0.0654],[0.0927],[0.0743],[0.0787],
                            [0.0802],[0.1074],[0.1105],[0.0595],[0.0693],[0.0620],[0.0688],[0.0669],[0.0591],[0.0986],
                            [0.0706],[0.1180],[0.0894],[0.1030],[0.1012],[0.0606],[0.1408],[0.0766],[0.1063],[0.1049],
                            [0.0749],[0.0588],[0.1177],[0.1127],[0.1090],[0.0806],[0.0954],[0.0799],[0.1048],[0.1378],
                            [0.0783],[0.1384],[0.0689],[0.0649],[0.0983],[0.1548],[0.0778],[0.0934],[0.0858],[0.1203],
                            [0.1073],[0.1139],[0.0716],[0.0988],[0.0918],[0.1109],[0.0918],[0.0928],[0.1070],[0.0742]])
    trainer = lightning.Trainer(max_epochs=5, logger=None, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0)
    model = Committor(layers=[6, 4, 2, 1], atomic_masses=atomic_masses, alpha=1e-1, delta_f=0, separate_boundary_dataset=False)
    trainer.fit(model, datamodule)
    out = model(X)
    out.sum().backward()
    assert( torch.allclose(out, ref_out, atol=1e-3) )

    # test log loss (log_var=True)
    ref_out = torch.Tensor([[0.7236],[0.6559],[0.5530],[0.6739],[0.7527],[0.6769],[0.7338],[0.6840],[0.6892],[0.6255],
                            [0.6321],[0.8155],[0.5824],[0.5282],[0.6315],[0.5164],[0.4789],[0.7296],[0.6918],[0.6379],
                            [0.6191],[0.7071],[0.5849],[0.6282],[0.5886],[0.7218],[0.6431],[0.5893],[0.6257],[0.7119],
                            [0.5604],[0.4941],[0.7952],[0.7044],[0.6574],[0.5482],[0.6171],[0.7085],[0.6243],[0.5334],
                            [0.6313],[0.5883],[0.7220],[0.6117],[0.4803],[0.6749],[0.6686],[0.7030],[0.7206],[0.5813],
                            [0.6033],[0.7746],[0.6691],[0.6363],[0.6862],[0.5791],[0.6586],[0.7126],[0.7538],[0.7382],
                            [0.7757],[0.5703],[0.6464],[0.5825],[0.6061],[0.5842],[0.7049],[0.5703],[0.7346],[0.6371],
                            [0.5740],[0.6844],[0.5948],[0.6675],[0.6640],[0.6047],[0.7321],[0.5453],[0.6755],[0.5813]])
    trainer = lightning.Trainer(max_epochs=5, logger=None, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0)
    model = Committor(layers=[6, 4, 2, 1], atomic_masses=atomic_masses, alpha=1e-1, delta_f=0, log_var=True)
    trainer.fit(model, datamodule)
    out = model(X)
    out.sum().backward()
    assert( torch.allclose(out, ref_out, atol=1e-3) )

    # test z regularization
    ref_out = torch.Tensor([[0.2878],[0.1591],[0.1665],[0.1166],[0.1349],[0.1053],[0.1544],[0.1113],[0.1435],[0.1232],
                            [0.1130],[0.1261],[0.1726],[0.2098],[0.2091],[0.1407],[0.1942],[0.1400],[0.1382],[0.1630],
                            [0.1573],[0.1742],[0.1613],[0.1289],[0.1703],[0.1390],[0.1184],[0.2557],[0.1520],[0.1328],
                            [0.2220],[0.2254],[0.1823],[0.1426],[0.1744],[0.2594],[0.1105],[0.1390],[0.1557],[0.1985],
                            [0.1340],[0.1971],[0.1429],[0.1270],[0.2239],[0.1134],[0.1999],[0.1416],[0.1707],[0.2238],
                            [0.2054],[0.1560],[0.2357],[0.2971],[0.1445],[0.1906],[0.2130],[0.1457],[0.1382],[0.1432],
                            [0.1337],[0.1444],[0.1603],[0.1396],[0.2043],[0.1964],[0.1459],[0.2243],[0.1930],[0.1893],
                            [0.2634],[0.1868],[0.1340],[0.2483],[0.1550],[0.1559],[0.1614],[0.2020],[0.1270],[0.2555]])
    trainer = lightning.Trainer(max_epochs=5, logger=None, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0)
    model = Committor(layers=[6, 4, 2, 1], atomic_masses=atomic_masses, alpha=1e-1, delta_f=0, z_regularization=100, z_threshold=0.000001)
    trainer.fit(model, datamodule)
    out = model(X)
    out.sum().backward()
    assert( torch.allclose(out, ref_out, atol=1e-3) )

    # test z_regularization errors
    # NOTE(review): each (z_regularization, z_threshold) pair is presumably an
    # invalid configuration expected to raise ValueError, but the test only
    # prints the error and would also pass if nothing were raised. Consider
    # asserting that the exception actually occurs.
    trainer = lightning.Trainer(max_epochs=5, logger=None, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0)
    for z_regularization, z_threshold in zip([10, 0, -1, 10], [None, 10, 1, -1]):
        try:
            model = Committor(layers=[6, 4, 2, 1], atomic_masses=atomic_masses, alpha=1e-1, delta_f=0, z_regularization=z_regularization, z_threshold=z_threshold, n_dim=2)
            trainer.fit(model, datamodule)
        except ValueError as e:
            print("[TEST LOG] Checked this error: ", e)

    # test dimension error
    # NOTE(review): presumably n_dim=2 is expected to be inconsistent with the
    # data/masses and raise RuntimeError; as above, the test passes silently
    # if no error is raised.
    try:
        trainer = lightning.Trainer(max_epochs=5, logger=None, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0)
        model = Committor(layers=[6, 4, 2, 1], atomic_masses=atomic_masses, alpha=1e-1, delta_f=0, z_regularization=10, z_threshold=1, n_dim=2)
        trainer.fit(model, datamodule)
    except RuntimeError as e:
        print("[TEST LOG] Checked this error: ", e)


def test_committor_with_derivatives():
    """Check that the supported ways of providing input derivatives agree.

    For both values of ``separate_boundary_dataset`` a model is trained with
    positions as input plus a pairwise-distances preprocessing. Its output is
    checked against a reference, and then used as the reference for:
    descriptors as input with explicitly precomputed derivatives (only when
    ``separate_boundary_dataset`` is False), and descriptors as input with
    ``SmartDerivatives``.
    """
    from mlcolvar.cvs.committor.utils import initialize_committor_masses
    from mlcolvar.data import DictModule, DictDataset
    from mlcolvar.core.loss.utils.smart_derivatives import SmartDerivatives, compute_descriptors_derivatives
    from mlcolvar.core.transform import PairwiseDistances

    torch.manual_seed(42)
    n_atoms = 10
    # NOTE(review): kT is not used anywhere below
    kT = 2.49432

    # input positions for alanine example (5 configurations, n_atoms * 3 coordinates each)
    ref_pos = torch.Tensor([[ 1.2980, 0.5370, 1.3370, 1.3270, 0.5710, 1.1960, 1.4110, 0.5070, 1.1310, 1.2520,
                              0.6710, 1.1440, 1.2490, 0.6890, 0.9990, 1.1270, 0.6130, 0.9550, 1.2340, 0.8420,
                              0.9810, 1.1860, 0.9140, 1.0700, 1.2790, 0.8870, 0.8630, 1.2550, 1.0230, 0.8240 ],
                            [ 2.7530, 0.7150, 0.5170, 2.8460, 0.6150, 0.5780, 2.9520, 0.6560, 0.6220, 2.8150,
                              0.4870, 0.5730, 2.9100, 0.3830, 0.6150, 2.9310, 0.3890, 0.7690, 2.8520, 0.2450,
                              0.5830, 2.7300, 0.2380, 0.5550, 2.9420, 0.1390, 0.5840, 2.9030, -0.0030, 0.5690 ],
                            [ 0.4830, 2.5610, 2.9980, 0.5620, 2.5410, 2.8660, 0.5080, 2.4950, 2.7660, 0.6960,
                              2.5590, 2.8790, 0.8060, 2.5410, 2.7750, 0.7890, 2.6570, 2.6680, 0.9450, 2.5390,
                              2.8400, 0.9620, 2.5380, 2.9610, 1.0510, 2.5430, 2.7590, 1.1860, 2.5410, 2.7990 ],
                            [ 1.0680, 0.1770, 0.1670, 0.9560, 0.2290, 0.0920, 0.9320, 0.1730, -0.0070, 0.8770,
                              0.3280, 0.1460, 0.7710, 0.4040, 0.0760, 0.7230, 0.5180, 0.1660, 0.8270, 0.4640,
                              -0.0530, 0.9010, 0.5650, -0.0450, 0.7790, 0.4160, -0.1670, 0.8260, 0.4500, -0.2950 ],
                            [ 2.4600, 0.5670, 2.4940, 2.6050, 0.5640, 2.5060, 2.6660, 0.4630, 2.5020, 2.6640,
                              0.6830, 2.5220, 2.8040, 0.7250, 2.5200, 2.8880, 0.6370, 2.6190, 2.8690, 0.7270,
                              2.3820, 2.9600, 0.8080, 2.3570, 2.8260, 0.6310, 2.3010, 2.8630, 0.6170, 2.1580 ]]
                           )

    # weights for inputs
    ref_weights = torch.Tensor([1.4809, 0.0736, 0.3693, 0.1849, 0.0885])

    # initialize dataset with positions
    dataset = DictDataset({"data": ref_pos, "weights": ref_weights, "labels": torch.arange((len(ref_pos)))})

    # initialize descriptors calculations: all pairwise distances
    # (n_atoms is given as a literal here; it matches `n_atoms` above)
    ComputeDistances = PairwiseDistances(n_atoms=10, PBC=False, cell=[1, 1, 1], scaled_coords=False)

    # create friction tensor
    masses = initialize_committor_masses(atom_types=[0,0,1,2,0,0,0,1,2,0], masses=[ 12.011, 12.011, 15.999, 14.0067, 12.011, 12.011, 12.011, 15.999, 14.0067, 12.011])

    # --------------------------------- TRAIN MODELS ---------------------------------
    # Train the models: positions as input, desc as input with smartderivatives and passing derivatives
    for separate_boundary_dataset in [False, True]:
        # 1 ------------ Positions as input ------------
        # initialize datamodule (seed for reproducibility)
        torch.manual_seed(42)
        datamodule = DictModule(dataset, lengths=[1.0])

        # model on positions; 45 inputs = number of pairwise distances of 10 atoms
        model = Committor(layers=[45, 20, 1], atomic_masses=masses, alpha=1, separate_boundary_dataset=separate_boundary_dataset)
        # here we use the preprocessing
        model.preprocessing = ComputeDistances

        trainer = lightning.Trainer(
            accelerator='cpu',
            callbacks=None,
            max_epochs=6,
            enable_progress_bar=False,
            enable_checkpointing=False,
            logger=False,
            limit_val_batches=0,
            num_sanity_val_steps=0,
        )

        # fit
        trainer.fit(model, datamodule)

        # save outputs as a reference
        X = dataset["data"]

        # this is to check other strategies
        ref_output = model(X)
        if separate_boundary_dataset:
            ref_output_check = torch.Tensor([[0.4759], [0.4765], [0.4828], [0.4786], [0.4725]])
        else:
            ref_output_check = torch.Tensor([[0.4756], [0.4762], [0.4825], [0.4783], [0.4723]])
        assert( (torch.allclose(ref_output, ref_output_check, atol=1e-3)))

        if not separate_boundary_dataset:
            # 2 ------------ Descriptors as input + explicit pass derivatives ------------
            # get descriptor and their derivatives
            # NOTE(review): `pos` is not used below
            pos, desc, d_desc_d_pos = compute_descriptors_derivatives(dataset=dataset,
                                                                      descriptor_function=ComputeDistances,
                                                                      n_atoms=n_atoms,
                                                                      separate_boundary_dataset=separate_boundary_dataset)
            dataset_desc = DictDataset({"data": desc, "weights": ref_weights, "labels": torch.arange((len(ref_pos)))}, create_ref_idx=True)

            # seed for reproducibility
            torch.manual_seed(42)
            datamodule = DictModule(dataset_desc, lengths=[1.0])

            model = Committor(layers=[45, 20, 1], atomic_masses=masses, alpha=1, separate_boundary_dataset=separate_boundary_dataset, descriptors_derivatives=d_desc_d_pos)

            trainer = lightning.Trainer(
                accelerator='cpu',
                callbacks=None,
                max_epochs=6,
                enable_progress_bar=False,
                enable_checkpointing=False,
                logger=False,
                limit_val_batches=0,
                num_sanity_val_steps=0,
            )

            # fit
            trainer.fit(model, datamodule)

            # evaluate on the descriptors dataset
            X = dataset_desc["data"]

            # must match the positions-based reference computed above
            ref_output = model(X)
            assert( (torch.allclose(ref_output, ref_output_check, atol=1e-3)))

            # test errors
            # NOTE(review): like the other error checks, this passes silently
            # if no ValueError is raised.
            try:
                # separate boundary with explicit derivatives
                model = Committor(layers=[45, 20, 1], atomic_masses=masses, alpha=1, separate_boundary_dataset=True, descriptors_derivatives=d_desc_d_pos)
                trainer = lightning.Trainer(
                    accelerator='cpu',
                    callbacks=None,
                    max_epochs=6,
                    enable_progress_bar=False,
                    enable_checkpointing=False,
                    logger=False,
                    limit_val_batches=0,
                    num_sanity_val_steps=0,
                )
                trainer.fit(model, datamodule)
            except ValueError as e:
                print("[TEST LOG] Checked this error: ", e)

        # 3 ------------ Descriptors as input + SmartDerivatives ------------
        # initialize smart derivatives, we do it explicitly to test different functionalities
        smart_derivatives = SmartDerivatives()
        smart_dataset = smart_derivatives.setup(dataset=dataset,
                                                descriptor_function=ComputeDistances,
                                                n_atoms=n_atoms,
                                                separate_boundary_dataset=separate_boundary_dataset)

        # seed for reproducibility
        torch.manual_seed(42)
        datamodule = DictModule(smart_dataset, lengths=[1.0])

        model = Committor(layers=[45, 20, 1], atomic_masses=masses, alpha=1, separate_boundary_dataset=separate_boundary_dataset, descriptors_derivatives=smart_derivatives)

        trainer = lightning.Trainer(
            accelerator='cpu',
            callbacks=None,
            max_epochs=6,
            enable_progress_bar=False,
            enable_checkpointing=False,
            logger=False,
            limit_val_batches=0,
            num_sanity_val_steps=0,
        )

        # fit
        trainer.fit(model, datamodule)

        # evaluate on the SmartDerivatives dataset
        X = smart_dataset["data"]

        # must match the positions-based reference computed above
        ref_output = model(X)
        assert( (torch.allclose(ref_output, ref_output_check, atol=1e-3)))

        # test errors
        # NOTE(review): passes silently if no ValueError is raised.
        try:
            # no ref_idx!
            wrong_dataset = DictDataset(data=smart_dataset['data'], labels=smart_dataset['labels'], weights=smart_dataset['weights'])
            wrong_datamodule = DictModule(wrong_dataset, lengths=[1.0])
            trainer = lightning.Trainer(
                accelerator='cpu',
                callbacks=None,
                max_epochs=6,
                enable_progress_bar=False,
                enable_checkpointing=False,
                logger=False,
                limit_val_batches=0,
                num_sanity_val_steps=0,
            )
            trainer.fit(model, wrong_datamodule)
        except ValueError as e:
            print("[TEST LOG] Checked this error: ", e)


if __name__ == "__main__":
    test_committor()
    test_committor_with_derivatives()