#!/usr/bin/env python
# =============================================================================
# MODULE DOCSTRING
# =============================================================================
"""
Committor function Loss Function
"""
__all__ = ["CommittorLoss", "committor_loss"]
# =============================================================================
# GLOBAL IMPORTS
# =============================================================================
import torch
from typing import Tuple, Union
from mlcolvar.core.loss.utils.smart_derivatives import SmartDerivatives
import torch_geometric
import warnings
from mlcolvar.utils._code import scatter_sum
# =============================================================================
# LOSS FUNCTIONS
# =============================================================================
[docs]
class CommittorLoss(torch.nn.Module):
"""Compute a loss function based on Kolmogorov's variational principle (or its approximation) for the determination of the committor function"""
[docs]
def __init__(self,
alpha: float,
atomic_masses: Union[torch.Tensor, None],
gamma: float = 10000.0,
delta_f: float = 0.0,
separate_boundary_dataset : bool = True,
descriptors_derivatives : Union[SmartDerivatives, torch.Tensor] = None,
log_var: bool = False,
use_gradients_wrt_positions: bool = True,
z_regularization: float = 0.0,
z_threshold: float = None,
n_dim : int = None,
):
"""Compute Kolmogorov's variational principle loss and impose boundary conditions on the metastable states
Parameters
----------
alpha : float
Hyperparamer that scales the boundary conditions contribution to loss, i.e. alpha*(loss_bound_A + loss_bound_B)
atomic_masses : torch.Tensor, optional
List of masses of all the atoms we are using, for each atom we need to repeat three times for x,y,z, by default None.
The mlcolvar.cvs.committor.utils.initialize_committor_masses can be used to simplify this.
If the position-less loss is used, this must be set to None.
gamma : float, optional
Hyperparamer that scales the whole loss to avoid too small numbers, i.e. gamma*(loss_var + loss_bound), by default 10000
delta_f : float, optional
Delta free energy between A (label 0) and B (label 1), units is kBT, by default 0.
State B is supposed to be higher in energy.
separate_boundary_dataset : bool, optional
Switch to exculde boundary condition labeled data from the variational loss, by default True
descriptors_derivatives : Union[SmartDerivatives, torch.Tensor], optional
Derivatives of descriptors wrt atomic positions (if used) to speed up calculation of gradients, by default None.
Can be either:
- A `SmartDerivatives` object to save both memory and time, see also mlcolvar.core.loss.committor_loss.SmartDerivatives
- A torch.Tensor with the derivatives to save time, memory-wise could be less efficient
ref_idx: torch.Tensor, optional
Reference indeces for the unshuffled dataset for properly handling batching/splitting/shuffling
when descriptors derivatives are provided, by default None.
Ref_idx can be generated automatically using SmartDerivatives or by setting create_ref_idx=True when initializing a DictDataset.
See also mlcolvar.core.loss.utils.smart_derivatives.SmartDerivatives
log_var : bool, optional
Switch to minimize the log of the variational functional, by default False.
use_gradients_wrt_positions : bool, optional
Whether to use gradients with respect to positions as prescribed in the original Kolmogorov variational functional, by default True.
z_regularization : float, optional
Scales a regularization on the learned z space preventing it from exceeding the threshold given with 'z_threshold'.
The magnitude of the regularization is scaled by the given number, by default 0.0
z_threshold : float, optional
Sets a maximum threshold for the z value during the training, by default None.
The magnitude of the regularization term is scaled via the `z_regularization` key.
n_dim : int
Number of dimensions, by default None.
If None, it defaults to 3 for the position-based loss and to 1 for the position-less loss.
"""
super().__init__()
self.register_buffer("atomic_masses", atomic_masses)
self.alpha = alpha
self.gamma = gamma
self.delta_f = delta_f
self.descriptors_derivatives = descriptors_derivatives
self.separate_boundary_dataset = separate_boundary_dataset
self.log_var = log_var
self.use_gradients_wrt_positions = use_gradients_wrt_positions
self.z_regularization = z_regularization
self.z_threshold = z_threshold
self.n_dim = n_dim
[docs]
def forward(self,
x: Union[torch.Tensor, torch_geometric.data.Batch],
z: torch.Tensor,
q: torch.Tensor,
labels: torch.Tensor,
w: torch.Tensor,
ref_idx: torch.Tensor = None,
create_graph: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Committor loss forward pass
Parameters
----------
x : torch.Tensor or torch_geometric.data.Batch
Model input, i.e., either positions or descriptors if using descriptors_derivatives
z : torch.Tensor
Model unactivated output, i.e., z value
q : torch.Tensor
Model final output, i.e., committor value
labels : torch.Tensor
Input labels
w : torch.Tensor
Input weights
ref_idx : torch.Tensor, optional
Reference indeces for the unshuffled dataset for properly handling batching/splitting/shuffling
when descriptors derivatives are provided, by default None.
Ref_idx can be generated automatically using SmartDerivatives or by setting create_ref_idx=True when initializing a DictDataset.
See also mlcolvar.core.loss.utils.smart_derivatives.SmartDerivatives
create_graph : bool, optional
Whether to create the graph during the computation for backpropagation, by default True
Returns
-------
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
Total loss and its components, i.e., variational, boundary A, and boundary B
"""
return committor_loss(x=x,
z=z,
q=q,
labels=labels,
w=w,
atomic_masses=self.atomic_masses,
alpha=self.alpha,
gamma=self.gamma,
delta_f=self.delta_f,
create_graph=create_graph,
separate_boundary_dataset=self.separate_boundary_dataset,
descriptors_derivatives=self.descriptors_derivatives,
log_var=self.log_var,
use_gradients_wrt_positions=self.use_gradients_wrt_positions,
z_regularization=self.z_regularization,
z_threshold=self.z_threshold,
ref_idx=ref_idx,
n_dim=self.n_dim
)
def committor_loss(x: torch.Tensor,
z: torch.Tensor,
q: torch.Tensor,
labels: torch.Tensor,
w: torch.Tensor,
alpha: float,
atomic_masses: Union[torch.Tensor, None] = None,
gamma: float = 10000,
delta_f: float = 0,
create_graph: bool = True,
separate_boundary_dataset: bool = True,
descriptors_derivatives: Union[SmartDerivatives, torch.Tensor] = None,
log_var: bool = False,
use_gradients_wrt_positions: bool = True,
z_regularization: float = 0.0,
z_threshold : float = None,
ref_idx: torch.Tensor = None,
n_dim : int = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute variational loss for committor optimization with boundary conditions
Parameters
----------
x : torch.Tensor
Input of the NN
z : torch.Tensor
z value z(x), it is the unactivated output of NN
q : torch.Tensor
Committor guess q(x), it is the output of NN
labels : torch.Tensor
Labels for states, A and B states for boundary conditions
w : torch.Tensor
Reweighing factors to Boltzmann distribution. This should depend on the simulation in which the data were collected.
alpha : float
Hyperparamer that scales the boundary conditions contribution to loss, i.e. alpha*(loss_bound_A + loss_bound_B)
atomic_masses : torch.Tensor, optional
List of masses of all the atoms we are using, for each atom we need to repeat three times for x,y,z, by default None.
The mlcolvar.cvs.committor.utils.initialize_committor_masses can be used to simplify this.
If the position-less loss is used, this must be set to None.
gamma : float
Hyperparamer that scales the whole loss to avoid too small numbers, i.e. gamma*(loss_var + loss_bound)
By default 10000
delta_f : float
Delta free energy between A (label 0) and B (label 1), units is kBT, by default 0.
create_graph : bool
Make loss backwardable, deactivate for validation to save memory, default True
separate_boundary_dataset : bool, optional
Switch to exculde boundary condition labeled data from the variational loss, by default True
descriptors_derivatives : Union[SmartDerivatives, torch.Tensor], optional
Derivatives of descriptors wrt atomic positions (if used) to speed up calculation of gradients, by default None.
Can be either:
- A `SmartDerivatives` object to save both memory and time, see also mlcolvar.core.loss.committor_loss.SmartDerivatives
- A torch.Tensor with the derivatives to save time, memory-wise could be less efficient
log_var : bool, optional
Switch to minimize the log of the variational functional, by default False.
use_gradients_wrt_positions : bool, optional
Whether to use gradients with respect to positions as prescribed in the original Kolmogorov variational functional, by default True.
Set to false to use the approximated variational principle defined in Ref. [3] without explicit dependence on the atomic coordinates derivatives.
z_regularization : float, optional
Scales a regularization on the learned z space preventing it from exceeding the threshold given with 'z_threshold'.
The magnitude of the regularization is scaled by the given number, by default 0.0
z_threshold : float, optional
Sets a maximum threshold for the z value during the training, by default None.
The magnitude of the regularization term is scaled via the `z_regularization` key.
ref_idx: torch.Tensor, optional
Reference indeces for the unshuffled dataset for properly handling batching/splitting/shuffling
when descriptors derivatives are provided, by default None.
Ref_idx can be generated automatically using SmartDerivatives or by setting create_ref_idx=True when initializing a DictDataset.
See also mlcolvar.core.loss.utils.smart_derivatives.SmartDerivatives
n_dim : int
Number of dimensions, by default None.
If None, it defaults to 3 for the position-based loss and to 1 for the position-less loss.
Returns
-------
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
Total loss and its components, i.e., variational, boundary A, and boundary B
"""
# internal consistency checks
if descriptors_derivatives is not None and ref_idx is None:
raise ValueError ("Descriptors derivatives need reference indeces from the dataset! Use a dataset with the ref_idx, see docstrign for details")
if use_gradients_wrt_positions and atomic_masses is None:
raise ValueError ("atomic_masses must be provided when using Kolmogorov variational functional (use_gradients_wrt_positions is True)")
elif not use_gradients_wrt_positions:
if atomic_masses is not None:
raise ValueError ("atomic_masses must be None when using approximated variational principle (use_gradients_wrt_positions is False)")
if descriptors_derivatives is not None:
raise ValueError ("descriptors_derivatives must be None when using approximated variational principle (use_gradients_wrt_positions is False)")
if isinstance(descriptors_derivatives, torch.Tensor) and separate_boundary_dataset:
raise ValueError ("Descriptors derivatives via explicit tensor are not implemented with separate_boundary_dataset key! Either use SmartDerivatives or deactivate separate_boundary_dataset")
if (z_threshold is not None and (z_regularization == 0 or z_threshold <= 0)) or (z_threshold is None and z_regularization != 0) or z_regularization < 0:
raise ValueError(f"To apply the regularization on z space both z_threshold and z_regularization key must be positive. Found {z_threshold} and {z_regularization}!")
# check if input is graph
if isinstance(x, torch_geometric.data.batch.Batch):
_is_graph_data = True
batch = torch.clone(x['batch'])
node_types = torch.where(x['node_attrs'])[1]
x = x['positions']
else:
_is_graph_data = False
if n_dim is None:
n_dim = 3 if use_gradients_wrt_positions else 1
else:
if n_dim < 1 or n_dim > 3:
raise ValueError(f"n_dim should be a positive integer between 1 and 3, found {n_dim}!")
if _is_graph_data and n_dim != 3:
raise ValueError(f"n_dim should be 3 for graph-based models, found {n_dim}!")
if _is_graph_data and not use_gradients_wrt_positions:
raise ValueError("use_gradients_wrt_positions should be set to True (default) when using graph-based models!")
# checks and warnings
if _is_graph_data and descriptors_derivatives is not None:
raise ValueError("The descriptors_derivatives key cannot be used with GNN-based models!")
if _is_graph_data and separate_boundary_dataset:
warnings.warn("Using GNN-based models it may be better to set separate_boundary_dataset to False")
# ------------------------ SETUP ------------------------
# inherit right device
device = x.device
# expand mass tensor to [1, n_atoms*spatial_dims]
if atomic_masses is not None:
atomic_masses = atomic_masses.to(device).repeat_interleave(n_dim)
else:
atomic_masses = torch.ones(x.shape[1], device=device)
# squeeze labels
labels = labels.squeeze()
# Create masks to access different states data
mask_A = torch.nonzero(labels == 0, as_tuple=True)
mask_B = torch.nonzero(labels == 1, as_tuple=True)
# create mask for variational data
if separate_boundary_dataset:
mask_var = torch.nonzero(labels > 1, as_tuple=not(_is_graph_data))
else:
mask_var = torch.ones_like(labels, dtype=torch.bool)
if _is_graph_data:
# this needs to be on the batch index, not only the labels
aux = torch.where(mask_var)[0].to(device)
mask_var_batches = torch.isin(batch, aux)
mask_var_batches = batch[mask_var_batches]
else:
mask_var_batches = mask_var
# setup atomic masses
atomic_masses = atomic_masses.to(device)
# mass should have size [1, n_atoms*spatial_dims]
if _is_graph_data:
atomic_masses = atomic_masses[node_types[mask_var_batches]].unsqueeze(-1)
else:
atomic_masses = atomic_masses.unsqueeze(0)
# Update weights of basin B using the information on the delta_f
delta_f = torch.Tensor([delta_f]).to(device)
# B higher in energy --> A-B < 0
if delta_f < 0:
w[mask_B] *= torch.exp(delta_f)
# A higher in energy --> A-B > 0
elif delta_f > 0:
w[mask_A] *= torch.exp(-delta_f)
# weights should have size [n_batch, 1]
w = w.unsqueeze(-1)
# ------------------------ LOSS ------------------------
# Each loss contribution is scaled by the number of samples
# 1. ----- VARIATIONAL LOSS
# Compute gradients of q(x) wrt x
grad_outputs = torch.ones_like(q[mask_var])
grad = torch.autograd.grad(q[mask_var],
x,
grad_outputs=grad_outputs,
retain_graph=True,
create_graph=create_graph)[0]
grad = grad[mask_var_batches]
if descriptors_derivatives is not None:
# in case the input is not positions but descriptors, we need to correct the gradients up to the positions
if use_gradients_wrt_positions and isinstance(descriptors_derivatives, SmartDerivatives):
# we use the precomputed derivatives from descriptors to pos
gradient_positions = descriptors_derivatives(grad, ref_idx[mask_var]).view(x[mask_var].shape[0], -1)
# --> If we directly pass the matrix d_desc/d_pos
elif use_gradients_wrt_positions and isinstance(descriptors_derivatives, torch.Tensor):
descriptors_derivatives = descriptors_derivatives.to(device)
gradient_positions = torch.einsum("bd,badx->bax", grad, descriptors_derivatives[ref_idx[mask_var]]).contiguous()
gradient_positions = gradient_positions.view(x[mask_var].shape[0], -1)
else:
# we get the square of grad(q) and we multiply by the weight
gradient_positions = grad
# we do the square
grad_square = torch.pow(gradient_positions, 2)
grad_square = torch.sum((grad_square * (1/atomic_masses)), axis=1, keepdim=True)
if _is_graph_data:
# we need to sum on the right batch first
grad_square = scatter_sum(grad_square, mask_var_batches, dim=0)
# for position-less committor
if not use_gradients_wrt_positions:
grad_square = grad_square**2
# variational contribution to loss: we sum over the batch
loss_var = torch.mean( grad_square * w[mask_var])
if log_var:
loss_var = gamma*torch.log1p(loss_var)
else:
loss_var = gamma*loss_var
# 2. ----- BOUNDARY LOSS
loss_A = gamma * torch.mean( q[mask_A].pow(2) )
loss_B = gamma * torch.mean( (q[mask_B] - 1).pow(2) )
# 3. ----- OPTIONAL regularization on z
if z_threshold is not None:
over_threshold = torch.relu(z.abs() - z_threshold)
loss_z_diff = z_regularization * torch.mean(over_threshold.pow(2))
else:
loss_z_diff = 0
# 4. ----- TOTAL LOSS
loss = loss_var + alpha*(loss_A + loss_B) + loss_z_diff
# TODO maybe there is no need to detach them for logging
return loss, loss_var.detach(), alpha*loss_A.detach(), alpha*loss_B.detach()