Source code for mlcolvar.core.loss.committor_loss

#!/usr/bin/env python

# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Loss function for learning the committor function.
"""

__all__ = ["CommittorLoss", "committor_loss"]

# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

import torch
from typing import Tuple, Union
from mlcolvar.core.loss.utils.smart_derivatives import SmartDerivatives

# =============================================================================
# LOSS FUNCTIONS
# =============================================================================


class CommittorLoss(torch.nn.Module):
    """Loss module for the determination of the committor function.

    The loss is based on Kolmogorov's variational principle (or on its
    approximation) and is complemented by boundary conditions on the
    metastable states. The options are stored at construction time and handed
    to :func:`committor_loss` together with the batch tensors on every call.
    """

    def __init__(
        self,
        alpha: float,
        atomic_masses: Union[torch.Tensor, None],
        cell: float = None,
        gamma: float = 10000.0,
        delta_f: float = 0.0,
        separate_boundary_dataset: bool = True,
        descriptors_derivatives: Union[SmartDerivatives, torch.Tensor] = None,
        log_var: bool = False,
        use_gradients_wrt_positions: bool = True,
        z_regularization: float = 0.0,
        z_threshold: float = None,
        n_dim: int = None,
    ):
        """Store the settings of the committor loss.

        Parameters
        ----------
        alpha : float
            Hyperparameter scaling the boundary conditions contribution to the
            loss, i.e. alpha*(loss_bound_A + loss_bound_B)
        atomic_masses : torch.Tensor, optional
            Masses of the atoms in use. It is registered as a buffer of the
            module. The mlcolvar.cvs.committor.utils.initialize_committor_masses
            helper can be used to build it. It must be None when the
            position-less loss is used.
        cell : float, optional
            CUBIC cell size length, used to scale the positions from reduced
            coordinates to real coordinates, by default None
        gamma : float, optional
            Hyperparameter scaling the whole loss to avoid too small numbers,
            i.e. gamma*(loss_var + loss_bound), by default 10000
        delta_f : float, optional
            Delta free energy between A (label 0) and B (label 1), in units of
            kBT, by default 0.
        separate_boundary_dataset : bool, optional
            Switch to exclude the boundary-condition labeled data from the
            variational loss, by default True
        descriptors_derivatives : Union[SmartDerivatives, torch.Tensor], optional
            Derivatives of the descriptors wrt the atomic positions (if used)
            to speed up the calculation of the gradients, by default None.
            Can be either:
                - A `SmartDerivatives` object to save both memory and time,
                  see mlcolvar.core.loss.utils.smart_derivatives.SmartDerivatives
                - A torch.Tensor with the derivatives to save time,
                  memory-wise it could be less efficient
        log_var : bool, optional
            Switch to minimize the log of the variational functional,
            by default False.
        use_gradients_wrt_positions : bool, optional
            Whether to use gradients with respect to positions as prescribed
            in the original Kolmogorov variational functional, by default True.
        z_regularization : float, optional
            Magnitude of the regularization that prevents the learned z space
            from exceeding the threshold given with 'z_threshold',
            by default 0.0
        z_threshold : float, optional
            Maximum threshold for the z value during the training, by default
            None. The regularization term is scaled via `z_regularization`.
        n_dim : int
            Number of dimensions, by default None. If None, it defaults to 3
            for the position-based loss and to 1 for the position-less loss.
        """
        super().__init__()

        # registered as a buffer (a None value is accepted by register_buffer)
        self.register_buffer("atomic_masses", atomic_masses)

        # scaling hyperparameters
        self.alpha = alpha
        self.gamma = gamma
        self.delta_f = delta_f

        # geometry-related options
        self.cell = cell
        self.n_dim = n_dim
        self.descriptors_derivatives = descriptors_derivatives

        # switches selecting the flavour of the variational term
        self.separate_boundary_dataset = separate_boundary_dataset
        self.use_gradients_wrt_positions = use_gradients_wrt_positions
        self.log_var = log_var

        # optional regularization on the z space
        self.z_regularization = z_regularization
        self.z_threshold = z_threshold

    def forward(
        self,
        x: torch.Tensor,
        z: torch.Tensor,
        q: torch.Tensor,
        labels: torch.Tensor,
        w: torch.Tensor,
        ref_idx: torch.Tensor = None,
        create_graph: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Evaluate the committor loss on a batch.

        Parameters
        ----------
        x : torch.Tensor
            Model input, i.e., either positions or descriptors if using
            descriptors_derivatives
        z : torch.Tensor
            Model unactivated output, i.e., z value
        q : torch.Tensor
            Model final output, i.e., committor value
        labels : torch.Tensor
            Input labels
        w : torch.Tensor
            Input weights
        ref_idx : torch.Tensor, optional
            Reference indices for the unshuffled dataset, needed to properly
            handle batching/splitting/shuffling when descriptors derivatives
            are provided, by default None. They can be generated automatically
            using SmartDerivatives or by setting create_ref_idx=True when
            initializing a DictDataset.
            See also mlcolvar.core.loss.utils.smart_derivatives.SmartDerivatives
        create_graph : bool, optional
            Whether to create the graph during the computation for
            backpropagation, by default True

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
            Total loss and its components, i.e., variational, boundary A, and
            boundary B
        """
        # options fixed at construction time
        stored_options = {
            "atomic_masses": self.atomic_masses,
            "alpha": self.alpha,
            "gamma": self.gamma,
            "delta_f": self.delta_f,
            "cell": self.cell,
            "separate_boundary_dataset": self.separate_boundary_dataset,
            "descriptors_derivatives": self.descriptors_derivatives,
            "log_var": self.log_var,
            "use_gradients_wrt_positions": self.use_gradients_wrt_positions,
            "z_regularization": self.z_regularization,
            "z_threshold": self.z_threshold,
            "n_dim": self.n_dim,
        }

        return committor_loss(
            x=x,
            z=z,
            q=q,
            labels=labels,
            w=w,
            ref_idx=ref_idx,
            create_graph=create_graph,
            **stored_options,
        )
def committor_loss(
    x: torch.Tensor,
    z: torch.Tensor,
    q: torch.Tensor,
    labels: torch.Tensor,
    w: torch.Tensor,
    alpha: float,
    atomic_masses: Union[torch.Tensor, None] = None,
    gamma: float = 10000,
    delta_f: float = 0,
    create_graph: bool = True,
    cell: Union[float, None] = None,
    separate_boundary_dataset: bool = True,
    descriptors_derivatives: Union["SmartDerivatives", torch.Tensor, None] = None,
    log_var: bool = False,
    use_gradients_wrt_positions: bool = True,
    z_regularization: float = 0.0,
    z_threshold: Union[float, None] = None,
    ref_idx: Union[torch.Tensor, None] = None,
    n_dim: Union[int, None] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute variational loss for committor optimization with boundary conditions

    The caller's tensors are not modified: when `delta_f` is non-zero the
    weights are rescaled on a private copy of `w`.

    Parameters
    ----------
    x : torch.Tensor
        Input of the NN. The gradients of `q` are taken with respect to it,
        so it must be part of the autograd graph of `q`.
    z : torch.Tensor
        z value z(x), it is the unactivated output of NN
    q : torch.Tensor
        Committor guess q(x), it is the output of NN
    labels : torch.Tensor
        Labels for states: 0 marks state A and 1 marks state B (boundary
        conditions). With `separate_boundary_dataset`, labels larger than 1
        mark the data entering the variational term.
    w : torch.Tensor
        Reweighing factors to Boltzmann distribution. This should depend on
        the simulation in which the data were collected. One value per sample;
        a trailing dimension is appended internally.
    alpha : float
        Hyperparameter that scales the boundary conditions contribution to
        loss, i.e. alpha*(loss_bound_A + loss_bound_B)
    atomic_masses : torch.Tensor, optional
        Masses of the atoms in use, by default None. Each mass is repeated
        `n_dim` times internally. The
        mlcolvar.cvs.committor.utils.initialize_committor_masses can be used
        to simplify this. If the position-less loss is used, this must be set
        to None.
    gamma : float
        Hyperparameter that scales the whole loss to avoid too small numbers,
        i.e. gamma*(loss_var + loss_bound), by default 10000
    delta_f : float
        Delta free energy between A (label 0) and B (label 1), units is kBT,
        by default 0. A negative value scales the weights of the state B
        samples by exp(delta_f), a positive value scales the weights of the
        state A samples by exp(-delta_f).
    create_graph : bool
        Make loss backwardable, deactivate for validation to save memory,
        default True
    cell : float
        CUBIC cell size length, used to scale the positions from reduced
        coordinates to real coordinates, default None
    separate_boundary_dataset : bool, optional
        Switch to exclude boundary condition labeled data from the variational
        loss, by default True
    descriptors_derivatives : Union[SmartDerivatives, torch.Tensor], optional
        Derivatives of descriptors wrt atomic positions (if used) to speed up
        calculation of gradients, by default None.
        Can be either:
            - A `SmartDerivatives` object to save both memory and time,
              see mlcolvar.core.loss.utils.smart_derivatives.SmartDerivatives
            - A torch.Tensor with the derivatives to save time, memory-wise
              could be less efficient
    log_var : bool, optional
        Switch to minimize the log of the variational functional,
        by default False.
    use_gradients_wrt_positions : bool, optional
        Whether to use gradients with respect to positions as prescribed in
        the original Kolmogorov variational functional, by default True.
        Set to false to use the approximated variational principle defined in
        Ref. [3] without explicit dependence on the atomic coordinates
        derivatives.
    z_regularization : float, optional
        Scales a regularization on the learned z space preventing it from
        exceeding the threshold given with 'z_threshold'. The magnitude of the
        regularization is scaled by the given number, by default 0.0
    z_threshold : float, optional
        Sets a maximum threshold for the z value during the training,
        by default None. The magnitude of the regularization term is scaled
        via the `z_regularization` key.
    ref_idx : torch.Tensor, optional
        Reference indices for the unshuffled dataset for properly handling
        batching/splitting/shuffling when descriptors derivatives are
        provided, by default None. Ref_idx can be generated automatically
        using SmartDerivatives or by setting create_ref_idx=True when
        initializing a DictDataset.
        See also mlcolvar.core.loss.utils.smart_derivatives.SmartDerivatives
    n_dim : int
        Number of dimensions, by default None. If None, it defaults to 3 for
        the position-based loss and to 1 for the position-less loss.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
        Total loss and its (detached) components, i.e., variational,
        boundary A, and boundary B

    Raises
    ------
    ValueError
        If the given options are inconsistent with each other (see the
        consistency checks at the top of the function body).
    RuntimeError
        If the mass weighting of the squared gradients fails; the original
        error is re-raised with a hint about `n_dim`.
    """
    # ------------------------ CONSISTENCY CHECKS ------------------------
    if descriptors_derivatives is not None and ref_idx is None:
        raise ValueError("Descriptors derivatives need reference indeces from the dataset! Use a dataset with the ref_idx, see docstrign for details")

    if use_gradients_wrt_positions and atomic_masses is None:
        raise ValueError("atomic_masses must be provided when using Kolmogorov variational functional (use_gradients_wrt_positions is True)")
    elif not use_gradients_wrt_positions:
        if atomic_masses is not None:
            raise ValueError("atomic_masses must be None when using approximated variational principle (use_gradients_wrt_positions is False)")
        if descriptors_derivatives is not None:
            raise ValueError("descriptors_derivatives must be None when using approximated variational principle (use_gradients_wrt_positions is False)")

    if isinstance(descriptors_derivatives, torch.Tensor) and separate_boundary_dataset:
        raise ValueError("Descriptors derivatives via explicit tensor are not implemented with separate_boundary_dataset key! Either use SmartDerivatives or deactivate separate_boundary_dataset")

    # z_threshold and z_regularization must be given together and be positive
    if (z_threshold is not None and (z_regularization == 0 or z_threshold <= 0)) or (z_threshold is None and z_regularization != 0) or z_regularization < 0:
        raise ValueError(f"To apply the regularization on z space both z_threshold and z_regularization key must be positive. Found {z_threshold} and {z_regularization}!")

    if n_dim is None:
        n_dim = 3 if use_gradients_wrt_positions else 1
    else:
        if n_dim < 1 or n_dim > 3:
            raise ValueError(f"n_dim should be a positive integer between 1 and 3, found {n_dim}!")

    # ------------------------ SETUP ------------------------
    # inherit right device
    device = x.device

    # expand mass tensor repeating each mass n_dim times; without masses use
    # unit weights, one per input feature
    if atomic_masses is not None:
        atomic_masses = atomic_masses.to(device).repeat_interleave(n_dim)
    else:
        atomic_masses = torch.ones(x.shape[1], device=device)

    # squeeze labels
    labels = labels.squeeze()

    # Create masks to access different states data
    mask_A = labels == 0
    mask_B = labels == 1

    # create mask for variational data
    if separate_boundary_dataset:
        mask_var = labels > 1
    else:
        mask_var = torch.ones_like(labels, dtype=torch.bool)

    # Update weights of the basins using the information on the delta_f.
    # The rescaling is done on a copy: writing into `w` directly would modify
    # the caller's tensor, so that repeated calls on the same data would
    # compound the scaling.
    delta_f = torch.Tensor([delta_f]).to(device)
    # B higher in energy --> A-B < 0
    if delta_f < 0:
        w = w.clone()
        w[mask_B] *= torch.exp(delta_f)
    # A higher in energy --> A-B > 0
    elif delta_f > 0:
        w = w.clone()
        w[mask_A] *= torch.exp(-delta_f)

    # weights should have size [n_batch, 1]
    w = w.unsqueeze(-1)

    # ------------------------ LOSS ------------------------
    # Each loss contribution is scaled by the number of samples

    # 1. ----- VARIATIONAL LOSS
    # Compute gradients of q(x) wrt x
    grad_outputs = torch.ones_like(q[mask_var])
    grad = torch.autograd.grad(
        q[mask_var],
        x,
        grad_outputs=grad_outputs,
        retain_graph=True,
        create_graph=create_graph,
    )[0]
    grad = grad[mask_var]

    if cell is not None:
        grad = grad / cell

    # in case the input is not positions but descriptors, we need to correct
    # the gradients up to the positions
    if use_gradients_wrt_positions and isinstance(descriptors_derivatives, SmartDerivatives):
        # we use the precomputed derivatives from descriptors to pos
        gradient_positions = descriptors_derivatives(grad, ref_idx[mask_var]).view(x[mask_var].shape[0], -1)
    # --> If we directly pass the matrix d_desc/d_pos
    elif use_gradients_wrt_positions and isinstance(descriptors_derivatives, torch.Tensor):
        descriptors_derivatives = descriptors_derivatives.to(device)
        gradient_positions = torch.einsum("bd,badx->bax", grad, descriptors_derivatives[ref_idx[mask_var]]).contiguous()
        gradient_positions = gradient_positions.view(x[mask_var].shape[0], -1)
    # If the input was already positions
    else:
        gradient_positions = grad

    # we do the square
    grad_square = torch.pow(gradient_positions, 2)

    # multiply by the inverse masses and sum over the features
    try:
        grad_square = torch.sum((grad_square * (1/atomic_masses)), axis=1, keepdim=True)
    except RuntimeError as e:
        # keep the original error as explicit cause of the re-raised one
        raise RuntimeError(e, """[HINT]: Is you system in 3 dimension? By default the code assumes so, if it's not the case change the n_dim key to the right dimensionality.""") from e

    # the approximated variational principle squares the summed term again
    if not use_gradients_wrt_positions:
        grad_square = grad_square**2

    # variational contribution to loss: weighted average over the batch
    loss_var = torch.mean(grad_square * w[mask_var])
    if log_var:
        loss_var = gamma * torch.log1p(loss_var)
    else:
        loss_var = gamma * loss_var

    # 2. ----- BOUNDARY LOSS
    # NOTE: torch.mean over an empty selection is NaN, hence a batch without
    # samples of state A or of state B gives a NaN loss.
    loss_A = gamma * torch.mean(q[mask_A].pow(2))
    loss_B = gamma * torch.mean((q[mask_B] - 1).pow(2))

    # 3. ----- OPTIONAL regularization on z
    if z_threshold is not None:
        # only the part of |z| exceeding the threshold is penalized
        over_threshold = torch.relu(z.abs() - z_threshold)
        loss_z_diff = z_regularization * torch.mean(over_threshold.pow(2))
    else:
        loss_z_diff = 0

    # 4. ----- TOTAL LOSS
    loss = loss_var + alpha*(loss_A + loss_B) + loss_z_diff

    return loss, loss_var.detach(), alpha*loss_A.detach(), alpha*loss_B.detach()