mlcolvar.explain.graph_sensitivity.graph_node_sensitivity

class mlcolvar.explain.graph_sensitivity.graph_node_sensitivity(model, dataset, component: int = 0, device: str = 'cpu', batch_size: int = None, show_progress: bool = True, extxyz_filename: str = None)

Performs a sensitivity analysis on a GNN-based CV model using the partial derivatives of the CV with respect to the nodes' positions. This measures how important each atom is to the CV model.
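
The idea of a node sensitivity can be illustrated without a GNN. The sketch below is not the mlcolvar implementation. It uses a hypothetical toy CV (the distance between atoms 0 and 1), estimates the derivatives with central finite differences in place of automatic differentiation, and assumes that the per-atom score is the Euclidean norm of the gradient with respect to that atom's Cartesian coordinates.

```python
import math
from typing import Callable, List

Positions = List[List[float]]  # one [x, y, z] triple per atom


def toy_cv(pos: Positions) -> float:
    """Hypothetical CV: distance between atoms 0 and 1 (atom 2 is ignored)."""
    return math.dist(pos[0], pos[1])


def node_sensitivities(cv: Callable[[Positions], float],
                       pos: Positions, h: float = 1e-5) -> List[float]:
    """Per-atom score |dCV/dr_i|, estimated with central finite differences."""
    scores = []
    for i in range(len(pos)):
        grad_i = []
        for d in range(3):
            # Work on copies so the input frame is left untouched
            plus = [row[:] for row in pos]
            minus = [row[:] for row in pos]
            plus[i][d] += h
            minus[i][d] -= h
            grad_i.append((cv(plus) - cv(minus)) / (2 * h))
        scores.append(math.sqrt(sum(g * g for g in grad_i)))
    return scores


frame = [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [0.0, 2.0, 0.0]]
print([round(s, 6) for s in node_sensitivities(toy_cv, frame)])
# → [1.0, 1.0, 0.0]
```

Atom 2 does not enter the toy CV, so its score is zero. With a PyTorch model, the same gradient would typically be obtained with automatic differentiation (e.g. `torch.autograd.grad`) rather than with 6N extra CV evaluations.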

If system and environment atoms are defined in the input dataset, the average node sensitivities are returned individually only for the system atoms, while aggregated sensitivities (mean and sum) are returned for the environment atoms.
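
A minimal sketch of this aggregation, assuming that each frame provides one sensitivity per atom plus a boolean mask marking the system atoms, and that the environment is reduced within each frame before averaging over frames. The helper names `aggregate_frame` and `average_over_frames` and the `is_system` mask are hypothetical, not part of mlcolvar.

```python
from statistics import fmean
from typing import List, Sequence


def aggregate_frame(sens: Sequence[float], is_system: Sequence[bool]) -> List[float]:
    """Keep system atoms individually; reduce the environment to [mean, sum], appended last."""
    system = [s for s, m in zip(sens, is_system) if m]
    env = [s for s, m in zip(sens, is_system) if not m]
    env_mean = fmean(env) if env else 0.0  # assumption: empty environment -> 0.0
    return system + [env_mean, sum(env)]


def average_over_frames(frames: Sequence[List[float]]) -> List[float]:
    """Element-wise average; assumes the number of system atoms is the same in every frame."""
    return [fmean(col) for col in zip(*frames)]


# Two hypothetical frames: 2 system atoms, followed by 3 and 1 environment atoms
f1 = aggregate_frame([0.9, 0.4, 0.1, 0.2, 0.3], [True, True, False, False, False])
f2 = aggregate_frame([0.7, 0.6, 0.5], [True, True, False])
print([round(v, 3) for v in average_over_frames([f1, f2])])
# → [0.8, 0.5, 0.35, 0.55]
```

Reducing the environment per frame gives every frame the same length even though the number of environment atoms changes, which is what makes the frame average well defined.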

Optionally, the results can be written to an extxyz file in which each atom is associated with its sensitivity score, so that they can be visualized in a molecular viewer.
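
In the extended XYZ format, per-atom columns are declared by a `Properties=` key on the comment line of each frame. The sketch below writes one frame with an extra real-valued `sensitivity` column. The function name and the column name are hypothetical and may differ from what mlcolvar actually writes.

```python
import io
from typing import Sequence, TextIO


def write_extxyz_frame(out: TextIO, species: Sequence[str],
                       positions: Sequence[Sequence[float]],
                       sensitivity: Sequence[float]) -> None:
    """Write one extended-XYZ frame with a per-atom 'sensitivity' column."""
    out.write(f"{len(species)}\n")
    # Comment line: declares the per-atom columns as name:type:ncols
    out.write("Properties=species:S:1:pos:R:3:sensitivity:R:1\n")
    for s, (x, y, z), v in zip(species, positions, sensitivity):
        out.write(f"{s} {x:.6f} {y:.6f} {z:.6f} {v:.6f}\n")


buf = io.StringIO()
write_extxyz_frame(buf, ["C", "O"], [[0.0, 0.0, 0.0], [1.2, 0.0, 0.0]], [0.75, 0.25])
print(buf.getvalue())
```

Multiple frames are stored by concatenating such blocks in the same file.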

Parameters:
  • model (mlcolvar.cvs.BaseCV) – Collective variable model based on GNN

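  • dataset (mlcolvar.data.DictDataset) – Graph-based dataset on which to compute the sensitivity analysis

  • component (int) – Index of the CV output component on which to perform the sensitivity analysis (default: 0)

  • device (str) – Name of the device on which to perform the computation (e.g. 'cpu' or 'cuda')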

  • batch_size (int) – Batch size used for evaluating the CV

  • show_progress (bool) – Whether to show a progress bar

  • extxyz_filename (str) – If provided, an extxyz file with this name is written, containing the positions of the atoms in the graph dataset together with the corresponding sensitivities, so that they can be visualized in a molecular viewer.

Returns:

results

Results of the sensitivity analysis, containing:
  • 'node_labels': names associated with the nodes of the graphs. If truncated graphs are used, only the system atoms are labeled, while the contribution from the environment atoms is aggregated with mean and sum.

  • 'avg_sensitivities': sensitivities averaged over the given dataset. If truncated graphs are used, the environment values are aggregated with mean and sum. The results are ordered consistently with the node labels, and the environment values are the last two entries.

  • 'raw_sensitivities': raw per-frame sensitivities, including the sensitivity relative to each atom. If truncated graphs are used, the sensitivities with respect to the individual environment atoms (whose number may vary from frame to frame) are returned. The results are ordered as the positions in the dataset.
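
As an illustration of how the returned dictionary might be consumed, the sketch below ranks the system atoms by their average sensitivity. It assumes, as described above, that with truncated graphs the last two entries of 'avg_sensitivities' are the environment mean and sum. The helper `rank_system_atoms` and the placeholder values are made up for this example and are not actual outputs.

```python
from typing import Dict, List, Sequence, Tuple


def rank_system_atoms(results: Dict[str, Sequence],
                      has_environment: bool) -> List[Tuple[str, float]]:
    """Return (label, avg_sensitivity) pairs for the system atoms, most sensitive first."""
    avg = list(results["avg_sensitivities"])
    if has_environment:
        avg = avg[:-2]  # drop the environment mean and sum (the last two values)
    labels = list(results["node_labels"])[: len(avg)]
    return sorted(zip(labels, avg), key=lambda pair: pair[1], reverse=True)


# Placeholder dictionary shaped like the documented output (values are illustrative only)
results = {
    "node_labels": ["C1", "O2", "N3"],
    "avg_sensitivities": [0.12, 0.48, 0.30, 0.05, 0.90],
}
print(rank_system_atoms(results, has_environment=True))
# → [('O2', 0.48), ('N3', 0.3), ('C1', 0.12)]
```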

Return type:

dictionary

See also

mlcolvar.utils.explain.sensitivity_analysis

Perform the sensitivity analysis of a feedforward model.
