import torch
import lightning
from mlcolvar.cvs import BaseCV
from mlcolvar.core import FeedForward, BaseGNN, Normalization
from mlcolvar.data import DictModule
from mlcolvar.core.stats import LDA
from mlcolvar.core.loss import ReduceEigenvaluesLoss
from typing import Union, List
__all__ = ["DeepLDA"]
# [docs]  (Sphinx "view source" link artifact, not part of the module)
class DeepLDA(BaseCV):
    """Deep Linear Discriminant Analysis (Deep-LDA) CV.

    Non-linear generalization of LDA in which a feature map is learned by a
    neural network optimized as to maximize the classes separation. The method
    is described in [1]_.

    **Data**: for training it requires a DictDataset containing:

    - If using descriptors as input, the keys 'data' and 'labels'
    - If using graphs as input, `torch_geometric.data` with 'graph_labels'
      in their 'data_list'.

    **Loss**: maximize LDA eigenvalues (ReduceEigenvaluesLoss)

    References
    ----------
    .. [1] L. Bonati, V. Rizzi, and M. Parrinello, "Data-driven collective
        variables for enhanced sampling", JPCL 11, 2998–3004 (2020).

    See also
    --------
    mlcolvar.core.stats.LDA
        Linear Discriminant Analysis method
    mlcolvar.core.loss.ReduceEigenvaluesLoss
        Eigenvalue reduction to a scalar quantity
    """

    # Blocks used when the network is built internally from a list of layers.
    DEFAULT_BLOCKS = ["norm_in", "nn", "lda"]
    # Blocks used when an externally initialized model is supplied.
    MODEL_BLOCKS = ["nn", "lda"]

    def __init__(
        self,
        model: Union[List[int], FeedForward, BaseGNN],
        n_states: int,
        options: dict = None,
        **kwargs,
    ):
        """
        Define a Deep Linear Discriminant Analysis (Deep-LDA) CV composed by a
        neural network module and a LDA object.
        By default a module standardizing the inputs is also used.

        Parameters
        ----------
        model : list or FeedForward or BaseGNN
            Determines the underlying machine-learning model. One can pass:

            1. A list of integers corresponding to the number of neurons per
               layer of a feed-forward NN. The model will be automatically
               initialized using a `mlcolvar.core.nn.feedforward.FeedForward`
               object. The CV class will be initialized according to the
               DEFAULT_BLOCKS.
            2. An externally initialized model (either
               `mlcolvar.core.nn.feedforward.FeedForward` or
               `mlcolvar.core.nn.graph.BaseGNN` object). The CV class will be
               initialized according to the MODEL_BLOCKS.
        n_states : int
            Number of states for the training
        options : dict[str, Any], optional
            Options for the building blocks of the model, by default None.
            Available blocks: ['norm_in','nn','lda'] .
            Set 'block_name' = None or False to turn off that block
        **kwargs
            Forwarded unchanged to the parent class constructor.
        """
        super().__init__(model=model, **kwargs)

        # ======= LOSS =======
        # Maximize the sum of all the LDA eigenvalues.
        self.loss_fn = ReduceEigenvaluesLoss(mode="sum")

        # ======= OPTIONS =======
        # parse and sanitize
        options = self.parse_options(options)

        # Save n_states
        self.n_states = n_states

        # ======= BLOCKS =======
        if self._override_model:
            # externally initialized model: use it as is, no input normalization
            self.nn = model
        else:
            # initialize norm_in (skipped when the option is None or False)
            o = "norm_in"
            if (options[o] is not False) and (options[o] is not None):
                self.norm_in = Normalization(self.in_features, **options[o])

            # initialize nn
            o = "nn"
            self.nn = FeedForward(self.layers, **options[o])

        # initialize lda on top of the NN outputs
        o = "lda"
        self.lda = LDA(self.nn.out_features, n_states, **options[o])

        # regularization: sets lda.sw_reg and derives lorentzian_reg = 2/sw_reg
        self.set_regularization(sw_reg=0.05)

    def forward_nn(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the (optional) input normalization and the network, without LDA.

        Parameters
        ----------
        x : torch.Tensor
            Input passed to the blocks through `self._apply_module`.

        Returns
        -------
        torch.Tensor
            Output of the `nn` block.
        """
        # norm_in only exists when the network was built internally
        if not self._override_model and self.norm_in is not None:
            x = self._apply_module(self.norm_in, x)
        return self._apply_module(self.nn, x)

    def set_regularization(self, sw_reg=0.05, lorentzian_reg=None):
        r"""
        Set magnitude of regularizations for the training:

        - add identity matrix multiplied by `sw_reg` to within scatter S_w.
        - add lorentzian regularization to NN outputs with magnitude
          `lorentzian_reg`

        If `lorentzian_reg` is None, set it equal to `2./sw_reg`.

        Parameters
        ----------
        sw_reg : float
            Regularization value for S_w.
        lorentzian_reg: float
            Regularization for lorentzian on NN outputs. A value <= 0 disables
            the term in `training_step`.

        Notes
        -----
        These regularizations are described in [1]_.

        - S_w

        .. math:: S_w = S_w + \mathtt{sw\_reg}\ \mathbf{1}.

        - Lorentzian (with :math:`\alpha` = `lorentzian_reg`)

        .. math:: \text{reg}_{lor}=-\alpha \left( 1+( \mathbb{E}\left[||\mathbf{s}||^2\right]-1)^2 \right)^{-1}

        With `lorentzian_reg=None` the value is computed as `2.0 / sw_reg`,
        so `sw_reg` must be non-zero in that case.
        """
        self.lda.sw_reg = sw_reg
        self.lorentzian_reg = (
            2.0 / sw_reg if lorentzian_reg is None else lorentzian_reg
        )

    def regularization_lorentzian(self, x):
        """
        Compute lorentzian regularization on the CVs.

        Parameters
        ----------
        x : torch.Tensor
            Input data; the squared entries are summed and divided by the
            size of the first dimension.

        Returns
        -------
        torch.Tensor
            Scalar `-lorentzian_reg / (1 + (m - 1)^2)`, where `m` is the sum of
            squares of `x` divided by `x.size(0)`.
        """
        reg_loss = x.pow(2).sum().div(x.size(0))
        reg_loss_lor = -self.lorentzian_reg / (1 + (reg_loss - 1).pow(2))
        return reg_loss_lor

    def training_step(self, train_batch, batch_idx):
        """Compute and return the training loss and record metrics.

        Parameters
        ----------
        train_batch : dict or graph batch
            For a `FeedForward` network, a mapping with the keys 'data' and
            'labels'; for a `BaseGNN`, a batch that `self._setup_graph_data`
            turns into a mapping containing 'graph_labels'.
        batch_idx : int
            Index of the batch (unused).

        Returns
        -------
        torch.Tensor
            Eigenvalue loss plus, when `self.lorentzian_reg > 0`, the
            lorentzian regularization.

        Raises
        ------
        TypeError
            If `self.nn` is neither a `FeedForward` nor a `BaseGNN`.
        """
        # =================get data===================
        if isinstance(self.nn, FeedForward):
            x = train_batch["data"]
            labels = train_batch["labels"]
        elif isinstance(self.nn, BaseGNN):
            x = self._setup_graph_data(train_batch)
            labels = x['graph_labels'].squeeze()
        else:
            raise TypeError(
                "DeepLDA.training_step supports FeedForward or BaseGNN models, "
                f"got {type(self.nn).__name__}."
            )

        # =================forward====================
        h = self.forward_nn(x)

        # ===================lda======================
        # LDA parameters are stored only while in training mode
        eigvals, _ = self.lda.compute(h, labels, save_params=self.training)

        # ===================loss=====================
        loss = self.loss_fn(eigvals)

        # metric names are prefixed according to the current mode
        name = "train" if self.training else "valid"
        loss_dict = {}
        if self.lorentzian_reg > 0:
            s = self.lda(h)
            lorentzian_reg = self.regularization_lorentzian(s)
            loss = loss + lorentzian_reg
            # logged only when computed: with the regularization disabled
            # there is no value to report
            loss_dict[f"{name}_lorentzian_reg"] = lorentzian_reg
        loss_dict[f"{name}_loss"] = loss

        # ====================log=====================
        eig_dict = {
            f"{name}_eigval_{i + 1}": eigval for i, eigval in enumerate(eigvals)
        }
        self.log_dict({**loss_dict, **eig_dict}, on_step=True, on_epoch=True)
        return loss
def test_deeplda(n_states=2):
    """Smoke test: train DeepLDA for one epoch with three kinds of model.

    Covers a network built from a list of layers, an externally created
    FeedForward, and an externally created SchNet GNN (also checked against
    its traced TorchScript version).
    """
    from mlcolvar.data import DictDataset

    def _banner(title):
        # section title surrounded by blank lines
        print()
        print(title)
        print()

    def _fit(cv, data, **trainer_kwargs):
        # one-epoch training run with the given Trainer settings
        trainer = lightning.Trainer(**trainer_kwargs)
        trainer.fit(cv, data)

    def _evaluate(cv, inputs):
        # forward pass in eval mode without gradients, as a numpy array
        cv.eval()
        with torch.no_grad():
            return cv(inputs).numpy()

    in_features = 2
    out_features = n_states - 1
    layers = [in_features, 50, 50, out_features]

    # create dataset: one gaussian blob per state, labelled by state index
    n_points = 500
    blobs = [
        torch.randn(n_points, in_features) * (state + 1)
        + torch.Tensor([10 * state, (state - 1) * 10])
        for state in range(n_states)
    ]
    targets = [torch.ones(n_points) * state for state in range(n_states)]
    X = torch.cat(blobs, dim=0)
    y = torch.cat(targets, dim=0)

    dataset = DictDataset({"data": X, "labels": y})
    datamodule = DictModule(
        dataset, lengths=[0.8, 0.2], batch_size=n_states * n_points
    )

    # Trainer settings shared by the two descriptor-based runs
    descriptor_trainer = dict(
        max_epochs=1, log_every_n_steps=2, logger=None, enable_checkpointing=False
    )

    # ---- network built internally from the list of layers ----
    opts = {
        "norm_in": {"mode": "mean_std"},
        "nn": {"activation": "relu"},
        "lda": {},
    }
    _banner('NORMAL')
    model = DeepLDA(layers, n_states, options=opts)
    _fit(model, datamodule, **descriptor_trainer)
    _ = _evaluate(model, X)

    # ---- feedforward external ----
    _banner('EXTERNAL')
    ff_model = FeedForward(layers=layers)
    model = DeepLDA(ff_model, n_states)
    _fit(model, datamodule, **descriptor_trainer)
    s = _evaluate(model, X)
    print(s)

    # ---- gnn external ----
    _banner('GNN')
    from mlcolvar.core.nn.graph.schnet import SchNetModel
    from mlcolvar.data.graph.utils import create_test_graph_input

    gnn_model = SchNetModel(n_out=2, cutoff=0.1, atomic_numbers=[1, 8])
    model = DeepLDA(gnn_model, n_states)
    datamodule = create_test_graph_input(
        output_type='datamodule', n_samples=200, n_states=n_states
    )
    _fit(
        model,
        datamodule,
        max_epochs=1,
        log_every_n_steps=2,
        logger=False,
        enable_checkpointing=False,
        enable_model_summary=False,
    )

    # the traced model must reproduce the eager outputs
    traced_model = model.to_torchscript(file_path=None, method="trace")
    example_input_graph_test = create_test_graph_input(
        output_type='example', n_atoms=4, n_samples=3, n_states=n_states
    )
    eager_out = model(example_input_graph_test)
    traced_out = traced_model(example_input_graph_test)
    assert torch.allclose(eager_out, traced_out)

    s = _evaluate(model, example_input_graph_test)
    print(s)