Source code for mlcolvar.core.nn.graph.schnet

import math
import torch
from torch import nn
from torch_geometric.nn.aggr import AttentionalAggregation
from torch_geometric.nn import MessagePassing

from mlcolvar.data import DictDataset
from mlcolvar.core.nn.utils import Shifted_Softplus
from mlcolvar.core.nn.graph.gnn import BaseGNN

from typing import List, Dict, Optional

"""
The SchNet components. This module is adapted from the PyG (torch_geometric) package:
https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/models/schnet.py
"""

# Public names exported by this module; CFConv is defined further down in
# this file but is not part of the exported API.
__all__ = ["SchNetModel", "InteractionBlock"]

class SchNetModel(BaseGNN):
    """
    SchNet [1] graph neural network.

    Adapted from the torch_geometric implementation:
    https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/models/schnet.py

    References
    ----------
    .. [1] Schütt, Kristof T., et al. "Schnet–a deep learning architecture
           for molecules and materials." The Journal of Chemical Physics
           148.24 (2018).
    """

    def __init__(
        self,
        n_out: int,
        dataset_for_initialization: DictDataset = None,
        pooling_operation : str = 'mean',
        n_bases: int = 16,
        n_layers: int = 2,
        n_filters: int = 16,
        n_hidden_channels: int = 16,
        aggr: str = 'mean',
        w_out_after_pool: bool = False,
        **kwargs
    ) -> None:
        """
        Build the embedding, the interaction blocks and the output head.

        Parameters
        ----------
        n_out : int
            Size of the output node features.
        dataset_for_initialization : DictDataset, optional
            Dataset containing the graphs the model will be applied to. Its
            metadata is used to initialize and register the cutoff, buffer,
            long_range_cutoff and atomic_numbers, which keeps model and
            dataset consistent (preferred). If None, those quantities can be
            given through ``kwargs`` instead.
        pooling_operation : str
            Pooling used to turn node-level features into graph-level
            features, either mean or sum, by default 'mean'.
        n_bases : int, optional
            Size of the basis set used for the embedding, by default 16.
        n_layers : int, optional
            Number of graph convolution layers, by default 2.
        n_filters : int, optional
            Number of filters, by default 16.
        n_hidden_channels : int, optional
            Size of the hidden embeddings, by default 16.
        aggr : str, optional
            GNN aggregation function, by default 'mean'. Possible choices:
            'mean', 'sum', 'max', 'min', 'mul',
            'attention'/'attentional' (one attention gate shared by all
            layers), 'attention_separate'/'attentional_separate' (an
            independent attention gate per layer).
        w_out_after_pool : bool, optional
            If True, the final hidden-to-output transformation is applied
            after pooling instead of before it, by default False.
        **kwargs
            Forwarded to the ``BaseGNN`` constructor.
        """
        super().__init__(
            n_out=n_out,
            dataset_for_initialization=dataset_for_initialization,
            pooling_operation=pooling_operation,
            n_bases=n_bases,
            n_polynomials=0,
            basis_type='gaussian',
            **kwargs
        )

        # node attributes -> hidden channels
        self.W_v = nn.Linear(
            in_features=len(self.atomic_numbers),
            out_features=n_hidden_channels,
            bias=False
        )

        def build_gate() -> nn.Sequential:
            # Small MLP mapping a filter-sized message to a scalar score.
            return nn.Sequential(
                nn.Linear(n_filters, n_filters // 2),
                Shifted_Softplus(),
                nn.Linear(n_filters // 2, 1)
            )

        # Resolve the aggregation to use in each layer.
        if aggr in ('attention', 'attentional'):
            # A single gate (and aggregation object) reused by every layer.
            self.attention_gate = build_gate()
            shared_aggr = AttentionalAggregation(self.attention_gate)
            per_layer_aggr = [shared_aggr for _ in range(n_layers)]
        elif aggr in ('attention_separate', 'attentional_separate'):
            # One independent gate per layer.
            self.attention_gate = nn.ModuleList(
                [build_gate() for _ in range(n_layers)]
            )
            per_layer_aggr = [
                AttentionalAggregation(gate) for gate in self.attention_gate
            ]
        else:
            self.attention_gate = None
            per_layer_aggr = [aggr for _ in range(n_layers)]

        # stack of interaction blocks
        self.layers = nn.ModuleList([
            InteractionBlock(
                n_hidden_channels,
                n_bases,
                n_filters,
                self.cutoff,
                self.long_range_cutoff,
                layer_aggr
            )
            for layer_aggr in per_layer_aggr
        ])

        # hidden channels -> output channels
        half_hidden = n_hidden_channels // 2
        self.W_out = nn.ModuleList([
            nn.Linear(n_hidden_channels, half_hidden),
            Shifted_Softplus(),
            nn.Linear(half_hidden, n_out)
        ])

        self._w_out_after_pool = w_out_after_pool

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Re-initialize every learnable parameter of the model.
        """
        self.W_v.reset_parameters()

        for block in self.layers:
            block.reset_parameters()

        # The output head has linear layers at positions 0 and 2.
        for idx in (0, 2):
            linear = self.W_out[idx]
            nn.init.xavier_uniform_(linear.weight)
            linear.bias.data.fill_(0)

        # Collect the attention gates (none, one shared, or one per layer).
        gate_container = self.attention_gate
        if isinstance(gate_container, torch.nn.Sequential):
            gates = [gate_container]
        elif isinstance(gate_container, torch.nn.ModuleList):
            gates = list(gate_container)
        else:
            gates = []

        for gate in gates:
            for idx in (0, 2):
                nn.init.xavier_uniform_(gate[idx].weight)
                gate[idx].bias.data.fill_(0)

    def forward(
        self,
        data: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Run the model on a batch of graphs.

        Parameters
        ----------
        data: Dict[str, torch.Tensor]
            The data dict, usually obtained from the `to_dict` method of a
            `torch_geometric.data.Batch` object. The keys 'node_attrs' and
            'edge_index' are read here; 'edge_masks_lr' is used if present.

        Returns
        -------
        torch.Tensor
            The pooled model output.
        """
        # edge and node embeddings
        edge_embedding = self.embed_edge(data)
        node_features = self.W_v(data['node_attrs'])

        # residual updates through the interaction blocks
        lr_mask = data.get("edge_masks_lr", None)
        for block in self.layers:
            update = block(
                node_features,
                data["edge_index"],
                edge_embedding[0],
                edge_embedding[1],
                lr_mask,
            )
            node_features = node_features + update

        # pooling is controlled by `self.pooling_operation` (mean/sum/None)
        if self._w_out_after_pool:
            # output head applied AFTER pooling
            out = self.pooling(input=node_features, data=data)
            for w in self.W_out:
                out = w(out)
        else:
            # output head applied BEFORE pooling
            for w in self.W_out:
                node_features = w(node_features)
            out = self.pooling(input=node_features, data=data)

        return out
class InteractionBlock(nn.Module):
    """SchNet interaction block: continuous-filter convolution, activation
    and a final linear layer."""

    def __init__(
        self,
        hidden_channels: int,
        num_gaussians: int,
        num_filters: int,
        cutoff: float,
        long_range_cutoff: float = -1.0,
        aggr: str = 'mean'
    ) -> None:
        """SchNet interaction block

        Parameters
        ----------
        hidden_channels : int
            Size of hidden embeddings
        num_gaussians : int
            Number of Gaussians for the embedding
        num_filters : int
            Number of filters
        cutoff : float
            Radial cutoff
        long_range_cutoff : float
            Cutoff radius for the long-range edges defined on subsystem atoms.
            If negative, no long-range interactions are considered,
            by default -1.0
        aggr : str, optional
            Aggregation function, by default 'mean'. The value is forwarded
            unchanged to `CFConv`; `SchNetModel` also passes aggregation
            modules (`AttentionalAggregation`) here.
        """
        super().__init__()

        # filter-generating network for the (short-range) edges
        self.mlp = nn.Sequential(
            nn.Linear(num_gaussians, num_filters),
            Shifted_Softplus(),
            nn.Linear(num_filters, num_filters),
        )

        # separate filter-generating network for long-range edges, only
        # created when a positive long-range cutoff is requested
        if long_range_cutoff > 0:
            self.mlp_lr = nn.Sequential(
                nn.Linear(num_gaussians, num_filters),
                Shifted_Softplus(),
                nn.Linear(num_filters, num_filters),
            )
        else:
            self.mlp_lr = None

        self.conv = CFConv(
            hidden_channels,
            hidden_channels,
            num_filters,
            self.mlp,
            cutoff,
            self.mlp_lr,
            long_range_cutoff,
            aggr
        )
        self.act = Shifted_Softplus()
        self.lin = nn.Linear(hidden_channels, hidden_channels)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Resets all learnable parameters of the module.
        """
        nn.init.xavier_uniform_(self.mlp[0].weight)
        self.mlp[0].bias.data.fill_(0)
        nn.init.xavier_uniform_(self.mlp[2].weight)
        self.mlp[2].bias.data.fill_(0)

        self.conv.reset_parameters()

        nn.init.xavier_uniform_(self.lin.weight)
        self.lin.bias.data.fill_(0)

        if self.mlp_lr is not None:
            nn.init.xavier_uniform_(self.mlp_lr[0].weight)
            self.mlp_lr[0].bias.data.fill_(0)
            nn.init.xavier_uniform_(self.mlp_lr[2].weight)
            self.mlp_lr[2].bias.data.fill_(0)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_weight: torch.Tensor,
        edge_attr: torch.Tensor,
        edge_masks_lr: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply convolution, activation and linear layer to node features.

        Parameters
        ----------
        x : torch.Tensor
            Node features.
        edge_index : torch.Tensor
            Edge connectivity, forwarded to the convolution.
        edge_weight : torch.Tensor
            Edge lengths (used by the convolution for the cutoff functions).
        edge_attr : torch.Tensor
            Edge embeddings fed to the filter-generating network(s).
        edge_masks_lr : Optional[torch.Tensor]
            Mask marking the long-range edges, by default None.

        Returns
        -------
        torch.Tensor
            Updated node features.
        """
        x = self.conv(x, edge_index, edge_weight, edge_attr, edge_masks_lr)
        x = self.act(x)
        x = self.lin(x)
        return x


class CFConv(MessagePassing):
    """Continuous-filter convolution from SchNet"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_filters: int,
        network: nn.Sequential,
        cutoff: float,
        network_lr: Optional[nn.Sequential] = None,
        long_range_cutoff: float = -1.0,
        aggr: str = 'mean'
    ) -> None:
        """Applies a continuous-filter convolution

        Parameters
        ----------
        in_channels : int
            Number of input channels
        out_channels : int
            Number of output channels
        num_filters : int
            Number of filters
        network : nn.Sequential
            Neural network generating the filters from the edge embeddings
        cutoff : float
            Radial cutoff
        network_lr : Optional[nn.Sequential]
            Neural network for long-range interactions
        long_range_cutoff : float
            Cutoff radius for the long-range edges defined on subsystem atoms
        aggr : str, optional
            Aggregation function, by default 'mean'. Passed as-is to the
            `MessagePassing` base class.
        """
        super().__init__(aggr=aggr)

        self.lin1 = nn.Linear(in_channels, num_filters, bias=False)
        self.lin2 = nn.Linear(num_filters, out_channels)
        self.network = network
        self.network_lr = network_lr
        self.cutoff = cutoff
        self.long_range_cutoff = long_range_cutoff

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the two linear layers owned by this module.

        The filter networks are not touched here; `InteractionBlock`
        resets them itself.
        """
        nn.init.xavier_uniform_(self.lin1.weight)
        nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_weight: torch.Tensor,
        edge_attr: torch.Tensor,
        edge_masks_lr: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the edge filters and propagate the node features.

        Parameters
        ----------
        x : torch.Tensor
            Node features.
        edge_index : torch.Tensor
            Edge connectivity, forwarded to `propagate`.
        edge_weight : torch.Tensor
            Edge lengths, used in the cosine cutoff functions.
        edge_attr : torch.Tensor
            Edge embeddings fed to the filter-generating network(s).
        edge_masks_lr : Optional[torch.Tensor]
            Mask whose non-zero rows select the long-range edges. It is only
            used when a long-range network is available, by default None.

        Returns
        -------
        torch.Tensor
            Updated node features.
        """
        # cosine cutoff applied to the filters of all edges
        C = 0.5 * (torch.cos(edge_weight * math.pi / self.cutoff) + 1.0)
        W = self.network(edge_attr) * C.view(-1, 1)

        if edge_masks_lr is not None and self.network_lr is not None:
            # first assert is redundant with the condition above; kept as in
            # the original (presumably for type narrowing)
            assert self.network_lr is not None
            assert self.long_range_cutoff > self.cutoff

            indices_lr = edge_masks_lr.nonzero()[:, 0]
            lengths_lr = edge_weight[indices_lr]
            edge_attr_lr = edge_attr[indices_lr]

            # switch-off towards the long-range cutoff
            C_l = 0.5 * torch.cos(
                lengths_lr * math.pi / self.long_range_cutoff
            ) + 0.5
            # switch-on inside the short-range cutoff
            C_l_1 = 0.5 - 0.5 * torch.cos(lengths_lr * math.pi / self.cutoff)
            # `>=` (instead of `>`) so that an edge whose length is exactly
            # `cutoff` gets factor 1 (the value C_l_1 reaches there) rather
            # than being silently zeroed by falling between both masks.
            C_l = C_l * (
                C_l_1 * (lengths_lr < self.cutoff)   # shorter than cutoff
                + 1.0 * (lengths_lr >= self.cutoff)  # at or beyond cutoff
            )

            # overwrite the filters of the long-range edges
            W_l = self.network_lr(edge_attr_lr) * C_l.view(-1, 1)
            W = W.index_copy_(0, indices_lr, W_l)

        x = self.lin1(x)
        x = self.propagate(edge_index, x=x, W=W)
        x = self.lin2(x)
        return x

    def message(self, x_j: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
        """Message on each edge: neighbour features scaled by the filter."""
        return x_j * W


from mlcolvar.data.graph.utils import create_graph_tracing_example, create_test_graph_input


def _create_test_data_list():
    """Return the 'data_list' entry of a small noise-free test batch."""
    batch = create_test_graph_input(
        output_type='batch',
        n_atoms=3,
        n_samples=6,
        n_states=1,
        add_noise=False,
    )
    return batch['data_list']


def test_schnet_1() -> None:
    """Default ('mean') aggregation with mean and sum pooling, plus tracing."""
    torch.manual_seed(0)
    torch.set_default_dtype(torch.float64)

    # The default dtype is global state: always restore it, even when an
    # assertion fails, so other tests are not affected.
    try:
        model = SchNetModel(
            n_out=2,
            cutoff=0.1,
            atomic_numbers=[1, 8],
            n_bases=6,
            n_layers=2,
            n_filters=16,
            n_hidden_channels=16
        )

        data = _create_test_data_list()
        ref_out = torch.tensor(
            [[0.40384621527953063, -0.12575133651389694]] * 5
        )
        assert torch.allclose(model(data), ref_out)

        model = SchNetModel(
            n_out=2,
            cutoff=0.1,
            atomic_numbers=[1, 8],
            n_bases=6,
            n_layers=2,
            n_filters=16,
            n_hidden_channels=16,
            pooling_operation='sum',
        )

        data = _create_test_data_list()
        ref_out = torch.tensor(
            [[0.15911003978422333, 0.45333821159230125]] * 5
        )
        assert torch.allclose(model(data), ref_out)

        traced_model = torch.jit.trace(
            model, example_inputs=create_graph_tracing_example(2)
        )
        assert torch.allclose(traced_model(data), ref_out)
    finally:
        torch.set_default_dtype(torch.float32)


def test_schnet_2() -> None:
    """'min' aggregation with the output head applied after pooling."""
    torch.manual_seed(0)
    torch.set_default_dtype(torch.float64)

    try:
        model = SchNetModel(
            n_out=2,
            cutoff=0.1,
            atomic_numbers=[1, 8],
            n_bases=6,
            n_layers=2,
            n_filters=16,
            n_hidden_channels=16,
            aggr='min',
            w_out_after_pool=True
        )

        data = _create_test_data_list()
        ref_out = torch.tensor(
            [[0.3654537816221449, -0.0748265132499575]] * 5
        )
        assert torch.allclose(model(data), ref_out)
    finally:
        torch.set_default_dtype(torch.float32)


def test_schnet_from_dataset() -> None:
    """Initialization from dataset metadata, without and with environment."""
    torch.manual_seed(0)
    torch.set_default_dtype(torch.float64)

    try:
        dataset = create_test_graph_input(
            output_type='dataset',
            n_atoms=3,
            n_samples=5,
            n_states=1,
            add_noise=False,
        )

        model = SchNetModel(
            n_out=2,
            dataset_for_initialization=dataset,
            n_bases=6,
            n_layers=2,
            n_filters=16,
            n_hidden_channels=16,
            aggr='max',
            w_out_after_pool=True
        )

        # check the model parameters are correctly initialized from the
        # dataset metadata
        assert model.cutoff == dataset.metadata['cutoff']
        assert torch.allclose(
            model.atomic_numbers,
            torch.as_tensor(dataset.metadata['atomic_numbers'])
        )
        assert torch.allclose(
            model.buffer, torch.as_tensor(dataset.metadata['buffer'])
        )

        # check output is consistent with the one obtained from the test
        # graph input
        ref_out = torch.tensor([[0.36632594, -0.08193991]] * 5)
        assert torch.allclose(model(dataset.get_graph_inputs()), ref_out)

        # test with environment atoms
        dataset = create_test_graph_input(
            output_type='dataset',
            n_atoms=3,
            n_samples=5,
            n_states=1,
            add_noise=False,
            environment=True
        )

        model = SchNetModel(
            n_out=2,
            dataset_for_initialization=dataset,
            n_bases=6,
            n_layers=2,
            n_filters=16,
            n_hidden_channels=16,
            aggr='max',
            w_out_after_pool=True
        )

        # check the model parameters are correctly initialized from the
        # dataset metadata
        assert model.cutoff == dataset.metadata['cutoff']
        assert torch.allclose(
            model.atomic_numbers,
            torch.as_tensor(dataset.metadata['atomic_numbers'])
        )
        assert torch.allclose(
            model.buffer, torch.as_tensor(dataset.metadata['buffer'])
        )

        # check output is consistent with the one obtained from the test
        # graph input
        ref_out = torch.tensor([[0.14110785, -0.22323715]] * 5)
        assert torch.allclose(model(dataset.get_graph_inputs()), ref_out)
    finally:
        torch.set_default_dtype(torch.float32)


def test_schnet_3() -> None:
    """Shared and per-layer attentional aggregation."""
    torch.manual_seed(0)
    torch.set_default_dtype(torch.float64)

    try:
        model = SchNetModel(
            n_out=2,
            cutoff=0.1,
            atomic_numbers=[1, 8],
            n_bases=6,
            n_layers=2,
            n_filters=16,
            n_hidden_channels=16,
            aggr='attention',
        )

        data = _create_test_data_list()
        ref_out = torch.tensor(
            [[-0.3191231788534454, -0.0436194218681725]] * 5
        )
        assert torch.allclose(model(data), ref_out)

        model = SchNetModel(
            n_out=2,
            cutoff=0.1,
            atomic_numbers=[1, 8],
            n_bases=6,
            n_layers=2,
            n_filters=16,
            n_hidden_channels=16,
            aggr='attention_separate',
        )

        data = _create_test_data_list()
        ref_out = torch.tensor(
            [[-0.1364561627454978, -0.1203537910489112]] * 5
        )
        assert torch.allclose(model(data), ref_out)
    finally:
        torch.set_default_dtype(torch.float32)


def test_schnet_4() -> None:
    """Long-range edges: all but the last two edges are flagged long-range."""
    torch.manual_seed(0)
    torch.set_default_dtype(torch.float64)

    try:
        model = SchNetModel(
            n_out=2,
            cutoff=0.1,
            long_range_cutoff=0.2,
            atomic_numbers=[1, 8],
            n_bases=6,
            n_layers=2,
            n_filters=16,
            n_hidden_channels=16,
            aggr='min',
        )

        data = _create_test_data_list()
        data['edge_masks_lr'] = torch.zeros(
            ((data['edge_index'].shape[1]), 1), dtype=bool
        )
        data['edge_masks_lr'][:-2] = True

        torch.set_printoptions(precision=16)
        ref_out = torch.tensor([
            [-0.1873424391457965, -0.0150953093265520],
            [-0.1873424391457965, -0.0150953093265520],
            [-0.1873424391457965, -0.0150953093265520],
            [-0.1873424391457965, -0.0150953093265520],
            [-0.1846414580701179, -0.0121660647548140],
        ])
        assert torch.allclose(model(data), ref_out)
    finally:
        torch.set_default_dtype(torch.float32)