import torch
import numpy as np
from bisect import bisect_left
from mlcolvar.data import DictDataset
import warnings
# optional packages
# pandas: used here only to accept DataFrame inputs in create_timelagged_dataset
try:
    import pandas as pd

    PANDAS_IS_INSTALLED = True
except ImportError:
    PANDAS_IS_INSTALLED = False

# tqdm (progress bar): used by find_timelagged_configurations when progress_bar=True
try:
    from tqdm import tqdm

    TQDM_IS_INSTALLED = True
except ImportError:
    TQDM_IS_INSTALLED = False

# public API of this module (helpers such as closest_idx / tprime_evaluation are internal)
__all__ = ["find_timelagged_configurations", "create_timelagged_dataset"]
def closest_idx(array, value):
    """
    Find the index of the element of 'array' just below 'value'.

    The search uses ``bisect_left``, so ``array`` is assumed to be sorted in
    increasing order. Torch tensors are detached and moved to CPU before the
    search, and other non-ndarray sequences (e.g. lists) are converted with
    ``np.asarray``. A torch scalar ``value`` is converted to a Python number.

    Note: despite the name, the result always rounds *down*: it is the index of
    the last element strictly smaller than ``value``, so an exact match returns
    the index *before* the matching element.

    Parameters:
        array (tensor/np.array/sequence): sorted values
        value (float): value to locate

    Returns:
        pos (int): 0 if ``value`` is not larger than the first element (or the
            array is empty), -1 if ``value`` is larger than every element,
            otherwise the index of the last element strictly smaller than ``value``.
    """
    if isinstance(array, torch.Tensor):
        # .numpy() fails on tensors that require grad or live on a GPU
        values = array.detach().cpu().numpy()
    elif isinstance(array, np.ndarray):
        values = array
    else:
        values = np.asarray(array)

    if isinstance(value, torch.Tensor):
        # compare plain numbers instead of mixing numpy scalars and tensors
        value = value.item()

    pos = bisect_left(values, value)
    if pos == 0:
        return 0
    if pos == len(values):
        return -1
    return pos - 1
# evaluation of tprime from simulation time and logweights
def tprime_evaluation(t, logweights=None):
    """
    Estimate the accelerated time if a set of (log)weights is given.

    The time step is taken as ``dt = t[1] - t[0]`` (rounded to 5 decimals), i.e.
    ``t`` is assumed to be uniformly spaced. The logweights are shifted so that
    their maximum is 0 and divided by their log-sum-exp. The rescaled time is the
    cumulative sum of ``dt * exp(logweights)``.

    Parameters
    ----------
    t : array-like
        unbiased time series
    logweights : array-like, optional
        logweights to evaluate rescaled time as dt' = dt*exp(logweights).
        The input object is not modified.

    Returns
    -------
    tprime : torch.Tensor, or ``t`` itself
        cumulative rescaled time if ``logweights`` is given, otherwise ``t`` unchanged.
    """
    # without reweighting the accelerated time is the simulation time
    if logweights is None:
        return t

    # compute time increment in simulation time t; float() handles python/numpy
    # scalars as well as torch scalars (np.round on a tensor is fragile)
    dt = np.round(float(t[1] - t[0]), 5)

    # torch.Tensor(tensor) returns an alias of its argument, so all operations
    # below are out-of-place to avoid modifying the caller's logweights
    logweights = torch.Tensor(logweights)
    # when the bias is not deposited the value of bias potential is minimum:
    # shift so that the largest logweight is 0
    logweights = logweights - torch.max(logweights)

    # note: exp(logweights/lognorm) != exp(logweights)/norm, where norm is sum_i beta V_i
    # possibilities considered for the normalization:
    # 1) logweights /= torch.min(logweights) -> logweights belong to [0,1]
    # 2) pass beta as an argument, then logweights *= beta
    # 3) tprime = dt * torch.cumsum( torch.exp( torch.logsumexp(logweights,0) ) ,0)
    # 4) tprime = dt *torch.exp ( torch.log (torch.cumsum (torch.exp(logweights) ) ) )
    lognorm = torch.logsumexp(logweights, 0)
    logweights = logweights / lognorm

    # instantaneous time increment in rescaled time t', then cumulative time t'
    d_tprime = torch.exp(logweights) * dt
    return torch.cumsum(d_tprime, 0)
def find_timelagged_configurations(
    x: torch.Tensor,
    t: torch.Tensor,
    lag_time: float,
    logweights: torch.Tensor = None,
    progress_bar: bool = True,
):
    """Searches for all the pairs which are distant 'lag' in time, and returns the weights associated.

    Each configuration ``x[i]`` is paired with every ``x[j]`` whose interval
    ``[t[j], t[j+1]]`` overlaps ``[t[i] + lag_time, t[i+1] + lag_time]``.
    ``t`` is assumed to be sorted in increasing order: it is searched with a
    binary search and the inner scan stops once ``t[j]`` leaves the window.

    Without logweights the weights are time intervals: ``w_t`` is
    ``(t[i+1] - t[i])`` split evenly among the partners of ``x[i]``, and ``w_lag``
    is the length of the overlap described above.
    If logweights are provided, ``exp(logweights)`` is returned both for x_t
    (split evenly among the partners of ``x[i]``) and for x_t+lag (the weight of
    the lagged configuration ``x[j]``). This is intended for the ``weights_t``
    reweighting mode.

    Parameters
    ----------
    x : torch.Tensor
        array whose columns are the descriptors and rows the time evolution
    t : torch.Tensor
        array with the simulation time
    lag_time : float
        lag-time
    logweights : torch.Tensor, optional
        logweights to be returned
    progress_bar : bool, optional
        display progress bar with tqdm (if installed), by default True

    Returns
    -------
    x_t: torch.Tensor
        descriptors at time t
    x_lag: torch.Tensor
        descriptors at time t+lag
    w_t: torch.Tensor
        weights at time t
    w_lag: torch.Tensor
        weights at time t+lag

    Raises
    ------
    ValueError
        if ``logweights`` and ``x`` have different lengths
    """
    x_t, x_lag, w_t, w_lag = [], [], [], []

    # loop bound: closest_idx rounds down, so (for sorted t) the shifted window
    # of every i < idx_end ends before t[-1]
    idx_end = closest_idx(t, t[-1] - lag_time)
    N = len(t)

    def progress(iterable):
        # show a bar only if requested and tqdm is available; warn only when it
        # was requested but cannot be shown
        if progress_bar and TQDM_IS_INSTALLED:
            return tqdm(iterable)
        if progress_bar:
            warnings.warn(
                "Monitoring the progress for the search of time-lagged configurations with a progress_bar requires `tqdm`."
            )
        return iterable

    # without logweights the weights are time intervals, otherwise exp(logweights)
    use_time_weights = logweights is None
    if not use_time_weights:
        if len(logweights) != len(x):
            raise ValueError(
                f"Length of logweights ({len(logweights)}) is different from length of data ({len(x)})."
            )
        weights = torch.exp(torch.Tensor(logweights))

    # loop over time array and find pairs which are far away by lag_time
    for i in progress(range(idx_end)):
        window_start = t[i] + lag_time
        stop_condition = lag_time + t[i + 1]
        n_j = 0
        for j in range(i, N):
            if (t[j] < stop_condition) and (t[j + 1] > window_start):
                x_t.append(x[i])
                x_lag.append(x[j])
                if use_time_weights:
                    # overlap between the shifted interval of i and the interval of j
                    delta_tau = min(stop_condition, t[j + 1]) - max(window_start, t[j])
                    w_lag.append(delta_tau)
                else:
                    w_lag.append(weights[j])
                n_j += 1
            elif t[j] > stop_condition:
                break

        if n_j > 0:
            # the weight of x[i] is shared evenly among its n_j partners
            if use_time_weights:
                share = (t[i + 1] - t[i]) / float(n_j)
            else:
                share = weights[i] / float(n_j)
            w_t.extend([share] * n_j)

    if isinstance(x, torch.Tensor):
        x_t = torch.stack(x_t)
        x_lag = torch.stack(x_lag)
    else:
        # build a single ndarray first: converting a list of ndarrays is slow
        x_t = torch.Tensor(np.array(x_t))
        x_lag = torch.Tensor(np.array(x_lag))
    w_t = torch.Tensor(w_t)
    w_lag = torch.Tensor(w_lag)
    return x_t, x_lag, w_t, w_lag
def create_timelagged_dataset(
    X: torch.Tensor,
    t: torch.Tensor = None,
    lag_time: float = 1,
    reweight_mode: str = None,
    logweights: torch.Tensor = None,
    tprime: torch.Tensor = None,
    interval: list = None,
    progress_bar: bool = False,
    walker: torch.Tensor = None,
):
    """
    Create a DictDataset of time-lagged configurations.

    In case of biased simulations the reweight can be performed in two different ways (``reweight_mode``):

    1. ``rescale_time`` : the search for time-lagged pairs is performed in the accelerated time (dt' = dt*exp(logweights)), as described in [1]_ .
    2. ``weights_t`` : the weight of each pair of configurations (t,t+lag_time) depends only on time t (logweights(t)), as done in [2]_ , [3]_ .

    If ``reweight_mode`` is None and ``tprime`` is given, the ``rescale_time`` mode is used.
    If only the logweights are given, ``rescale_time`` is also used by default.
    If none of ``tprime``, ``logweights`` and ``reweight_mode`` is given, the data are
    treated as unbiased and all weights are set to 1.

    In the unbiased case and in the ``weights_t`` mode the time series is assumed to be
    uniformly spaced: ``lag_time`` is converted into ``round(lag_time / (t[1] - t[0]))``
    steps and the pairs are built by slicing.

    References
    ----------
    .. [1] Y. I. Yang and M. Parrinello, "Refining collective coordinates and improving free energy
        representation in variational enhanced sampling," JCTC 14, 2889-2894 (2018).
    .. [2] J. McCarty and M. Parrinello, "A variational conformational dynamics approach to the selection
        of collective variables in meta- dynamics," JCP 147, 204109 (2017).
    .. [3] H. Wu, et al. "Variational Koopman models: Slow collective variables and molecular kinetics
        from short off-equilibrium simulations." JCP 146.15 (2017).

    Parameters
    ----------
    X : array-like
        input descriptors (a pandas DataFrame is converted to its values)
    t : array-like, optional
        time series, by default ``torch.arange(len(X))``
    lag_time : float, optional
        lag between configurations, by default 1
    reweight_mode : str, optional
        how to do the reweighting: None, ``'weights_t'`` or ``'rescale_time'`` (see above), by default None
    logweights : array-like, optional
        logweight of each configuration (typically beta*bias)
    tprime : array-like, optional
        rescaled time estimated from the simulation. If not given and ``reweight_mode`` is ``rescale_time``
        then ``tprime_evaluation(t, logweights)`` is used
    interval : list or np.array or tuple, optional
        ``(start, stop)`` range for slicing the returned dataset. Useful to work with batches of same sizes.
        Recall that with different lag_times one obtains different datasets, with different lengths
    progress_bar : bool
        Display progress bar with tqdm (only used by the ``rescale_time`` search)
    walker : array-like, optional
        Identifier of the trajectory (walker) to which each configuration belongs; pairs whose
        configurations belong to different walkers are discarded.
        This cannot be used when ``reweight_mode`` is ``rescale_time``.

    Returns
    -------
    dataset: DictDataset
        Dataset with keys 'data', 'data_lag' (data at time t and t+lag), 'weights', 'weights_lag' (weights at time t and t+lag).

    Raises
    ------
    ValueError
        on inconsistent lengths of ``t``, ``walker`` or ``logweights``, an incompatible combination
        of options, an unknown ``reweight_mode``, a ``lag_time`` corresponding to fewer than one or
        at least ``len(X)`` steps, or an ``interval`` that does not have exactly two elements.
    """
    if PANDAS_IS_INSTALLED:
        # accept pandas containers by taking their underlying numpy arrays
        if isinstance(X, pd.DataFrame):
            X = X.values
        if isinstance(t, (pd.DataFrame, pd.Series)):
            t = t.values

    # resolve the reweighting mode
    # 1) if rescaled time tprime is given
    if tprime is not None:
        if reweight_mode is None:
            reweight_mode = "rescale_time"
        elif reweight_mode != "rescale_time":
            raise ValueError(
                f"The `reweighting_mode` needs to be equal to `rescale_time`, and not {reweight_mode} if the rescale time `tprime` is given."
            )
    # 2) if only logweights are given
    elif logweights is not None and reweight_mode is None:
        # TODO output warning or error if mode not specified?
        reweight_mode = "rescale_time"

    # define time if not given
    if t is None:
        t = torch.arange(0, len(X))
    elif len(t) != len(X):
        raise ValueError(
            f"The length of t ({len(t)}) is different from the one of X ({len(X)}) "
        )

    if walker is not None:
        if reweight_mode == "rescale_time":
            raise ValueError(
                "The `walker` argument is not compatible with `reweight_mode='rescale_time'`."
            )
        if len(walker) != len(X):
            raise ValueError(
                f"The length of walker ({len(walker)}) is different from the one of X ({len(X)}) "
            )

    if reweight_mode is None or reweight_mode == "weights_t":
        # Fast slicing shortcut: uniform time grid, lag expressed in steps
        dt = float(t[1] - t[0]) if len(t) > 1 else 1.0
        lag_steps = int(round(lag_time / dt))
        if lag_steps < 1:
            raise ValueError("lag_time too small.")
        if lag_steps >= len(X):
            raise ValueError("lag_time too large.")

        x_t = X[:-lag_steps]
        x_lag = X[lag_steps:]

        if reweight_mode is None:
            w_t = torch.ones(len(x_t))
            w_lag = torch.ones(len(x_lag))
        else:
            if logweights is None:
                raise ValueError("`reweight_mode='weights_t'` requires `logweights`.")
            if len(logweights) != len(X):
                raise ValueError(
                    f"Length of logweights ({len(logweights)}) is different from length of data ({len(X)})."
                )
            # as_tensor avoids the copy-construct warning of torch.tensor(tensor)
            weights = torch.exp(torch.as_tensor(logweights, dtype=torch.float32))
            w_t = weights[:-lag_steps]
            w_lag = weights[lag_steps:]

        if walker is not None:
            # keep only pairs whose two configurations come from the same walker
            walker = torch.as_tensor(walker)
            valid = walker[:-lag_steps] == walker[lag_steps:]
            x_t = x_t[valid]
            x_lag = x_lag[valid]
            w_t = w_t[valid]
            w_lag = w_lag[valid]

    elif reweight_mode == "rescale_time":
        # Full search in the (possibly) accelerated time
        if tprime is None:
            tprime = tprime_evaluation(t, logweights)
        x_t, x_lag, w_t, w_lag = find_timelagged_configurations(
            X,
            tprime,
            lag_time=lag_time,
            logweights=None,
            progress_bar=progress_bar,
        )

    else:
        raise ValueError(
            f"Unknown reweight_mode '{reweight_mode}'. "
            "Supported modes are: None, 'weights_t', 'rescale_time'."
        )

    # return only a slice of the data (N. Pedrani)
    if interval is not None:
        if len(interval) != 2:
            raise ValueError(
                f"`interval` must contain exactly two elements (start, stop), got {len(interval)}."
            )
        start, stop = interval
        x_t, x_lag, w_t, w_lag = (
            item[start:stop] for item in (x_t, x_lag, w_t, w_lag)
        )

    dataset = DictDataset(
        {"data": x_t, "data_lag": x_lag, "weights": w_t, "weights_lag": w_lag}
    )
    return dataset
def test_create_timelagged_dataset():
    """Smoke-test create_timelagged_dataset in the unbiased, rescale_time,
    weights_t and multi-walker configurations (seeded for reproducibility)."""
    rng = np.random.default_rng(42)
    in_features = 2
    n_points = 100
    lag_time = 5
    X = torch.as_tensor(
        rng.random((n_points, in_features)) * 100, dtype=torch.float32
    )

    # unbiased case: one pair per step that has a partner lag_time later
    t = np.arange(n_points)
    dataset = create_timelagged_dataset(X, t, lag_time=lag_time, walker=None)
    assert len(dataset) == n_points - lag_time

    # reweight mode rescale_time (default when only logweights are given)
    logweights = rng.random(n_points)
    dataset = create_timelagged_dataset(X, t, lag_time=lag_time, logweights=logweights)
    assert len(dataset) > 0

    # reweight mode weights_t
    logweights = rng.random(n_points)
    dataset = create_timelagged_dataset(
        X, t, logweights=logweights, lag_time=lag_time, reweight_mode="weights_t"
    )
    assert len(dataset) == n_points - lag_time

    # unbiased multi-walker case: pairs across the walker boundary are dropped
    walker = np.array([0] * (n_points // 2) + [1] * (n_points // 2))
    dataset = create_timelagged_dataset(X, t, lag_time=lag_time, walker=walker)
    assert len(dataset) == n_points - 2 * lag_time

    # multi-walker case combined with weights_t
    dataset = create_timelagged_dataset(
        X,
        t,
        logweights=logweights,
        lag_time=lag_time,
        reweight_mode="weights_t",
        walker=walker,
    )
    assert len(dataset) == n_points - 2 * lag_time

    # walker is rejected when the mode resolves to rescale_time
    try:
        create_timelagged_dataset(
            X, t, lag_time=lag_time, logweights=logweights, walker=walker
        )
    except ValueError:
        pass
    else:
        raise AssertionError(
            "walker should be rejected with reweight_mode='rescale_time'"
        )