mlcolvar.explain.sensitivity.sensitivity_analysis
- class mlcolvar.explain.sensitivity.sensitivity_analysis(model, dataset, std=None, feature_names=None, metric='mean_abs_val', per_class=False, plot_mode='violin', ax=None)
Perform a sensitivity analysis using the partial derivatives method. This allows us to measure which input features the model is most sensitive to (i.e., which quantities produce significant changes in the output).
To do this, the partial derivatives of the model output \(s\) with respect to each input \(x_i\) are computed over a dataset of \(N\) points \(\{\mathbf{x}^{(j)}\}_{j=1}^N\). If the dataset is not standardized, these derivatives are multiplied by the standard deviation \(\sigma_i\) of each feature over the dataset, so that features with different scales can be compared.
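The gradient step can be sketched in NumPy as follows. This is a minimal illustration, not the mlcolvar implementation: the helper `scaled_gradients` is hypothetical, the model is assumed to be any Python callable mapping an `(N, d)` array to `N` CV values, and central finite differences stand in for the automatic differentiation a PyTorch model would use.

```python
import numpy as np

def scaled_gradients(f, X, std=None, eps=1e-5):
    """Hypothetical helper: per-sample ds/dx_i multiplied by sigma_i.

    f   : callable mapping an (N, d) array to an (N,) array of CV values
    X   : (N, d) array of input points x^(j)
    std : (d,) feature standard deviations; computed from X if None
    Central finite differences are used here instead of autograd.
    """
    X = np.asarray(X, dtype=float)
    n, d = X.shape
    if std is None:
        std = X.std(axis=0)
    grads = np.empty((n, d))
    for i in range(d):
        step = np.zeros(d)
        step[i] = eps
        # derivative with respect to x_i, evaluated at all samples at once
        grads[:, i] = (f(X + step) - f(X - step)) / (2 * eps)
    return grads * std  # broadcasting: column i is multiplied by sigma_i
```

For a linear model \(s(\mathbf{x}) = \mathbf{w}\cdot\mathbf{x}\) every row of the result equals \(w_i \sigma_i\). Passing `std=np.ones(d)` returns the unscaled derivatives, which is roughly what happens when the data are already standardized.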
Then, an average sensitivity value \(s_i\) is computed, either as the mean absolute value (metric=`MAV`):
\[ s_i = \frac{1}{N} \sum_{j} \left| \frac{\partial s}{\partial x_i}(\mathbf{x}^{(j)}) \right| \sigma_i \]
or as the root mean square (metric=`RMS`):
\[ s_i = \sqrt{ \frac{1}{N} \sum_{j} \left( \frac{\partial s}{\partial x_i}(\mathbf{x}^{(j)})\, \sigma_i \right)^2 } \]
Alternatively, one can simply compute the average without taking absolute values (metric=`mean`), in which case derivatives of opposite sign can cancel out.
In all the above cases, the sensitivity values are normalized such that they sum to 1.
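Given the scaled gradients, the three metrics and the final normalization reduce to a few NumPy reductions. The helper `sensitivity_from_gradients` below is a hypothetical sketch that accepts the same metric strings as the `metric` parameter; it assumes `grads` already contains \(\frac{\partial s}{\partial x_i}(\mathbf{x}^{(j)})\,\sigma_i\) with shape `(N, d)`.

```python
import numpy as np

def sensitivity_from_gradients(grads, metric="mean_abs_val"):
    """Hypothetical helper: aggregate (N, d) scaled gradients into d values summing to 1."""
    grads = np.asarray(grads, dtype=float)
    if metric in ("mean_abs_val", "MAV"):
        s = np.mean(np.abs(grads), axis=0)       # (1/N) sum_j |g_ij|
    elif metric in ("root_mean_square", "RMS"):
        s = np.sqrt(np.mean(grads ** 2, axis=0))  # sqrt((1/N) sum_j g_ij^2)
    elif metric == "mean":
        s = np.mean(grads, axis=0)               # signed average, may cancel
    else:
        raise ValueError(f"unknown metric: {metric!r}")
    return s / s.sum()  # normalize so that the sensitivities sum to 1
```

Since \(\sigma_i > 0\), taking the absolute value of the scaled gradient is the same as multiplying \(|\partial s/\partial x_i|\) by \(\sigma_i\), so the code matches the formulas above. With `metric='mean'` the denominator can be close to zero when contributions cancel, so the normalized values should be read with care.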
If the dataset is labeled, these quantities can also be computed on the subset of data belonging to each class (per_class=True).
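The per-class variant only changes which rows enter the average. In this hypothetical sketch, `labels` holds one class label per sample and `metric_fn` is any function mapping an `(N_c, d)` gradient array to `d` sensitivities; the default used here is a normalized mean absolute value.

```python
import numpy as np

def mav(grads):
    # normalized mean absolute value over the samples (rows)
    s = np.mean(np.abs(grads), axis=0)
    return s / s.sum()

def per_class_sensitivity(grads, labels, metric_fn=mav):
    # hypothetical helper: select the gradient rows of each class, then aggregate
    grads = np.asarray(grads, dtype=float)
    labels = np.asarray(labels)
    return {c: metric_fn(grads[labels == c]) for c in np.unique(labels)}
```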
See also:
- mlcolvar.utils.fes.plot_sensitivity
Plot the sensitivity analysis results.
Parameters:
- model : mlcolvar.cvs.BaseCV
Collective variable model.
- dataset : mlcolvar.data.DictDataset
Dataset on which to compute the sensitivity analysis.
- std : array_like, optional
Standard deviation of the features; by default it is computed from the dataset.
- feature_names : array_like, optional
Names of the input features; by default they are taken from the dataset, if available.
- metric : str, optional
Sensitivity measure ('mean_abs_val'|'MAV', 'root_mean_square'|'RMS', 'mean'); by default 'mean_abs_val'.
- per_class : bool, optional
If the dataset has labels, also compute the sensitivity per class; by default False.
- plot_mode : str, optional
How to visualize the results ('violin', 'barh', 'scatter'); by default 'violin'.
- ax : matplotlib.axes.Axes, optional
Axes on which to plot the results; by default a new one is created.
Returns:
- results : dict
Results of the sensitivity analysis, containing the 'feature_names', the 'sensitivity' values, and the per-sample 'gradients', ordered according to the sensitivity.
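The returned dictionary can be pictured as below. The key names come from the description above; the array layout (gradients as an `(N, d)` array) and the decreasing sort order are assumptions of this hypothetical `build_results` sketch, not a statement of what mlcolvar returns.

```python
import numpy as np

def build_results(feature_names, sensitivity, gradients):
    """Hypothetical helper: pack results sorted by decreasing sensitivity."""
    sensitivity = np.asarray(sensitivity, dtype=float)
    order = np.argsort(sensitivity)[::-1]  # feature indices, most sensitive first
    return {
        "feature_names": np.asarray(feature_names)[order],
        "sensitivity": sensitivity[order],
        "gradients": np.asarray(gradients)[:, order],  # reorder per-sample columns
    }
```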
- __init__(**kwargs)