import torch
import lightning
from mlcolvar.cvs import BaseCV
from mlcolvar.core import FeedForward, Normalization
from mlcolvar.core.loss import CommittorLoss
from mlcolvar.core.nn.utils import Custom_Sigmoid
__all__ = ["Committor"]
class Committor(BaseCV, lightning.LightningModule):
    """Base class for data-driven learning of committor function.

    The committor function q is expressed as the output of a neural network optimized with a self-consistent
    approach based on the Kolmogorov's variational principle for the committor and on the imposition of its boundary conditions (see Refs. [1,2]).
    It is also possible to use an approximated variational approach without explicit dependence on the atomic coordinates (see Ref. [3]).

    **Data**: for training it requires a DictDataset with the keys 'data', 'labels' and 'weights'.
    An optional 'ref_idx' key is forwarded to the loss when present.

    **Loss**: Minimize Kolmogorov's variational functional of q and impose boundary condition on the metastable states (CommittorLoss) from Refs. [1,2].
    It is also possible to use an approximated variational approach without explicit dependence on the atomic coordinates (see Ref. [3]).

    References
    ----------
    .. [1] P. Kang, E. Trizio, and M. Parrinello, "Computing the committor using the committor to study the transition state ensemble", Nat. Comput. Sci., 2024, DOI: 10.1038/s43588-024-00645-0
    .. [2] E. Trizio, P. Kang, and M. Parrinello, "Everything everywhere all at once: a probability-based enhanced sampling approach to rare events", Nat. Comput. Sci., 2025, DOI: 10.1038/s43588-025-00799-5
    .. [3] E. Trizio, G. Rossi, and M. Parrinello, "Ceci n'est pas un committor: Efficient sampling via approximated committor functions", J Chem. Phys., 2026, DOI: 10.1063/5.0331622

    See also
    --------
    mlcolvar.cvs.committor.utils.compute_committor_weights
        Utils to compute the appropriate weights for the training set
    mlcolvar.cvs.committor.utils.initialize_committor_masses
        Utils to initialize the masses tensor for the training
    mlcolvar.core.loss.CommittorLoss
        Kolmogorov's variational optimization of committor and imposition of boundary conditions
    mlcolvar.core.loss.utils.SmartDerivatives
        Class to optimize the gradients calculation improving speed and memory efficiency.
    """

    # names of the building blocks, in the order they are applied
    BLOCKS = ["norm_in", "nn", "sigmoid"]

    def __init__(
        self,
        layers: list,
        alpha: float,
        atomic_masses: torch.Tensor = None,
        gamma: float = 10000,
        delta_f: float = 0,
        cell: float = None,
        separate_boundary_dataset: bool = True,
        descriptors_derivatives: torch.nn.Module = None,
        log_var: bool = False,
        use_gradients_wrt_positions: bool = True,
        z_regularization: float = 0.0,
        z_threshold: float = None,
        n_dim: int = None,
        norm_in: bool = False,
        options: dict = None,
        **kwargs,
    ):
        """Define a NN-based committor model

        Parameters
        ----------
        layers : list
            Number of neurons per layer
        alpha : float
            Hyperparameter that scales the boundary conditions contribution to loss, i.e. alpha*(loss_bound_A + loss_bound_B)
        atomic_masses : torch.Tensor
            List of masses of all the atoms we are using, for each atom we need to repeat three times for x,y,z, by default None.
            The mlcolvar.cvs.committor.utils.initialize_committor_masses can be used to simplify this.
            If the position-less loss is used, this must be set to None.
        gamma : float, optional
            Hyperparameter that scales the whole loss to avoid too small numbers, i.e. gamma*(loss_var + loss_bound), by default 10000
        delta_f : float, optional
            Delta free energy between A (label 0) and B (label 1), units is kBT, by default 0.
            State B is supposed to be higher in energy.
        cell : float, optional
            CUBIC cell size length, used to scale the positions from reduced coordinates to real coordinates, by default None
        separate_boundary_dataset : bool, optional
            Switch to exclude boundary condition labeled data from the variational loss, by default True
        descriptors_derivatives : torch.nn.Module, optional
            `SmartDerivatives` object to save memory and time when using descriptors, by default None.
            Must be None when `use_gradients_wrt_positions` is False.
            See also mlcolvar.core.loss.committor_loss.SmartDerivatives
        log_var : bool, optional
            Switch to minimize the log of the variational functional, by default False.
        use_gradients_wrt_positions : bool, optional
            Whether to use gradients with respect to positions as prescribed in the original Kolmogorov variational functional, by default True.
            Set to false to use the approximated variational principle defined in Ref. [3] without explicit dependence on the atomic coordinates derivatives.
        z_regularization : float, optional
            Scales a regularization on the learned z space preventing it from exceeding the threshold given with 'z_threshold'.
            The magnitude of the regularization is scaled by the given number, by default 0.0
        z_threshold : float, optional
            Sets a maximum threshold for the z value during the training, by default None.
            The magnitude of the regularization term is scaled via the `z_regularization` key.
        n_dim : int
            Number of dimensions, by default None.
            If None, it defaults to 3 for the position-based loss and to 1 for the position-less loss.
        norm_in : bool
            Whether to normalize the input of the NN model, by default False.
        options : dict[str, Any], optional
            Options for the building blocks of the model, by default None.
            The dictionary is sanitized with `parse_options` before use.
            Available blocks: ['norm_in', 'nn', 'sigmoid'].
            The 'norm_in' and 'sigmoid' blocks are skipped when their parsed option is False or None;
            'norm_in' additionally requires the `norm_in` argument to be True.

        Raises
        ------
        ValueError
            If `atomic_masses` is missing while `use_gradients_wrt_positions` is True, or if
            `atomic_masses` or `descriptors_derivatives` is given while it is False.
        """
        super().__init__(in_features=layers[0], out_features=layers[-1], **kwargs)

        # the two variational principles need mutually exclusive inputs
        if use_gradients_wrt_positions and atomic_masses is None:
            raise ValueError("atomic_masses must be provided when using Kolmogorov variational functional (use_gradients_wrt_positions is True)")
        elif not use_gradients_wrt_positions:
            if atomic_masses is not None:
                raise ValueError("atomic_masses must be None when using approximated variational principle (use_gradients_wrt_positions is False)")
            if descriptors_derivatives is not None:
                raise ValueError("descriptors_derivatives must be None when using approximated variational principle (use_gradients_wrt_positions is False)")

        # ======= LOSS =======
        self.loss_fn = CommittorLoss(
            alpha=alpha,
            atomic_masses=atomic_masses,
            gamma=gamma,
            delta_f=delta_f,
            cell=cell,
            separate_boundary_dataset=separate_boundary_dataset,
            descriptors_derivatives=descriptors_derivatives,
            log_var=log_var,
            use_gradients_wrt_positions=use_gradients_wrt_positions,
            z_regularization=z_regularization,
            z_threshold=z_threshold,
            n_dim=n_dim,
        )

        # ======= OPTIONS =======
        # parse and sanitize
        options = self.parse_options(options)

        # ======= BLOCKS =======
        # initialize norm_in only when requested and not disabled through the options
        o = "norm_in"
        if norm_in and (options[o] is not False) and (options[o] is not None):
            self.norm_in = Normalization(self.in_features, **options[o])

        # initialize NN
        o = "nn"
        # set default activation to tanh
        if "activation" not in options[o]:
            options[o]["activation"] = "tanh"
        self.nn = FeedForward(layers, **options[o])

        # separately add sigmoid activation on last layer, this way it can be deactivated
        o = "sigmoid"
        if (options[o] is not False) and (options[o] is not None):
            self.sigmoid = Custom_Sigmoid(**options[o])

    def forward_nn(self, x):
        """Return the network output z, i.e. the model output before the sigmoid block.

        The optional `preprocessing` and `norm_in` blocks are applied, in this order,
        only when they are not None.
        """
        if self.preprocessing is not None:
            x = self.preprocessing(x)
        if self.norm_in is not None:
            x = self.norm_in(x)
        z = self.nn(x)
        return z

    def training_step(self, train_batch, batch_idx):
        """Compute and return the training loss and record metrics.

        Parameters
        ----------
        train_batch : dict
            Batch with the keys 'data', 'labels' and 'weights'; the optional key
            'ref_idx' is passed to the loss when present, otherwise None is used.
        batch_idx : int
            Index of the batch (not used here).

        Returns
        -------
        loss
            First element returned by the loss function. The metrics are logged with the
            prefix 'train' or 'valid' depending on `self.training`.
        """
        # gradients are switched on explicitly before evaluating the loss
        # NOTE(review): presumably needed because this method may also run in a
        # no-grad (validation) context, as the 'valid' log prefix suggests - confirm
        torch.set_grad_enabled(True)

        # =================get data===================
        x = train_batch["data"]
        # make sure data have shape (n_data, -1)
        x = x.reshape((x.shape[0], -1))
        # x is handed to the loss together with z and q, so it must track gradients
        x.requires_grad = True
        labels = train_batch["labels"]
        weights = train_batch["weights"]
        try:
            ref_idx = train_batch["ref_idx"]
        except KeyError:
            ref_idx = None

        # =================forward====================
        z = self.forward_nn(x)
        # q is the sigmoid of z, or z itself when the sigmoid block is disabled
        q = self.sigmoid(z) if self.sigmoid is not None else z

        # ===================loss=====================
        # the same loss call is used in training and in evaluation mode
        loss, loss_var, loss_bound_A, loss_bound_B = self.loss_fn(
            x, z, q, labels, weights, ref_idx
        )

        # ====================log=====================
        name = "train" if self.training else "valid"
        self.log(f"{name}_loss", loss, on_epoch=True)
        self.log(f"{name}_loss_var", loss_var, on_epoch=True)
        self.log(f"{name}_loss_bound_A", loss_bound_A, on_epoch=True)
        self.log(f"{name}_loss_bound_B", loss_bound_B, on_epoch=True)
        return loss
def test_committor():
    """Regression test: short Committor trainings on random data against stored outputs.

    Scenarios, in order: default (boundary data separated) plus the KolmogorovBias
    built on it, whole-dataset variational loss, log-variational loss, z-space
    regularization, position-less loss, and finally the argument combinations
    expected to fail.
    """
    from mlcolvar.data import DictDataset, DictModule
    from mlcolvar.cvs.committor.utils import initialize_committor_masses, KolmogorovBias

    torch.manual_seed(42)

    # two fake atoms, whose fake positions are used as inputs
    atomic_masses = initialize_committor_masses(atom_types=[0,1], masses=[15.999, 1.008])

    # random inputs, split in four consecutive groups of `samples` points
    samples = 20
    X = torch.randn((4*samples, 6))
    n_points = X.shape[0]

    # labels: 0, 1, 2, 3 for the four groups
    y = torch.zeros(n_points)
    for group_start in (samples, int(2*samples), int(3*samples)):
        y[group_start:] += 1

    # uniform weights
    w = torch.ones(n_points)

    datamodule = DictModule(DictDataset({"data": X, "labels": y, "weights": w}), lengths=[1])

    def make_trainer():
        # every scenario gets its own, identically configured trainer
        return lightning.Trainer(max_epochs=5, logger=None, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0)

    def fit_committor(masses=atomic_masses, **extra):
        # build the trainer first, then the model, then fit
        fit_trainer = make_trainer()
        fitted = Committor(layers=[6, 4, 2, 1], atomic_masses=masses, alpha=1e-1, **extra)
        fit_trainer.fit(fitted, datamodule)
        return fitted

    # ---------------- dataset separation (default) ----------------
    expected_out = torch.Tensor([[0.6544],[0.6197],[0.5898],[0.5733],[0.6533],[0.5534],[0.5616],[0.6202],[0.5582],[0.5430],
                                 [0.6364],[0.5984],[0.6382],[0.5967],[0.6440],[0.5697],[0.6061],[0.6010],[0.6399],[0.6172],
                                 [0.6164],[0.6528],[0.6583],[0.6236],[0.6641],[0.5834],[0.5832],[0.6204],[0.6409],[0.6558],
                                 [0.5891],[0.5879],[0.5890],[0.6583],[0.6577],[0.6467],[0.6405],[0.6590],[0.6463],[0.5581],
                                 [0.6154],[0.6368],[0.6196],[0.5162],[0.5998],[0.6041],[0.5513],[0.6476],[0.5742],[0.6162],
                                 [0.6462],[0.6371],[0.5295],[0.6148],[0.5999],[0.5870],[0.6352],[0.6145],[0.5708],[0.4992],
                                 [0.6539],[0.6014],[0.6470],[0.6299],[0.6254],[0.5268],[0.6286],[0.6056],[0.6077],[0.6055],
                                 [0.5861],[0.5991],[0.6449],[0.6500],[0.6295],[0.5627],[0.6269],[0.6392],[0.5961],[0.6694]])
    expected_bias = torch.Tensor([-6.2043, -6.8591, -7.7645, -7.8704, -5.8342, -7.5036, -7.8780, -6.9957,
                                  -7.8679, -7.7473, -7.2451, -7.6833, -6.7631, -7.7863, -6.6693, -7.6212,
                                  -7.6929, -7.5685, -6.6894, -7.4857, -7.5187, -4.9488, -6.4961, -7.3898,
                                  -6.0350, -7.8837, -7.8748, -7.2552, -7.1221, -5.8647, -7.9190, -7.7184,
                                  -7.7073, -4.7898, -5.4073, -5.9113, -6.5451, -4.7149, -5.8899, -7.7421,
                                  -7.3999, -7.3456, -7.3005, -7.5067, -7.7396, -7.7099, -7.8664, -6.3275,
                                  -7.8864, -7.7243, -6.4288, -5.7041, -7.9351, -7.1991, -7.7027, -7.7947,
                                  -6.7121, -7.6094, -7.9009, -7.0479, -5.2398, -7.8241, -5.8642, -7.0701,
                                  -7.0348, -7.2577, -6.6142, -7.6322, -7.3279, -7.6393, -7.8608, -7.7037,
                                  -6.6949, -6.3947, -7.2246, -7.7009, -6.7359, -7.2186, -7.7849, -5.6882])
    model = fit_committor()
    out = model(X)
    out.sum().backward()
    assert torch.allclose(out, expected_out, atol=1e-3)

    # bias built on top of the model trained above
    bias_model = KolmogorovBias(input_model=model, beta=1, epsilon=1e-6, lambd=1)
    bias = bias_model(X)
    assert torch.allclose(bias, expected_bias, atol=1e-3)

    # ---------------- naive whole dataset ----------------
    expected_out = torch.Tensor([[0.1206],[0.0688],[0.0941],[0.1026],[0.0739],[0.1279],[0.1115],[0.0629],[0.0994],[0.1012],
                                 [0.0886],[0.1218],[0.0785],[0.0704],[0.0948],[0.1193],[0.0877],[0.0964],[0.0774],[0.0874],
                                 [0.0948],[0.0636],[0.0869],[0.0664],[0.0659],[0.0927],[0.0654],[0.0927],[0.0743],[0.0787],
                                 [0.0802],[0.1074],[0.1105],[0.0595],[0.0693],[0.0620],[0.0688],[0.0669],[0.0591],[0.0986],
                                 [0.0706],[0.1180],[0.0894],[0.1030],[0.1012],[0.0606],[0.1408],[0.0766],[0.1063],[0.1049],
                                 [0.0749],[0.0588],[0.1177],[0.1127],[0.1090],[0.0806],[0.0954],[0.0799],[0.1048],[0.1378],
                                 [0.0783],[0.1384],[0.0689],[0.0649],[0.0983],[0.1548],[0.0778],[0.0934],[0.0858],[0.1203],
                                 [0.1073],[0.1139],[0.0716],[0.0988],[0.0918],[0.1109],[0.0918],[0.0928],[0.1070],[0.0742]])
    model = fit_committor(separate_boundary_dataset=False)
    out = model(X)
    out.sum().backward()
    assert torch.allclose(out, expected_out, atol=1e-3)

    # ---------------- log loss ----------------
    expected_out = torch.Tensor([[0.7287],[0.6505],[0.5594],[0.6758],[0.7482],[0.6804],[0.7313],[0.6762],[0.6873],[0.6267],
                                 [0.6362],[0.8129],[0.5853],[0.5262],[0.6359],[0.5263],[0.4839],[0.7291],[0.6884],[0.6375],
                                 [0.6231],[0.6997],[0.5906],[0.6247],[0.5876],[0.7198],[0.6356],[0.5933],[0.6229],[0.7093],
                                 [0.5618],[0.5005],[0.7924],[0.6965],[0.6540],[0.5476],[0.6151],[0.7042],[0.6190],[0.5362],
                                 [0.6275],[0.5959],[0.7194],[0.6122],[0.4873],[0.6653],[0.6741],[0.7011],[0.7207],[0.5863],
                                 [0.6040],[0.7643],[0.6696],[0.6424],[0.6886],[0.5775],[0.6620],[0.7104],[0.7517],[0.7387],
                                 [0.7714],[0.5825],[0.6442],[0.5796],[0.6131],[0.5923],[0.7023],[0.5730],[0.7307],[0.6404],
                                 [0.5780],[0.6850],[0.5959],[0.6718],[0.6626],[0.6068],[0.7319],[0.5498],[0.6772],[0.5846]])
    model = fit_committor(log_var=True)
    out = model(X)
    out.sum().backward()
    assert torch.allclose(out, expected_out, atol=1e-3)

    # ---------------- z regularization ----------------
    expected_out = torch.Tensor([[0.2878],[0.1591],[0.1665],[0.1166],[0.1349],[0.1053],[0.1544],[0.1113],[0.1435],[0.1232],
                                 [0.1130],[0.1261],[0.1726],[0.2098],[0.2091],[0.1407],[0.1942],[0.1400],[0.1382],[0.1630],
                                 [0.1573],[0.1742],[0.1613],[0.1289],[0.1703],[0.1390],[0.1184],[0.2557],[0.1520],[0.1328],
                                 [0.2220],[0.2254],[0.1823],[0.1426],[0.1744],[0.2594],[0.1105],[0.1390],[0.1557],[0.1985],
                                 [0.1340],[0.1971],[0.1429],[0.1270],[0.2239],[0.1134],[0.1999],[0.1416],[0.1707],[0.2238],
                                 [0.2054],[0.1560],[0.2357],[0.2971],[0.1445],[0.1906],[0.2130],[0.1457],[0.1382],[0.1432],
                                 [0.1337],[0.1444],[0.1603],[0.1396],[0.2043],[0.1964],[0.1459],[0.2243],[0.1930],[0.1893],
                                 [0.2634],[0.1868],[0.1340],[0.2483],[0.1550],[0.1559],[0.1614],[0.2020],[0.1270],[0.2555]])
    model = fit_committor(z_regularization=100, z_threshold=0.000001)
    out = model(X)
    out.sum().backward()
    assert torch.allclose(out, expected_out, atol=1e-3)

    # ---------------- position-less loss ----------------
    expected_out = torch.Tensor([[0.2318],[0.2119],[0.3039],[0.2349],[0.1933],[0.2506],[0.1453],[0.2849],[0.2042],[0.2514],
                                 [0.3391],[0.1043],[0.3083],[0.3091],[0.2147],[0.4049],[0.4225],[0.1488],[0.2421],[0.2429],
                                 [0.2354],[0.1662],[0.3195],[0.3682],[0.2881],[0.2027],[0.3813],[0.2461],[0.2892],[0.2725],
                                 [0.2833],[0.3431],[0.1060],[0.2795],[0.2566],[0.3266],[0.3747],[0.3010],[0.2916],[0.3081],
                                 [0.3136],[0.2971],[0.1736],[0.2491],[0.3451],[0.3594],[0.1713],[0.2217],[0.1426],[0.2170],
                                 [0.2296],[0.1287],[0.1386],[0.1911],[0.1898],[0.2731],[0.1899],[0.1999],[0.1325],[0.1380],
                                 [0.1678],[0.3036],[0.2935],[0.3828],[0.2623],[0.2214],[0.2149],[0.2332],[0.1143],[0.2179],
                                 [0.2237],[0.1641],[0.3423],[0.2643],[0.2220],[0.2521],[0.1774],[0.3679],[0.2753],[0.1834]])
    model = fit_committor(masses=None, use_gradients_wrt_positions=False)
    out = model(X)
    print(out)
    out.sum().backward()
    assert torch.allclose(out, expected_out, atol=1e-3)

    # ---------------- z_regularization errors ----------------
    # a ValueError is tolerated and reported; nothing fails if it is not raised
    trainer = make_trainer()
    invalid_settings = [(10, None), (0, 10), (-1, 1), (10, -1)]
    for z_regularization, z_threshold in invalid_settings:
        try:
            model = Committor(layers=[6, 4, 2, 1], atomic_masses=atomic_masses, alpha=1e-1, z_regularization=z_regularization, z_threshold=z_threshold, n_dim=2)
            trainer.fit(model, datamodule)
        except ValueError as e:
            print("[TEST LOG] Checked this error: ", e)

    # ---------------- dimension error ----------------
    # a RuntimeError is tolerated and reported here
    try:
        trainer = make_trainer()
        model = Committor(layers=[6, 4, 2, 1], atomic_masses=atomic_masses, alpha=1e-1, z_regularization=10, z_threshold=1, n_dim=2)
        trainer.fit(model, datamodule)
    except RuntimeError as e:
        print("[TEST LOG] Checked this error: ", e)
def test_committor_with_derivatives():
"""Check that Committor gives the same outputs with positions or descriptors as input.

Three training strategies are compared against the same stored reference values:
1. positions as input, with the pairwise distances computed as model preprocessing;
2. precomputed descriptors as input, passing their derivatives explicitly;
3. precomputed descriptors as input, with a `SmartDerivatives` object.
Two error paths (ValueError) are also exercised and only reported with a print.
"""
# NOTE(review): the indentation of this function is not available in this view, so the exact
# extent of the `for` loop and of the `if not separate_boundary_dataset:` branch below cannot be
# told from here - confirm against the original file before restructuring anything.
from mlcolvar.cvs.committor.utils import initialize_committor_masses
from mlcolvar.data import DictModule, DictDataset
from mlcolvar.core.loss.utils.smart_derivatives import SmartDerivatives, compute_descriptors_derivatives
from mlcolvar.core.transform import PairwiseDistances
torch.manual_seed(42)
n_atoms = 10
# NOTE(review): kT is not referenced again in the code shown here
kT = 2.49432
# input positions for alanine example: 5 configurations, 30 values each (n_atoms x 3)
ref_pos = torch.Tensor([[ 1.2980, 0.5370, 1.3370, 1.3270, 0.5710, 1.1960, 1.4110, 0.5070, 1.1310, 1.2520, 0.6710, 1.1440,
1.2490, 0.6890, 0.9990, 1.1270, 0.6130, 0.9550, 1.2340, 0.8420, 0.9810, 1.1860, 0.9140, 1.0700,
1.2790, 0.8870, 0.8630, 1.2550, 1.0230, 0.8240 ],
[ 2.7530, 0.7150, 0.5170, 2.8460, 0.6150, 0.5780, 2.9520, 0.6560, 0.6220, 2.8150, 0.4870, 0.5730,
2.9100, 0.3830, 0.6150, 2.9310, 0.3890, 0.7690, 2.8520, 0.2450, 0.5830, 2.7300, 0.2380, 0.5550,
2.9420, 0.1390, 0.5840, 2.9030, -0.0030, 0.5690 ],
[ 0.4830, 2.5610, 2.9980, 0.5620, 2.5410, 2.8660, 0.5080, 2.4950, 2.7660, 0.6960, 2.5590, 2.8790,
0.8060, 2.5410, 2.7750, 0.7890, 2.6570, 2.6680, 0.9450, 2.5390, 2.8400, 0.9620, 2.5380, 2.9610,
1.0510, 2.5430, 2.7590, 1.1860, 2.5410, 2.7990 ],
[ 1.0680, 0.1770, 0.1670, 0.9560, 0.2290, 0.0920, 0.9320, 0.1730, -0.0070, 0.8770, 0.3280, 0.1460,
0.7710, 0.4040, 0.0760, 0.7230, 0.5180, 0.1660, 0.8270, 0.4640, -0.0530, 0.9010, 0.5650, -0.0450,
0.7790, 0.4160, -0.1670, 0.8260, 0.4500, -0.2950 ],
[ 2.4600, 0.5670, 2.4940, 2.6050, 0.5640, 2.5060, 2.6660, 0.4630, 2.5020, 2.6640, 0.6830, 2.5220,
2.8040, 0.7250, 2.5200, 2.8880, 0.6370, 2.6190, 2.8690, 0.7270, 2.3820, 2.9600, 0.8080, 2.3570,
2.8260, 0.6310, 2.3010, 2.8630, 0.6170, 2.1580 ]]
)
# weights for inputs, one per configuration
ref_weights = torch.Tensor([1.4809, 0.0736, 0.3693, 0.1849, 0.0885])
# initialize dataset with positions; every configuration gets its own label (0..4)
dataset = DictDataset({"data": ref_pos, "weights": ref_weights, "labels": torch.arange((len(ref_pos)))})
# initialize descriptors calculations: all pairwise distances
# (the literal n_atoms=10 matches the n_atoms variable defined above)
ComputeDistances = PairwiseDistances(n_atoms=10,
PBC=False,
cell=[1, 1, 1],
scaled_coords=False)
# create friction tensor
masses = initialize_committor_masses(atom_types=[0,0,1,2,0,0,0,1,2,0],
masses=[ 12.011, 12.011, 15.999, 14.0067, 12.011, 12.011, 12.011, 15.999, 14.0067, 12.011])
# --------------------------------- TRAIN MODELS ---------------------------------
# Train the models: positions as input, desc as input with smartderivatives and passing derivatives
# The whole comparison is repeated without and with the separate boundary dataset option.
for separate_boundary_dataset in [False, True]:
# 1 ------------ Positions as input ------------
# seed for reproducibility: reset before each training so the strategies can be compared
torch.manual_seed(42)
# initialize datamodule
datamodule = DictModule(dataset, lengths=[1.0])
# 45 inputs = number of distinct pairs among 10 atoms (10*9/2)
model = Committor(layers=[45, 20, 1],
atomic_masses=masses,
alpha=1,
separate_boundary_dataset=separate_boundary_dataset)
# here we use the preprocessing: the model receives positions and computes the distances itself
model.preprocessing = ComputeDistances
trainer = lightning.Trainer(
accelerator='cpu',
callbacks=None,
max_epochs=6,
enable_progress_bar=False,
enable_checkpointing=False,
logger=False,
limit_val_batches=0,
num_sanity_val_steps=0,
)
# fit
trainer.fit(model, datamodule)
# evaluate the trained model on the position dataset
X = dataset["data"]
# this is to check other strategies
ref_output = model(X)
# stored reference outputs, one set for each value of separate_boundary_dataset
if separate_boundary_dataset:
ref_output_check = torch.Tensor([[0.4759],
[0.4765],
[0.4828],
[0.4786],
[0.4725]])
else:
ref_output_check = torch.Tensor([[0.4756],
[0.4762],
[0.4825],
[0.4783],
[0.4723]])
assert( (torch.allclose(ref_output, ref_output_check, atol=1e-3)))
# the explicit-derivatives strategy is trained only without the separate boundary dataset
if not separate_boundary_dataset:
# 2 ------------ Descriptors as input + explicit pass derivatives ------------
# get descriptor and their derivatives
pos, desc, d_desc_d_pos = compute_descriptors_derivatives(dataset=dataset,
descriptor_function=ComputeDistances,
n_atoms=n_atoms,
separate_boundary_dataset=separate_boundary_dataset)
# descriptors replace the positions as data; create_ref_idx=True is requested for this dataset
dataset_desc = DictDataset({"data": desc, "weights": ref_weights, "labels": torch.arange((len(ref_pos)))}, create_ref_idx=True)
# seed for reproducibility
torch.manual_seed(42)
datamodule = DictModule(dataset_desc, lengths=[1.0])
model = Committor(layers=[45, 20, 1],
atomic_masses=masses,
alpha=1,
separate_boundary_dataset=separate_boundary_dataset,
descriptors_derivatives=d_desc_d_pos)
trainer = lightning.Trainer(
accelerator='cpu',
callbacks=None,
max_epochs=6,
enable_progress_bar=False,
enable_checkpointing=False,
logger=False,
limit_val_batches=0,
num_sanity_val_steps=0,
)
# fit
trainer.fit(model, datamodule)
# evaluate on the descriptors; must match the reference of the position-based model
X = dataset_desc["data"]
# this is to check other strategies
ref_output = model(X)
assert( (torch.allclose(ref_output, ref_output_check, atol=1e-3)))
# test errors
# A ValueError is caught and printed; the test does not fail if no error is raised.
try:
# separate boundary with explicit derivatives
model = Committor(layers=[45, 20, 1],
atomic_masses=masses,
alpha=1,
separate_boundary_dataset=True,
descriptors_derivatives=d_desc_d_pos)
trainer = lightning.Trainer(
accelerator='cpu',
callbacks=None,
max_epochs=6,
enable_progress_bar=False,
enable_checkpointing=False,
logger=False,
limit_val_batches=0,
num_sanity_val_steps=0,
)
trainer.fit(model, datamodule)
except ValueError as e:
print("[TEST LOG] Checked this error: ", e)
# 3 ------------ Descriptors as input + SmartDerivatives ------------
# initialize smart derivatives, we do it explicitly to test different functionalities
smart_derivatives = SmartDerivatives()
# setup returns the dataset that is used for training below
smart_dataset = smart_derivatives.setup(dataset=dataset,
descriptor_function=ComputeDistances,
n_atoms=n_atoms,
separate_boundary_dataset=separate_boundary_dataset)
# seed for reproducibility
torch.manual_seed(42)
datamodule = DictModule(smart_dataset, lengths=[1.0])
model = Committor(layers=[45, 20, 1],
atomic_masses=masses,
alpha=1,
separate_boundary_dataset=separate_boundary_dataset,
descriptors_derivatives=smart_derivatives)
trainer = lightning.Trainer(
accelerator='cpu',
callbacks=None,
max_epochs=6,
enable_progress_bar=False,
enable_checkpointing=False,
logger=False,
limit_val_batches=0,
num_sanity_val_steps=0,
)
# fit
trainer.fit(model, datamodule)
# evaluate on the SmartDerivatives dataset; must match the same reference values
X = smart_dataset["data"]
# this is to check other strategies
ref_output = model(X)
assert( (torch.allclose(ref_output, ref_output_check, atol=1e-3)))
# test errors
# A ValueError is caught and printed; the test does not fail if no error is raised.
try:
# no ref_idx! the dataset is rebuilt from data, labels and weights only
wrong_dataset = DictDataset(data=smart_dataset['data'], labels=smart_dataset['labels'], weights=smart_dataset['weights'])
wrong_datamodule = DictModule(wrong_dataset, lengths=[1.0])
trainer = lightning.Trainer(
accelerator='cpu',
callbacks=None,
max_epochs=6,
enable_progress_bar=False,
enable_checkpointing=False,
logger=False,
limit_val_batches=0,
num_sanity_val_steps=0,
)
trainer.fit(model, wrong_datamodule)
except ValueError as e:
print("[TEST LOG] Checked this error: ", e)