import functools
import math
import torch
from torch import nn
from torch_geometric.nn import MessagePassing
from typing import Tuple, Callable, Optional, List, Dict
from mlcolvar.core.nn.graph.gnn import BaseGNN
from mlcolvar.data import DictDataset
"""
The Geometric Vector Perceptron (GVP) layer. This module is adapted from:
https://github.com/chaitjo/geometric-gnn-dojo/blob/main/models/layers/py,
and made compilable.
"""
# Names exported by `from <module> import *`.
__all__ = ['GVPModel', 'GVPConvLayer', 'LayerNorm', 'Dropout']
[docs]
class GVPModel(BaseGNN):
    """
    The Geometric Vector Perceptron (GVP) model [1, 2] with vector gate [2].

    References
    ----------
    .. [1] Jing, Bowen, et al.
           "Learning from protein structure with geometric vector perceptrons."
           International Conference on Learning Representations. 2020.
    .. [2] Jing, Bowen, et al.
           "Equivariant graph neural networks for 3d macromolecular structure."
           arXiv preprint arXiv:2106.03843 (2021).
    """

    def __init__(
        self,
        n_out: int,
        dataset_for_initialization: Optional[DictDataset] = None,
        pooling_operation: str = 'mean',
        n_bases: int = 8,
        n_polynomials: int = 6,
        n_layers: int = 1,
        n_messages: int = 2,
        n_feedforwards: int = 2,
        n_scalars_node: int = 8,
        n_vectors_node: int = 8,
        n_scalars_edge: int = 8,
        drop_rate: float = 0.1,
        activation: str = 'SiLU',
        basis_type: str = 'bessel',
        smooth: bool = False,
        **kwargs
    ) -> None:
        """Initializes a Geometric Vector Perceptron (GVP) model.

        Parameters
        ----------
        n_out: int
            Number of the output scalar node features.
        dataset_for_initialization : DictDataset, optional
            Dataset containing the graphs on which the gnn model will be applied.
            This is used to initialize and register the cutoff, buffer, and
            atomic_numbers from the dataset metadata.
            This is the preferred way to initialize the gnn model, as it ensures
            consistency between the model and the dataset.
            As an alternative this can be set to None and the cutoff, buffer,
            and atomic_numbers can be provided as kwargs.
        pooling_operation : str
            Type of pooling operation to combine node-level features into
            graph-level features, either mean or sum, by default 'mean'
        n_bases: int
            Size of the basis set used for the embedding, by default 8.
        n_polynomials: int
            Order of the polynomials in the basis functions, by default 6.
        n_layers: int
            Number of the graph convolution layers, by default 1.
        n_messages: int
            Number of GVP layers to be used in the message functions, by default 2.
        n_feedforwards: int
            Number of GVP layers to be used in the feedforward functions, by default 2.
        n_scalars_node: int
            Size of the scalar channel of the node embedding in hidden layers, by default 8.
        n_vectors_node: int
            Size of the vector channel of the node embedding in hidden layers, by default 8.
        n_scalars_edge: int
            Size of the scalar channel of the edge embedding in hidden layers, by default 8.
        drop_rate: float
            Drop probability in all dropout layers, by default 0.1.
        activation: str
            Name of the activation module under ``torch.nn`` to be used in the
            GVPs (case sensitive), by default SiLU.
        basis_type: str
            Type of the basis function, by default bessel.
        smooth: bool
            If use the smoothed GVPConv, by default False.

        Raises
        ------
        AttributeError
            If `activation` does not name an attribute of ``torch.nn``.
        """
        super().__init__(
            n_out=n_out,
            dataset_for_initialization=dataset_for_initialization,
            pooling_operation=pooling_operation,
            n_bases=n_bases,
            n_polynomials=n_polynomials,
            basis_type=basis_type,
            **kwargs
        )

        # Resolve the activation by attribute lookup instead of `eval`, so an
        # arbitrary string can never be executed as code. Dotted paths below
        # `torch.nn` keep working, as they did with the previous lookup.
        activation_cls = functools.reduce(
            getattr, activation.split('.'), torch.nn
        )

        # Edge embedding: radial basis (scalars) + one direction vector.
        self.W_e = nn.ModuleList([
            LayerNorm((n_bases, 1)),
            GVP(in_dims=(n_bases, 1),
                out_dims=(n_scalars_edge, 1),
                activations=(None, None),
                vector_gate=True
                )
        ])

        # Node embedding: scalar node attributes only, no input vectors.
        self.W_v = nn.ModuleList([
            LayerNorm((len(self.atomic_numbers), 0)),
            GVP(in_dims=(len(self.atomic_numbers), 0),
                out_dims=(n_scalars_node, n_vectors_node),
                activations=(None, None),
                vector_gate=True
                )
        ])

        # A negative cutoff disables the smooth envelope inside GVPConv.
        # A fresh activation instance is created for every layer.
        self.layers = nn.ModuleList(
            GVPConvLayer(node_dims=(n_scalars_node, n_vectors_node),
                         edge_dims=(n_scalars_edge, 1),
                         n_message=n_messages,
                         n_feedforward=n_feedforwards,
                         drop_rate=drop_rate,
                         activations=(activation_cls(), None),
                         vector_gate=True,
                         cutoff=(self.cutoff if smooth else -1),
                         long_range_cutoff=(self.long_range_cutoff if smooth else -1),
                         )
            for _ in range(n_layers)
        )

        # Readout: project node embeddings to `n_out` scalars, no vectors.
        self.W_out = nn.ModuleList([
            LayerNorm((n_scalars_node, n_vectors_node)),
            GVP(in_dims=(n_scalars_node, n_vectors_node),
                out_dims=(n_out, 0),
                activations=(None, None),
                vector_gate=True)
        ])

    def forward(
        self, data: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """The forward pass.

        Parameters
        ----------
        data: Dict[str, torch.Tensor]
            The data dict. Usually came from the `to_dict` method of a
            `torch_geometric.data.Batch` object. The keys 'node_attrs' and
            'edge_index' are read here; 'edge_masks_lr' is optional.

        Returns
        -------
        torch.Tensor
            The scalar outputs after `self.pooling` has been applied.
        """
        # Embed the nodes; the input carries no vector channel.
        h_V = (data['node_attrs'], None)
        for w in self.W_v:
            h_V = w(h_V)
        h_V_1, h_V_2 = h_V
        assert h_V_2 is not None
        h_V = (h_V_1, h_V_2)

        # `embed_edge` comes from the base class; entry 0 is used as the edge
        # lengths, entry 1 as scalar edge features and entry 2 (with an extra
        # channel axis) as the single edge vector.
        h_E = self.embed_edge(data)
        lengths = h_E[0]
        h_E = (h_E[1], h_E[2].unsqueeze(-2))
        for w in self.W_e:
            h_E = w(h_E)
        h_E_1, h_E_2 = h_E
        assert h_E_2 is not None
        h_E = (h_E_1, h_E_2)

        # The long-range mask does not change between layers: look it up once.
        mask = data.get("edge_masks_lr", None)
        for layer in self.layers:
            h_V = layer(
                h_V,
                data['edge_index'],
                h_E,
                lengths,
                None,
                mask
            )

        for w in self.W_out:
            h_V = w(h_V)
        out = h_V[0]

        # pooling is controlled by `self.pooling_operation` (mean/sum/None)
        out = self.pooling(input=out, data=data)
        return out
class GVP(nn.Module):
    """
    Geometric Vector Perceptron (GVP) layer from [1, 2] with vector gate [2].

    References
    ----------
    .. [1] Jing, Bowen, et al.
           "Learning from protein structure with geometric vector perceptrons."
           International Conference on Learning Representations. 2020.
    .. [2] Jing, Bowen, et al.
           "Equivariant graph neural networks for 3d macromolecular structure."
           arXiv preprint arXiv:2106.03843 (2021).
    """

    def __init__(
        self,
        in_dims: Tuple[int, Optional[int]],
        out_dims: Tuple[int, Optional[int]],
        h_dim: Optional[int] = None,
        activations: Tuple[
            Optional[Callable], Optional[Callable]
        ] = (nn.functional.relu, torch.sigmoid),
        vector_gate: bool = True,
    ) -> None:
        r"""Geometric Vector Perceptron layer.

        Updates the scalar node feature s as:

        .. math:: \bm{s}^n \leftarrow \sigma \left(\bm{s}'\right) \quad\text{with}\quad \bm{s}' \coloneq \bm{W}_m \left[{\|\bm{W}_h\vec{\bm{v}}^o\|_2 \atop \bm{s}^o}\right] + \bm{b}

        And the vector node features as:

        .. math:: \vec{\bm{v}}^n \leftarrow \sigma_g \left(\bm{W}_g\left(\sigma^+ \left(\bm{s}'\right)\right) + \bm{b}_g \right) \odot \bm{W}_\mu\bm{W}_h\vec{\bm{v}}^o

        Parameters
        ----------
        in_dims : Tuple[int, Optional[int]]
            Dimension of inputs (n_scalar, n_vector)
        out_dims : Tuple[int, Optional[int]]
            Dimension of outputs (n_scalar, n_vector)
        h_dim : int, optional
            Intermediate number of vector channels, by default None, in which
            case the larger of the input and output vector sizes is used
        activations : Tuple[ Optional[Callable], Optional[Callable] ], optional
            Scalar and vector activation functions (scalar_act, vector_act),
            by default (nn.functional.relu, torch.sigmoid)
        vector_gate : bool, optional
            Whether to use vector gating, by default True. The vector
            activation will be used as sigma^+ in vector gating if `True`
        """
        super().__init__()
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        self.vector_gate = vector_gate

        if self.vi:
            self.h_dim = h_dim or max(self.vi, self.vo or 0)
            self.wh = nn.Linear(self.vi, self.h_dim, bias=False)
            self.ws = nn.Linear(self.h_dim + self.si, self.so)
            if self.vo:
                self.wv = nn.Linear(self.h_dim, self.vo, bias=False)
                # Always define `wsv`, also without gating, so the attribute
                # exists on every instance.
                if self.vector_gate:
                    self.wsv = nn.Linear(self.so, self.vo)
                else:
                    self.wsv = None
            else:
                self.wv = None
                self.wsv = None
        else:
            self.wh = None
            self.wv = None
            self.wsv = None
            self.ws = nn.Linear(self.si, self.so)

        self.scalar_act, self.vector_act = activations
        # Empty parameter used only to query the device the module lives on.
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(
        self,
        x: Tuple[torch.Tensor, Optional[torch.Tensor]]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Forward pass of GVP

        Parameters
        ----------
        x : Tuple[torch.Tensor, Optional[torch.Tensor]]
            Input scalar and vector embeddings. The vector entry may be `None`
            when there is no vector channel.

        Returns
        -------
        Tuple[torch.Tensor, Optional[torch.Tensor]]
            Output scalar and vector embeddings. The vector entry is `None`
            when the layer has no output vector channels.
        """
        s, v = x
        if v is not None:
            assert self.wh is not None
            # Move the channel axis last so `nn.Linear` mixes the channels.
            v = torch.transpose(v, -1, -2)
            vh = self.wh(v)
            # Norms of the hidden vectors feed the scalar update.
            vn = _norm_no_nan(vh, axis=-2)
            s = self.ws(torch.cat([s, vn], -1))
            if self.vo:
                assert self.wv is not None
                v = self.wv(vh)
                v = torch.transpose(v, -1, -2)
                if self.vector_gate:
                    assert self.wsv is not None
                    # The gate is computed from the pre-activation scalars.
                    gate = (
                        self.wsv(self.vector_act(s))
                        if self.vector_act is not None
                        else self.wsv(s)
                    )
                    v = v * torch.sigmoid(gate).unsqueeze(-1)
                elif self.vector_act is not None:
                    v = v * self.vector_act(
                        _norm_no_nan(v, axis=-1, keepdims=True)
                    )
            else:
                # No output vector channels: do not leak the transposed input.
                v = None
        else:
            s = self.ws(s)
            if self.vo:
                # No input vectors: emit zero vectors of the requested size.
                v = torch.zeros(
                    s.shape[0],
                    self.vo,
                    3,
                    device=self.dummy_param.device,
                    dtype=s.dtype
                )
            else:
                v = None

        if self.scalar_act is not None:
            s = self.scalar_act(s)

        return s, v
class GVPConv(MessagePassing):
    """
    Graph convolution / message passing with Geometric Vector Perceptrons.
    """
    # Types of the keyword arguments forwarded through `propagate`.
    propagate_type = {
        's': torch.Tensor,
        'v': torch.Tensor,
        'edge_attr_s': torch.Tensor,
        'edge_attr_v': torch.Tensor,
        'edge_lengths': torch.Tensor,
        'edge_masks_lr': Optional[torch.Tensor],
    }

    def __init__(
        self,
        in_dims,
        out_dims,
        edge_dims,
        n_layers=3,
        aggr='mean',
        activations=(nn.functional.relu, torch.sigmoid),
        vector_gate=True,
        cutoff: float = -1.0,
        long_range_cutoff: float = -1.0,
    ) -> None:
        """Graph convolution / message passing with Geometric Vector Perceptrons.

        Takes in a graph with node and edge embeddings,
        and returns new node embeddings.

        This does NOT do residual updates and pointwise feedforward layers
        --- see `GVPConvLayer` instead.

        Parameters
        ----------
        in_dims :
            input node embedding dimensions (n_scalar, n_vector)
        out_dims :
            output node embedding dimensions (n_scalar, n_vector)
        edge_dims :
            input edge embedding dimensions (n_scalar, n_vector)
        n_layers : int, optional
            number of GVPs in the message function, by default 3
        aggr : str, optional
            Type of message aggregate function, by default 'mean'
        activations : tuple, optional
            activation functions (scalar_act, vector_act) to be used in GVPs,
            by default (nn.functional.relu, torch.sigmoid)
        vector_gate : bool, optional
            Whether to use vector gating, by default True. The vector
            activation will be used as sigma^+ in vector gating if `True`
        cutoff : float, optional
            Radial cutoff, by default -1.0. The smooth envelope is applied to
            the messages only when this is positive.
        long_range_cutoff : float
            Cutoff radius for the long-range edges defined on subsystem atoms.
            If negative, no long-range interactions are considered, by default -1.0
        """
        super().__init__(aggr=aggr)
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        self.se, self.ve = edge_dims
        self.cutoff = cutoff
        self.long_range_cutoff = long_range_cutoff

        make_gvp = functools.partial(
            GVP, activations=activations, vector_gate=vector_gate
        )
        # A message concatenates sender, edge and receiver features.
        message_in_dims = (2 * self.si + self.se, 2 * self.vi + self.ve)

        self._module_list = torch.nn.ModuleList()
        if n_layers == 1:
            # Single GVP: no activation at all.
            self._module_list.append(
                make_gvp(in_dims=message_in_dims,
                         out_dims=(self.so, self.vo),
                         activations=(None, None))
            )
        else:
            self._module_list.append(
                make_gvp(in_dims=message_in_dims, out_dims=out_dims)
            )
            for _ in range(n_layers - 2):
                self._module_list.append(make_gvp(out_dims, out_dims))
            # The last GVP of the stack is linear (no activations).
            self._module_list.append(
                make_gvp(in_dims=out_dims,
                         out_dims=out_dims,
                         activations=(None, None))
            )

    def forward(
        self,
        x: Tuple[torch.Tensor, torch.Tensor],
        edge_index: torch.Tensor,
        edge_attr: Tuple[torch.Tensor, torch.Tensor],
        edge_lengths: torch.Tensor,
        edge_masks_lr: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of GVPConv

        Parameters
        ----------
        x : Tuple[torch.Tensor, torch.Tensor]
            Input scalar and vector node embeddings
        edge_index : torch.Tensor
            Index of edge sources and destinations
        edge_attr : Tuple[torch.Tensor, torch.Tensor]
            Edge attributes
        edge_lengths : torch.Tensor
            Edge lengths
        edge_masks_lr : Optional[torch.Tensor]
            Mask for long-range edges defined on subsystem atoms, by default None.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Output scalar and vector node embeddings
        """
        node_s, node_v = x
        assert node_v is not None
        # `propagate` works on 2-D tensors, so the vector channels are
        # flattened here and restored inside `message`.
        flat_v = node_v.contiguous().view(node_v.shape[0], node_v.shape[1] * 3)
        aggregated = self.propagate(
            edge_index,
            s=node_s,
            v=flat_v,
            edge_attr_s=edge_attr[0],
            edge_attr_v=edge_attr[1],
            edge_lengths=edge_lengths,
            edge_masks_lr=edge_masks_lr,
        )
        return _split(aggregated, self.vo)

    def message(
        self,
        s_i: torch.Tensor,
        v_i: torch.Tensor,
        s_j: torch.Tensor,
        v_j: torch.Tensor,
        edge_attr_s: torch.Tensor,
        edge_attr_v: torch.Tensor,
        edge_lengths: torch.Tensor,
        edge_masks_lr: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Build the per-edge messages, as a merged (scalar + flattened
        vector) tensor, optionally scaled by a smooth cutoff envelope."""
        assert edge_attr_s is not None
        assert edge_attr_v is not None

        # Restore the (n, channels, 3) layout flattened in `forward`.
        v_j = v_j.view(v_j.shape[0], v_j.shape[1] // 3, 3)
        v_i = v_i.view(v_i.shape[0], v_i.shape[1] // 3, 3)

        features = _tuple_cat(
            (s_j, v_j), (edge_attr_s, edge_attr_v), (s_i, v_i)
        )
        out_s, out_v = self.message_func(features)
        merged = _merge(out_s, out_v)

        # Without a positive cutoff the messages are returned unscaled.
        if not self.cutoff > 0:
            return merged

        # Cosine envelope on the regular cutoff.
        # NOTE(review): the cosine is not clamped for lengths beyond the
        # cutoff -- confirm such edges cannot reach this point.
        envelope = 0.5 * (torch.cos(edge_lengths * math.pi / self.cutoff) + 1.0)

        if edge_masks_lr is not None and self.long_range_cutoff > self.cutoff:
            lr_mask = edge_masks_lr.view(-1, 1)
            # Cosine envelope on the long-range cutoff ...
            lr_envelope = 0.5 * torch.cos(edge_lengths * math.pi / self.long_range_cutoff) + 0.5
            # ... multiplied by a factor that rises from 0 to 1 below the
            # regular cutoff and equals 1 at or above it.
            rise = 0.5 - 0.5 * torch.cos(edge_lengths * math.pi / self.cutoff)
            lr_envelope = lr_envelope * (
                rise * (edge_lengths < self.cutoff) +
                (edge_lengths >= self.cutoff).float()
            )
            # Masked (long-range) edges take the long-range envelope.
            envelope = torch.where(lr_mask, lr_envelope, envelope)

        return merged * envelope

    def message_func(
        self, x: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the stack of message GVPs to a (scalar, vector) tuple."""
        for gvp in self._module_list:
            x = gvp(x)
        scalars, vectors = x
        assert vectors is not None
        return scalars, vectors
class GVPConvLayer(nn.Module):
    """
    Full graph convolution / message passing layer with
    Geometric Vector Perceptrons.

    Residually updates node embeddings with
    aggregated incoming messages, applies a pointwise feedforward
    network to node embeddings, and returns updated node embeddings.

    To only compute the aggregated messages, see `GVPConv`.
    """

    def __init__(
        self,
        node_dims,
        edge_dims,
        n_message=3,
        n_feedforward=2,
        drop_rate=0.1,
        activations=(nn.functional.relu, torch.sigmoid),
        vector_gate=True,
        residual=True,
        cutoff: float = -1.0,
        long_range_cutoff: float = -1.0,
        aggr: str = 'mean',
    ) -> None:
        """Build the message-passing block and the feedforward block.

        Parameters
        ----------
        node_dims :
            node embedding dimensions (n_scalar, n_vector)
        edge_dims :
            input edge embedding dimensions (n_scalar, n_vector)
        n_message : int, optional
            number of GVP layers to be used in message function, by default 3
        n_feedforward : int, optional
            number of GVPs to be used in feedforward function, by default 2
        drop_rate : float, optional
            drop probability in all dropout layers, by default 0.1
        activations : tuple, optional
            activation functions (scalar_act, vector_act) to be used in GVPs,
            by default (nn.functional.relu, torch.sigmoid)
        vector_gate : bool, optional
            whether to use vector gating, by default True. The vector
            activation will be used as sigma^+ in vector gating if `True`
        residual : bool, optional
            whether to perform the update residually, by default True
        cutoff : float, optional
            radial cutoff, by default -1.0
        long_range_cutoff : float
            Cutoff radius for the long-range edges defined on subsystem atoms.
            If negative, no long-range interactions are considered, by default -1.0
        aggr : str, optional
            Type of message aggregate function, by default 'mean'
        """
        super().__init__()
        self.conv = GVPConv(
            node_dims,
            node_dims,
            edge_dims,
            n_message,
            aggr=aggr,
            activations=activations,
            vector_gate=vector_gate,
            cutoff=cutoff,
            long_range_cutoff=long_range_cutoff,
        )
        make_gvp = functools.partial(
            GVP, activations=activations, vector_gate=vector_gate
        )
        # One norm / dropout pair per residual stage (message, feedforward).
        self.norm = nn.ModuleList([LayerNorm(node_dims) for _ in range(2)])
        self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)])

        self._module_list = nn.ModuleList()
        if n_feedforward == 1:
            self._module_list.append(
                make_gvp(in_dims=node_dims,
                         out_dims=node_dims,
                         activations=(None, None))
            )
        else:
            # The feedforward block widens the scalars 4x and the vectors 2x.
            hid_dims = 4 * node_dims[0], 2 * node_dims[1]
            self._module_list.append(make_gvp(node_dims, hid_dims))
            for _ in range(n_feedforward - 2):
                self._module_list.append(
                    make_gvp(in_dims=hid_dims, out_dims=hid_dims)
                )
            # Project back to the node dimensions without activations.
            self._module_list.append(
                make_gvp(in_dims=hid_dims, out_dims=node_dims, activations=(None, None))
            )
        self.residual = residual

    def _residual_stage(
        self,
        stage: int,
        x: Tuple[torch.Tensor, torch.Tensor],
        dh: Tuple[torch.Tensor, torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return norm[stage](x + dropout[stage](dh))."""
        dropped_s, dropped_v = self.dropout[stage](dh)
        assert dropped_v is not None
        normed_s, normed_v = self.norm[stage](
            _tuple_sum(x, (dropped_s, dropped_v))
        )
        assert normed_v is not None
        return normed_s, normed_v

    def forward(
        self,
        x: Tuple[torch.Tensor, torch.Tensor],
        edge_index: torch.Tensor,
        edge_attr: Tuple[torch.Tensor, torch.Tensor],
        edge_lengths: torch.Tensor,
        node_mask: Optional[torch.Tensor] = None,
        edge_masks_lr: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of GVPConvLayer

        Parameters
        ----------
        x : Tuple[torch.Tensor, torch.Tensor]
            Input scalar and vector node embeddings
        edge_index : torch.Tensor
            Index of edge sources and destinations
        edge_attr : Tuple[torch.Tensor, torch.Tensor]
            Edge attributes
        edge_lengths : torch.Tensor
            Edge lengths
        node_mask : Optional[torch.Tensor], optional
            Mask to restrict the node update to a subset.
            It should be a tensor of type `bool` to index the first dim of
            node embeddings (s, V), by default None.
            If not `None`, only the selected nodes will be updated; the
            selected rows are written in place into the tensors of `x`.
        edge_masks_lr : Optional[torch.Tensor]
            Mask for long-range edges defined on subsystem atoms, by default None.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Output scalar and vector node embeddings
        """
        update = self.conv(x, edge_index, edge_attr, edge_lengths, edge_masks_lr)

        # Keep a handle on the full embeddings for the masked write-back.
        full_x = x
        if node_mask is not None:
            x = _tuple_index(x, node_mask)
            update = _tuple_index(update, node_mask)

        # Stage 0: message-passing update.
        if self.residual:
            x = self._residual_stage(0, x, update)
        else:
            x = update

        # Stage 1: pointwise feedforward update.
        update = self.ff_func(x)
        if self.residual:
            x = self._residual_stage(1, x, update)
        else:
            x = update

        if node_mask is not None:
            # In-place write of the updated rows into the input tensors.
            full_x[0][node_mask] = x[0]
            full_x[1][node_mask] = x[1]
            x = full_x
        return x

    def ff_func(
        self, x: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the stack of feedforward GVPs to a (scalar, vector) tuple."""
        for gvp in self._module_list:
            x = gvp(x)
        scalars, vectors = x[0], x[1]
        assert vectors is not None
        return scalars, vectors
class LayerNorm(nn.Module):
    """
    Combined LayerNorm for tuples (s, V).
    Takes tuples (s, V) as input and as output.
    """

    def __init__(self, dims) -> None:
        """Store the (n_scalar, n_vector) sizes and build the scalar norm."""
        super().__init__()
        self.s, self.v = dims
        self.scalar_norm = nn.LayerNorm(self.s)

    def forward(
        self,
        x: Tuple[torch.Tensor, Optional[torch.Tensor]]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Forward pass of LayerNorm

        Parameters
        ----------
        x : Tuple[torch.Tensor, Optional[torch.Tensor]]
            Input (scalar, vector) channels. The vector entry is ignored when
            the layer was built with no vector channels.

        Returns
        -------
        Tuple[torch.Tensor, Optional[torch.Tensor]]
            Normalized channels; the vector entry is `None` when the layer
            was built with no vector channels.
        """
        scalars, vectors = x
        if self.v:
            assert vectors is not None
            # Squared vector norms, averaged over the channel axis; the
            # vectors are divided by the square root of that mean.
            squared = _norm_no_nan(vectors, axis=-1, keepdims=True, sqrt=False)
            scale = torch.sqrt(torch.mean(squared, dim=-2, keepdim=True))
            return self.scalar_norm(scalars), vectors / scale
        return self.scalar_norm(scalars), None
class Dropout(nn.Module):
    """
    Combined dropout for tuples (s, V).
    Takes tuples (s, V) as input and as output.
    """

    def __init__(self, drop_rate) -> None:
        """Create the scalar and vector dropout layers with one shared rate."""
        super().__init__()
        self.sdropout = nn.Dropout(drop_rate)
        self.vdropout = _VDropout(drop_rate)

    def forward(
        self,
        x: Tuple[torch.Tensor, Optional[torch.Tensor]]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Forward pass of Dropout

        Parameters
        ----------
        x : Tuple[torch.Tensor, Optional[torch.Tensor]]
            Input (scalar, vector) channels; the vector entry may be `None`.

        Returns
        -------
        Tuple[torch.Tensor, Optional[torch.Tensor]]
            Dropped out channels; the vector entry stays `None` if it was
            `None` on input.
        """
        scalars, vectors = x
        # Scalars first, then vectors: keeps the order of the random draws.
        dropped_scalars = self.sdropout(scalars)
        if vectors is None:
            return dropped_scalars, None
        return dropped_scalars, self.vdropout(vectors)
class _VDropout(nn.Module):
"""
Vector channel dropout where the elements of each
vector channel are dropped together.
"""
def __init__(self, drop_rate) -> None:
super(_VDropout, self).__init__()
self.drop_rate = drop_rate
self.dummy_param = nn.Parameter(torch.empty(0))
def forward(self, x : torch.Tensor) -> torch.Tensor:
"""Forward pass of _VDropout
Parameters
----------
x : torch.Tensor
Vector channel
Returns
-------
torch.Tensor
Dropped out vector channel
"""
device = self.dummy_param.device
if not self.training:
return x
mask = torch.bernoulli(
(1 - self.drop_rate) * torch.ones(x.shape[:-1], device=device)
).unsqueeze(-1)
x = mask * x / (1 - self.drop_rate)
return x
def _tuple_sum(
input_1: Tuple[torch.Tensor, torch.Tensor],
input_2: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Sums any number of tuples (s, V) elementwise.
"""
out = [i + j for i, j in zip(input_1, input_2)]
return out[0], out[1]
@torch.jit.script_if_tracing
def _tuple_cat(
input_1: Tuple[torch.Tensor, torch.Tensor],
input_2: Tuple[torch.Tensor, torch.Tensor],
input_3: Tuple[torch.Tensor, torch.Tensor],
dim: int = -1
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Concatenates any number of tuples (s, V) elementwise.
Parameters
----------
input_1 : Tuple[torch.Tensor, torch.Tensor]
First input to concatenate
input_2 : Tuple[torch.Tensor, torch.Tensor]
Second input to concatenate
input_3 : Tuple[torch.Tensor, torch.Tensor]
Third input to concatenate
dim : int, optional
dimension along which to concatenate when viewed
as the `dim` index for the scalar-channel tensors, by default -1.
This means that `dim=-1` will be applied as
`dim=-2` for the vector-channel tensors.
Returns
-------
Tuple[torch.Tensor, torch.Tensor]
Concatenated tuple
"""
dim = int(dim % len(input_1[0].shape))
s_args, v_args = list(zip(input_1, input_2, input_3))
return torch.cat(s_args, dim=dim), torch.cat(v_args, dim=dim)
@torch.jit.script_if_tracing
def _tuple_index(
x: Tuple[torch.Tensor, torch.Tensor], idx: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Indexes a tuple (s, V) along the first dimension at a given index.
Parameters
----------
x : Tuple[torch.Tensor, torch.Tensor]
Tuple to be indexed
idx : torch.Tensor
any object which can be used to index a `torch.Tensor`
Returns
-------
Tuple[torch.Tensor, torch.Tensor]
Tuple with the element at the given index
"""
return x[0][idx], x[1][idx]
@torch.jit.script_if_tracing
def _norm_no_nan(
x: torch.Tensor,
axis: int = -1,
keepdims: bool = False,
eps: float = 1e-8,
sqrt: bool = True
) -> torch.Tensor:
"""L2 norm of tensor clamped above a minimum value `eps`.
Parameters
----------
x : torch.Tensor
Input tensor
axis : int, optional
Axis along which to compute the norm, by default -1
keepdims : bool, optional
Whether to keep the original dimensions, by default False
eps : float, optional
Lowest threshold for clamping the norm, by default 1e-8
sqrt : bool, optional
Compute the sqaure root in L2 norm, by default True.
If `False`, returns the square of the L2 norm
Returns
-------
torch.Tensor
Normed tensor
"""
out = torch.clamp(torch.sum(torch.square(x), axis, keepdims), min=eps)
return torch.sqrt(out) if sqrt else out
@torch.jit.script_if_tracing
def _split(x: torch.Tensor, nv: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""Splits a merged representation of (s, V) back into a tuple.
Should be used only with `_merge(s, V)` and only if the tuple
representation cannot be used.
Parameters
----------
x : torch.Tensor
the `torch.Tensor` returned from `_merge`
nv : int
the number of vector channels in the input to `_merge`
Returns
-------
Tuple[torch.Tensor, torch.Tensor]
split representation
"""
s = x[..., :-3 * nv]
v = x[..., -3 * nv:].contiguous().view(x.shape[0], nv, 3)
return s, v
@torch.jit.script_if_tracing
def _merge(s: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Merges a tuple (s, V) into a single `torch.Tensor`, where the
vector channels are flattened and appended to the scalar channels.
Should be used only if the tuple representation cannot be used.
Use `_split(x, nv)` to reverse.
"""
v = v.contiguous().view(v.shape[0], v.shape[1] * 3)
return torch.cat([s, v], -1)
from mlcolvar.data.graph.utils import create_graph_tracing_example, create_test_graph_input
def _create_test_data_list():
    """Build the noise-free test batch and return its 'data_list' entry."""
    batch_options = dict(
        output_type='batch',
        n_atoms=3,
        n_samples=6,
        n_states=1,
        add_noise=False,
    )
    graph_batch = create_test_graph_input(**batch_options)
    return graph_batch['data_list']
def test_gvp() -> None:
    """Check eager and traced `GVPModel` outputs against reference values.

    Two models are built one after the other from the same seeded random
    stream, with one and then two feedforward GVPs per layer.
    """
    # Remember the caller's default dtype so it is restored even when an
    # assertion fails; otherwise float64 would leak into later tests.
    previous_dtype = torch.get_default_dtype()
    torch.manual_seed(0)
    torch.set_default_dtype(torch.float64)
    try:
        # (n_feedforwards, expected output row); the order matters because
        # both models draw from the same random stream.
        cases = (
            (1, [0.6100070244145421, -0.2559670171962067]),
            (2, [0.5097288781305398, -0.032077559793064814]),
        )
        for n_feedforwards, ref_row in cases:
            model = GVPModel(
                n_out=2,
                cutoff=0.1,
                atomic_numbers=[1, 8],
                n_bases=6,
                n_polynomials=6,
                n_layers=2,
                n_messages=2,
                n_feedforwards=n_feedforwards,
                n_scalars_node=16,
                n_vectors_node=8,
                n_scalars_edge=16,
                drop_rate=0,
                activation='SiLU',
            )

            data = _create_test_data_list()
            ref_out = torch.tensor([ref_row] * 5)
            assert torch.allclose(model(data), ref_out)

            # The traced model must reproduce the eager output.
            traced_model = torch.jit.trace(
                model, example_inputs=create_graph_tracing_example(2)
            )
            assert torch.allclose(traced_model(data), ref_out)
    finally:
        torch.set_default_dtype(previous_dtype)
def test_gvp_from_dataset() -> None:
    """Check `GVPModel` initialization from dataset metadata.

    Runs once on the plain test dataset and once with environment atoms,
    checking the registered metadata and the model output each time.
    """
    from mlcolvar.data.graph.utils import create_test_graph_input

    # Remember the caller's default dtype so it is restored even when an
    # assertion fails; otherwise float64 would leak into later tests.
    previous_dtype = torch.get_default_dtype()
    torch.manual_seed(0)
    torch.set_default_dtype(torch.float64)
    try:
        # (extra dataset options, expected output row); the order matters
        # because both models draw from the same random stream.
        cases = (
            ({}, [-0.12551015, -0.5192468]),
            ({'environment': True}, [0.32278482, 0.05976963]),  # environment atoms
        )
        for extra_options, ref_row in cases:
            dataset = create_test_graph_input(
                output_type='dataset',
                n_atoms=3,
                n_samples=5,
                n_states=1,
                add_noise=False,
                **extra_options
            )

            model = GVPModel(
                n_out=2,
                dataset_for_initialization=dataset,
                n_bases=6,
                n_polynomials=6,
                n_layers=2,
                n_messages=2,
                n_feedforwards=2,
                n_scalars_node=16,
                n_vectors_node=8,
                n_scalars_edge=16,
                drop_rate=0,
                activation='SiLU',
            )

            # check the model parameters are correctly initialized from the dataset metadata
            assert model.cutoff == dataset.metadata['cutoff']
            assert torch.allclose(
                model.atomic_numbers,
                torch.as_tensor(dataset.metadata['atomic_numbers'])
            )
            assert torch.allclose(
                model.buffer, torch.as_tensor(dataset.metadata['buffer'])
            )

            # check output is consistent with the one obtained from the test graph input
            ref_out = torch.tensor([ref_row] * 5)
            assert torch.allclose(model(dataset.get_graph_inputs()), ref_out)
    finally:
        torch.set_default_dtype(previous_dtype)
# When executed as a script, run the dataset-initialization test directly.
if __name__ == "__main__":
    test_gvp_from_dataset()