import torch
import gc
import numpy as np
import inspect
from mlcolvar.utils._code import scatter_sum
from mlcolvar.data import DictDataset
from mlcolvar.core.transform import Transform
from mlcolvar.core.transform.descriptors.utils import sanitize_positions_shape, resolve_cell
__all__ = ["SmartDerivatives", "compute_descriptors_derivatives"]
[docs]
class SmartDerivatives(torch.nn.Module):
    """
    Utils to compute efficiently (time and memory wise) the derivatives of the model output wrt the positions
    used to compute the input descriptors.

    Rather than computing explicitly the derivatives wrt the positions, we compute those wrt the descriptors (right input)
    and multiply them by the matrix of the derivatives of the descriptors wrt the positions (left input).

    Overview
    --------
    Preparation:
        - Finds the non-zero entries of the derivatives of the descriptors wrt the positions (left)
        - Stores such entries in a big 1D tensor
        - Stores the indices to find such entries in the original derivatives matrix
        - Creates a big 1D tensor of scatter indices (frame * 3 * n_atoms + atom * 3 + dimension)
          that allows properly taking together the contributions
    Forward:
        - Use the matrix indices to retrieve the corresponding elements from the derivatives of output wrt the
          descriptors (right) into a big 1D tensor
        - Get the single contributions via element-wise multiplication (i.e., of each atom to the output due
          to a single descriptor along a single space dimension)
        - Scatter the single contributions to global contributions (of each atom to each output along each space dimension)

    When working with batches or splits, the frame part of the scatter indices is remapped from the position
    in the whole dataset to the position in the batch.
    """

    def __init__(self,
                 setup_device: str = 'cpu',
                 force_all_atoms: bool = False
                 ):
        """Initialize the smart derivatives object.
        To setup the class, use the `setup` method.

        Parameters
        ----------
        setup_device : str
            Device on which to perform the expensive setup calculations. Either 'cpu' or 'cuda', by default 'cpu'
        force_all_atoms : bool
            Whether to allow the use of atoms that are not involved in the calculation of any descriptor, by default False.
            If False, the setup raises a ValueError when such atoms are found.
        """
        super().__init__()
        self.force_all_atoms = force_all_atoms
        self.setup_device = setup_device
        # auxiliary variable to check if the module has been properly set up
        self._check_setup = False
        # auxiliary variable to check if elements have been loaded on computation device
        self._device_preload = False

    def setup(self,
              dataset: "DictDataset",
              descriptor_function: "Transform",
              n_atoms: int,
              separate_boundary_dataset: bool = False,
              descriptors_batch_size: int = None
              ) -> "DictDataset":
        """Setup the smart derivatives object from a dataset and a descriptor function.
        Returns a properly formatted new dataset with the descriptors as data.

        Parameters
        ----------
        dataset : DictDataset
            Input dataset containing atomic positions as `data` and the needed entries
        descriptor_function : Transform
            Function to compute the descriptors from the atomic positions, it should be taken from the mlcolvar.core.transform module
        n_atoms : int
            Number of atoms in the dataset
        separate_boundary_dataset : bool, optional
            Whether to separate the boundary dataset from the variational one, by default False
            NB: Should be used only for mlcolvar.cvs.committor.
        descriptors_batch_size : int
            Size of batches used to compute the descriptors and their derivatives, useful for heavy computation
            to avoid memory overflows, if None a single batch is used, by default None

        Returns
        -------
        DictDataset
            Updated dataset with the computed descriptors as 'data' and the reference indices as 'ref_idx'.
        """
        self.n_atoms = n_atoms
        # compute descriptors and their derivatives from original dataset
        _, desc, d_desc_d_x = compute_descriptors_derivatives(dataset=dataset,
                                                              descriptor_function=descriptor_function,
                                                              n_atoms=n_atoms,
                                                              separate_boundary_dataset=separate_boundary_dataset,
                                                              batch_size=descriptors_batch_size)
        # create a new dataset with the descriptors and reference indices
        smart_dataset = create_smart_dataset(desc=desc,
                                             dataset=dataset,
                                             separate_boundary_dataset=separate_boundary_dataset)
        # initialize the fixed part of the calculation of smart derivatives (i.e., left part)
        self._setup_left(left_input=d_desc_d_x, setup_device=self.setup_device)
        self._check_setup = True
        return smart_dataset

    def _setup_left(self,
                    left_input: torch.Tensor,
                    setup_device: str = 'cpu'):
        """Setup the fixed part of the computation: the non-zero elements of the derivatives of the
        descriptors wrt the positions and the related indices.

        Parameters
        ----------
        left_input : torch.Tensor
            Derivatives of the descriptors wrt the positions, indexed as (frame, atom, descriptor, dimension).
            The tensor is not modified.
        setup_device : str
            Device on which the setup is performed, by default 'cpu'
        """
        with torch.no_grad():
            self.total_dataset_length = len(left_input)
            # all the setup should be done on the CPU by default
            left_input = left_input.to(torch.device(setup_device))
            # the indices in mat_ind are: frame (batch), atom, descriptor and dimension
            self.left, mat_ind = self._create_nonzero_left(left_input)
            # save them with clearer names
            self.batch_ind, self.atom_ind, self.desc_ind, self.dim_ind = (ind.long().detach() for ind in mat_ind)
            # get indices to scatter the contributions to the right place at the end
            self.scatter_indeces, self.batch_shift = self._get_scatter_indices(batch_ind=self.batch_ind,
                                                                              atom_ind=self.atom_ind,
                                                                              dim_ind=self.dim_ind)
            self.scatter_indeces = self.scatter_indeces.long().detach()
        # the freshly created tensors live on setup_device: they must be (re)loaded on the computation device
        self._device_preload = False

    def _create_nonzero_left(self, x):
        """Find the non-zero elements of the left input (i.e., derivatives of descriptors wrt positions).

        A (frame, atom, descriptor) row is kept with all its xyz components as soon as one of them is non-zero,
        so that every stored row is homogeneous. The input tensor is never modified.

        Returns
        -------
        x_vec : torch.Tensor
            1D tensor with the selected entries (in row-major order of x)
        mat_ind : tuple of torch.Tensor
            Indices (frame, atom, descriptor, dimension) of the selected entries
        """
        # boolean selection mask, working on it avoids altering (and losing precision on) the values of x
        keep = x != 0
        # rows with only some xyz components vanishing, e.g., one zero component and the others nonzero
        row_inhomogeneous = keep.any(dim=-1) & ~keep.all(dim=-1)
        # make them homogeneous
        keep[row_inhomogeneous] = True
        # find indices of the selected entries of the d_desc_d_pos
        mat_ind = keep.nonzero(as_tuple=True)
        # it is possible that some atoms are not involved in any descriptor
        used_atoms = torch.unique(mat_ind[1])
        if len(used_atoms) < self.n_atoms:
            all_atoms = torch.arange(self.n_atoms, device=used_atoms.device)
            missing_atoms = all_atoms[~torch.isin(all_atoms, used_atoms)]
            if not self.force_all_atoms:
                raise ValueError(f"Some of the input atoms are not used in any of the descriptors. The not used atom IDs are : {missing_atoms}. If you want to include all atoms even if not used swtich the force_all_atoms key on. ")
            # add (zero-valued) entries for the missing atoms on the first descriptor of every frame
            keep[:, missing_atoms, 0, :] = True
            mat_ind = keep.nonzero(as_tuple=True)
        # boolean indexing returns the values in the same row-major order as nonzero()
        x_vec = x[keep]
        return x_vec, mat_ind

    def _get_scatter_indices(self, batch_ind, atom_ind, dim_ind):
        """Compute the indices that map the long vector of nonzero derivatives to the right frame, atom and dimension.

        The derivatives wrt the same atom coming from different descriptors are gathered on the same index to obtain
        the total gradient. Each frame owns a contiguous slot of 3 * n_atoms entries, so the mapping is correct also
        when some atoms do not contribute in some frames (e.g., descriptors with a cutoff).

        Returns
        -------
        scatter_indices : torch.Tensor
            frame * 3 * n_atoms + atom * 3 + dimension
        batch_shift : torch.Tensor
            frame * 3 * n_atoms, i.e. the offset of the slot of the frame of each entry
        """
        batch_shift = batch_ind.long() * (3 * self.n_atoms)
        scatter_indices = batch_shift + atom_ind.long() * 3 + dim_ind.long()
        return scatter_indices, batch_shift

    def _preload_on_device(self, device):
        """Preloads the tensors used in the forward pass onto the desired device for speeding up.
        This can be reverted moving everything to cpu using the method `move_to_cpu`.
        """
        for attr in ["left", "batch_ind", "desc_ind", "scatter_indeces", "batch_shift"]:
            tensor = getattr(self, attr)
            if tensor.device != device:
                print(f"[SmartDerivatives] Moving {attr} to {device}")
                setattr(self, attr, tensor.to(device))
        print("[SmartDerivatives] To move the preloaded tensors back to cpu, use the `SmartDerivatives.move_to_cpu` method")
        self._device_preload = True

    def move_to_cpu(self):
        """Moves the tensors used in the forward pass onto the cpu."""
        for attr in ["left", "batch_ind", "desc_ind", "scatter_indeces", "batch_shift"]:
            print(f"[SmartDerivatives] Moving {attr} to cpu")
            setattr(self, attr, getattr(self, attr).to(torch.device("cpu")))
        self._device_preload = False

    def forward(self, x: torch.Tensor, ref_idx: torch.Tensor = None):
        """Adds the derivatives of descriptors wrt atomic positions to the derivatives of output using the chain rule only for non-zero contributions.

        Parameters
        ----------
        x : torch.Tensor
            Derivatives of output wrt to descriptors, shape (batch, n_descriptors) or (batch, n_descriptors, n_out)
        ref_idx : torch.Tensor
            Reference indices of the pristine dataset (i.e., before splitting, shuffling..) of the rows of x.
            If None, the rows of x are assumed to be the first x.size(0) entries of the dataset.

        Returns
        -------
        torch.Tensor
            Derivatives of the output wrt atomic positions, shape (batch, n_atoms, 3) or (batch, n_atoms, 3, n_out),
            in the same row order as x

        Raises
        ------
        RuntimeError
            If the object has not been set up
        ValueError
            If ref_idx does not match the rows of x or contains negative entries
        """
        if not hasattr(self, "left"):
            raise RuntimeError("SmartDerivatives has not been set up, call the `setup` method before using it.")
        # (re)load the fixed tensors on the device of the input if needed
        if not self._device_preload or self.left.device != x.device:
            self._preload_on_device(device=x.device)

        if ref_idx is None:
            ref_idx = torch.arange(x.size(0), device=x.device)
        ref_idx = ref_idx.to(device=x.device, dtype=torch.long)
        if ref_idx.numel() != x.size(0):
            raise ValueError(f"ref_idx has {ref_idx.numel()} entries but x has {x.size(0)} rows.")
        if (ref_idx < 0).any():
            raise ValueError("ref_idx contains negative entries (e.g., boundary data marked with -1): filter them out before calling SmartDerivatives.")

        # =========================== SORT DATA ==========================
        # order by ref_idx, this way it's easier to handle later
        ref_idx, ordering = torch.sort(ref_idx)
        x = x.index_select(0, ordering)
        # we store the indices to properly re-order the output
        revert_ordering = torch.empty_like(ordering)
        revert_ordering[ordering] = torch.arange(ordering.size(0), device=ordering.device)
        batch_size = x.size(0)

        # =========================== HANDLE BATCHING/SPLITTING ==========================
        # map each frame of the whole dataset to its row in the (sorted) batch, -1 if not in the batch
        n_slots = self.total_dataset_length
        if ref_idx.numel() > 0:
            n_slots = max(n_slots, int(ref_idx[-1]) + 1)
        frame_pos = torch.full((n_slots,), -1, dtype=torch.long, device=x.device)
        frame_pos[ref_idx] = torch.arange(batch_size, device=x.device)
        entry_frame = frame_pos[self.batch_ind]
        # nonzero entries belonging to the frames of this batch
        used_batch = entry_frame >= 0
        rows = entry_frame[used_batch]
        # replace the frame offset in the whole dataset with the offset of the row in the batch
        scatter_indeces = ((self.scatter_indeces[used_batch] - self.batch_shift[used_batch]).long()
                           + rows * (3 * self.n_atoms))

        # get the used part of the left elements
        left = self.left[used_batch]
        # get the vector with non-zero elements of derivatives of output wrt the descriptors
        right = self._create_right(x=x, used_batch=used_batch, rows=rows)
        # do element-wise product between:
        # left: desc/pos derivatives matrix non-zero elements
        # right: out/desc derivatives matrix non-zero elements
        if left.shape == right.shape:
            src = left * right
        else:
            src = torch.einsum("j,jr->jr", left, right)
        # sum contributions from different descriptors to the same atoms
        out = self._sum_desc_contributions(x=src, scatter_indeces=scatter_indeces, batch_size=batch_size)
        # get the original order back
        return out.index_select(0, revert_ordering)

    def _create_right(self, x: torch.Tensor, used_batch: torch.Tensor, rows: torch.Tensor = None):
        """Create a big 1D tensor with the elements of the derivatives of the output
        wrt the descriptors needed to propagate the derivatives to the positions.

        Parameters
        ----------
        x : torch.Tensor
            Derivatives of the output wrt the descriptors for the (sorted) batch only
        used_batch : torch.Tensor
            Boolean mask of the nonzero left entries belonging to the batch
        rows : torch.Tensor, optional
            Row of x for each used entry. If None, it is obtained by numbering the used frames consecutively.
        """
        if rows is None:
            _, rows = torch.unique_consecutive(self.batch_ind[used_batch], return_inverse=True)
        # descriptors indices need to be restricted to the used entries
        desc_ind = self.desc_ind[used_batch]
        # keep only the elements matching the non zero elements of the left input
        return x[rows, desc_ind]

    def _sum_desc_contributions(self, x: torch.Tensor, scatter_indeces: torch.Tensor, batch_size: int):
        """Sums the elements of x according to the scatter indices to obtain the total contribution of each atom
        to the output along each space dimension.

        The output size is fixed to batch_size * n_atoms * 3 (times n_out), so atoms without contributions get zeros.
        """
        n_slots = batch_size * self.n_atoms * 3
        extra_shape = tuple(x.shape[1:])  # empty for a single output, (n_out,) for multiple outputs
        out = x.new_zeros((n_slots, *extra_shape)).index_add(0, scatter_indeces, x)
        return out.reshape((batch_size, self.n_atoms, 3, *extra_shape))
def compute_descriptors_derivatives(dataset,
                                    descriptor_function,
                                    n_atoms: int,
                                    separate_boundary_dataset=False,
                                    batch_size: int = None):
    """Compute the derivatives of a set of descriptors wrt input positions in a dataset for committor optimization

    Parameters
    ----------
    dataset :
        DictDataset with the positions under the 'data' key and the labels under the 'labels' key.
        If a 'cell' key is present, it is passed to the descriptor function.
    descriptor_function : torch.nn.Module
        Transform module for the computation of the descriptors
    n_atoms : int
        Number of atoms in the system
    separate_boundary_dataset : bool, optional
        Switch to exclude boundary condition labeled data (labels <= 1) from the derivatives calculation, by default False.
        The descriptors are still computed for all the entries.
    batch_size : int
        Size of batches to process data, useful for heavy computation to avoid memory overflows,
        if None (or -1) a single batch is used, by default None

    Returns
    -------
    pos : torch.Tensor
        Positions tensor (detached)
    desc : torch.Tensor
        Computed descriptors for all the entries of the dataset (detached)
    d_desc_d_pos : torch.Tensor
        Derivatives of desc wrt to pos (detached), only for the entries used in the variational part

    Raises
    ------
    ValueError
        If the dataset is empty, if no entries are left after separating the boundary data,
        or if batch_size is set to a non-positive value other than -1.
    """
    # get and prepare positions
    pos = dataset['data']
    labels = dataset['labels']
    cell = dataset["cell"] if "cell" in dataset.keys else None
    if cell is not None:
        cell = resolve_cell(cell=cell,
                            PBC=descriptor_function.PBC,
                            scaled_coords=descriptor_function.scaled_coords,
                            device=pos.device,
                            batch_size=len(pos))
    pos = sanitize_positions_shape(pos=pos, n_atoms=n_atoms)[0]
    # original device, outputs are moved back here at the end
    device = pos.device
    n_samples = len(pos)
    if n_samples == 0:
        raise ValueError("The dataset is empty, no descriptors or derivatives can be computed.")

    # mask of the entries used for the derivatives (i.e., the variational part)
    # atleast_1d keeps the mask sliceable also for single-entry datasets
    labels = torch.atleast_1d(labels.squeeze())
    if separate_boundary_dataset:
        mask_var = labels > 1
        if mask_var.sum() == 0:
            raise(ValueError('No points left after separating boundary and variational datasets. \n If you are using only unbiased data set separate_boundary_dataset=False here and in Committor or don\'t use SmartDerivatives!!'))
    else:
        mask_var = torch.ones_like(labels).to(torch.bool)

    # check batches size for calculation
    if batch_size is None or batch_size == -1:
        batch_size = n_samples
    elif batch_size <= 0:
        raise ValueError(f"Batch size must be larger than zero if set! Found {batch_size}")
    n_batches = int(np.ceil(n_samples / batch_size))

    # compute descriptors and derivatives
    # we loop over batches and compute everything only for that part of the data, inside we loop over descriptors
    # we save lists (on the cpu) and make them proper tensors later
    batch_aux_der = []
    batch_aux_desc = []
    for batch_count in range(n_batches):
        print(f"Processing batch {batch_count + 1}/{n_batches}", end='\r')
        # batch slicing, the last batch can be smaller
        batch_slice = slice(batch_count * batch_size, (batch_count + 1) * batch_size)
        batch_mask_var = mask_var[batch_slice]
        # positions of the variational entries of this batch
        batch_pos = pos[batch_slice][batch_mask_var]
        if len(batch_pos) == 0:
            continue
        # fresh leaf tensor, so that the gradients can be requested even if pos is part of a graph
        batch_pos = batch_pos.detach().requires_grad_(True)
        if cell is None:
            batch_desc = descriptor_function(batch_pos)
        else:
            batch_desc = descriptor_function(batch_pos, cell=cell[batch_slice][batch_mask_var])
        # one backward pass per descriptor
        batch_aux = [
            torch.autograd.grad(batch_desc[:, i],
                                batch_pos,
                                grad_outputs=torch.ones_like(batch_desc[:, i]),
                                retain_graph=True)[0].detach()
            for i in range(batch_desc.shape[1])
        ]
        # derivatives of this batch, descriptors stacked along dim 2
        batch_aux_der.append(torch.stack(batch_aux, dim=2).cpu())
        # descriptors of this batch
        batch_aux_desc.append(batch_desc.detach().cpu())
        # cleanup
        del batch_aux, batch_pos, batch_desc
        gc.collect()
        # to be sure, clean the gpu cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    print(f"Processed all data in {n_batches} batches!")

    # avoid a copy when there is a single batch
    d_desc_d_pos = batch_aux_der[0] if len(batch_aux_der) == 1 else torch.cat(batch_aux_der, dim=0)
    desc_var = batch_aux_desc[0] if len(batch_aux_desc) == 1 else torch.cat(batch_aux_desc, dim=0)
    del batch_aux_der, batch_aux_desc

    # we compute the descriptors on the whole dataset to always have all of them, no need for grads
    if separate_boundary_dataset:
        # everything is assembled on the cpu, with the dtype of the computed descriptors
        mask_cpu = mask_var.cpu()
        desc = torch.zeros((n_samples, desc_var.shape[-1]), dtype=desc_var.dtype)
        desc[mask_cpu] = desc_var
        if not mask_cpu.all():
            not_var = ~mask_var
            with torch.no_grad():
                if cell is None:
                    desc_not_var = descriptor_function(pos[not_var])
                else:
                    desc_not_var = descriptor_function(pos[not_var], cell=cell[not_var])
            desc[~mask_cpu] = desc_not_var.detach().cpu().to(desc_var.dtype)
    else:
        desc = desc_var

    # detach and move back to original device
    pos = pos.detach().to(device)
    desc = desc.detach().to(device)
    d_desc_d_pos = d_desc_d_pos.detach().to(device)
    # to be sure, clean the gpu cache
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return pos, desc, d_desc_d_pos.squeeze(-1)
def create_smart_dataset(desc, dataset, separate_boundary_dataset):
    """Create the 'smart' dataset with the descriptors as data and the reference indices to handle batching/splitting/shuffling.

    Parameters
    ----------
    desc : torch.Tensor
        Descriptors for all the entries of the dataset
    dataset : DictDataset
        Original dataset, its 'labels' and 'weights' are copied into the new dataset
    separate_boundary_dataset : bool
        Whether entries with labels <= 1 are excluded from the variational part; such entries get ref_idx = -1

    Returns
    -------
    DictDataset
        Dataset with keys 'data', 'labels', 'weights' and 'ref_idx'. The variational entries get consecutive
        ref_idx values (0, 1, ...) following the order of the dataset.
    """
    n_samples = len(dataset)
    # check if to separate boundary data
    if separate_boundary_dataset:
        # atleast_1d keeps the mask usable also for single-entry datasets
        mask_var = torch.atleast_1d(dataset["labels"].squeeze()) > 1
    else:
        mask_var = torch.ones(n_samples, dtype=torch.bool)
    # create reference indices for batching, on the same device as the mask to allow indexing
    ref_idx = torch.full((n_samples,), -1, dtype=torch.int, device=mask_var.device)
    ref_idx[mask_var] = torch.arange(int(mask_var.sum()), dtype=torch.int, device=mask_var.device)
    # update dataset with the descriptors as data
    smart_dataset = DictDataset({'data': desc.detach(),
                                 'labels': torch.clone(dataset['labels']),
                                 'weights': torch.clone(dataset['weights']),
                                 'ref_idx': ref_idx
                                 })
    return smart_dataset
def test_smart_derivatives():
    from mlcolvar.core.transform import PairwiseDistances
    from mlcolvar.core.nn import FeedForward
    from mlcolvar.data import DictDataset

    default_dtype = torch.get_default_dtype()
    # this way tests are less prone to fail on different OS
    torch.set_default_dtype(torch.float64)
    # restore the default dtype even if a check fails, otherwise float64 would leak into the other tests
    try:
        # full atoms with all distances
        n_atoms_1 = 10
        pos_1 = torch.Tensor([[ 1.4970,  1.3861, -0.0273, -1.4933,  1.5070, -0.1133, -1.4473, -1.4193,
                               -0.0553,  1.4940,  1.4990, -0.2403,  1.4780, -1.4173, -0.3363, -1.4243,
                               -1.4093, -0.4293,  1.3530, -1.4313, -0.4183,  1.3060,  1.4750, -0.4333,
                                1.2970, -1.3233, -0.4643,  1.1670, -1.3253, -0.5354]])
        ref_distances_1 = torch.Tensor([[0.1521, 0.2335, 0.2412, 0.3798, 0.4733, 0.4649, 0.4575, 0.5741, 0.6815,
                                         0.1220, 0.1323, 0.2495, 0.3407, 0.3627, 0.3919, 0.4634, 0.5885, 0.2280,
                                         0.2976, 0.3748, 0.4262, 0.4821, 0.5043, 0.6376, 0.1447, 0.2449, 0.2454,
                                         0.2705, 0.3597, 0.4833, 0.1528, 0.1502, 0.2370, 0.2408, 0.3805, 0.2472,
                                         0.3243, 0.3159, 0.4527, 0.1270, 0.1301, 0.2440, 0.2273, 0.2819, 0.1482]])
        force_all_atoms_1 = False
        slicing_pairs_1 = None
        do_check_1 = True
        batch_size_1 = None

        # five atoms and only two distances --> useless atoms
        n_atoms_2 = 5
        pos_2 = torch.Tensor([[ 1.4970,  1.3861, -0.0273, -1.4933,  1.5070, -0.1133, -1.4473, -1.4193,
                               -0.0553,  1.4940,  1.4990, -0.2403,  1.4780, -1.4173, -0.3363]])
        ref_distances_2 = torch.Tensor([[0.1521, 0.1220]])
        force_all_atoms_2 = True
        slicing_pairs_2 = [[0, 1], [1, 2]]
        do_check_2 = True
        batch_size_2 = None

        # three atoms, disappearing components and batches
        n_atoms_3 = 3
        pos_3 = torch.Tensor([[ 1.4970,  1.3861, -0.0273,
                                1.4970,  1.5070, -0.1133,
                               -1.4473, -1.4193, -0.0553]])
        ref_distances_3 = torch.Tensor([[0.1521, 0.1220]])
        force_all_atoms_3 = True
        slicing_pairs_3 = [[0, 1], [1, 2]]
        do_check_3 = False
        batch_size_3 = 3

        zipped = zip([pos_1, pos_2, pos_3],
                     [ref_distances_1, ref_distances_2, ref_distances_3],
                     [n_atoms_1, n_atoms_2, n_atoms_3],
                     [force_all_atoms_1, force_all_atoms_2, force_all_atoms_3],
                     [slicing_pairs_1, slicing_pairs_2, slicing_pairs_3],
                     [do_check_1, do_check_2, do_check_3],
                     [batch_size_1, batch_size_2, batch_size_3])

        for pos_original, ref_distances_original, n_atoms, force_all_atoms, slicing_pairs, do_check, batch_size in zipped:
            for cell_mode in ["fixed", "varying"]:
                pos = pos_original.repeat(4, 1)
                labels = torch.arange(0, 4)
                if not do_check:
                    labels[-1] = 0
                weights = torch.ones_like(labels)

                if cell_mode == 'fixed':
                    cell = torch.Tensor([3.0233])
                    dataset = DictDataset({'data': pos, 'labels': labels, 'weights': weights})
                    ComputeDescriptors = PairwiseDistances(n_atoms=n_atoms,
                                                           PBC=True,
                                                           cell=cell,
                                                           scaled_coords=False,
                                                           slicing_pairs=slicing_pairs)
                elif cell_mode == 'varying':
                    cell = torch.Tensor([3.0233]).repeat(len(pos), 1)
                    dataset = DictDataset({'data': pos, 'labels': labels, 'weights': weights, 'cell': cell})
                    ComputeDescriptors = PairwiseDistances(n_atoms=n_atoms,
                                                           PBC=True,
                                                           cell=None,
                                                           scaled_coords=False,
                                                           slicing_pairs=slicing_pairs)

                ref_distances = ref_distances_original.repeat(4, 1)

                for separate_boundary_dataset in [False, True]:
                    # boolean mask of the variational entries
                    if separate_boundary_dataset:
                        mask = labels > 1
                    else:
                        mask = torch.ones_like(labels, dtype=torch.bool)

                    pos, desc, d_desc_d_x = compute_descriptors_derivatives(dataset=dataset,
                                                                            descriptor_function=ComputeDescriptors,
                                                                            n_atoms=n_atoms,
                                                                            separate_boundary_dataset=separate_boundary_dataset)
                    if do_check:
                        assert torch.allclose(desc, ref_distances, atol=1e-3)

                    # compute descriptors outside to have their derivatives for checks
                    pos.requires_grad = True
                    desc = ComputeDescriptors(pos, cell=cell if cell_mode == 'varying' else None)

                    # apply simple NN
                    NN = FeedForward(layers=[desc.shape[-1], 2, 1])
                    out = NN(desc)

                    # compute derivatives of out wrt input
                    d_out_d_x = torch.autograd.grad(out, pos, grad_outputs=torch.ones_like(out), retain_graph=True, create_graph=False)[0]
                    # compute derivatives of out wrt descriptors
                    d_out_d_d = torch.autograd.grad(out, desc, grad_outputs=torch.ones_like(out), retain_graph=True, create_graph=True)[0]
                    ref = torch.einsum('badx,bd->bax', d_desc_d_x, d_out_d_d[mask])
                    Ref = d_out_d_x[mask]

                    # apply smart derivatives
                    smart_derivatives = SmartDerivatives(force_all_atoms=force_all_atoms)
                    smart_dataset = smart_derivatives.setup(dataset=dataset,
                                                            descriptor_function=ComputeDescriptors,
                                                            n_atoms=n_atoms,
                                                            separate_boundary_dataset=separate_boundary_dataset,
                                                            descriptors_batch_size=batch_size)

                    # check dataset has the right data
                    assert torch.allclose(smart_dataset['data'], desc, atol=1e-3)

                    # check forward
                    right_input = d_out_d_d.squeeze(-1)
                    smart_out = smart_derivatives(right_input, smart_dataset['ref_idx'][mask])

                    # do checks
                    if do_check:
                        assert torch.allclose(smart_out, ref)
                        assert torch.allclose(smart_out, Ref)

                    smart_out.sum().backward()

        # Test with multiple outputs
        # compute some descriptors from positions --> distances
        n_atoms = 10
        pos_original = pos_1
        ref_distances = ref_distances_1
        force_all_atoms = force_all_atoms_1
        batch_size = batch_size_1

        for cell_mode in ["fixed", "varying"]:
            pos = pos_original.repeat(4, 1)
            labels = torch.arange(0, 4)
            weights = torch.ones_like(labels)

            if cell_mode == "fixed":
                cell = torch.Tensor([3.0233])
                dataset = DictDataset({"data": pos, "labels": labels, "weights": weights})
                ComputeDescriptors = PairwiseDistances(n_atoms=n_atoms,
                                                       PBC=True,
                                                       cell=cell,
                                                       scaled_coords=False)
            elif cell_mode == "varying":
                cell = torch.Tensor([3.0233]).repeat(len(pos), 1)
                dataset = DictDataset({"data": pos, "labels": labels, "weights": weights, "cell": cell})
                ComputeDescriptors = PairwiseDistances(n_atoms=n_atoms,
                                                       PBC=True,
                                                       cell=None,
                                                       scaled_coords=False)

            ref_distances_this = ref_distances.repeat(4, 1)

            for separate_boundary_dataset in [False, True]:
                # boolean mask of the variational entries
                if separate_boundary_dataset:
                    mask = labels > 1
                else:
                    mask = torch.ones_like(labels, dtype=torch.bool)

                pos, desc, d_desc_d_x = compute_descriptors_derivatives(dataset=dataset,
                                                                        descriptor_function=ComputeDescriptors,
                                                                        n_atoms=n_atoms,
                                                                        separate_boundary_dataset=separate_boundary_dataset)
                assert torch.allclose(desc, ref_distances_this, atol=1e-3)

                # compute descriptors outside to have their derivatives for checks
                pos.requires_grad = True
                desc = ComputeDescriptors(pos, cell=cell if cell_mode == "varying" else None)

                # apply simple NN
                torch.manual_seed(42)
                NN = FeedForward(layers=[desc.shape[-1], 2, 2])
                out = NN(desc)

                # compute derivatives of out wrt input
                d_out_d_x = torch.stack([torch.autograd.grad(out[:, i], pos, grad_outputs=torch.ones_like(out[:, i]), retain_graph=True, create_graph=False)[0]
                                         for i in range(out.shape[-1])], dim=3)
                # compute derivatives of out wrt descriptors
                d_out_d_d = torch.stack([torch.autograd.grad(out[:, i], desc, grad_outputs=torch.ones_like(out[:, i]), retain_graph=True, create_graph=True)[0]
                                         for i in range(out.shape[-1])], dim=2)
                ref = torch.einsum("badx,bdo->baxo", d_desc_d_x, d_out_d_d[mask])
                Ref = d_out_d_x[mask]

                # apply smart derivatives
                smart_derivatives = SmartDerivatives(force_all_atoms=force_all_atoms)
                smart_dataset = smart_derivatives.setup(dataset=dataset,
                                                        descriptor_function=ComputeDescriptors,
                                                        n_atoms=n_atoms,
                                                        separate_boundary_dataset=separate_boundary_dataset,
                                                        descriptors_batch_size=batch_size)

                # check dataset has the right data
                assert torch.allclose(smart_dataset["data"], desc, atol=1e-3)

                # check forward
                right_input = d_out_d_d
                smart_out = smart_derivatives(right_input, smart_dataset["ref_idx"][mask])

                assert torch.allclose(smart_out, ref, atol=1e-3)
                assert torch.allclose(smart_out, Ref, atol=1e-3)

                smart_out.sum().backward()
    finally:
        # reset original default dtype
        torch.set_default_dtype(default_dtype)
def test_batched_smart_derivatives():
    from mlcolvar.core.transform import PairwiseDistances
    from mlcolvar.core.nn import FeedForward
    from mlcolvar.data import DictDataset, DictModule

    torch.manual_seed(45)
    # compute some descriptors from positions --> distances
    n_atoms = 3
    pos_original = torch.Tensor([[ 1.4970,  1.3861, -0.0273, -1.4933,  1.5070, -0.1133, -1.4473, -1.4193,
                                  -0.0553]])

    for cell_mode in ["fixed", "varying"]:
        pos = pos_original.repeat(20, 1)
        pos = pos + torch.randn_like(pos) * 1e-2
        labels = torch.arange(0, 4).repeat(5).sort()[0]
        weights = torch.ones_like(labels)

        if cell_mode == "fixed":
            cell = torch.Tensor([3.0233])
            dataset = DictDataset({"data": pos, "labels": labels, "weights": weights})
            ComputeDescriptors = PairwiseDistances(n_atoms=n_atoms,
                                                   PBC=True,
                                                   cell=cell,
                                                   scaled_coords=False)
        elif cell_mode == "varying":
            cell = torch.Tensor([3.0233]).repeat(len(pos), 1)
            dataset = DictDataset({"data": pos, "labels": labels, "weights": weights, "cell": cell})
            ComputeDescriptors = PairwiseDistances(n_atoms=n_atoms,
                                                   PBC=True,
                                                   cell=None,
                                                   scaled_coords=False)

        for separate_boundary_dataset in [False, True]:
            print(f"********************************************** {separate_boundary_dataset} **********************************************")
            # boolean mask of the variational entries
            if separate_boundary_dataset:
                mask = labels > 1
            else:
                mask = torch.ones_like(labels, dtype=torch.bool)

            # apply smart derivatives
            smart_derivatives = SmartDerivatives()
            smart_dataset = smart_derivatives.setup(dataset=dataset,
                                                    descriptor_function=ComputeDescriptors,
                                                    n_atoms=n_atoms,
                                                    separate_boundary_dataset=separate_boundary_dataset)

            pos, desc, d_desc_d_x = compute_descriptors_derivatives(dataset=dataset,
                                                                    descriptor_function=ComputeDescriptors,
                                                                    n_atoms=n_atoms,
                                                                    separate_boundary_dataset=separate_boundary_dataset)

            # compute descriptors outside to have their derivatives for checks
            pos.requires_grad = True
            desc = ComputeDescriptors(pos, cell=cell if cell_mode == "varying" else None)

            # check dataset has the right data
            assert torch.allclose(smart_dataset["data"], desc, atol=1e-3)

            # apply simple NN
            torch.manual_seed(42)
            NN = FeedForward(layers=[desc.shape[-1], 2, 1])
            out = NN(desc)

            # here we compute things on the whole dataset and we slice it later to get the right entries
            # compute derivatives of out wrt input
            d_out_d_x = torch.autograd.grad(out, pos, grad_outputs=torch.ones_like(out), retain_graph=True, create_graph=False)[0]
            # compute derivatives of out wrt descriptors
            d_out_d_d = torch.autograd.grad(out, desc, grad_outputs=torch.ones_like(out), retain_graph=True, create_graph=True)[0]

            # get total reference values
            ref = torch.einsum("badx,bd->bax", d_desc_d_x, d_out_d_d[mask])
            Ref = d_out_d_x[mask]

            # test for different seeds for dataloader
            for i in [42, 420]:
                print(f"====================== {i} ======================")
                torch.manual_seed(i)
                datamodule = DictModule(smart_dataset, lengths=[0.8, 0.2], batch_size=4, shuffle=True, random_split=True)
                datamodule.setup()

                for loader in [datamodule.train_dataloader(), datamodule.val_dataloader()]:
                    for b, batch in enumerate(iter(loader)):
                        print(f"==================== BATCH {b} ====================")
                        aux_dataset = DictDataset(batch)

                        # we have to mimic what happens during training
                        if separate_boundary_dataset:
                            aux_mask = aux_dataset["labels"] > 1
                        else:
                            aux_mask = torch.ones_like(aux_dataset["labels"], dtype=torch.bool)

                        # we get the ref indices only for the "var" part
                        ref_idx = torch.clone(aux_dataset["ref_idx"])[aux_mask]

                        # we get only the right input for the "var" part
                        right_input = d_out_d_d.squeeze(-1)[ref_idx]

                        # get smart out
                        smart_out = smart_derivatives(right_input, ref_idx)

                        # do checks with the reference value for the elements present in the batch
                        assert torch.allclose(smart_out, ref[ref_idx], atol=1e-3)
                        assert torch.allclose(smart_out, Ref[ref_idx], atol=1e-3)

                        smart_out.sum().backward(retain_graph=True)
def test_compute_descriptors_and_derivatives():
    """Check ``compute_descriptors_derivatives`` against direct autograd.

    Uses 10 atoms described by all their pairwise distances under a fixed PBC
    cell. For several batch sizes, with and without a separate boundary
    dataset, the returned descriptors are compared with precomputed reference
    distances and the returned derivatives with those obtained by
    differentiating the descriptor function directly.
    Finally, it checks that providing the same cell at runtime (as a repeated
    per-frame ``cell`` entry of the dataset) reproduces the fixed-cell results.
    """
    from mlcolvar.core.transform import PairwiseDistances

    # full atoms with all distances
    n_atoms = 10
    n_frames = 5
    pos = torch.Tensor([[ 1.4970, 1.3861, -0.0273, -1.4933, 1.5070, -0.1133, -1.4473, -1.4193,
                         -0.0553, 1.4940, 1.4990, -0.2403, 1.4780, -1.4173, -0.3363, -1.4243,
                         -1.4093, -0.4293, 1.3530, -1.4313, -0.4183, 1.3060, 1.4750, -0.4333,
                          1.2970, -1.3233, -0.4643, 1.1670, -1.3253, -0.5354]])
    ref_distances = torch.Tensor([[0.1521, 0.2335, 0.2412, 0.3798, 0.4733, 0.4649, 0.4575, 0.5741, 0.6815,
                                   0.1220, 0.1323, 0.2495, 0.3407, 0.3627, 0.3919, 0.4634, 0.5885, 0.2280,
                                   0.2976, 0.3748, 0.4262, 0.4821, 0.5043, 0.6376, 0.1447, 0.2449, 0.2454,
                                   0.2705, 0.3597, 0.4833, 0.1528, 0.1502, 0.2370, 0.2408, 0.3805, 0.2472,
                                   0.3243, 0.3159, 0.4527, 0.1270, 0.1301, 0.2440, 0.2273, 0.2819, 0.1482]])
    pos = pos.repeat(n_frames, 1)
    labels = torch.arange(0, n_frames)
    weights = torch.ones_like(labels)
    dataset = DictDataset({'data': pos, 'labels': labels, 'weights': weights})
    cell = torch.Tensor([3.0233])
    ref_distances = ref_distances.repeat(n_frames, 1)

    ComputeDescriptors = PairwiseDistances(n_atoms=n_atoms,
                                           PBC=True,
                                           cell=cell,
                                           scaled_coords=False,
                                           slicing_pairs=None)

    for batch_size in [2, 3, 5]:
        for separate_boundary_dataset in [False, True]:
            # With a separate boundary dataset the returned derivatives are
            # expected to cover only frames with labels > 1 (checked below).
            # A plain boolean tensor is used (not a list wrapping it), matching
            # the varying-cell test and avoiding legacy list-as-tuple indexing.
            if separate_boundary_dataset:
                mask = labels > 1
            else:
                mask = torch.ones_like(labels, dtype=torch.bool)

            # pos_out: positions returned by the function (kept distinct from the input `pos`)
            pos_out, desc, d_desc_d_x = compute_descriptors_derivatives(dataset=dataset,
                                                                        descriptor_function=ComputeDescriptors,
                                                                        n_atoms=n_atoms,
                                                                        separate_boundary_dataset=separate_boundary_dataset,
                                                                        batch_size=batch_size)
            assert torch.allclose(desc, ref_distances, atol=1e-3)

            # compute descriptors outside to have their derivatives for checks
            pos_out.requires_grad = True
            desc_ref = ComputeDescriptors(pos_out)

            # derivatives of each descriptor wrt positions, one column at a time
            aux = []
            for i in range(desc_ref.shape[1]):
                aux_der = torch.autograd.grad(desc_ref[:, i],
                                              pos_out,
                                              grad_outputs=torch.ones_like(desc_ref[:, i]),
                                              retain_graph=True)[0]
                aux.append(aux_der.detach().cpu())
            d_desc_d_x_ref = torch.stack(aux, dim=2)

            # checks
            assert torch.allclose(desc, desc_ref)
            assert torch.allclose(d_desc_d_x, d_desc_d_x_ref[mask])

    # ---------------------------------------------------------------------
    # Mock check: repeated runtime cell should reproduce fixed-cell results
    # ---------------------------------------------------------------------
    # detach + clone so this comparison starts from a plain tensor,
    # independent of any autograd flags set in the loop above
    pos_mock = dataset["data"].detach().clone()
    dataset_fixed = DictDataset({'data': pos_mock, 'labels': labels, 'weights': weights})
    # one copy of the (fixed) cell per frame, passed at runtime through the dataset
    repeated_cell = cell.repeat(len(pos_mock), 1)
    dataset_repeated_cell = DictDataset({'data': pos_mock, 'labels': labels, 'weights': weights, 'cell': repeated_cell})

    descriptor_fixed = PairwiseDistances(
        n_atoms=n_atoms,
        PBC=True,
        cell=cell,
        scaled_coords=False,
        slicing_pairs=None,
    )
    descriptor_runtime = PairwiseDistances(
        n_atoms=n_atoms,
        PBC=True,
        cell=None,
        scaled_coords=False,
        slicing_pairs=None,
    )

    for separate_boundary_dataset in [False, True]:
        pos_fix, desc_fix, d_desc_d_x_fix = compute_descriptors_derivatives(
            dataset=dataset_fixed,
            descriptor_function=descriptor_fixed,
            n_atoms=n_atoms,
            separate_boundary_dataset=separate_boundary_dataset,
            batch_size=3,
        )
        pos_rep, desc_rep, d_desc_d_x_rep = compute_descriptors_derivatives(
            dataset=dataset_repeated_cell,
            descriptor_function=descriptor_runtime,
            n_atoms=n_atoms,
            separate_boundary_dataset=separate_boundary_dataset,
            batch_size=3,
        )
        assert torch.allclose(pos_fix, pos_rep, atol=1e-8)
        assert torch.allclose(desc_fix, desc_rep, atol=1e-8)
        assert torch.allclose(d_desc_d_x_fix, d_desc_d_x_rep, atol=1e-8)
def test_compute_descriptors_and_derivatives_varying_cell():
    """Check ``compute_descriptors_derivatives`` with a frame-dependent cell.

    Two atoms sit at fixed reduced coordinates (0.1 and 0.9 along every axis)
    in frames whose cell changes from frame to frame; the cell is passed at
    runtime through the dataset's ``cell`` entry (the descriptor has
    ``cell=None``). Descriptor values are compared with frame-by-frame
    evaluations, and derivatives with direct autograd on the frames selected
    by ``mask``.
    """
    from mlcolvar.core.transform import PairwiseDistances

    n_atoms = 2
    n_frames = 6

    # Reduced coordinates, identical in every frame: atom 0 at 0.1 and atom 1
    # at 0.9 along each axis (built explicitly: nothing here is random).
    pos_reduced = torch.tensor([[0.1, 0.1, 0.1],
                                [0.9, 0.9, 0.9]]).repeat(n_frames, 1, 1)
    # Frame-dependent cells, with the same length along x, y and z in each frame.
    cell_lengths = torch.tensor([2.5, 3.0, 3.5, 2.8, 3.2, 2.2])
    cell = cell_lengths.unsqueeze(-1).repeat(1, 3)

    # absolute coordinates, flattened to (n_frames, n_atoms * 3)
    pos_abs = (pos_reduced * cell[:, None, :]).reshape(n_frames, -1)
    labels = torch.arange(n_frames)
    weights = torch.ones_like(labels)
    dataset = DictDataset({"data": pos_abs, "labels": labels, "weights": weights, "cell": cell})

    descriptor = PairwiseDistances(
        n_atoms=n_atoms,
        PBC=True,
        cell=None,
        scaled_coords=False,
        slicing_pairs=[[0, 1]],
    )

    for separate_boundary_dataset in [False, True]:
        # frames whose derivatives are expected in the output
        if separate_boundary_dataset:
            mask = labels > 1
        else:
            mask = torch.ones_like(labels, dtype=torch.bool)

        pos, desc, d_desc_d_x = compute_descriptors_derivatives(
            dataset=dataset,
            descriptor_function=descriptor,
            n_atoms=n_atoms,
            separate_boundary_dataset=separate_boundary_dataset,
            batch_size=2,
        )

        # Descriptor values should match frame-wise runtime-cell evaluation.
        desc_ref = torch.cat(
            [descriptor(pos[i : i + 1], cell=cell[i]) for i in range(n_frames)],
            dim=0,
        )
        assert torch.allclose(desc, desc_ref, atol=1e-8)

        # Derivatives should match direct autograd on the masked subset.
        pos_var = pos[mask].clone().detach().requires_grad_(True)
        cell_var = cell[mask]
        desc_var = descriptor(pos_var, cell=cell_var)
        aux = []
        for i in range(desc_var.shape[1]):
            aux_der = torch.autograd.grad(
                desc_var[:, i],
                pos_var,
                grad_outputs=torch.ones_like(desc_var[:, i]),
                retain_graph=True,
            )[0]
            aux.append(aux_der.detach())
        d_desc_d_x_ref = torch.stack(aux, dim=2)
        assert torch.allclose(d_desc_d_x, d_desc_d_x_ref, atol=1e-8)
def test_train_with_smart_derivatives():
    """Smoke test: train CVs that use ``SmartDerivatives`` end to end.

    Two models are trained on 200 copies of the same 10-atom configuration,
    described by all 45 pairwise distances under PBC:

    1. a ``Committor`` with a separate boundary dataset;
    2. a ``Generator`` without it, for which eigenfunctions are also computed.

    After each training, ``sensitivity_analysis`` is run on the corresponding
    dataset. The test only checks that these steps run without errors.
    """
    from mlcolvar.core.transform import PairwiseDistances
    from mlcolvar.data import DictModule, DictDataset
    from mlcolvar.cvs import Committor, Generator
    from mlcolvar.cvs.committor.utils import initialize_committor_masses
    from mlcolvar.core.loss.utils.smart_derivatives import SmartDerivatives
    from mlcolvar.explain.sensitivity import sensitivity_analysis
    import lightning

    # ================================ committor ================================
    # full atoms with all distances
    n_atoms = 10
    pos = torch.Tensor([[ 1.4970, 1.3861, -0.0273, -1.4933, 1.5070, -0.1133, -1.4473, -1.4193,
                         -0.0553, 1.4940, 1.4990, -0.2403, 1.4780, -1.4173, -0.3363, -1.4243,
                         -1.4093, -0.4293, 1.3530, -1.4313, -0.4183, 1.3060, 1.4750, -0.4333,
                          1.2970, -1.3233, -0.4643, 1.1670, -1.3253, -0.5354]])
    pos = pos.repeat(200, 1)
    # 200 labels: 40 copies each of 0, 1, 2, 3, 4, sorted, shape (200, 1)
    labels = torch.arange(0, 5, dtype=torch.float32).unsqueeze(-1).repeat(40, 1).sort()[0]
    weights = torch.ones_like(labels)
    atomic_masses = initialize_committor_masses(atom_types=[0, 0, 1, 2, 0, 0, 0, 1, 2, 0],
                                                masses=[12.011, 15.999, 14.007])
    dataset = DictDataset({'data': pos, 'labels': labels, 'weights': weights})
    cell = torch.Tensor([3.0233])
    ComputeDescriptors = PairwiseDistances(n_atoms=n_atoms,
                                           PBC=True,
                                           cell=cell,
                                           scaled_coords=False,
                                           slicing_pairs=None)

    smart_derivatives = SmartDerivatives()
    smart_dataset = smart_derivatives.setup(dataset=dataset,
                                            descriptor_function=ComputeDescriptors,
                                            n_atoms=n_atoms,
                                            separate_boundary_dataset=True,
                                            descriptors_batch_size=25)
    datamodule = DictModule(dataset=smart_dataset, lengths=[0.8, 0.2], batch_size=80)
    model = Committor(model=[45, 10, 1],
                      atomic_masses=atomic_masses,
                      alpha=1,
                      separate_boundary_dataset=True,
                      descriptors_derivatives=smart_derivatives
                      )
    trainer = lightning.Trainer(max_epochs=3, logger=False, enable_checkpointing=False)
    trainer.fit(model, datamodule)

    # check that sensitivity works
    sensitivity_analysis(model=model, dataset=smart_dataset)

    # ================================ Generator ================================
    kT = 2.49432
    # create friction tensor
    # TODO: this part should be made easier using committor utils
    masses = torch.Tensor([12.011, 12.011, 15.999, 14.0067, 12.011, 12.011, 12.011, 15.999, 14.0067, 12.011])
    gamma = 1 / 0.05
    friction = kT / (gamma * masses)

    ref_weights = torch.ones(len(pos))
    dataset = DictDataset({'data': pos, 'labels': labels, 'weights': ref_weights})

    # --------------------------------- TRAIN MODEL ---------------------------------
    # ------------ Descriptors as input + SmartDerivatives ------------
    # initialize smart derivatives, we do it explicitly to test different functionalities
    smart_derivatives = SmartDerivatives()
    smart_dataset = smart_derivatives.setup(dataset=dataset,
                                            descriptor_function=ComputeDescriptors,
                                            n_atoms=n_atoms,
                                            separate_boundary_dataset=False)
    datamodule = DictModule(smart_dataset, lengths=[0.8, 0.2], random_split=True, shuffle=True)

    # seed for reproducibility
    torch.manual_seed(42)
    options = {"nn": {"activation": "tanh"},
               "optimizer": {"lr": 1e-3, "weight_decay": 1e-5}
               }
    model = Generator(
        r=3,
        layers=[45, 20, 20, 1],
        eta=0.005,
        alpha=0.01,
        friction=friction,
        descriptors_derivatives=smart_derivatives,
        options=options,
    )

    trainer = lightning.Trainer(
        accelerator='cpu',
        callbacks=None,
        max_epochs=6,
        enable_progress_bar=False,
        enable_checkpointing=False,
        logger=False,
        limit_val_batches=0,
        num_sanity_val_steps=0,
    )
    # fit
    trainer.fit(model, datamodule)

    # model outputs on the full dataset after training
    X = smart_dataset["data"]
    q = model(X)

    # compute eigenfunctions
    eigfuncs, eigvals, eigvecs = model.compute_eigenfunctions(dataset=smart_dataset, descriptors_derivatives=smart_derivatives)
    print(eigfuncs.shape)
    print(eigvals.shape)
    print(eigvecs.shape)

    # check that sensitivity works
    sensitivity_analysis(model=model, dataset=smart_dataset)