mlcolvar.explain.sensitivity.sensitivity_analysis

class mlcolvar.explain.sensitivity.sensitivity_analysis(model, dataset, std=None, feature_names=None, metric='mean_abs_val', per_class=False, plot_mode='violin', ax=None)


Perform a sensitivity analysis using the partial derivatives method. This allows us to measure which input features the model is most sensitive to (i.e., which quantities produce significant changes in the output).

To do this, the partial derivatives of the model output with respect to each input \(x_i\) are computed over the \(N\) points of a dataset \(\{\mathbf{x}^{(j)}\}_{j=1}^N\). If the dataset is not standardized, these values are multiplied by the standard deviation \(\sigma_i\) of each feature over the dataset.

Then, an average sensitivity value \(s_i\) is computed, either as the mean absolute value (metric=`MAV`):

\[ s_i = \frac{1}{N} \sum_j \left| \frac{\partial s}{\partial x_i}(\mathbf{x}^{(j)}) \right| \sigma_i \]

or as the root mean square (metric=`RMS`):

\[ s_i = \sqrt{ \frac{1}{N} \sum_j \left( \frac{\partial s}{\partial x_i}(\mathbf{x}^{(j)}) \sigma_i \right)^2 } \]

Alternatively, one can simply compute the average, without taking the absolute values (metric=`mean`).

In all the above cases, the sensitivity values are normalized such that they sum to 1.
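The three metrics and the final normalization can be sketched with NumPy as follows. This is only an illustration of the formulas above, not the library's implementation: the helper name `sensitivity_from_gradients` is hypothetical, and it assumes the per-sample gradients have already been computed and stored as an array of shape (N, D).

```python
import numpy as np


def sensitivity_from_gradients(gradients, std=None, metric="mean_abs_val"):
    """Hypothetical helper: reduce per-sample gradients to normalized sensitivities.

    gradients : (N, D) array with d s / d x_i evaluated at each of the N samples
    std       : (D,) array of feature standard deviations, or None to skip scaling
    """
    grads = np.asarray(gradients, dtype=float)
    if std is not None:
        # scale each partial derivative by the spread of its feature (sigma_i)
        grads = grads * np.asarray(std, dtype=float)

    if metric in ("mean_abs_val", "MAV"):
        s = np.mean(np.abs(grads), axis=0)
    elif metric in ("root_mean_square", "RMS"):
        s = np.sqrt(np.mean(grads**2, axis=0))
    elif metric == "mean":
        s = np.mean(grads, axis=0)
    else:
        raise ValueError(f"unknown metric: {metric}")

    # normalize so that the sensitivities sum to 1
    return s / np.sum(s)


# Toy linear model s(x) = 3*x_0 - x_1: the gradient is (3, -1) at every sample.
toy_gradients = np.tile([3.0, -1.0], (4, 1))
toy_std = np.array([1.0, 2.0])

mav = sensitivity_from_gradients(toy_gradients, toy_std, metric="MAV")
rms = sensitivity_from_gradients(toy_gradients, toy_std, metric="RMS")
avg = sensitivity_from_gradients(toy_gradients, toy_std, metric="mean")
```

With the `mean` metric the signs are kept, so individual values can be negative (or larger than 1) even though they still sum to 1.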

If the dataset is labeled, these quantities can also be computed on the subset of the data belonging to each class.
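A per-class variant could look like the sketch below, shown here for the mean absolute value only. The function name `mean_abs_sensitivity_per_class` is hypothetical, and normalizing each class separately is an assumption made for illustration; the library may handle this differently.

```python
import numpy as np


def mean_abs_sensitivity_per_class(gradients, labels, std):
    """Hypothetical helper: mean-absolute-value sensitivity on each labeled subset.

    gradients : (N, D) array of per-sample partial derivatives
    labels    : (N,) array of integer class labels
    std       : (D,) array of feature standard deviations
    """
    grads = np.asarray(gradients, dtype=float) * np.asarray(std, dtype=float)
    labels = np.asarray(labels)

    per_class = {}
    for c in np.unique(labels):
        # restrict the average to the samples belonging to class c
        s = np.mean(np.abs(grads[labels == c]), axis=0)
        # assumption: each class is normalized on its own to sum to 1
        per_class[int(c)] = s / np.sum(s)
    return per_class


toy_gradients = np.array([[1.0, 1.0], [1.0, 3.0], [2.0, 0.0], [4.0, 0.0]])
toy_labels = np.array([0, 0, 1, 1])
per_class = mean_abs_sensitivity_per_class(toy_gradients, toy_labels, std=[1.0, 1.0])
```

In this toy data the second feature matters only for class 0, which a single average over the whole dataset would partly hide.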

mlcolvar.utils.fes.plot_sensitivity

Plot the sensitivity analysis results

model: mlcolvar.cvs.BaseCV

collective variable model

dataset: mlcolvar.data.DictDataset

dataset on which to compute the sensitivity analysis.

std: array_like, optional

standard deviation of the features, by default it will be computed from the dataset

feature_names: array_like, optional

array-like with input features names, by default they will be taken from the dataset if available

metric: str, optional

sensitivity measure (‘mean_abs_val’ | ‘MAV’, ‘root_mean_square’ | ‘RMS’, ‘mean’), by default ‘mean_abs_val’

per_class: bool, optional

if the dataset has labels, compute also the sensitivity per class, by default False

plot_mode: str, optional

how to visualize the results (‘violin’, ‘barh’, ‘scatter’), by default ‘violin’

ax: matplotlib.axis, optional

ax where to plot the results, by default it will be initialized

results: dictionary

results of the sensitivity analysis, containing the ‘feature_names’, the ‘sensitivity’ values and the per-sample ‘gradients’, ordered according to the sensitivity.
