Source code for mlcolvar.core.stats.tica

"""Time-lagged independent component analysis"""

__all__ = ["TICA"]

import torch
from mlcolvar.core.stats import Stats
from mlcolvar.core.stats.utils import (
    correlation_matrix,
    cholesky_eigh,
    compute_average,
    reduced_rank_eig,
)
from mlcolvar.core.transform.tools.utils import batch_reshape
import warnings


[docs] class TICA(Stats): """ Time-lagged independent component analysis base class. """
[docs] def __init__(self, in_features, out_features=None): """ Initialize a TICA object. """ super().__init__() # save attributes self.in_features = in_features self.out_features = out_features if out_features is not None else in_features # buffers # tica eigenvectors self.register_buffer("evecs", torch.eye(in_features, self.out_features)) # mean to obtain mean free inputs self.register_buffer("mean", torch.zeros(in_features)) # init other attributes self.evals = None self.C_0 = None self.C_lag = None # Regularization self.reg_C_0 = 1e-6
[docs] def extra_repr(self) -> str: repr = f"in_features={self.in_features}, out_features={self.out_features}" return repr
[docs] def compute(self, data, weights=None, remove_average=True, save_params=False): """Perform TICA computation. Parameters ---------- data : [list of torch.Tensors] Time-lagged configurations (x_t, x_{t+lag}) weights : [list of torch.Tensors], optional Weights at time t and t+lag, by default None remove_average: bool, optional whether to make the inputs mean free, by default True save_params : bool, optional Save parameters of estimator, by default False Returns ------- torch.Tensor,torch.Tensor eigenvalues,eigenvectors """ # parse args x_t, x_lag = data w_t, w_lag = None, None if weights is not None: w_t, w_lag = weights if remove_average: x_ave = compute_average(x_t, w_t) x_t = x_t.sub(x_ave) x_lag = x_lag.sub(x_ave) C_0 = correlation_matrix(x_t, x_t, w_t) C_lag = correlation_matrix(x_t, x_lag, w_lag) evals, evecs = cholesky_eigh(C_lag, C_0, self.reg_C_0, n_eig=self.out_features) if save_params: self.evals = evals self.evecs = evecs if remove_average: self.mean = x_ave return evals, evecs
[docs] def timescales(self, lag): r"""Return implied timescales from eigenvalues and lag-time. Parameters ---------- lag : float lag-time Returns ------- its : tensor implied timescales Notes ----- If `lambda_i` are the eigenvalues and `tau` the lag-time, the implied times are given by: .. math:: t_i = - tau / \log\lambda_i """ its = -lag / torch.log(torch.abs(self.evals)) return its
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """ Compute linear combination with saved eigenvectors Parameters ---------- x: torch.Tensor input Returns ------- out : torch.Tensor output """ mean = batch_reshape(self.mean, x.size()) return torch.matmul(x.sub(mean), self.evecs)
def test_tica(): in_features = 2 X = torch.rand(100, in_features) * 100 x_t = X[:-1] x_lag = X[1:] w_t = torch.rand(len(x_t)) w_lag = w_t # direct way, compute tica function tica = TICA(in_features, out_features=2) print(tica) tica.compute([x_t, x_lag], [w_t, w_lag], save_params=True) s = tica(X) print(X.shape, "-->", s.shape) print("eigvals", tica.evals) print("timescales", tica.timescales(lag=10)) # step by step tica = TICA(in_features) C_0 = correlation_matrix(x_t, x_t) C_lag = correlation_matrix(x_t, x_lag) print(C_0.shape, C_lag.shape) evals, evecs = cholesky_eigh(C_lag, C_0) print(evals.shape, evecs.shape) print(">> batch") s = tica(X) print(X.shape, "-->", s.shape) print(">> single") X2 = X[0] s2 = tica(X2) print(X2.shape, "-->", s2.shape) if __name__ == "__main__": test_tica()