import math
import torch
from torch import nn
from torch_geometric.nn.aggr import AttentionalAggregation
from torch_geometric.nn import MessagePassing
from mlcolvar.data import DictDataset
from mlcolvar.core.nn.utils import Shifted_Softplus
from mlcolvar.core.nn.graph.gnn import BaseGNN
from typing import List, Dict, Optional
"""
The SchNet components. This module is adapted from the PyG (torch_geometric) package:
https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/models/schnet.py
"""
__all__ = ["SchNetModel", "InteractionBlock"]
class SchNetModel(BaseGNN):
    """
    The SchNet [1] model. This implementation is adapted from torch_geometric:
    https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/models/schnet.py

    Node attributes are linearly embedded into hidden channels, updated
    residually by a stack of `InteractionBlock` layers, and finally mapped to
    `n_out` channels by a small readout network applied either before or after
    the pooling step (`self.pooling`).

    References
    ----------
    .. [1] Schütt, Kristof T., et al. "Schnet–a deep learning architecture for
       molecules and materials." The Journal of Chemical Physics 148.24
       (2018).
    """

    def __init__(
        self,
        n_out: int,
        dataset_for_initialization: DictDataset = None,
        pooling_operation: str = 'mean',
        n_bases: int = 16,
        n_layers: int = 2,
        n_filters: int = 16,
        n_hidden_channels: int = 16,
        aggr: str = 'mean',
        w_out_after_pool: bool = False,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        n_out : int
            Size of the output node features.
        dataset_for_initialization : DictDataset, optional
            Dataset containing the graphs on which the gnn model will be applied.
            This is used to initialize and register the cutoff, buffer,
            long_range_cutoff and atomic_numbers from the dataset metadata.
            This is the preferred way to initialize the gnn model, as it ensures
            consistency between the model and the dataset.
            As an alternative this can be set to None and the cutoff, buffer,
            long_range_cutoff and atomic_numbers can be provided as kwargs.
        pooling_operation : str
            Type of pooling operation to combine node-level features into
            graph-level features, either mean or sum, by default 'mean'
        n_bases : int, optional
            Size of the basis set used for the embedding, by default 16
        n_layers : int, optional
            Number of the graph convolution layers, by default 2
        n_filters : int, optional
            Number of filters, by default 16
        n_hidden_channels : int, optional
            Size of hidden embeddings, by default 16
        aggr : str, optional
            Type of the GNN aggregation function, by default 'mean'
            Possible choices are: 'mean', 'sum', 'max', 'min', 'mul',
            'attention'/'attentional' (shared attention gate across all layers),
            'attention_separate'/'attentional_separate' (independent attention
            gate for each layer).
        w_out_after_pool : bool, optional
            Whether to apply the last transformation from hidden to output
            channels after the pooling instead of before it, by default False
        **kwargs
            Forwarded unchanged to the parent class constructor.
        """
        super().__init__(
            n_out=n_out,
            dataset_for_initialization=dataset_for_initialization,
            pooling_operation=pooling_operation,
            n_bases=n_bases,
            n_polynomials=0,
            basis_type='gaussian',
            **kwargs
        )

        # node attributes -> hidden channels
        self.W_v = nn.Linear(len(self.atomic_numbers), n_hidden_channels, bias=False)

        # resolve one aggregation object per layer
        if aggr in ('attention', 'attentional'):
            # a single gate, wrapped once, and the very same aggregation
            # instance handed to every layer
            self.attention_gate = self._build_attention_gate(n_filters)
            shared_aggregation = AttentionalAggregation(self.attention_gate)
            per_layer_aggr = [shared_aggregation for _ in range(n_layers)]
        elif aggr in ('attention_separate', 'attentional_separate'):
            # all gates are created first, then wrapped, one per layer
            self.attention_gate = nn.ModuleList(
                [self._build_attention_gate(n_filters) for _ in range(n_layers)]
            )
            per_layer_aggr = [
                AttentionalAggregation(gate) for gate in self.attention_gate
            ]
        else:
            self.attention_gate = None
            per_layer_aggr = [aggr for _ in range(n_layers)]

        # stack of interaction blocks
        self.layers = nn.ModuleList(
            [
                InteractionBlock(
                    n_hidden_channels,
                    n_bases,
                    n_filters,
                    self.cutoff,
                    self.long_range_cutoff,
                    layer_aggr,
                )
                for layer_aggr in per_layer_aggr
            ]
        )

        # hidden channels -> output channels
        half_hidden = n_hidden_channels // 2
        self.W_out = nn.ModuleList(
            [
                nn.Linear(n_hidden_channels, half_hidden),
                Shifted_Softplus(),
                nn.Linear(half_hidden, n_out),
            ]
        )

        self._w_out_after_pool = w_out_after_pool
        self.reset_parameters()

    @staticmethod
    def _build_attention_gate(n_filters: int) -> nn.Sequential:
        """Build the gate network (n_filters -> n_filters // 2 -> 1)."""
        half_filters = n_filters // 2
        return nn.Sequential(
            nn.Linear(n_filters, half_filters),
            Shifted_Softplus(),
            nn.Linear(half_filters, 1),
        )

    @staticmethod
    def _xavier_with_zero_bias(linear: nn.Linear) -> None:
        """Xavier-uniform initialize the weight of `linear` and zero its bias."""
        nn.init.xavier_uniform_(linear.weight)
        linear.bias.data.fill_(0)

    def reset_parameters(self) -> None:
        """
        Resets all learnable parameters of the module.
        """
        self.W_v.reset_parameters()
        for block in self.layers:
            block.reset_parameters()

        # indices 0 and 2 are the two linear layers; index 1 is the activation
        for position in (0, 2):
            self._xavier_with_zero_bias(self.W_out[position])

        # normalize "one shared gate" / "one gate per layer" / "no gate"
        # into a flat list of gate networks
        gate = self.attention_gate
        if isinstance(gate, torch.nn.Sequential):
            gate_networks = [gate]
        elif isinstance(gate, torch.nn.ModuleList):
            gate_networks = list(gate)
        else:
            gate_networks = []
        for gate_network in gate_networks:
            for position in (0, 2):
                self._xavier_with_zero_bias(gate_network[position])

    def _apply_readout(self, features: torch.Tensor) -> torch.Tensor:
        """Apply the modules of `self.W_out` to `features`, in order."""
        for module in self.W_out:
            features = module(features)
        return features

    def forward(
        self, data: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        The forward pass.

        Parameters
        ----------
        data: Dict[str, torch.Tensor]
            The data dict. Usually came from the `to_dict` method of a
            `torch_geometric.data.Batch` object. The keys 'node_attrs' and
            'edge_index' are read here; 'edge_masks_lr' is used if present.

        Returns
        -------
        torch.Tensor
            The pooled features, with the readout network applied before or
            after pooling depending on `w_out_after_pool`.
        """
        # embed edges and node attributes
        edge_embedding = self.embed_edge(data)
        edge_weight = edge_embedding[0]
        edge_attr = edge_embedding[1]
        node_features = self.W_v(data['node_attrs'])

        # residual update through the interaction blocks
        lr_mask = data.get("edge_masks_lr", None)
        edge_index = data["edge_index"]
        for block in self.layers:
            node_features = node_features + block(
                node_features,
                edge_index,
                edge_weight,
                edge_attr,
                lr_mask,
            )

        # pooling is controlled by `self.pooling_operation` (mean/sum/None)
        if self._w_out_after_pool:
            # readout AFTER pooling
            pooled = self.pooling(input=node_features, data=data)
            return self._apply_readout(pooled)

        # readout BEFORE pooling
        return self.pooling(input=self._apply_readout(node_features), data=data)
class InteractionBlock(nn.Module):
    """SchNet interaction block: continuous-filter convolution, activation, linear layer."""

    def __init__(
        self,
        hidden_channels: int,
        num_gaussians: int,
        num_filters: int,
        cutoff: float,
        long_range_cutoff: float = -1.0,
        aggr: str = 'mean'
    ) -> None:
        """SchNet interaction block

        Parameters
        ----------
        hidden_channels : int
            Size of hidden embeddings
        num_gaussians : int
            Number of Gaussians for the embedding
        num_filters : int
            Number of filters
        cutoff : float
            Radial cutoff
        long_range_cutoff : float
            Cutoff radius for the long-range edges defined on subsystem atoms.
            If not positive, no long-range filter network is created,
            by default -1.0
        aggr : str, optional
            Aggregation function, by default 'mean'
        """
        super().__init__()

        # filter-generating network for the regular edges
        self.mlp = self._build_filter_network(num_gaussians, num_filters)

        # a second, independent filter network exists only when long-range
        # interactions are enabled
        self.mlp_lr = (
            self._build_filter_network(num_gaussians, num_filters)
            if long_range_cutoff > 0
            else None
        )

        self.conv = CFConv(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            num_filters=num_filters,
            network=self.mlp,
            cutoff=cutoff,
            network_lr=self.mlp_lr,
            long_range_cutoff=long_range_cutoff,
            aggr=aggr,
        )
        self.act = Shifted_Softplus()
        self.lin = nn.Linear(hidden_channels, hidden_channels)

        self.reset_parameters()

    @staticmethod
    def _build_filter_network(num_gaussians: int, num_filters: int) -> nn.Sequential:
        """Build the filter network (num_gaussians -> num_filters -> num_filters)."""
        return nn.Sequential(
            nn.Linear(num_gaussians, num_filters),
            Shifted_Softplus(),
            nn.Linear(num_filters, num_filters),
        )

    @staticmethod
    def _xavier_with_zero_bias(linear: nn.Linear) -> None:
        """Xavier-uniform initialize the weight of `linear` and zero its bias."""
        nn.init.xavier_uniform_(linear.weight)
        linear.bias.data.fill_(0)

    def reset_parameters(self) -> None:
        """
        Resets all learnable parameters of the module.
        """
        # indices 0 and 2 are the linear layers of each filter network
        for position in (0, 2):
            self._xavier_with_zero_bias(self.mlp[position])
        self.conv.reset_parameters()
        self._xavier_with_zero_bias(self.lin)
        if self.mlp_lr is not None:
            for position in (0, 2):
                self._xavier_with_zero_bias(self.mlp_lr[position])

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_weight: torch.Tensor,
        edge_attr: torch.Tensor,
        edge_masks_lr: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Apply convolution, activation and the final linear layer to `x`.

        All arguments are passed straight through to the `CFConv` layer.
        """
        convolved = self.conv(x, edge_index, edge_weight, edge_attr, edge_masks_lr)
        return self.lin(self.act(convolved))
class CFConv(MessagePassing):
    """Continuous-filter convolution from SchNet"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_filters: int,
        network: nn.Sequential,
        cutoff: float,
        network_lr: Optional[nn.Sequential] = None,
        long_range_cutoff: float = -1.0,
        aggr: str = 'mean'
    ) -> None:
        """Applies a continuous-filter convolution

        Parameters
        ----------
        in_channels : int
            Number of input channels
        out_channels : int
            Number of output channels
        num_filters : int
            Number of filters
        network : nn.Sequential
            Neural network generating the filters from the edge attributes
        cutoff : float
            Radial cutoff
        network_lr : Optional[nn.Sequential]
            Neural network generating the filters for long-range interactions
        long_range_cutoff : float
            Cutoff radius for the long-range edges defined on subsystem atoms
        aggr : str, optional
            Aggregation function, by default 'mean'
        """
        super().__init__(aggr=aggr)
        self.lin1 = nn.Linear(in_channels, num_filters, bias=False)
        self.lin2 = nn.Linear(num_filters, out_channels)
        self.network = network
        self.network_lr = network_lr
        self.cutoff = cutoff
        self.long_range_cutoff = long_range_cutoff
        self.reset_parameters()

    def reset_parameters(self):
        """
        Resets the parameters of the two linear layers owned by this module.

        The filter networks are not touched here.
        """
        nn.init.xavier_uniform_(self.lin1.weight)
        nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_weight: torch.Tensor,
        edge_attr: torch.Tensor,
        edge_masks_lr: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        The forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Node features.
        edge_index : torch.Tensor
            Edge indices, passed to `propagate`.
        edge_weight : torch.Tensor
            Per-edge scalar entering the cosine cutoff functions.
        edge_attr : torch.Tensor
            Per-edge attributes fed to the filter network(s).
        edge_masks_lr : Optional[torch.Tensor]
            Mask whose nonzero rows select the long-range edges. Only used
            when a long-range filter network was provided.

        Returns
        -------
        torch.Tensor
            The updated node features.
        """
        # cosine cutoff envelope applied to the generated filters
        C = 0.5 * (torch.cos(edge_weight * math.pi / self.cutoff) + 1.0)
        W = self.network(edge_attr) * C.view(-1, 1)

        if edge_masks_lr is not None and self.network_lr is not None:
            # redundant with the condition above, kept as a defensive check
            assert self.network_lr is not None
            assert self.long_range_cutoff > self.cutoff

            # row indices of the edges flagged as long-range
            indices_lr = edge_masks_lr.nonzero()[:, 0]
            lengths_lr = edge_weight[indices_lr]
            edge_attr_lr = edge_attr[indices_lr]

            # envelope decaying to zero at the long-range cutoff ...
            C_l = 0.5 * torch.cos(lengths_lr * math.pi / self.long_range_cutoff) + 0.5
            # ... multiplied by a factor rising from 0 to 1 up to the
            # short-range cutoff and equal to 1 from there on
            C_l_1 = 0.5 - 0.5 * torch.cos(lengths_lr * math.pi / self.cutoff)
            # `>=` keeps the factor continuous: C_l_1 equals 1 exactly at the
            # cutoff, whereas two strict comparisons would both be False
            # there and zero the filter for that edge
            C_l = C_l * (
                C_l_1 * (lengths_lr < self.cutoff)  # shorter than cutoff
                + 1.0 * (lengths_lr >= self.cutoff)  # at or beyond cutoff
            )

            # overwrite the filters of the long-range edges
            W_l = self.network_lr(edge_attr_lr) * C_l.view(-1, 1)
            W = W.index_copy_(0, indices_lr, W_l)

        x = self.lin1(x)
        x = self.propagate(edge_index, x=x, W=W)
        x = self.lin2(x)
        return x

    def message(self, x_j: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
        """Modulate the neighbor features `x_j` element-wise by the filters `W`."""
        return x_j * W
from mlcolvar.data.graph.utils import create_graph_tracing_example, create_test_graph_input
def _create_test_data_list():
    """Return the 'data_list' entry of a noise-free test batch (3 atoms, 6 samples, 1 state)."""
    graph_options = {
        'output_type': 'batch',
        'n_atoms': 3,
        'n_samples': 6,
        'n_states': 1,
        'add_noise': False,
    }
    return create_test_graph_input(**graph_options)['data_list']
def test_schnet_1() -> None:
    """Check default ('mean') and 'sum' pooling against reference outputs, plus a traced model."""
    torch.manual_seed(0)
    torch.set_default_dtype(torch.float64)
    try:
        # default pooling
        model = SchNetModel(
            n_out=2,
            cutoff=0.1,
            atomic_numbers=[1, 8],
            n_bases=6,
            n_layers=2,
            n_filters=16,
            n_hidden_channels=16
        )
        data = _create_test_data_list()
        ref_out = torch.tensor([[0.40384621527953063, -0.12575133651389694]] * 5)
        assert torch.allclose(model(data), ref_out)

        # sum pooling
        model = SchNetModel(
            n_out=2,
            cutoff=0.1,
            atomic_numbers=[1, 8],
            n_bases=6,
            n_layers=2,
            n_filters=16,
            n_hidden_channels=16,
            pooling_operation='sum',
        )
        data = _create_test_data_list()
        ref_out = torch.tensor([[0.15911003978422333, 0.45333821159230125]] * 5)
        assert torch.allclose(model(data), ref_out)

        # the traced model must reproduce the eager output
        traced_model = torch.jit.trace(model, example_inputs=create_graph_tracing_example(2))
        assert torch.allclose(traced_model(data), ref_out)
    finally:
        # restore the global default dtype, as the sibling tests do, so the
        # float64 setting does not leak into tests that run afterwards
        torch.set_default_dtype(torch.float32)
def test_schnet_2() -> None:
    """Check 'min' aggregation with the readout applied after pooling."""
    torch.manual_seed(0)
    torch.set_default_dtype(torch.float64)
    try:
        model = SchNetModel(
            n_out=2,
            cutoff=0.1,
            atomic_numbers=[1, 8],
            n_bases=6,
            n_layers=2,
            n_filters=16,
            n_hidden_channels=16,
            aggr='min',
            w_out_after_pool=True
        )
        data = _create_test_data_list()
        ref_out = torch.tensor([[0.3654537816221449, -0.0748265132499575]] * 5)
        assert torch.allclose(model(data), ref_out)
    finally:
        # restore the global default dtype even when an assertion fails
        torch.set_default_dtype(torch.float32)
def test_schnet_from_dataset() -> None:
    """Check initialization from dataset metadata, without and with environment atoms."""
    from mlcolvar.data.graph.utils import create_test_graph_input

    # (extra keyword arguments for the dataset, reference output row)
    scenarios = [
        ({}, [0.36632594, -0.08193991]),
        # test with environment atoms
        ({'environment': True}, [0.14110785, -0.22323715]),
    ]

    torch.manual_seed(0)
    torch.set_default_dtype(torch.float64)
    try:
        for extra_dataset_kwargs, ref_row in scenarios:
            dataset = create_test_graph_input(
                output_type='dataset',
                n_atoms=3,
                n_samples=5,
                n_states=1,
                add_noise=False,
                **extra_dataset_kwargs
            )
            model = SchNetModel(
                n_out=2,
                dataset_for_initialization=dataset,
                n_bases=6,
                n_layers=2,
                n_filters=16,
                n_hidden_channels=16,
                aggr='max',
                w_out_after_pool=True
            )

            # check the model parameters are correctly initialized from the dataset metadata
            assert model.cutoff == dataset.metadata['cutoff']
            assert torch.allclose(
                model.atomic_numbers,
                torch.as_tensor(dataset.metadata['atomic_numbers']),
            )
            assert torch.allclose(
                model.buffer,
                torch.as_tensor(dataset.metadata['buffer']),
            )

            # check output is consistent with the one obtained from the test graph input
            ref_out = torch.tensor([ref_row] * 5)
            assert torch.allclose(model(dataset.get_graph_inputs()), ref_out)
    finally:
        # restore the global default dtype even when an assertion fails
        torch.set_default_dtype(torch.float32)
def test_schnet_3() -> None:
    """Check the shared ('attention') and per-layer ('attention_separate') attentional aggregations."""
    torch.manual_seed(0)
    torch.set_default_dtype(torch.float64)
    try:
        # one attention gate shared by all layers
        model = SchNetModel(
            n_out=2,
            cutoff=0.1,
            atomic_numbers=[1, 8],
            n_bases=6,
            n_layers=2,
            n_filters=16,
            n_hidden_channels=16,
            aggr='attention',
        )
        data = _create_test_data_list()
        ref_out = torch.tensor([[-0.3191231788534454, -0.0436194218681725]] * 5)
        assert torch.allclose(model(data), ref_out)

        # one independent attention gate per layer
        model = SchNetModel(
            n_out=2,
            cutoff=0.1,
            atomic_numbers=[1, 8],
            n_bases=6,
            n_layers=2,
            n_filters=16,
            n_hidden_channels=16,
            aggr='attention_separate',
        )
        data = _create_test_data_list()
        ref_out = torch.tensor([[-0.1364561627454978, -0.1203537910489112]] * 5)
        assert torch.allclose(model(data), ref_out)
    finally:
        # restore the global default dtype even when an assertion fails
        torch.set_default_dtype(torch.float32)
def test_schnet_4() -> None:
    """Check the long-range branch: a model with `long_range_cutoff` fed an `edge_masks_lr` mask."""
    torch.manual_seed(0)
    torch.set_default_dtype(torch.float64)
    try:
        model = SchNetModel(
            n_out=2,
            cutoff=0.1,
            long_range_cutoff=0.2,
            atomic_numbers=[1, 8],
            n_bases=6,
            n_layers=2,
            n_filters=16,
            n_hidden_channels=16,
            aggr='min',
        )
        data = _create_test_data_list()

        # flag every edge except the last two as long-range
        n_edges = data['edge_index'].shape[1]
        data['edge_masks_lr'] = torch.zeros((n_edges, 1), dtype=torch.bool)
        data['edge_masks_lr'][:-2] = True

        # the last row differs from the first four
        ref_out = torch.tensor(
            [[-0.1873424391457965, -0.0150953093265520]] * 4
            + [[-0.1846414580701179, -0.0121660647548140]]
        )
        assert torch.allclose(model(data), ref_out)
    finally:
        # restore the global default dtype even when an assertion fails
        torch.set_default_dtype(torch.float32)