Source code for mlcolvar.core.transform.tools.switching_functions

import torch

from mlcolvar.core.transform import Transform


__all__ = ["SwitchingFunctions"]

class SwitchingFunctions(Transform):
    """Transform applying a named switching function element-wise.

    The function is selected by ``name``, which must be one of
    ``SWITCH_FUNCS``; the name ``X`` is dispatched to the method
    ``X_switch`` of this class.
    """

    SWITCH_FUNCS = ['Fermi', 'Rational']

    def __init__(self, in_features: int, name: str, cutoff: float, dmax: float = None, options: dict = None):
        """Initialize switching function object.

        Parameters
        ----------
        in_features : int
            Number of input features. It is also passed to the parent class
            as the number of output features.
        name : str
            Name of the switching function to be used; must be one of
            ``SwitchingFunctions.SWITCH_FUNCS``.
        cutoff : float
            Cutoff for the switching function.
        dmax : float, optional
            Distance at which, if set, the switching function will be forced
            to be zero by stretching it and shifting it, by default None.
        options : dict, optional
            Dictionary with all the extra keyword arguments of the switching
            function, by default None (treated as an empty dictionary).

        Raises
        ------
        NotImplementedError
            If ``name`` is not one of ``SWITCH_FUNCS``.
        """
        # NOTE: this must be a plain string literal. An f-string in this
        # position is an ordinary expression, so __doc__ would stay None.

        # Fail fast on an unknown name, before any parent initialization.
        if name not in self.SWITCH_FUNCS:
            raise NotImplementedError(
                f'The switching function {name} is not implemented in this class. '
                f'The available options are: {",".join(self.SWITCH_FUNCS)}. '
                'You can initialize it as a method of the SwitchingFunctions class '
                'and tell us on Github, contributions are welcome!'
            )

        super().__init__(in_features=in_features, out_features=in_features)

        self.name = name
        self.cutoff = cutoff
        if options is None:
            options = {}
        self.options = options

        # Kept as a one-element tensor so that it can be fed through the
        # very same switching function as the inputs in `forward`.
        self.dmax = torch.Tensor([dmax]) if dmax is not None else None

    def forward(self, x: torch.Tensor):
        """Apply the selected switching function to ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Input values.

        Returns
        -------
        torch.Tensor
            ``switch(x)``; if ``dmax`` was set, the result is rescaled as
            ``(y - ymax) / (1 - ymax)`` with ``ymax = switch(dmax)``, so that
            the output is zero at ``x == dmax``. No clamping is applied to
            inputs larger than ``dmax``.
        """
        switch_function = getattr(self, f'{self.name}_switch')
        y = switch_function(x, self.cutoff, **self.options)
        if self.dmax is not None:
            # dmax lives on the CPU by default: move it next to the input.
            ymax = switch_function(self.dmax.to(x.device), self.cutoff, **self.options)
            y = torch.div((y - ymax), (1 - ymax))
        return y

    # ========================== define here switching functions ==========================
    def Fermi_switch(self, x: torch.Tensor, cutoff: float, q: float = 0.01, prefactor_cutoff: float = 1.0):
        """Fermi switching function.

        Computes ``1 / (1 + exp((x - prefactor_cutoff * cutoff) / q))``.

        Parameters
        ----------
        x : torch.Tensor
            Input values.
        cutoff : float
            Cutoff of the switching function.
        q : float, optional
            Width of the switching region, by default 0.01.
        prefactor_cutoff : float, optional
            Multiplicative factor applied to ``cutoff``, by default 1.0.

        Returns
        -------
        torch.Tensor
            Switching function evaluated at ``x``.
        """
        # 1 / (1 + exp(z)) == sigmoid(-z). Evaluating it through sigmoid
        # avoids the overflow of exp(z) to inf for large z (easily reached
        # with the small default q), which left the forward value at 0 but
        # turned the gradient into NaN (0 * inf) during backpropagation.
        return torch.sigmoid(torch.div(prefactor_cutoff * cutoff - x, q))

    def Rational_switch(self, x: torch.Tensor, cutoff: float, n: int = 6, m: int = 12, eps: float = 1e-8, prefactor_cutoff: float = 1.0):
        """Rational switching function.

        Computes ``(1 - r**n + eps) / (1 - r**m + 2 * eps)`` with
        ``r = x / (prefactor_cutoff * cutoff)``.

        Parameters
        ----------
        x : torch.Tensor
            Input values.
        cutoff : float
            Cutoff of the switching function.
        n : int, optional
            Exponent of the numerator, by default 6.
        m : int, optional
            Exponent of the denominator, by default 12.
        eps : float, optional
            Small regularization that avoids 0/0 at ``r == 1``, by default 1e-8.
        prefactor_cutoff : float, optional
            Multiplicative factor applied to ``cutoff``, by default 1.0.

        Returns
        -------
        torch.Tensor
            Switching function evaluated at ``x``.
        """
        ratio = x / (prefactor_cutoff * cutoff)
        # NOTE(review): at ratio == 1 this evaluates to eps / (2 * eps) = 0.5,
        # which equals the analytical limit n / m only when m == 2 * n.
        y = torch.div((1 - torch.pow(ratio, n) + eps),
                      (1 - torch.pow(ratio, m) + 2 * eps))
        return y

def test_switchingfunctions():
    """Check values, shapes, the ``dmax`` rescaling and the error path."""
    x = torch.Tensor([1., 2., 3.])
    cutoff = 2

    # Fermi with default options: a sharp step centered on the cutoff.
    switch = SwitchingFunctions(in_features=len(x), name='Fermi', cutoff=cutoff)
    out = switch(x)
    assert out.shape == x.shape
    assert torch.allclose(out, torch.Tensor([1., 0.5, 0.]), atol=1e-5)

    # Fermi with a wider switching region passed through `options`.
    switch = SwitchingFunctions(in_features=len(x), name='Fermi', cutoff=cutoff, options={'q': 0.5})
    out = switch(x)
    assert out.shape == x.shape
    assert torch.allclose(out, torch.Tensor([0.880797, 0.5, 0.119203]), atol=1e-5)

    # Rational switch; x == cutoff hits the regularized 0/0 point.
    switch = SwitchingFunctions(in_features=len(x), name='Rational', cutoff=cutoff, options={'n': 6, 'm': 12})
    out = switch(x)
    assert out.shape == x.shape
    assert torch.allclose(out, torch.Tensor([0.984615, 0.5, 0.080706]), atol=1e-5)

    # With dmax set, the output is shifted and stretched to vanish at dmax.
    switch = SwitchingFunctions(in_features=len(x), name='Fermi', cutoff=cutoff, dmax=3., options={'q': 0.5})
    out = switch(x)
    assert out.shape == x.shape
    assert abs(float(out[-1])) < 1e-6

    # Unknown switching functions must be rejected at construction time.
    try:
        SwitchingFunctions(in_features=len(x), name='NotASwitch', cutoff=cutoff)
    except NotImplementedError:
        pass
    else:
        raise AssertionError('NotImplementedError was not raised for an unknown switching function')