Source code for mlcolvar.explain.graph_sensitivity

import numpy as np
from typing import Dict
import torch
import copy

from mlcolvar.data import DictModule
from mlcolvar.utils.plot import pbar
from mlcolvar.core.nn import BaseGNN
from mlcolvar.data.utils import save_dataset_configurations_as_extyz


__all__ = ['graph_node_sensitivity']


[docs] def graph_node_sensitivity(model, dataset, component: int = 0, device: str = 'cpu', batch_size: int = None, show_progress: bool = True, extxyz_filename: str = None, ) -> Dict[str, np.ndarray]: """Performs a sensitivity analysis on a GNN-based CV model using partial derivatives w.r.t. nodes' positions. This allows us to measure which atom is most important to the CV model. If system/environment atoms are defined in the input dataset, the average node-sensitivities are returned only for the system atom, while a aggregated sensitivities (mean and sum) are returned for environment instead. Optionally, the results can be printed to an extxyz, where each atom is associated to its sensitivity score so that they can be visualized in a molecular viewer. Parameters ---------- model: mlcolvar.cvs.BaseCV Collective variable model based on GNN dataset: mlcovar.data.DictDataset Graph-based dataset on which to compute the sensitivity analysis device: str Name of the device on which to perform the computation batch_size: Batch size used for evaluating the CV show_progress: bool If to show the progress bar extxyz_filename: str If provided, a extxyz file with this name is printed with the positions of the atoms in the graph dataset with the corresponding sensitivities so that they can be visualized in a molecular viewer. Returns ------- results: dictionary Results of the sensitivity analysis, containing: - 'node_labels': names associated to the nodes of the graphs. If truncated graphs are used, only the system atoms are labeled, while the contribution from the environment atoms is aggregated with mean and sum. - 'avg_sensitivities': averaged sensititivities over the given dataset. If truncated graphs are used, environment values are aggregated with mean and sum. The results are ordered consistently with the node labels, the enviroment values are the last two. - 'raw_sensitivities': raw sensitivities per-frame, including the sensitivities relative to each atom. If truncated graphs are used, the sensitivities wrt different environment atoms (whose number may change for different frame) are returned. The results are ordered as the positions in the dataset. See also -------- mlcolvar.utils.explain.sensitivity_analysis Perform the sensitivity analysis of a feedforward model. """ # check model is GNN-based if not isinstance(model.nn, BaseGNN): raise ValueError ( "The CV model is not based on GNN! Maybe you should use the feedforward sensitivity_analysis from mlcolvar.utils.explain.sensitivity!" ) # make user aware of behaviour for truncated graphs if dataset.metadata['is_truncated_graph']: print("[NOTE] The input dataset contains truncated graphs for which system and environment selections are provided.\nThe average node-sensitivities" + "will be returned only for the system atom, while an aggregated sensitivty is returned for environment instead") model = model.to(device) # get gradients of cv model on dataset gradients_norm, gradients_components = get_dataset_cv_gradients(model=model, dataset=dataset, component=component, batch_size=batch_size, show_progress=show_progress, progress_prefix='Getting gradients' ) # create a nice dataset with the results results = {} results['atoms_list'] = [a for a in dataset.metadata['system_atoms_names']] # append environment result entries if necessary if dataset.metadata['is_truncated_graph']: results['atoms_list'].extend(['environment_atoms_mean', 'environment_atoms_sum']) results['node_labels'] = [str(a) for a in results['atoms_list']] results['avg_sensitivities'] = gradients_norm results['raw_sensitivities'] = gradients_components # save an extxyz file with the atoms in the graph and their sensitivity score if extxyz_filename is not None: updated_dataset = copy.deepcopy(dataset) for i, d in enumerate(updated_dataset['data_list']): d['sensitivities'] = gradients_components[i] save_dataset_configurations_as_extyz(dataset=updated_dataset, file_name=extxyz_filename, extra_keys=['sensitivities']) return results
def get_dataset_cv_values(model, dataset, batch_size: int = None, show_progress: bool = True, progress_prefix: str = 'Calculating CV values' ) -> np.ndarray: """Gets the values of a CV model on a given dataset. The calculation will run on the device where the model is on. Parameters ---------- model: mlcolvar.cvs.BaseCV Collective variable model dataset: mlcovar.data.DictDataset Dataset on which to compute the sensitivity analysis batch_size: Batch size used for evaluating the CV show_progress: bool If to show the progress bar """ datamodule = DictModule( dataset=dataset, lengths=(1.0,), batch_size=batch_size, random_split=False, shuffle=False ) datamodule.setup() cv_values = [] device = next(model.parameters()).device if show_progress: items = pbar( datamodule.train_dataloader(), frequency=0.001, prefix=progress_prefix ) else: items = datamodule.train_dataloader() with torch.no_grad(): for batchs in items: outputs = model(batchs['data_list'].to(device).to_dict()) outputs = outputs.cpu().numpy() cv_values.append(outputs) return np.concatenate(cv_values) def get_dataset_cv_gradients(model, dataset, component: int = 0, batch_size: int = None, show_progress: bool = True, progress_prefix: str = 'Calculating CV gradients' ) -> np.ndarray: """Get gradients of a GNN-based CV w.r.t. node positions in a given dataset. The calculation will run on the device where the model is on. Parameters ---------- model: mlcolvar.cvs.BaseCV Collective variable model based on GNN dataset: mlcovar.data.DictDataset Graph-based dataset on which to compute the sensitivity analysis component: int Component of the CV to analyse batch_size: Batch size used for evaluating the CV show_progress: bool If to show the progress bar """ # get device device = next(model.parameters()).device # create a datamodule to initialize the dataloader datamodule = DictModule(dataset=dataset, lengths=(1.0,), batch_size=batch_size, random_split=False, shuffle=False ) datamodule.setup() if show_progress: items = pbar( datamodule.train_dataloader(), frequency=0.001, prefix=progress_prefix ) else: items = datamodule.train_dataloader() cv_value_gradients = [] cv_value_gradients_components = [] # iterate over the batches for batchs in items: # get data batch_dict = batchs['data_list'].to(device) batch_dict['positions'].requires_grad_(True) # get desired cv component cv_values = model(batch_dict) cv_values = cv_values[:, component] # compute gradients grad_outputs = [torch.ones_like(cv_values, device=device)] gradients = torch.autograd.grad( outputs=[cv_values], inputs=[batch_dict['positions']], grad_outputs=grad_outputs, retain_graph=False, create_graph=False, ) graph_sizes = batch_dict['ptr'][1:] - batch_dict['ptr'][:-1] # if we used the removed isolated atoms this will give an inhomogenous tensor! gradients = torch.split( gradients[0].detach(), graph_sizes.cpu().numpy().tolist() ) if dataset.metadata['is_truncated_graph']: # here we separate the system (constant in number) and environment (possibly changing) gradients # system grads are treated individually, env grads are aggragated with sum and mean for i,g in enumerate(gradients): # slice the gradients to get the contribution from system and environment atoms system_atoms_grads = torch.linalg.vector_norm( g[batch_dict[i]['system_masks'].squeeze()], dim=-1) env_atoms_grads = torch.linalg.vector_norm( g[torch.logical_not(batch_dict[i]['system_masks'].squeeze())], dim=-1) # add the values to the result lists # one with the components of the gradients of all atoms, ordered as the positions in the dataset cv_value_gradients_components.append(torch.linalg.vector_norm( g, dim=-1).unsqueeze(-1)) # one with the norm of the gradients for each system atom and aggregated env atoms cv_value_gradients.extend(torch.concat([system_atoms_grads, env_atoms_grads.mean(dim=-1, keepdim=True), env_atoms_grads.sum(dim=-1, keepdim=True) ]).unsqueeze(0).cpu().numpy() ) else: # here we ensure that all the gradients have the correct shape # and that each entry is at the correct index accordingly max_used_atoms = len(dataset.metadata['system_idx']) for i,g in enumerate(gradients): aux = torch.zeros((max_used_atoms, 3)) # this populates the right entries according to the orignal indexing aux[batch_dict[i]['system_names_idx'], :] = g aux = aux.unsqueeze(0) cv_value_gradients.extend(torch.linalg.vector_norm(aux, dim=-1).cpu().numpy()) cv_value_gradients_components.extend(aux.cpu().numpy()) return cv_value_gradients, cv_value_gradients_components def test_get_cv_values_graph(): import lightning from mlcolvar.cvs import DeepTDA from mlcolvar.core.nn.graph import SchNetModel from mlcolvar.data.graph.utils import create_test_graph_input # create data, we need the dataset for sensitivity analysis later dataset = create_test_graph_input(output_type='dataset', n_samples=50, n_states=2, n_atoms=3) datamodule = DictModule(dataset=dataset, lengths=[0.8, 0.2], shuffle=[1, 0]) # create model gnn_model = SchNetModel(n_out=1, cutoff=0.1, atomic_numbers=[8, 1]) model = DeepTDA( n_states=2, n_cvs=1, target_centers=[-5, 5], target_sigmas=[0.2, 0.2], model=gnn_model ) # train model trainer = lightning.Trainer( accelerator="cpu", max_epochs=2, logger=False, enable_checkpointing=False, enable_model_summary=False ) trainer.fit(model, datamodule) # do analysis cv_values = get_dataset_cv_values(model=model, dataset=dataset, batch_size=0) # print results print(cv_values) assert (torch.allclose(model(dataset.get_graph_inputs()), torch.Tensor(cv_values))) def test_graph_sensitivity(): import lightning from mlcolvar.cvs import DeepTDA from mlcolvar.core.nn.graph import SchNetModel from mlcolvar.data.graph.utils import create_test_graph_input for environment in [False, True]: # create data, we need the dataset for sensitivity analysis later dataset = create_test_graph_input(output_type='dataset', n_samples=100, n_states=2, n_atoms=3, environment=environment) datamodule = DictModule(dataset=dataset, lengths=[0.8, 0.2], shuffle=[1, 0]) # create model gnn_model = SchNetModel(n_out=1, cutoff=0.1, atomic_numbers=[8, 1]) model = DeepTDA( n_states=2, n_cvs=1, target_centers=[-5, 5], target_sigmas=[0.2, 0.2], model=gnn_model ) # train model trainer = lightning.Trainer( accelerator="cpu", max_epochs=2, logger=False, enable_checkpointing=False, enable_model_summary=False ) trainer.fit(model, datamodule) # do analysis test_sensitivity = graph_node_sensitivity(model=model, dataset=dataset, batch_size=0) # print results print(test_sensitivity)