Learning the committor with graph neural networks (GNNs)

Open in Colab

[1]:
# Colab setup
import os

# Run the bootstrap only when the COLAB_RELEASE_TAG environment variable is set
# (presumably this means the notebook is being executed on Google Colab).
if os.environ.get("COLAB_RELEASE_TAG"):
    import subprocess

    # download the setup script from the mlcolvar repository
    subprocess.run('wget https://raw.githubusercontent.com/luigibonati/mlcolvar/main/colab_setup.sh', shell=True)

    # execute it, capturing stdout so that it can be echoed in the cell output
    cmd = subprocess.run('bash colab_setup.sh TUTORIAL', stdout=subprocess.PIPE, shell=True)
    print(str(cmd.stdout, 'utf-8'))

Common variables

[2]:
# temperature used to build beta (presumably in Kelvin, consistently with kb below)
T = 300

# Boltzmann constant: make sure it is expressed in the RIGHT ENERGY UNITS!
kb = 0.0083144621  # kJ/(mol K)

# inverse temperature, passed later to the weights and bias helpers
beta = 1 / (kb * T)

Iter 0

Load data

[ ]:
import torch
from mlcolvar.cvs.committor.utils import compute_committor_weights
from mlcolvar.data import DictModule
from mlcolvar.io import create_dataset_from_trajectories, load_dataframe

# root of the remote repository hosting the example data
data_url = "https://github.com/EnricoTrizio/alanine_gnn_committor_data/raw/refs/heads/main"

# the same frame slicing is applied to the trajectories and to the COLVAR files,
# so that the two stay aligned frame by frame
frame_slice = {'start': 0, 'stop': 10000, 'stride': 2}
unbiased_states = ['A', 'B']

# loading arguments for each file (one independent copy per trajectory)
load_args = [dict(frame_slice) for _ in unbiased_states]

# create dataset from trajectory files, xyz and anything compatible with MDtraj can be used
dataset = create_dataset_from_trajectories(
    trajectories=[f"{data_url}/unbiased/{state}/traj_comp.xtc" for state in unbiased_states],
    topologies=f"{data_url}/unbiased/A/confAvac.gro",  # with xyz use None
    cutoff=10.0,  # Angstrom
    system_selection='all and not type H',  # this uses MDtraj selection syntax
    show_progress=False,
    load_args=load_args,
    lengths_conversion=10.0,  # MDTraj uses nm by default, we use Angstroms
)
print('Dataset info:\n', dataset, end="\n\n")


# load the corresponding COLVAR files; the easiest thing is to print trajectory
# and COLVAR files in sync, so the same slicing can be reused here
colvar = load_dataframe(
    [f"{data_url}/unbiased/{state}/COLVAR" for state in unbiased_states],
    **frame_slice,
)


# initialize weights, here we have unbiased simulations so the bias is zero everywhere
bias = torch.zeros(len(dataset))
dataset = compute_committor_weights(
    dataset=dataset,
    bias=bias,
    data_groups=[0, 1],
    beta=beta,
)


# load dataset into a DictModule
datamodule = DictModule(dataset, lengths=[1])
print('Datamodule info:\n', datamodule)
/home/etrizio@iit.local/Bin/dev/mlcolvar/mlcolvar/data/graph/utils.py:64: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  graph_labels = torch.tensor( config.graph_labels, dtype=torch.get_default_dtype() )   if config.graph_labels is not None else None
Dataset info:
 DictDataset( "data_list": 10000,
            metadata={"atomic_numbers": [6, 7, 8],
                      "cutoff": 10.0,
                      "buffer": 0.0,
                      "used_idx": tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
                      "used_names": [ACE1-CH3, ACE1-C, ACE1-O, ALA2-N, ALA2-CA, ALA2-CB, ALA2-C, ALA2-O, NME3-N, NME3-C],
                      "data_type": graphs } )

Datamodule info:
 DictModule(dataset -> DictDataset( "data_list": 10000,
            metadata={"atomic_numbers": [6, 7, 8],
                      "cutoff": 10.0,
                      "buffer": 0.0,
                      "used_idx": tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
                      "used_names": [ACE1-CH3, ACE1-C, ACE1-O, ALA2-N, ALA2-CA, ALA2-CB, ALA2-C, ALA2-O, NME3-N, NME3-C],
                      "data_type": graphs } ),
             train_loader -> DictLoader(length=1, batch_size=10000, shuffle=True))

Initialize model

[ ]:
from mlcolvar.core.nn.graph import SchNetModel
from mlcolvar.cvs import Committor
import copy

# initialize GNN model, here we use SchNet; the dataset is passed so that the
# model can be initialized from it
gnn_model = SchNetModel(
    n_out=1,
    dataset_for_initialization=dataset,
    n_bases=16,
    n_layers=2,
    n_filters=16,
    n_hidden_channels=16,
    aggr='min',
    w_out_after_pool=True,
)

# model options: optimizer learning rate and exponential learning-rate decay
lr_scheduler_options = {
    'scheduler': torch.optim.lr_scheduler.ExponentialLR,
    'gamma': 0.99995,
}
options = {
    'optimizer': {'lr': 1e-3},
    'lr_scheduler': lr_scheduler_options,
}

# masses of the atomic species kept in the dataset
atomic_masses = torch.Tensor([12.011, 14.007, 15.999])  # C, N, O

# initialize committor model
model = Committor(
    model=gnn_model,
    atomic_masses=atomic_masses,
    alpha=100,
    separate_boundary_dataset=False,  # here we only have the boundary dataset
    z_regularization=100,
    z_threshold=5,
    log_var=True,
    gamma=1,
    options=options,
)

# keep a copy of the model sigmoid to switch easily between q and z later on
sigmoid = copy.deepcopy(model.sigmoid)
/home/etrizio@iit.local/Bin/miniconda3/envs/graph_mlcolvar_test_2.5/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.

Train model

[ ]:
from lightning import Trainer
from mlcolvar.utils.trainer import MetricsCallback

# callback that collects the metrics logged during training
metrics = MetricsCallback()

# settings of the lightning trainer, gathered in one place
trainer_options = dict(
    callbacks=[metrics],
    logger=False,
    enable_checkpointing=False,
    accelerator='cpu',  # cpu for testing, use cuda!
    max_epochs=5,  # small for testing, use 2/5k or more
    enable_model_summary=False,
    limit_val_batches=0,
    num_sanity_val_steps=0,
)

# initialize lightning trainer
trainer = Trainer(**trainer_options)

# fit model on datamodule
trainer.fit(model, datamodule)

Check training metrics

[6]:
from mlcolvar.utils.plot import plot_metrics
import matplotlib.pyplot as plt

# loss terms to show and the colour assigned to each of them
loss_keys = ['train_loss', 'train_loss_var', 'train_loss_bound_A', 'train_loss_bound_B']
loss_colors = ['fessa0', 'fessa3', 'fessa5', 'fessa6']

# the same curves are drawn twice: full range on the left, zoomed on the right
fig, axs = plt.subplots(1, 2, figsize=(10, 3))
for ax in axs:
    plot_metrics(metrics.metrics, keys=loss_keys, colors=loss_colors, yscale='log', ax=ax)

# we zoom more in one plot
axs[1].set_ylim(1e-5, 1e-1)

plt.tight_layout()
../../_images/notebooks_examples_ex_GNN_committor_12_0.png

Check training results

[7]:
from mlcolvar.cvs.committor.utils import KolmogorovBias

fig, axs = plt.subplots(1, 3, figsize=(13, 3))

# compute z value over training set: the model is evaluated without the final sigmoid
model.sigmoid = None
with torch.no_grad():
    out_z = model(dataset.get_graph_inputs())

# compute q value over training set: the sigmoid is restored first
# (no .detach() needed, nothing is tracked inside torch.no_grad())
model.sigmoid = sigmoid
with torch.no_grad():
    out_q = model(dataset.get_graph_inputs())

# compute V_K value over training set; this is done outside torch.no_grad()
# as in the original cell (presumably the bias needs the gradients of q)
bias_model = KolmogorovBias(model, beta)
bias = bias_model(dataset.get_graph_inputs())

# one panel per quantity, all drawn on the (phi, psi) plane
panels = zip(axs, (out_z, out_q, bias), ('z', 'q', 'V$_K$'))
for ax, values, label in panels:
    cp = ax.scatter(colvar['phi'], colvar['psi'], c=values, s=5, cmap='fessa')
    plt.colorbar(cp, ax=ax, label=label)
    # raw strings: "\p" is not a valid escape sequence in a regular string literal
    ax.set_xlabel(r"$\phi$")
    ax.set_ylabel(r"$\psi$")


plt.tight_layout()
../../_images/notebooks_examples_ex_GNN_committor_14_0.png

Export torchscript models to be used in PLUMED

[8]:
# export both variants of the model: z (no sigmoid) first, then q (with sigmoid),
# so that the model is left with its sigmoid in place at the end
for activation, tag in ((None, 'z'), (sigmoid, 'q')):
    model.sigmoid = activation
    traced_model = model.to_torchscript(f'gnn_model_iter_0_{tag}.pt', method='trace')


# we can also check the outputs coincide (here for the last exported model, i.e. q)
eager_out = model(dataset.get_graph_inputs())
traced_out = traced_model(dataset.get_graph_inputs())
torch.allclose(eager_out, traced_out)
/home/etrizio@iit.local/Bin/dev/mlcolvar/mlcolvar/data/datamodule.py:322: UserWarning: Length of split at index 1 is 0. This might result in an empty dataset.
  warnings.warn(
[8]:
True

Run enhanced sampling simulations with PLUMED

Check simulation results

[18]:
fig, axs = plt.subplots(2, 2, figsize=(10, 6))

# COLVAR column used to colour the points
color_plot = 'gnn.node-z'

# COLVAR file of the biased simulation started from a given state
colvar_url = "https://github.com/EnricoTrizio/alanine_gnn_committor_data/raw/refs/heads/main/biased/iter_0/{state}/COLVAR"

# one column of panels per starting state: (phi, psi) map on top, phi time series below
for column, state in enumerate(("A", "B")):
    sampling = load_dataframe(colvar_url.format(state=state), stop=-5)

    ax = axs[0, column]
    cp = ax.scatter(sampling['phi'], sampling['psi'], c=sampling[color_plot], cmap='fessa', s=2)
    plt.colorbar(cp, ax=ax, label=color_plot)
    # raw strings: "\p" is not a valid escape sequence in a regular string literal
    ax.set_xlabel(r'$\phi$')
    ax.set_ylabel(r'$\psi$')
    ax.set_title(f"From state {state}")

    ax = axs[1, column]
    cp = ax.scatter(sampling['time'], sampling['phi'], c=sampling[color_plot], cmap='fessa', s=2)
    plt.colorbar(cp, ax=ax, label=color_plot)
    ax.set_xlabel('time [ps]')
    ax.set_ylabel(r'$\phi$')


plt.tight_layout()
plt.show()
../../_images/notebooks_examples_ex_GNN_committor_19_0.png

Iter 1

Load data

[ ]:
# root of the remote repository hosting the example data
data_url = "https://github.com/EnricoTrizio/alanine_gnn_committor_data/raw/refs/heads/main"

# the same frame slicing is applied to the trajectories and to the COLVAR files,
# so that the two stay aligned frame by frame
frame_slice = {'start': 1000, 'stop': 10000, 'stride': 4}

# simulations used in this iteration: the two unbiased runs plus the biased run
# of iteration 0 started from state B
simulation_dirs = ["unbiased/A", "unbiased/B", "biased/iter_0/B"]

# loading arguments for each file (one independent copy per trajectory)
load_args = [dict(frame_slice) for _ in simulation_dirs]

# create dataset from trajectory files, xyz and anything compatible with MDtraj can be used
# (the URLs are built with single slashes only, matching the COLVAR paths below)
dataset = create_dataset_from_trajectories(
    trajectories=[f"{data_url}/{folder}/traj_comp.xtc" for folder in simulation_dirs],
    topologies=f"{data_url}/unbiased/B/confAvac.gro",  # with xyz use None
    cutoff=10.0,  # Angstrom
    system_selection='all and not type H',  # this uses MDtraj selection syntax
    show_progress=False,
    load_args=load_args,
    lengths_conversion=10.0,  # MDTraj uses nm by default, we use Angstroms
)
print('Dataset info:\n', dataset, end="\n\n")


# load the corresponding COLVAR files; the easiest thing is to print trajectory
# and COLVAR files in sync, so the same slicing can be reused here
colvar = load_dataframe(
    [f"{data_url}/{folder}/COLVAR" for folder in simulation_dirs],
    **frame_slice,
)


# initialize weights: this time biased data are included, so the total bias is
# the sum of the 'opes.bias' and 'bias' columns. Missing entries are set to zero
# (presumably the rows of the unbiased runs, which carry no bias — TODO confirm)
colvar = colvar.fillna({'opes.bias': 0, 'bias': 0})
bias = torch.Tensor(colvar['opes.bias'].values + colvar['bias'].values)
dataset = compute_committor_weights(
    dataset=dataset,
    bias=bias,
    data_groups=[0, 1, 2],
    beta=beta,
)


# load dataset into a DictModule
datamodule = DictModule(dataset, lengths=[1])
print('Datamodule info:\n', datamodule)
Dataset info:
 DictDataset( "data_list": 6750, metadata={ "atomic_numbers": [6, 7, 8], "cutoff": 10.0, "used_idx": tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), "used_names": [ACE1-CH3, ACE1-C, ACE1-O, ALA2-N, ALA2-CA, ALA2-CB, ALA2-C, ALA2-O, NME3-N, NME3-C], "data_type": graphs } )

Datamodule info:
 DictModule(dataset -> DictDataset( "data_list": 6750, metadata={ "atomic_numbers": [6, 7, 8], "cutoff": 10.0, "used_idx": tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), "used_names": [ACE1-CH3, ACE1-C, ACE1-O, ALA2-N, ALA2-CA, ALA2-CB, ALA2-C, ALA2-O, NME3-N, NME3-C], "data_type": graphs } ),
                     train_loader -> DictLoader(length=1, batch_size=6750, shuffle=True))

Initialize model

[ ]:
# initialize GNN model, here we use SchNet; this time cutoff and atomic numbers
# are read explicitly from the dataset metadata
gnn_model = SchNetModel(
    n_out=1,
    cutoff=dataset.metadata['cutoff'],
    atomic_numbers=dataset.metadata['atomic_numbers'],
    n_bases=16,
    n_layers=2,
    n_filters=16,
    n_hidden_channels=16,
    aggr='min',
    w_out_after_pool=True,
)

# model options: optimizer learning rate and exponential learning-rate decay
lr_scheduler_options = {
    'scheduler': torch.optim.lr_scheduler.ExponentialLR,
    'gamma': 0.9999,
}
options = {
    'optimizer': {'lr': 1e-3},
    'lr_scheduler': lr_scheduler_options,
}

# masses of the atomic species kept in the dataset
atomic_masses = torch.Tensor([12.011, 14.007, 15.999])  # C, N, O

# initialize committor model
model = Committor(
    model=gnn_model,
    atomic_masses=atomic_masses,
    alpha=100,
    # NOTE(review): unlike iteration 0, the dataset now also contains biased data and
    # this flag is True; mlcolvar warns that False may be better with GNN models — confirm
    separate_boundary_dataset=True,
    z_regularization=100,
    z_threshold=5,
    log_var=True,
    gamma=1,
    options=options,
)

# keep a copy of the model sigmoid to switch easily between q and z later on
sigmoid = copy.deepcopy(model.sigmoid)
/home/etrizio@iit.local/Bin/miniconda3/envs/graph_mlcolvar_test_2.5/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.

Train model

[ ]:
# callback that collects the metrics logged during training
metrics = MetricsCallback()

# settings of the lightning trainer, gathered in one place
trainer_options = dict(
    callbacks=[metrics],
    logger=False,
    enable_checkpointing=False,
    accelerator='cpu',  # cpu for testing, use cuda!
    max_epochs=5,  # small for testing, use 2/5k or more
    enable_model_summary=False,
    limit_val_batches=0,
    num_sanity_val_steps=0,
)

# initialize lightning trainer
trainer = Trainer(**trainer_options)

# fit model on datamodule
trainer.fit(model, datamodule)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/etrizio@iit.local/Bin/miniconda3/envs/graph_mlcolvar_test_2.5/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
/home/etrizio@iit.local/Bin/dev/mlcolvar/mlcolvar/core/loss/committor_loss.py:251: UserWarning: Using GNN-based models it may be better to set separate_boundary_dataset to False
  warnings.warn("Using GNN-based models it may be better to set separate_boundary_dataset to False")
`Trainer.fit` stopped: `max_epochs=5000` reached.

Check training metrics

[18]:
# loss terms to show and the colour assigned to each of them
loss_keys = ['train_loss', 'train_loss_var', 'train_loss_bound_A', 'train_loss_bound_B']
loss_colors = ['fessa0', 'fessa3', 'fessa5', 'fessa6']

# the same curves are drawn twice: full range on the left, zoomed on the right
fig, axs = plt.subplots(1, 2, figsize=(10, 3))
for ax in axs:
    plot_metrics(metrics.metrics, keys=loss_keys, colors=loss_colors, yscale='log', ax=ax)

# we zoom more in one plot
axs[1].set_ylim(1e-5, 1e-1)

plt.tight_layout()
../../_images/notebooks_examples_ex_GNN_committor_28_0.png

Check training results

[19]:
fig, axs = plt.subplots(1, 3, figsize=(13, 3))

# compute z value over training set: the model is evaluated without the final sigmoid
model.sigmoid = None
with torch.no_grad():
    out_z = model(dataset.get_graph_inputs())

# compute q value over training set: the sigmoid is restored first
# (no .detach() needed, nothing is tracked inside torch.no_grad())
model.sigmoid = sigmoid
with torch.no_grad():
    out_q = model(dataset.get_graph_inputs())

# compute V_K value over training set; this is done outside torch.no_grad()
# as in the original cell (presumably the bias needs the gradients of q)
bias_model = KolmogorovBias(model, beta)
bias = bias_model(dataset.get_graph_inputs())

# one panel per quantity, all drawn on the (phi, psi) plane
panels = zip(axs, (out_z, out_q, bias), ('z', 'q', 'V$_K$'))
for ax, values, label in panels:
    cp = ax.scatter(colvar['phi'], colvar['psi'], c=values, s=5, cmap='fessa')
    plt.colorbar(cp, ax=ax, label=label)
    # raw strings: "\p" is not a valid escape sequence in a regular string literal
    ax.set_xlabel(r"$\phi$")
    ax.set_ylabel(r"$\psi$")


plt.tight_layout()
../../_images/notebooks_examples_ex_GNN_committor_30_0.png

Export torchscript models to be used in PLUMED

[287]:
# export both variants of the model: z (no sigmoid) first, then q (with sigmoid),
# so that the model is left with its sigmoid in place at the end
for activation, tag in ((None, 'z'), (sigmoid, 'q')):
    model.sigmoid = activation
    traced_model = model.to_torchscript(f'gnn_model_iter_1_{tag}.pt', method='trace')


# we can also check the outputs coincide (here for the last exported model, i.e. q)
eager_out = model(dataset.get_graph_inputs())
traced_out = traced_model(dataset.get_graph_inputs())
torch.allclose(eager_out, traced_out)
/home/etrizio@iit.local/Bin/dev/mlcolvar/mlcolvar/data/datamodule.py:322: UserWarning: Length of split at index 1 is 0. This might result in an empty dataset.
  warnings.warn(
[287]:
True

Run enhanced sampling simulations with PLUMED

Check simulation results

[19]:
fig, axs = plt.subplots(2, 2, figsize=(10, 6))

# COLVAR column used to colour the points
color_plot = 'gnn.node-z'

# COLVAR file of the biased simulation started from a given state
colvar_url = "https://github.com/EnricoTrizio/alanine_gnn_committor_data/raw/refs/heads/main/biased/iter_1/{state}/COLVAR"

# one column of panels per starting state: (phi, psi) map on top, phi time series below
for column, state in enumerate(("A", "B")):
    sampling = load_dataframe(colvar_url.format(state=state), stop=-5)

    ax = axs[0, column]
    cp = ax.scatter(sampling['phi'], sampling['psi'], c=sampling[color_plot], cmap='fessa', s=2)
    plt.colorbar(cp, ax=ax, label=color_plot)
    # raw strings: "\p" is not a valid escape sequence in a regular string literal
    ax.set_xlabel(r'$\phi$')
    ax.set_ylabel(r'$\psi$')
    ax.set_title(f"From state {state}")

    ax = axs[1, column]
    cp = ax.scatter(sampling['time'], sampling['phi'], c=sampling[color_plot], cmap='fessa', s=2)
    plt.colorbar(cp, ax=ax, label=color_plot)
    ax.set_xlabel('time [ps]')
    ax.set_ylabel(r'$\phi$')


plt.tight_layout()
plt.show()
../../_images/notebooks_examples_ex_GNN_committor_35_0.png