# Source code for mlcolvar.core.loss.generator_loss

# Public API of this module: only the loss class is exported by a star-import.
# The functional form ``generator_loss`` defined below is not listed here but
# can still be imported explicitly by name.
__all__ = ["GeneratorLoss"]

# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

import torch
from typing import Union, Tuple
from mlcolvar.core.loss.utils.smart_derivatives import SmartDerivatives


class GeneratorLoss(torch.nn.Module):
    r"""Computes the loss function to learn a representation for the resolvent of the infinitesimal generator.

    The module stores the hyperparameters and the trainable ``lambdas`` and
    delegates the actual computation to the functional form ``generator_loss``.
    """

    def __init__(
        self,
        r: int,
        eta: float,
        friction: torch.Tensor,
        alpha: float,
        descriptors_derivatives: Union["SmartDerivatives", torch.Tensor, None] = None,
        n_dim: int = 3,
        u_stat: bool = True,
    ):
        r"""Computes the loss to learn a representation on which the resolvent of the infinitesimal generator can be learned

        Parameters
        ----------
        r : int
            Number of eigenfunctions wanted, i.e., number of outputs of model.
            One trainable ``lambda`` is created per eigenfunction.
        eta : float
            Hyperparameter for the shift to define the resolvent, i.e., $(\eta I-\mathcal{L})^{-1}$
        friction : torch.Tensor
            Langevin friction, i.e., $\sqrt{k_B*T/(gamma*m_i)}$.
            Registered as a buffer, so it follows the module across devices.
        alpha : float
            Hyperparameter that scales the contribution of orthonormality loss to the total loss,
            i.e., L = L_ef + alpha*L_ortho
        descriptors_derivatives : Union[SmartDerivatives, torch.Tensor], optional
            Derivatives of descriptors wrt atomic positions (if used) to speed up calculation of gradients,
            by default None.
            Can be either:
                - A `SmartDerivatives` object to save both memory and time,
                  see also mlcolvar.core.loss.utils.smart_derivatives.SmartDerivatives
                - A torch.Tensor with the derivatives to save time, memory-wise could be less efficient
            When this is provided, ``forward`` must be called with ``ref_idx``.
        n_dim : int
            Number of dimensions, by default 3.
        u_stat : bool
            Do we use U-statistics to compute the loss, by default True.
        """
        super().__init__()

        self.eta = eta
        self.register_buffer("friction", friction)
        # trainable parameters; the functional loss uses them squared
        self.lambdas = torch.nn.Parameter(10 * torch.randn(r), requires_grad=True)
        self.alpha = alpha
        self.descriptors_derivatives = descriptors_derivatives
        self.n_dim = n_dim
        self.u_stat = u_stat

    def forward(
        self,
        input: torch.Tensor,
        output: torch.Tensor,
        weights: torch.Tensor,
        ref_idx: Union[torch.Tensor, None] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Evaluate the generator loss on a batch.

        Parameters
        ----------
        input : torch.Tensor
            Input of the (set of) neural networks
        output : torch.Tensor
            Output of the (set of) neural networks
        weights : torch.Tensor
            Statistical weights of the samples, this could be from reweighting.
        ref_idx : torch.Tensor, optional
            Reference indeces for the unshuffled dataset for properly handling
            batching/splitting/shuffling when descriptors derivatives are provided, by default None.
            Ref_idx can be generated automatically using SmartDerivatives or by setting
            create_ref_idx=True when initializing a DictDataset.
            See also mlcolvar.core.loss.utils.smart_derivatives.SmartDerivatives

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            Whatever ``generator_loss`` returns for these arguments and the
            hyperparameters stored on the module.
        """
        # A plain tensor attribute is not moved by ``Module.to``: move the
        # derivatives matrix once, the first time it is seen on another device.
        derivatives = self.descriptors_derivatives
        if isinstance(derivatives, torch.Tensor) and derivatives.device != input.device:
            self.descriptors_derivatives = derivatives.to(input.device)

        return generator_loss(
            input=input,
            output=output,
            weights=weights,
            eta=self.eta,
            alpha=self.alpha,
            friction=self.friction,
            lambdas=self.lambdas,
            descriptors_derivatives=self.descriptors_derivatives,
            ref_idx=ref_idx,
            n_dim=self.n_dim,
            u_stat=self.u_stat,
        )
# TODO check that maybe we can replace this by the one from deepTICA
def compute_covariance(X, weights):
    """Weighted second-moment matrix of ``X`` with an ``n / (n - 1)`` prefactor.

    No mean is subtracted: the result is the weighted sum of outer products
    over the first axis, divided by ``n`` and rescaled by ``n / (n - 1)``.

    Parameters
    ----------
    X : torch.Tensor
        Either a 2D tensor (samples, features) or a 3D tensor
        (samples, features, components). In the 3D case the last axis is
        contracted as well.
    weights : torch.Tensor
        One weight per sample (first axis of ``X``).

    Returns
    -------
    torch.Tensor
        Matrix of shape (features, features).

    Raises
    ------
    ZeroDivisionError
        If ``X`` contains a single sample, since the prefactor is ``n / (n - 1)``.
    """
    n = X.size(0)
    pre_factor = n / (n - 1)
    if X.ndim == 2:
        second_moment = torch.einsum("ij,ik,i->jk", X, X, weights)
    else:
        second_moment = torch.einsum("ijk,ilk,i->jl", X, X, weights)
    return pre_factor * (second_moment / n)


def generator_loss(
    input: torch.Tensor,
    output: torch.Tensor,
    weights: torch.Tensor,
    eta: float,
    alpha: float,
    friction: torch.Tensor,
    lambdas: torch.Tensor,
    descriptors_derivatives: Union["SmartDerivatives", torch.Tensor, None] = None,
    ref_idx: Union[torch.Tensor, None] = None,
    n_dim: int = 3,
    u_stat: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""Optimizes r functions to be the representation on which the resolvent of the infinitesimal generator can be learned

    Parameters
    ----------
    input : torch.Tensor
        Input of the (set of) neural networks. It must be part of the autograd
        graph of ``output``, as the gradients of ``output`` wrt ``input`` are computed here.
    output : torch.Tensor
        Output of the (set of) neural networks, with one column per function.
    weights : torch.Tensor
        Statistical weights of the samples, this could be from reweighting.
    eta : float
        Hyperparameter for the shift to define the resolvent, i.e., $(\eta I-\mathcal{L})^{-1}$
    alpha : float
        Hyperparameter that scales the contribution of orthonormality loss to the total loss,
        i.e., L = L_ef + alpha*L_ortho
    friction : torch.Tensor
        Langevin friction, i.e., $\sqrt{k_B*T/(gamma*m_i)}$.
        Each entry is repeated ``n_dim`` times and the gradients are multiplied
        by ``torch.sqrt`` of the result.
    lambdas : torch.Tensor
        Trainable parameters, used squared on the diagonal of a matrix.
        After training, they should correspond to the resolvent eigenvalues.
    descriptors_derivatives : Union[SmartDerivatives, torch.Tensor], optional
        Derivatives of descriptors wrt atomic positions (if used) to speed up calculation of gradients,
        by default None.
        Can be either:
            - A `SmartDerivatives` object to save both memory and time,
              see also mlcolvar.core.loss.utils.smart_derivatives.SmartDerivatives
            - A torch.Tensor with the derivatives to save time, memory-wise could be less efficient
    ref_idx : torch.Tensor, optional
        Reference indeces for the unshuffled dataset for properly handling
        batching/splitting/shuffling when descriptors derivatives are provided, by default None.
        Ref_idx can be generated automatically using SmartDerivatives or by setting
        create_ref_idx=True when initializing a DictDataset.
        See also mlcolvar.core.loss.utils.smart_derivatives.SmartDerivatives
    n_dim : int
        Number of dimensions, by default 3.
    u_stat : bool
        Do we use U-statistics to compute the loss, by default True. If True the
        batch is split in two halves and the covariances of the two halves are combined.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        Total loss, eigenfunctions loss, orthonormality loss. The last two are
        detached, and the orthonormality term is already multiplied by ``alpha``.

    Raises
    ------
    ValueError
        If ``descriptors_derivatives`` is given without ``ref_idx``.
    RuntimeError
        If the gradients cannot be multiplied by the expanded friction tensor
        (e.g. ``n_dim`` does not match the dimensionality of the system).
    """
    if descriptors_derivatives is not None and ref_idx is None:
        raise ValueError ("Descriptors derivatives need reference indeces from the dataset! Use a dataset with the ref_idx, see docstrign for details")

    # ------------------------ SETUP ------------------------
    # get correct device
    device = input.device

    # move and process lambdas to device
    lambdas = lambdas.to(device)
    diag_lamb = torch.diag(lambdas**2)

    # get number of outputs and sample sizes
    r = output.shape[1]
    sample_size = output.shape[0] // 2

    # expand friction tensor: one entry per coordinate
    friction = friction.repeat_interleave(n_dim)

    # ------------------------ GRADIENTS ------------------------
    # compute gradients of output wrt to the input iterating on the outputs,
    # stacked on a new last axis -> (batch, n_inputs, r)
    grad_outputs = torch.ones(len(output), device=device)
    gradient = torch.stack(
        [
            torch.autograd.grad(
                outputs=output[:, idx],
                inputs=input,
                grad_outputs=grad_outputs,
                retain_graph=True,
                create_graph=True,
            )[0]
            for idx in range(r)
        ],
        dim=2,
    )

    # in case the input is not positions but descriptors, we need to correct the gradients up to the positions
    if descriptors_derivatives is None:
        # the input was already positions
        gradient_positions = gradient
    # --> If we pass a SmartDerivative object that takes the nonzero elements of the matrix d_desc/d_pos
    elif isinstance(descriptors_derivatives, SmartDerivatives):
        gradient_positions = descriptors_derivatives(gradient, ref_idx).view(input.shape[0], -1, r)
    # --> If we directly pass the matrix d_desc/d_pos
    elif isinstance(descriptors_derivatives, torch.Tensor):
        descriptors_derivatives = descriptors_derivatives.to(device)
        gradient_positions = torch.einsum(
            "bdo,badx->baxo", gradient, descriptors_derivatives[ref_idx]
        ).contiguous()
        # flatten (atoms, components) into one axis; the size is inferred so
        # that the number of spatial components is not hard-coded to 3
        gradient_positions = gradient_positions.view(
            input.shape[0],  # number of entries
            -1,  # number of atoms * number of components
            r,  # number of outputs
        )
    else:
        # any other object is ignored and the input is treated as positions
        gradient_positions = gradient

    # Every branch above already yields (batch, n_positions, r), also for
    # r == 1, because the gradients are stacked on a dedicated last axis. No
    # extra axis must be added here, otherwise the covariance einsum fails.
    # (batch, n_positions, r) -> (batch, r, n_positions)
    gradient_positions = gradient_positions.transpose(2, 1).contiguous()

    # multiply by friction
    try:
        gradient_positions = gradient_positions * torch.sqrt(friction)
    except RuntimeError as e:
        raise RuntimeError(
            e,
            "[HINT]: Is you system in 3 dimension? \n"
            "By default the code assumes so, if it's not the case change the n_dim key to the right dimensionality.",
        ) from e

    # identity used by the orthonormality term
    I = torch.eye(r, device=output.device, dtype=output.dtype)

    # ------------------------ COVARIANCES ------------------------
    if u_stat:
        # In order to have unbiased estimation, we split the dataset in two chunks
        first = slice(0, sample_size)
        second = slice(sample_size, None)

        weights_X, weights_Y = weights[first], weights[second]
        gradient_X, gradient_Y = gradient_positions[first], gradient_positions[second]
        psi_X, psi_Y = output[first], output[second]

        # compute covariances
        cov_X = compute_covariance(psi_X, weights_X)
        cov_Y = compute_covariance(psi_Y, weights_Y)
        dcov_X = compute_covariance(gradient_X, weights_X)
        dcov_Y = compute_covariance(gradient_Y, weights_Y)

        # action of shifted generator on the two chunks
        W1 = (eta * cov_X + dcov_X) @ diag_lamb
        W2 = (eta * cov_Y + dcov_Y) @ diag_lamb

        # ------------------------ COMPUTE LOSSES ------------------------
        # Unbiased estimation of the "variational part": each chunk is only
        # ever multiplied with quantities from the other chunk
        loss_ef = torch.trace(
            ((cov_X @ diag_lamb) @ W2 + (cov_Y @ diag_lamb) @ W1) / 2
            - cov_X @ diag_lamb
            - cov_Y @ diag_lamb
        )

        # Orthonormality part
        loss_ortho = alpha * torch.trace((I - cov_X) @ (I - cov_Y))
    else:
        # compute covariances on the whole batch
        cov = compute_covariance(output, weights)
        dcov = compute_covariance(gradient_positions, weights)

        # action of shifted generator
        W = (eta * cov + dcov) @ diag_lamb

        # ------------------------ COMPUTE LOSSES ------------------------
        # "variational part" estimated from a single set of covariances
        loss_ef = torch.trace((cov @ diag_lamb) @ W - 2 * cov @ diag_lamb)

        # Orthonormality part
        loss_ortho = alpha * torch.trace((I - cov) @ (I - cov))

    # combine
    loss = loss_ef + loss_ortho

    return loss, loss_ef.detach(), loss_ortho.detach()