Source code for mlcolvar.core.transform.transform
import torch
# Public API of this module: only the Transform base class is exported.
__all__ = ["Transform"]
[docs]
class Transform(torch.nn.Module):
    """Common base class for all transforms.

    Subclasses implement their behaviour by overriding ``forward``. Any
    parameters a transform needs are expected to be provided either at
    construction time or later through ``setup_from_datamodule``.
    """

    def __init__(self, in_features: int, out_features: int):
        """Record the input and output dimensionality of the transform.

        Parameters
        ----------
        in_features : int
            Number of inputs of the transform
        out_features : int
            Number of outputs of the transform
        """
        super().__init__()
        # Stored as plain attributes; assigned left to right.
        self.in_features, self.out_features = in_features, out_features

    def setup_from_datamodule(self, datamodule):
        """Initialize parameters from a PyTorch Lightning datamodule.

        The base implementation ignores ``datamodule`` and does nothing.
        """
        return None

    def teardown(self):
        """Do nothing in the base class."""
        return None