mlcolvar.core.nn.graph.SchNetModel

class mlcolvar.core.nn.graph.SchNetModel(n_out: int, dataset_for_initialization: DictDataset = None, pooling_operation: str = 'mean', n_bases: int = 16, n_layers: int = 2, n_filters: int = 16, n_hidden_channels: int = 16, aggr: str = 'mean', w_out_after_pool: bool = False, **kwargs)

Bases: BaseGNN

The SchNet [1] model. This implementation is adapted from torch_geometric: https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/models/schnet.py
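The following minimal pure-Python sketch shows one continuous-filter convolution in the style of torch_geometric's `CFConv`. Each neighbour's features are multiplied channel by channel by a filter and by a cosine cutoff envelope. The results are then combined with one of the simple aggregation functions accepted by `aggr`. The function names and the `(source, target)` edge convention are assumptions of this sketch, not the mlcolvar API. The filters are also passed in ready-made, whereas the real layer generates them from the edge embeddings with a small network and wraps the convolution in linear layers.

```python
import math
from typing import Dict, List, Tuple


def cosine_cutoff(d: float, cutoff: float) -> float:
    """Smooth envelope 0.5 * (cos(pi * d / cutoff) + 1); zero at and beyond the cutoff."""
    if d >= cutoff:
        return 0.0
    return 0.5 * (math.cos(math.pi * d / cutoff) + 1.0)


def aggregate(msgs: List[List[float]], aggr: str, n_channels: int) -> List[float]:
    """Combine the incoming messages of one node, channel by channel."""
    if not msgs:
        return [0.0] * n_channels  # in this sketch, nodes without neighbours receive zeros
    cols = list(zip(*msgs))  # regroup the messages per channel
    if aggr == "sum":
        return [sum(c) for c in cols]
    if aggr == "mean":
        return [sum(c) / len(c) for c in cols]
    if aggr == "max":
        return [max(c) for c in cols]
    if aggr == "min":
        return [min(c) for c in cols]
    if aggr == "mul":
        return [math.prod(c) for c in cols]
    raise ValueError(f"unsupported aggregation: {aggr!r}")


def cfconv(
    x: List[List[float]],               # node features, shape [n_nodes, n_channels]
    edge_index: List[Tuple[int, int]],  # (source j, target i) pairs
    edge_lengths: List[float],          # distance d_ij for each edge
    filters: List[List[float]],         # filter W(e_ij) per edge, shape [n_edges, n_channels]
    cutoff: float,
    aggr: str = "mean",
) -> List[List[float]]:
    n_nodes, n_channels = len(x), len(x[0])
    inbox: Dict[int, List[List[float]]] = {i: [] for i in range(n_nodes)}
    for (j, i), d, w in zip(edge_index, edge_lengths, filters):
        c = cosine_cutoff(d, cutoff)
        # message j -> i: x_j * W(e_ij) * C(d_ij), element-wise over channels
        inbox[i].append([x[j][k] * w[k] * c for k in range(n_channels)])
    return [aggregate(inbox[i], aggr, n_channels) for i in range(n_nodes)]


if __name__ == "__main__":
    x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
    edges = [(1, 0), (2, 0)]  # nodes 1 and 2 both send to node 0
    print(cfconv(x, edges, [1.0, 1.0], [[1.0, 1.0], [1.0, 1.0]], cutoff=2.0, aggr="mean"))
```

Changing `aggr` only changes how each node's inbox is reduced. In the example, node 0 receives `[2.0, 2.5]` with `"mean"` and `[4.0, 5.0]` with `"sum"`. The attention-based choices are not sketched here because they involve learnable gates.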

References

__init__(n_out: int, dataset_for_initialization: DictDataset = None, pooling_operation: str = 'mean', n_bases: int = 16, n_layers: int = 2, n_filters: int = 16, n_hidden_channels: int = 16, aggr: str = 'mean', w_out_after_pool: bool = False, **kwargs) → None
Parameters:
  • n_out (int) – Size of the output node features.

  • dataset_for_initialization (DictDataset, optional) – Dataset containing the graphs to which the GNN model will be applied. It is used to initialize and register the cutoff, buffer, long_range_cutoff and atomic_numbers from the dataset metadata. This is the preferred way to initialize the GNN model, as it ensures consistency between the model and the dataset. Alternatively, this can be set to None and the cutoff, buffer, long_range_cutoff and atomic_numbers can be provided as kwargs.

  • pooling_operation (str, optional) – Type of pooling operation used to combine node-level features into graph-level features, either ‘mean’ or ‘sum’, by default ‘mean’

  • n_bases (int, optional) – Size of the basis set used to embed the edge lengths, by default 16

  • n_layers (int, optional) – Number of graph convolution layers, by default 2

  • n_filters (int, optional) – Number of filters, by default 16

  • n_hidden_channels (int, optional) – Size of hidden embeddings, by default 16

  • aggr (str, optional) – Type of the GNN aggregation function, by default ‘mean’. Possible choices are: ‘mean’, ‘sum’, ‘max’, ‘min’, ‘mul’, ‘attention’/‘attentional’ (a single attention gate shared across all layers) and ‘attention_separate’/‘attentional_separate’ (an independent attention gate for each layer).

  • w_out_after_pool (bool, optional) – Whether to apply the last linear transformation from hidden to output channels after the pooling operation instead of before it, by default False
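The small pure-Python sketch below illustrates what `w_out_after_pool` changes. It assumes that the last hidden-to-output transformation is a plain affine layer `y = W h + b`. The helper names are hypothetical and only mirror the order of operations described by the parameter.

```python
from typing import List

Matrix = List[List[float]]


def linear(h: List[float], weight: Matrix, bias: List[float]) -> List[float]:
    """Affine hidden-to-output map y = W h + b, with W of shape [n_out, n_hidden]."""
    return [sum(w * v for w, v in zip(row, h)) + b for row, b in zip(weight, bias)]


def pool(rows: Matrix, op: str) -> List[float]:
    """Mean or sum pooling over the node rows of a single graph."""
    cols = list(zip(*rows))
    if op == "sum":
        return [sum(c) for c in cols]
    if op == "mean":
        return [sum(c) / len(c) for c in cols]
    raise ValueError(f"unsupported pooling: {op!r}")


def readout(node_h: Matrix, weight: Matrix, bias: List[float], op: str, w_out_after_pool: bool) -> List[float]:
    if w_out_after_pool:
        # pool the hidden node features first, then project once to n_out channels
        return linear(pool(node_h, op), weight, bias)
    # project every node to n_out channels, then pool the node-level outputs
    return pool([linear(h, weight, bias) for h in node_h], op)


if __name__ == "__main__":
    h = [[1.0, 2.0], [3.0, 4.0]]  # 2 nodes, 2 hidden channels
    W, b = [[1.0, 1.0]], [0.5]    # n_out = 1
    for op in ("mean", "sum"):
        print(op, readout(h, W, b, op, False), readout(h, W, b, op, True))
```

With mean pooling both orders give `[5.5]`, because averaging commutes with an affine map. With sum pooling the results differ (`[11.0]` versus `[10.5]`): the bias is added once per node before pooling, but only once in total after it.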

Methods

__init__(n_out[, ...])

forward(data)

The forward pass on a data dict, usually obtained from the to_dict method of a torch_geometric.data.Batch object.

reset_parameters()

Resets all learnable parameters of the module.

embed_edge(data: Dict[str, Tensor], normalize: bool = True) → Tuple[Tensor, Tensor, Tensor]

Computes the model's edge embeddings from a data dict, usually obtained from a torch_geometric.data.Batch object.

Parameters:
  • data (Dict[str, torch.Tensor]) – The data dict. Usually from the to_dict method of a torch_geometric.data.Batch object.

  • normalize (bool) – Whether to return normalized edge vectors, by default True.

Returns:

  • edge_lengths (torch.Tensor (shape: [n_edges, 1])) – The edge lengths.

  • edge_length_embeddings (torch.Tensor (shape: [n_edges, n_bases])) – The edge length embeddings.

  • edge_unit_vectors (torch.Tensor (shape: [n_edges, 3])) – The normalized edge vectors.
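The pure-Python sketch below shows how the three outputs of embed_edge relate to each other for a single graph. It assumes Gaussian radial basis functions with evenly spaced centres on `[0, cutoff]`, as in the `GaussianSmearing` layer of torch_geometric's SchNet; mlcolvar's actual basis may differ. The function name is hypothetical, as are the explicit `positions`/`edge_index`/`cutoff` arguments and the source-to-target direction of the edge vectors. They do not describe the layout of the data dict.

```python
import math
from typing import List, Sequence, Tuple


def embed_edges_sketch(
    positions: Sequence[Sequence[float]],   # [n_nodes, 3] Cartesian coordinates
    edge_index: Sequence[Tuple[int, int]],  # (source, target) pairs
    cutoff: float,
    n_bases: int,
    normalize: bool = True,
) -> Tuple[List[List[float]], List[List[float]], List[List[float]]]:
    if n_bases < 2:
        raise ValueError("need at least two basis functions")
    step = cutoff / (n_bases - 1)
    centres = [k * step for k in range(n_bases)]  # evenly spaced on [0, cutoff]
    coeff = -0.5 / step**2                         # Gaussian width tied to the centre spacing
    lengths, embeddings, vectors = [], [], []
    for src, dst in edge_index:
        v = [positions[dst][a] - positions[src][a] for a in range(3)]
        d = math.sqrt(sum(c * c for c in v))       # assumes no zero-length edges
        lengths.append([d])                                                     # [n_edges, 1]
        embeddings.append([math.exp(coeff * (d - mu) ** 2) for mu in centres])  # [n_edges, n_bases]
        vectors.append([c / d for c in v] if normalize else v)                 # [n_edges, 3]
    return lengths, embeddings, vectors


if __name__ == "__main__":
    pos = [(0.0, 0.0, 0.0), (3.0, 4.0, 0.0)]
    lengths, emb, unit = embed_edges_sketch(pos, [(0, 1)], cutoff=5.0, n_bases=6)
    print(lengths, unit)
    print([round(e, 4) for e in emb[0]])
```

Each edge length is thus expanded into n_bases smooth features that peak at the basis centre closest to that length.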

forward(data: Dict[str, Tensor]) → Tensor

The forward pass.

Parameters:
  • data (Dict[str, torch.Tensor]) – The data dict, usually obtained from the to_dict method of a torch_geometric.data.Batch object.

pooling(input: Tensor, data: Dict[str, Tensor]) → Tensor

Performs pooling of the node-level outputs to obtain a graph-level output.

Parameters:
  • input (torch.Tensor) – Node-level features to be pooled.

  • data (Dict[str, torch.Tensor]) – Data batch containing the graph data.

Returns:

Pooled graph-level output.

Return type:

torch.Tensor
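The next block is a minimal sketch of mean and sum pooling over a batch of graphs. It assumes, as with torch_geometric Batch objects, that the data dict carries a `batch` vector assigning each node to its graph. The function name and the plain-list representation are hypothetical; the library itself operates on tensors.

```python
from typing import List


def pool_by_batch(
    node_out: List[List[float]],  # node-level features, shape [n_nodes, n_channels]
    batch: List[int],             # graph index of each node, values in [0, n_graphs)
    n_graphs: int,
    op: str = "mean",
) -> List[List[float]]:
    n_channels = len(node_out[0])
    totals = [[0.0] * n_channels for _ in range(n_graphs)]
    counts = [0] * n_graphs
    for row, g in zip(node_out, batch):
        counts[g] += 1
        for k, value in enumerate(row):
            totals[g][k] += value  # accumulate per-graph sums
    if op == "sum":
        return totals
    if op == "mean":
        # divide by the number of nodes in each graph; empty graphs stay at zero
        return [[t / max(n, 1) for t in row] for row, n in zip(totals, counts)]
    raise ValueError(f"unsupported pooling: {op!r}")


if __name__ == "__main__":
    out = [[1.0], [3.0], [10.0]]
    print(pool_by_batch(out, [0, 0, 1], n_graphs=2, op="mean"))
    print(pool_by_batch(out, [0, 0, 1], n_graphs=2, op="sum"))
```

With sum pooling, the graph-level output scales with the number of nodes (atoms). With mean pooling, it does not depend on graph size.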

reset_parameters() → None

Resets all learnable parameters of the module.

Attributes

T_destination

call_super_init

dump_patches

in_features

out_features

training