import torch
import lightning
from mlcolvar.cvs import BaseCV
from mlcolvar.core import FeedForward, Normalization, BaseGNN
from mlcolvar.core.loss import MSELoss
from typing import Union, List
__all__ = ["RegressionCV"]
class RegressionCV(BaseCV):
    """
    Example of collective variable obtained with a regression task.
    Combine the inputs with a neural-network and optimize it to match a target function.

    **Data**: for training it requires a DictDataset containing:

    - If using descriptors as input, the keys 'data', 'target' and optionally 'weights'.
    - If using graphs as input, `torch_geometric.data` with either 'graph_labels' (graph-level)
      or 'node_labels' (node-level) as regression target, and optionally 'weight' in 'data_list'.
      Weights are only applied when regressing on 'graph_labels'.

    **Loss**: least squares (MSELoss).

    See also
    --------
    mlcolvar.core.loss.MSELoss
        (weighted) Mean Squared Error (MSE) loss function.
    """

    DEFAULT_BLOCKS = ["norm_in", "nn"]
    MODEL_BLOCKS = ["nn"]

    def __init__(
        self,
        model: Union[List[int], FeedForward, BaseGNN],
        options: dict = None,
        graph_target_key: str = "graph_labels",
        **kwargs,
    ):
        """Example of collective variable obtained with a regression task.
        By default a module standardizing the inputs is used.

        Parameters
        ----------
        model : list or FeedForward or BaseGNN
            Determines the underlying machine-learning model. One can pass:
            1. A list of integers corresponding to the number of neurons per layer of a feed-forward NN.
               The model will be automatically initialized using a `mlcolvar.core.nn.feedforward.FeedForward` object.
               The CV class will be initialized according to the DEFAULT_BLOCKS.
            2. An externally initialized model (either `mlcolvar.core.nn.feedforward.FeedForward` or `mlcolvar.core.nn.graph.BaseGNN` object).
               The CV class will be initialized according to the MODEL_BLOCKS.
        options : dict[str, Any], optional
            Options for the building blocks of the model, by default None.
            Available blocks: ['norm_in', 'nn'].
            Set options['norm_in'] = None or False to turn off input normalization.
            Only used when `model` is a list of integers.
        graph_target_key : str, optional
            Graph regression target key, either 'graph_labels' or 'node_labels', by default 'graph_labels'.
            Only used when `model` is a `BaseGNN` and should match the model output level
            configured through `model.pooling_operation`.

        Raises
        ------
        ValueError
            If `graph_target_key` is neither 'graph_labels' nor 'node_labels'.
        """
        # Validate the cheap argument first so a wrong key fails before the CV is built.
        allowed_graph_targets = {"graph_labels", "node_labels"}
        if graph_target_key not in allowed_graph_targets:
            raise ValueError(
                f"`graph_target_key` must be one of {sorted(allowed_graph_targets)}, found '{graph_target_key}'."
            )

        super().__init__(model, **kwargs)
        self.graph_target_key = graph_target_key

        # ======= LOSS =======
        self.loss_fn = MSELoss()

        # ======= OPTIONS =======
        # parse and sanitize
        options = self.parse_options(options)

        # ======= BLOCKS =======
        if not self._override_model:
            # Initialize norm_in (skipped when its options are None or False)
            o = "norm_in"
            if (options[o] is not False) and (options[o] is not None):
                self.norm_in = Normalization(self.in_features, **options[o])

            # initialize NN
            o = "nn"
            self.nn = FeedForward(self.layers, **options[o])
        else:
            # externally initialized model is used as-is
            self.nn = model

    def training_step(self, train_batch, batch_idx):
        """Compute and return the training loss and record metrics.

        Parameters
        ----------
        train_batch : dict
            Batch from the datamodule. For graph models it is converted with
            ``self._setup_graph_data`` and must then contain ``self.graph_target_key``;
            otherwise it must contain 'data', 'target' and optionally 'weights'.
        batch_idx : int
            Index of the batch (unused).

        Returns
        -------
        torch.Tensor
            Regression loss of the batch.

        Raises
        ------
        KeyError
            If the graph batch does not contain ``self.graph_target_key``.
        ValueError
            If prediction and target shapes cannot be aligned.
        """
        # =================get data===================
        loss_kwargs = {}
        if isinstance(self.nn, BaseGNN):
            x = self._setup_graph_data(train_batch)
            if self.graph_target_key not in x:
                raise KeyError(
                    f"Missing '{self.graph_target_key}' in graph batch. Available keys: {list(x.keys())}"
                )
            labels = x[self.graph_target_key]
            # Weights are only used for graph-level targets, not node-level ones.
            if self.graph_target_key == "graph_labels" and "weight" in x:
                loss_kwargs["weights"] = x["weight"]
        else:
            # Descriptor inputs (FeedForward or any other tensor-based model).
            x = train_batch["data"]
            labels = train_batch["target"]
            if "weights" in train_batch:
                loss_kwargs["weights"] = train_batch["weights"]

        # =================forward====================
        y = self.forward_cv(x)
        # Keep compatibility with scalar targets stored with extra singleton dims.
        y, labels = self._align_regression_tensors(y, labels)

        # ===================loss=====================
        loss = self._compute_loss(y, labels, loss_kwargs)

        # ====================log=====================
        name = "train" if self.training else "valid"
        self.log(f"{name}_loss", loss, on_epoch=True)
        return loss

    def _compute_loss(self, y: torch.Tensor, labels: torch.Tensor, loss_kwargs: dict):
        """Evaluate ``self.loss_fn``, tolerating custom losses without a ``weights`` argument.

        When weights are available but the loss function rejects the ``weights``
        keyword, the loss is recomputed without them and a warning is emitted, so
        the weights are not dropped silently.
        """
        if not loss_kwargs:
            return self.loss_fn(y, labels)
        try:
            return self.loss_fn(y, labels, **loss_kwargs)
        except TypeError as e:
            if "unexpected keyword argument 'weights'" not in str(e):
                raise
        import warnings

        warnings.warn(
            "The loss function does not accept a `weights` argument: "
            "sample weights are ignored in the loss computation.",
            UserWarning,
        )
        return self.loss_fn(y, labels)

    @staticmethod
    def _squeeze_trailing_singletons(x: torch.Tensor) -> torch.Tensor:
        """Remove trailing singleton dimensions, preserving multi-target tensors.

        At least one dimension is always kept, e.g. (N, 1, 1) -> (N,), (1, 1) -> (1,),
        while (N, 2) is left unchanged.
        """
        while x.ndim > 1 and x.shape[-1] == 1:
            x = x.squeeze(-1)
        return x

    def _align_regression_tensors(
        self, y: torch.Tensor, labels: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Align prediction and target shapes before computing the regression loss.

        Raises
        ------
        ValueError
            If the shapes still differ after removing trailing singleton dimensions.
        """
        y = self._squeeze_trailing_singletons(y)
        labels = self._squeeze_trailing_singletons(labels)
        if y.shape != labels.shape:
            raise ValueError(
                "Prediction/target shape mismatch in RegressionCV: "
                f"pred shape={tuple(y.shape)}, target shape={tuple(labels.shape)}. "
                "Check `graph_target_key`, pooling configuration, and label tensor shapes."
            )
        return y, labels
def _make_trainer():
    """Return a fresh single-epoch CPU trainer with logging and checkpointing disabled.

    A new trainer is used for every ``fit`` call: a trainer that already reached
    ``max_epochs`` does not run the training loop again on a second ``fit``.
    """
    return lightning.Trainer(
        accelerator="cpu",
        max_epochs=1,
        logger=False,
        enable_checkpointing=False,
        enable_model_summary=False,
    )


def _check_descriptor_cv(model, make_custom_model):
    """Train, trace and retrain a descriptor-based RegressionCV on synthetic data.

    Parameters
    ----------
    model : RegressionCV
        CV taking 2D descriptors as input.
    make_custom_model : callable
        Zero-argument factory returning the RegressionCV used for the custom-loss check.
    """
    from mlcolvar.data import DictDataset, DictModule

    # create dataset: target is the squared norm of the input
    X = torch.randn((100, 2))
    y = X.square().sum(1)
    dataset = DictDataset({"data": X, "target": y})
    datamodule = DictModule(dataset, lengths=[0.75, 0.2, 0.05], batch_size=25)

    # train model
    model.optimizer_name = "SGD"
    model.optimizer_kwargs.update(dict(lr=1e-2))
    _make_trainer().fit(model, datamodule)
    model.eval()

    # trace model and check it reproduces the eager outputs
    traced_model = model.to_torchscript(
        file_path=None, method="trace", example_inputs=X[0]
    )
    assert torch.allclose(model(X), traced_model(X))

    # weighted loss: sample weights drawn from [0, 1) so they are non-negative
    print("weighted loss")
    w = torch.rand(100)
    dataset_weights = DictDataset({"data": X, "target": y, "weights": w})
    datamodule_weights = DictModule(
        dataset_weights, lengths=[0.75, 0.2, 0.05], batch_size=25
    )
    _make_trainer().fit(model, datamodule_weights)

    # use custom loss (a function without a `weights` argument)
    print("custom loss")
    custom_model = make_custom_model()
    custom_model.loss_fn = lambda y, y_ref: (y - y_ref).abs().mean()
    _make_trainer().fit(custom_model, datamodule)


def test_regression_cv():
    """
    Create a synthetic dataset and test functionality of the RegressionCV class
    """
    in_features, out_features = 2, 1
    layers = [in_features, 5, 10, out_features]

    print()
    print('NORMAL')
    print()
    # initialize via dictionary
    options = {"nn": {"activation": "relu"}}
    model = RegressionCV(model=layers, options=options)
    print("----------")
    print(model)
    _check_descriptor_cv(model, lambda: RegressionCV(model=[2, 10, 10, 1]))

    print()
    print('EXTERNAL FEEDFORWARD')
    print()
    ff_model = FeedForward(layers=layers)
    model = RegressionCV(model=ff_model)
    _check_descriptor_cv(model, lambda: RegressionCV(model=ff_model))

    print()
    print('EXTERNAL GNN')
    print()
    # gnn external
    from mlcolvar.core.nn.graph.schnet import SchNetModel
    from mlcolvar.data.graph.utils import create_test_graph_input

    gnn_model = SchNetModel(n_out=1, cutoff=0.1, atomic_numbers=[1, 8])
    # create model
    model = RegressionCV(model=gnn_model)
    datamodule = create_test_graph_input(output_type='datamodule', n_samples=100, n_states=2)

    # train model
    _make_trainer().fit(model, datamodule)
    model.eval()

    # trace model
    traced_model = model.to_torchscript(file_path=None, method="trace")
    example_input_graph_test = create_test_graph_input(output_type='example', n_atoms=4, n_samples=3, n_states=2)
    assert torch.allclose(model(example_input_graph_test), traced_model(example_input_graph_test))

    # weighted loss
    print("weighted loss")
    datamodule_weights = create_test_graph_input(output_type='datamodule', n_samples=100, n_states=2, random_weights=True)
    _make_trainer().fit(model, datamodule_weights)

    # use custom loss
    print("custom loss")
    model = RegressionCV(model=gnn_model)
    model.loss_fn = lambda y, y_ref: (y - y_ref).abs().mean()
    _make_trainer().fit(model, datamodule)

    # node-level regression with GNN configured without pooling
    print("node-level")
    gnn_model_node = SchNetModel(n_out=1, cutoff=0.1, atomic_numbers=[1, 8], pooling_operation=None)
    model = RegressionCV(model=gnn_model_node, graph_target_key="node_labels")
    _make_trainer().fit(model, datamodule)
    model.eval()
    traced_model = model.to_torchscript(file_path=None, method="trace")
    example_input_graph_test = create_test_graph_input(output_type='example', n_atoms=4, n_samples=3, n_states=2)
    assert torch.allclose(model(example_input_graph_test), traced_model(example_input_graph_test))