# Source code for mlcolvar.utils.plot

##########################################################################
## FESSA COLOR PALETTE
##########################################################################
#  https://github.com/luigibonati/fessa-color-palette/blob/master/fessa.py

import mlcolvar
from matplotlib.colors import LinearSegmentedColormap, ColorConverter
from matplotlib import patches as mpatches
import matplotlib as mpl

# Public names exported by ``from mlcolvar.utils.plot import *``.
__all__ = [
    "paletteFessa",
    "paletteCortina",
    "plot_metrics",
    "plot_features_distribution",
]

# Fessa colormap
paletteFessa = [
    "#1F3B73",  # dark-blue
    "#2F9294",  # green-blue
    "#50B28D",  # green
    "#A7D655",  # pisello
    "#FFE03E",  # yellow
    "#FFA955",  # orange
    "#D6573B",  # red
]

cm_fessa = LinearSegmentedColormap.from_list("fessa", paletteFessa)
# Register "fessa" and its reversed version "fessa_r". ``register`` raises a
# ValueError when the name is already taken (e.g. when this module is
# reloaded, or the standalone ``fessa`` module was imported first), which
# would make importing this module fail; only register names not yet present.
for _cmap in (cm_fessa, cm_fessa.reversed()):
    if _cmap.name not in mpl.colormaps:
        mpl.colormaps.register(cmap=_cmap)

# Expose each palette entry as a named matplotlib color: "fessa0" ... "fessa6".
for i, _color in enumerate(paletteFessa):
    ColorConverter.colors[f"fessa{i}"] = _color

### To set it as default
# import mlcolvar.utils.plot  # importing registers 'fessa' and 'fessa_r'
# plt.set_cmap('fessa')
### or the reversed one
# plt.set_cmap('fessa_r')
### For contour plots
# plt.contourf(X, Y, Z, cmap='fessa')
### For standard plots
# plt.plot(x, y, color='fessa0')


# Cortina1980 colormap, given as [r, g, b, alpha] entries with channels in [0, 1]
paletteCortina = [
    [0.0, 0.0, 0.803921568627451, 1],  # mediumblue
    [0.4823529411764706, 0.40784313725490196, 0.9333333333333333, 1],  # mediumslateblue
    [0.0, 0.9803921568627451, 0.6039215686274509, 1],  # mediumspringgreen
    [0.23529411764705882, 0.7019607843137254, 0.44313725490196076, 1],  # mediumseagreen
    [0.8588235294117647, 0.4392156862745098, 0.5764705882352941, 1],  # palevioletred
    [0.7803921568627451, 0.08235294117647059, 0.5215686274509804, 1],  # mediumvioletred
]

cm_cortina = LinearSegmentedColormap.from_list("cortina80", paletteCortina)
# Register "cortina80" and its reversed version "cortina80_r". ``register``
# raises a ValueError for a name that is already registered (e.g. on module
# reload), so only register the names that are not present yet.
for _cmap in (cm_cortina, cm_cortina.reversed()):
    if _cmap.name not in mpl.colormaps:
        mpl.colormaps.register(cmap=_cmap)

##########################################################################
## HELPER FUNCTIONS FOR 2D SYSTEMS
##########################################################################

import matplotlib.pyplot as plt
import numpy as np
import torch


def muller_brown_potential(x, y):
    """Muller-Brown analytical potential.

    Computes ``-prefactor * offset + sum_k prefactor * A_k * exp(a_k dx^2
    + b_k dx dy + c_k dy^2)`` with ``dx = x - x0_k`` and ``dy = y - y0_k``
    over four Gaussian-like terms (prefactor 0.15, offset -146.7).

    Parameters
    ----------
    x, y : float or numpy.ndarray
        Coordinates; arrays are combined element-wise (numpy broadcasting).

    Returns
    -------
    float or numpy.ndarray
        Potential value(s) with the broadcast shape of ``x`` and ``y``.
    """
    scale = 0.15
    shift = -146.7
    # One row per term: (A, a, b, c, x0, y0)
    terms = (
        (-200, -1, 0, -10, 1, 0),
        (-100, -1, 0, -10, 0, 0.5),
        (-170, -6.5, 11, -6.5, -0.5, 1.5),
        (15, 0.7, 0.6, 0.7, -1, 1),
    )

    potential = -scale * shift
    for amp, a_xx, a_xy, a_yy, cx, cy in terms:
        dx = x - cx
        dy = y - cy
        potential += scale * amp * np.exp(a_xx * dx**2 + a_xy * dx * dy + a_yy * dy**2)
    return potential


def muller_brown_mfep():
    """Load the MFEP (presumably minimum free-energy path) of the Muller-Brown potential.

    Reads the first two columns of ``mfep.txt`` from the tutorial data in the
    ``docs`` folder, located relative to the installed ``mlcolvar`` package.
    NOTE(review): this path presumably exists only in a source checkout.

    Returns
    -------
    numpy.ndarray
        The two columns loaded by ``np.loadtxt``.
    """
    package_dir = mlcolvar.__path__[0]
    data_file = package_dir + "/../docs/notebooks/tutorials/data/muller-brown/mfep.txt"
    return np.loadtxt(data_file, usecols=(0, 1))


def muller_brown_potential_three_states(x, y):
    """Three-state variant of the Muller-Brown analytical potential.

    Same functional form as ``muller_brown_potential`` (prefactor 0.15,
    offset -146.7), but the first two terms use different amplitudes, widths
    and centers (A = -280, a = -15 for the first; A = -170, x0 = 0.2 for the
    second).

    Parameters
    ----------
    x, y : float or numpy.ndarray
        Coordinates; arrays are combined element-wise (numpy broadcasting).

    Returns
    -------
    float or numpy.ndarray
        Potential value(s) with the broadcast shape of ``x`` and ``y``.
    """
    scale = 0.15
    shift = -146.7
    # One row per term: (A, a, b, c, x0, y0)
    terms = (
        (-280, -15, 0, -10, 1, 0),
        (-170, -1, 0, -10, 0.2, 0.5),
        (-170, -6.5, 11, -6.5, -0.5, 1.5),
        (15, 0.7, 0.6, 0.7, -1, 1),
    )

    potential = -scale * shift
    for amp, a_xx, a_xy, a_yy, cx, cy in terms:
        dx = x - cx
        dy = y - cy
        potential += scale * amp * np.exp(a_xx * dx**2 + a_xy * dx * dy + a_yy * dy**2)
    return potential


def muller_brown_three_states_mfep():
    """Load the MFEP (presumably minimum free-energy path) of the three-state Muller-Brown potential.

    Reads the first two columns of ``mfep.txt`` from the tutorial data in the
    ``docs`` folder, located relative to the installed ``mlcolvar`` package.
    NOTE(review): this path presumably exists only in a source checkout.

    Returns
    -------
    numpy.ndarray
        The two columns loaded by ``np.loadtxt``.
    """
    package_dir = mlcolvar.__path__[0]
    data_file = package_dir + "/../docs/notebooks/tutorials/data/muller-brown-3states/mfep.txt"
    return np.loadtxt(data_file, usecols=(0, 1))


def _evaluate_module_on_grid(module, xv, yv, component=None, allow_grad=False):
    """Evaluate a torch module point by point on a 2D grid.

    Each grid point ``(xv[idx], yv[idx])`` is fed to ``module`` as a ``(1, 2)``
    tensor; the (optionally ``component``-selected) output is stored in an
    array with the same shape as ``xv``.

    If ``allow_grad`` is False the module is put in eval mode and evaluated
    under ``torch.no_grad()``. Afterwards the train/eval flag of every
    submodule is restored, even if the evaluation raises.
    """
    z = np.zeros_like(xv)
    # ``module.eval()`` switches every submodule, so snapshot all of them:
    # resetting only ``module.training`` would leave children in eval mode.
    saved_modes = [] if allow_grad else [(m, m.training) for m in module.modules()]
    if not allow_grad:
        module.eval()
    try:
        # Iterate over the real grid shape (ny, nx) so non-square grids work.
        for idx in np.ndindex(xv.shape):
            xy = torch.Tensor([xv[idx], yv[idx]]).unsqueeze(0)
            if allow_grad:
                s = module(xy).squeeze(0).detach().numpy()
            else:
                with torch.no_grad():
                    s = module(xy).squeeze(0).numpy()
            if component is not None:
                s = s[component]
            z[idx] = np.squeeze(s)
    finally:
        for submodule, was_training in saved_modes:
            submodule.training = was_training
    return z


def plot_isolines_2D(
    function,
    component=None,
    limits=((-1.8, 1.2), (-0.4, 2.1)),
    num_points=(100, 100),
    mode="contourf",
    levels=12,
    cmap=None,
    colorbar=None,
    max_value=None,
    ax=None,
    allow_grad=False,
    **kwargs,
):
    """Plot isolines of a function/model in a 2D space.

    Parameters
    ----------
    function : torch.nn.Module or callable
        A torch module is evaluated point by point, with each point passed as a
        ``(1, 2)`` tensor. Any other callable is called once as
        ``function(xv, yv)`` on the meshgrid arrays.
    component : int, optional
        Index of the module output to plot, when it returns several values.
    limits : ((xmin, xmax), (ymin, ymax)), optional
        Extent of the grid.
    num_points : int or (nx, ny), optional
        Number of grid points along x and y; an int gives a square grid.
    mode : str, optional
        "contourf" for filled contours; any other value draws contour lines.
    levels : int or array-like, optional
        Contour levels, forwarded to matplotlib.
    cmap : str or Colormap, optional
        Colormap. Defaults to "fessa" for "contourf". For "contour" it defaults
        to "Greys_r", unless ``colors`` is given in ``kwargs``.
    colorbar : bool, optional
        Whether to draw a colorbar. Defaults to True for "contourf" and False otherwise.
    max_value : float, optional
        Values above this threshold are clipped to it.
    ax : matplotlib axis, optional
        Axis to draw on; a new figure is created if None.
    allow_grad : bool, optional
        If True, a torch module is evaluated in its current mode with autograd
        enabled (outputs are detached). Otherwise it is evaluated in eval mode
        under ``torch.no_grad()``, and its train/eval state is restored afterwards.
    **kwargs
        Forwarded to ``ax.contourf`` / ``ax.contour``.

    Returns
    -------
    matplotlib axis or None
        The newly created axis if ``ax`` was None, otherwise None.
    """
    # Define grid where to evaluate function (an int means a square grid)
    if isinstance(num_points, (int, np.integer)):
        num_points = (num_points, num_points)
    xx = np.linspace(limits[0][0], limits[0][1], num_points[0])
    yy = np.linspace(limits[1][0], limits[1][1], num_points[1])
    # default 'xy' indexing: xv and yv have shape (num_points[1], num_points[0])
    xv, yv = np.meshgrid(xx, yy)

    if isinstance(function, torch.nn.Module):
        z = _evaluate_module_on_grid(function, xv, yv, component, allow_grad)
    else:
        # apply function directly to the grid arrays
        z = function(xv, yv)

    if max_value is not None:
        z[z > max_value] = max_value

    # Setup plot
    return_axs = ax is None
    if return_axs:
        _, ax = plt.subplots(figsize=(6, 4.0), dpi=100)

    # Color scheme (matplotlib rejects a cmap together with explicit colors)
    if cmap is None:
        if mode == "contourf":
            cmap = "fessa"
        elif mode == "contour" and "colors" not in kwargs:
            cmap = "Greys_r"

    # Colorbar: shown by default only for filled contours
    if colorbar is None:
        colorbar = mode == "contourf"

    # Plot
    if mode == "contourf":
        pp = ax.contourf(xv, yv, z, levels=levels, cmap=cmap, **kwargs)
    else:
        pp = ax.contour(xv, yv, z, levels=levels, cmap=cmap, **kwargs)
    # honor an explicit colorbar request in both modes
    if colorbar:
        plt.colorbar(pp, ax=ax)

    return ax if return_axs else None


def plot_metrics(
    metrics,
    keys=("train_loss_epoch", "valid_loss"),
    x=None,  # e.g. 'epoch'
    labels=None,  # e.g. ['Train','Valid']
    linestyles=None,  # e.g. ['-','--']
    colors=None,  # e.g. ['fessa0','fessa1']
    xlabel="Epoch",
    ylabel="Loss",
    title="Learning curves",
    yscale=None,
    ax=None,
):
    """Plot logged metrics.

    Parameters
    ----------
    metrics : mapping
        Maps each metric name to the sequence of its logged values.
    keys : sequence of str, optional
        Metrics to plot, by default ("train_loss_epoch", "valid_loss").
    x : str or array-like, optional
        Abscissa shared by all curves. A string is looked up in ``metrics``
        (e.g. 'epoch'). By default each curve is plotted against
        ``0, 1, ..., len(y) - 1``.
    labels, linestyles, colors : sequence, optional
        Per-key legend labels (default: the key), line styles and colors.
    xlabel, ylabel, title : str or None, optional
        Axis labels and title; None leaves them unset.
    yscale : str, optional
        Scale of the y axis (e.g. "log"); None leaves it unchanged.
    ax : matplotlib axis, optional
        Axis to draw on; a new figure is created if None.

    Returns
    -------
    matplotlib axis or None
        The newly created axis if ``ax`` was None, otherwise None.
    """
    # Setup axis
    return_axs = ax is None
    if return_axs:
        _, ax = plt.subplots(figsize=(5, 4), dpi=100)

    # A string abscissa names a logged metric (e.g. 'epoch')
    if isinstance(x, str):
        x = metrics[x]

    # Plot metrics
    for i, key in enumerate(keys):
        y = metrics[key]
        lstyle = linestyles[i] if linestyles is not None else None
        label = labels[i] if labels is not None else key
        color = colors[i] if colors is not None else None
        # without an explicit abscissa, plot against the logging index
        x_values = np.arange(len(y)) if x is None else x
        ax.plot(x_values, y, linestyle=lstyle, label=label, color=color)

    # Plot settings
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if title is not None:
        ax.set_title(title)
    if yscale is not None:
        ax.set_yscale(yscale)

    ax.legend(ncol=1, frameon=False)

    return ax if return_axs else None
def plot_features_distribution(dataset, features, titles=None, axs=None):
    """Plot distribution of the given features.

    One panel per feature. If the dataset has a ``"labels"`` entry, each panel
    shows one step histogram per distinct label; otherwise a single histogram.

    Parameters
    ----------
    dataset : DictDataset
        dataset; ``dataset['data']`` is indexed as (samples, features), and
        ``dataset.feature_names`` holds the feature (column) names
    features : list
        list of features names
    titles : list, optional
        list titles to be displayed (one per feature, used only when the
        dataset has labels), by default None
    axs : optional
        matplotlib axs: a single axis, a list/tuple or an array of axes with one
        axis per feature. By default None, in which case a new figure is created.

    Raises
    ------
    TypeError
        If ``features`` is a dictionary.
    ValueError
        If the number of axes differs from the number of features, or a
        feature is missing from ``dataset.feature_names``.
    """
    if isinstance(features, dict):
        raise TypeError('features should be a list of feature names, not a dictionary')

    n_feat = len(features)
    if axs is None:
        # squeeze=False always gives a 2D array, also for a single feature
        if n_feat <= 5:
            fig, axs = plt.subplots(1, n_feat, figsize=(3 * n_feat + 1, 3), squeeze=False)
        else:
            fig, axs = plt.subplots(n_feat, 1, figsize=(3, 3 * n_feat), squeeze=False)
        fig.suptitle('Features distribution')
        axs = axs.ravel()
    else:
        if isinstance(axs, np.ndarray):
            axs = axs.ravel()
        elif not isinstance(axs, (list, tuple)):
            axs = [axs]
        if n_feat != len(axs):
            raise ValueError(f'Number of features ({len(features)}) != number of axis ({len(axs)})')

    axs[0].set_ylabel('Distribution')

    # Resolve each feature name to its column index once, with a clear error
    feature_names = np.asarray(dataset.feature_names)
    feat_ids = []
    for feat in features:
        matches = np.flatnonzero(feature_names == feat)
        if matches.size == 0:
            raise ValueError(f"Feature '{feat}' not found in dataset.feature_names")
        feat_ids.append(int(matches[0]))

    data = dataset['data']
    legend_kwargs = dict(loc='upper center', framealpha=0.8, edgecolor='white')

    if "labels" in dataset.keys:
        labels = dataset['labels']
        for label in sorted(labels.unique().numpy()):
            # reshape(-1) accepts labels stored either as (N,) or as (N, 1)
            mask = (labels == label).reshape(-1)
            data_label = data[mask]
            for i, (feat, feat_id) in enumerate(zip(features, feat_ids)):
                ax = axs[i]
                x = data_label[:, feat_id].numpy()
                ax.hist(x, bins=50, label=f"State {int(label)}", histtype='step')
                ax.set_yticks([])
                ax.set_xlabel(feat)
                # the first panel carries the state legend; the others only the title
                if i == 0:
                    if titles is not None:
                        ax.legend(title=titles[i], **legend_kwargs)
                    else:
                        ax.legend(**legend_kwargs)
                elif titles is not None:
                    ax.legend([], [], title=titles[i], **legend_kwargs)
    else:
        for ax, feat, feat_id in zip(axs, features, feat_ids):
            x = data[:, feat_id].numpy()
            ax.hist(x, bins=100)
            ax.set_yticks([])
            ax.legend([], [], title=feat, loc='upper center', frameon=False)
def test_utils_plot():
    """Smoke-test the 2D helpers, palettes and colormap registration of this module."""
    import matplotlib

    x = np.linspace(-1.5, 1.5)
    y = np.linspace(-0.5, 2.5)
    # Both analytical potentials are evaluated element-wise on the inputs.
    for potential in (muller_brown_potential, muller_brown_potential_three_states):
        v = potential(x, y)
        assert v.shape == x.shape
        assert np.all(np.isfinite(v))

    assert len(paletteFessa) == 7
    assert len(paletteCortina) == 6

    # Importing this module registers the colormaps and their reversed versions;
    # look them up in the registry instead of building new anonymous Colormaps.
    for name in ("fessa", "fessa_r", "cortina80", "cortina80_r"):
        assert name in matplotlib.colormaps
        assert matplotlib.colormaps[name].name == name

    # Palette entries are also available as named colors "fessa0", "fessa1", ...
    assert matplotlib.colors.to_hex("fessa0") == paletteFessa[0].lower()