#!/usr/bin/env python
# =============================================================================
# MODULE DOCSTRING
# =============================================================================
"""
Committor function Loss Function
"""
__all__ = ["CommittorLoss", "committor_loss"]
# =============================================================================
# GLOBAL IMPORTS
# =============================================================================
import torch
from typing import Tuple, Union
from mlcolvar.core.loss.utils.smart_derivatives import SmartDerivatives
# =============================================================================
# LOSS FUNCTIONS
# =============================================================================
[docs]
class CommittorLoss(torch.nn.Module):
"""Compute a loss function based on Kolmogorov's variational principle for the determination of the committor function"""
[docs]
def __init__(self,
atomic_masses: torch.Tensor,
alpha: float,
cell: float = None,
gamma: float = 10000.0,
delta_f: float = 0.0,
separate_boundary_dataset : bool = True,
descriptors_derivatives : Union[SmartDerivatives, torch.Tensor] = None,
log_var: bool = False,
z_regularization: float = 0.0,
z_threshold: float = None,
n_dim : int = 3,
):
"""Compute Kolmogorov's variational principle loss and impose boundary conditions on the metastable states
Parameters
----------
atomic_masses : torch.Tensor
Atomic masses of the atoms in the system
alpha : float
Hyperparamer that scales the boundary conditions contribution to loss, i.e. alpha*(loss_bound_A + loss_bound_B)
cell : float, optional
CUBIC cell size length, used to scale the positions from reduce coordinates to real coordinates, by default None
gamma : float, optional
Hyperparamer that scales the whole loss to avoid too small numbers, i.e. gamma*(loss_var + loss_bound), by default 10000
delta_f : float, optional
Delta free energy between A (label 0) and B (label 1), units is kBT, by default 0.
State B is supposed to be higher in energy.
separate_boundary_dataset : bool, optional
Switch to exculde boundary condition labeled data from the variational loss, by default True
descriptors_derivatives : Union[SmartDerivatives, torch.Tensor], optional
Derivatives of descriptors wrt atomic positions (if used) to speed up calculation of gradients, by default None.
Can be either:
- A `SmartDerivatives` object to save both memory and time, see also mlcolvar.core.loss.committor_loss.SmartDerivatives
- A torch.Tensor with the derivatives to save time, memory-wise could be less efficient
ref_idx: torch.Tensor, optional
Reference indeces for the unshuffled dataset for properly handling batching/splitting/shuffling
when descriptors derivatives are provided, by default None.
Ref_idx can be generated automatically using SmartDerivatives or by setting create_ref_idx=True when initializing a DictDataset.
See also mlcolvar.core.loss.utils.smart_derivatives.SmartDerivatives
log_var : bool, optional
Switch to minimize the log of the variational functional, by default False.
z_regularization : float, optional
Scales a regularization on the learned z space preventing it from exceeding the threshold given with 'z_threshold'.
The magnitude of the regularization is scaled by the given number, by default 0.0
z_threshold : float, optional
Sets a maximum threshold for the z value during the training, by default None.
The magnitude of the regularization term is scaled via the `z_regularization` key.
n_dim : int
Number of dimensions, by default 3.
"""
super().__init__()
self.register_buffer("atomic_masses", atomic_masses)
self.alpha = alpha
self.cell = cell
self.gamma = gamma
self.delta_f = delta_f
self.descriptors_derivatives = descriptors_derivatives
self.separate_boundary_dataset = separate_boundary_dataset
self.log_var = log_var
self.z_regularization = z_regularization
self.z_threshold = z_threshold
self.n_dim = n_dim
[docs]
def forward(self,
x: torch.Tensor,
z: torch.Tensor,
q: torch.Tensor,
labels: torch.Tensor,
w: torch.Tensor,
ref_idx: torch.Tensor = None,
create_graph: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Committor loss forward pass
Parameters
----------
x : torch.Tensor
Model input, i.e., either positions or descriptors if using descriptors_derivatives
z : torch.Tensor
Model unactivated output, i.e., z value
q : torch.Tensor
Model final output, i.e., committor value
labels : torch.Tensor
Input labels
w : torch.Tensor
Input weights
ref_idx : torch.Tensor, optional
Reference indeces for the unshuffled dataset for properly handling batching/splitting/shuffling
when descriptors derivatives are provided, by default None.
Ref_idx can be generated automatically using SmartDerivatives or by setting create_ref_idx=True when initializing a DictDataset.
See also mlcolvar.core.loss.utils.smart_derivatives.SmartDerivatives
create_graph : bool, optional
Whether to create the graph during the computation for backpropagation, by default True
Returns
-------
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
Total loss and its components, i.e., variational, boundary A, and boundary B
"""
return committor_loss(x=x,
z=z,
q=q,
labels=labels,
w=w,
atomic_masses=self.atomic_masses,
alpha=self.alpha,
gamma=self.gamma,
delta_f=self.delta_f,
create_graph=create_graph,
cell=self.cell,
separate_boundary_dataset=self.separate_boundary_dataset,
descriptors_derivatives=self.descriptors_derivatives,
log_var=self.log_var,
z_regularization=self.z_regularization,
z_threshold=self.z_threshold,
ref_idx=ref_idx,
n_dim=self.n_dim
)
def committor_loss(x: torch.Tensor,
z: torch.Tensor,
q: torch.Tensor,
labels: torch.Tensor,
w: torch.Tensor,
atomic_masses: torch.Tensor,
alpha: float,
gamma: float = 10000,
delta_f: float = 0,
create_graph: bool = True,
cell: float = None,
separate_boundary_dataset: bool = True,
descriptors_derivatives: Union[SmartDerivatives, torch.Tensor] = None,
log_var: bool = False,
z_regularization: float = 0.0,
z_threshold : float = None,
ref_idx: torch.Tensor = None,
n_dim : int = 3,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute variational loss for committor optimization with boundary conditions
Parameters
----------
x : torch.Tensor
Input of the NN
z : torch.Tensor
z value z(x), it is the unactivated output of NN
q : torch.Tensor
Committor guess q(x), it is the output of NN
labels : torch.Tensor
Labels for states, A and B states for boundary conditions
w : torch.Tensor
Reweighing factors to Boltzmann distribution. This should depend on the simulation in which the data were collected.
atomic_masses : torch.Tensor
List of masses of all the atoms we are using, for each atom we need to repeat three times for x,y,z.
Can be created using `committor.utils.initialize_committor_masses`
alpha : float
Hyperparamer that scales the boundary conditions contribution to loss, i.e. alpha*(loss_bound_A + loss_bound_B)
gamma : float
Hyperparamer that scales the whole loss to avoid too small numbers, i.e. gamma*(loss_var + loss_bound)
By default 10000
delta_f : float
Delta free energy between A (label 0) and B (label 1), units is kBT, by default 0.
create_graph : bool
Make loss backwardable, deactivate for validation to save memory, default True
cell : float
CUBIC cell size length, used to scale the positions from reduce coordinates to real coordinates, default None
separate_boundary_dataset : bool, optional
Switch to exculde boundary condition labeled data from the variational loss, by default True
descriptors_derivatives : Union[SmartDerivatives, torch.Tensor], optional
Derivatives of descriptors wrt atomic positions (if used) to speed up calculation of gradients, by default None.
Can be either:
- A `SmartDerivatives` object to save both memory and time, see also mlcolvar.core.loss.committor_loss.SmartDerivatives
- A torch.Tensor with the derivatives to save time, memory-wise could be less efficient
log_var : bool, optional
Switch to minimize the log of the variational functional, by default False.
z_regularization : float, optional
Scales a regularization on the learned z space preventing it from exceeding the threshold given with 'z_threshold'.
The magnitude of the regularization is scaled by the given number, by default 0.0
z_threshold : float, optional
Sets a maximum threshold for the z value during the training, by default None.
The magnitude of the regularization term is scaled via the `z_regularization` key.
ref_idx: torch.Tensor, optional
Reference indeces for the unshuffled dataset for properly handling batching/splitting/shuffling
when descriptors derivatives are provided, by default None.
Ref_idx can be generated automatically using SmartDerivatives or by setting create_ref_idx=True when initializing a DictDataset.
See also mlcolvar.core.loss.utils.smart_derivatives.SmartDerivatives
n_dim : int
Number of dimensions, by default 3.
Returns
-------
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
Total loss and its components, i.e., variational, boundary A, and boundary B
"""
if descriptors_derivatives is not None and ref_idx is None:
raise ValueError ("Descriptors derivatives need reference indeces from the dataset! Use a dataset with the ref_idx, see docstrign for details")
if isinstance(descriptors_derivatives, torch.Tensor) and separate_boundary_dataset:
raise ValueError ("Descriptors derivatives via explicit tensor are not implemented with separate_boundary_dataset key! Either use SmartDerivatives or deactivate separate_boundary_dataset")
if (z_threshold is not None and (z_regularization == 0 or z_threshold <= 0)) or (z_threshold is None and z_regularization != 0) or z_regularization < 0:
raise ValueError(f"To apply the regularization on z space both z_threshold and z_regularization key must be positive. Found {z_threshold} and {z_regularization}!")
# ------------------------ SETUP ------------------------
# inherit right device
device = x.device
# expand mass tensor to [1, n_atoms*spatial_dims]
atomic_masses = atomic_masses.to(device).repeat_interleave(n_dim)
# squeeze labels
labels = labels.squeeze()
# Create masks to access different states data
mask_A = labels == 0
mask_B = labels == 1
# create mask for variational data
if separate_boundary_dataset:
mask_var = labels > 1
else:
mask_var = torch.ones_like(labels, dtype=torch.bool)
# Update weights of basin B using the information on the delta_f
delta_f = torch.Tensor([delta_f]).to(device)
# B higher in energy --> A-B < 0
if delta_f < 0:
w[mask_B] *= torch.exp(delta_f)
# A higher in energy --> A-B > 0
elif delta_f > 0:
w[mask_A] *= torch.exp(-delta_f)
# weights should have size [n_batch, 1]
w = w.unsqueeze(-1)
# ------------------------ LOSS ------------------------
# Each loss contribution is scaled by the number of samples
# 1. ----- VARIATIONAL LOSS
# Compute gradients of q(x) wrt x
grad_outputs = torch.ones_like(q[mask_var])
grad = torch.autograd.grad(q[mask_var],
x,
grad_outputs=grad_outputs,
retain_graph=True,
create_graph=create_graph)[0]
grad = grad[mask_var]
if cell is not None:
grad = grad / cell
# in case the input is not positions but descriptors, we need to correct the gradients up to the positions
if isinstance(descriptors_derivatives, SmartDerivatives):
# we use the precomputed derivatives from descriptors to pos
gradient_positions = descriptors_derivatives(grad, ref_idx[mask_var]).view(x[mask_var].shape[0], -1)
# --> If we directly pass the matrix d_desc/d_pos
elif isinstance(descriptors_derivatives, torch.Tensor):
descriptors_derivatives = descriptors_derivatives.to(device)
gradient_positions = torch.einsum("bd,badx->bax", grad, descriptors_derivatives[ref_idx[mask_var]]).contiguous()
gradient_positions = gradient_positions.view(x[mask_var].shape[0], -1)
# If the input was already positions
else:
gradient_positions = grad
# we do the square
grad_square = torch.pow(gradient_positions, 2)
# multiply by masses
try:
grad_square = torch.sum((grad_square * (1/atomic_masses)),
axis=1,
keepdim=True)
except RuntimeError as e:
raise RuntimeError(e, """[HINT]: Is you system in 3 dimension? By default the code assumes so, if it's not the case change the n_dim key to the right dimensionality.""")
# variational contribution to loss: we sum over the batch
loss_var = torch.mean(grad_square * w[mask_var])
if log_var:
loss_var = torch.log1p(loss_var)
else:
loss_var *= gamma
# 2. ----- BOUNDARY LOSS
loss_A = gamma * torch.mean( q[mask_A].pow(2) )
loss_B = gamma * torch.mean( (q[mask_B] - 1).pow(2) )
# 3. ----- OPTIONAL regularization on z
if z_threshold is not None:
over_threshold = torch.relu(z.abs() - z_threshold)
loss_z_diff = z_regularization * torch.mean(over_threshold.pow(2))
else:
loss_z_diff = 0
# 4. ----- TOTAL LOSS
loss = loss_var + alpha*(loss_A + loss_B) + loss_z_diff
return loss, loss_var.detach(), alpha*loss_A.detach(), alpha*loss_B.detach()