mlcolvar.core.nn.graph.SchNetModel
- class mlcolvar.core.nn.graph.SchNetModel(n_out: int, dataset_for_initialization: DictDataset = None, pooling_operation: str = 'mean', n_bases: int = 16, n_layers: int = 2, n_filters: int = 16, n_hidden_channels: int = 16, aggr: str = 'mean', w_out_after_pool: bool = False, **kwargs)
Bases: BaseGNN
The SchNet [1] model. This implementation is adapted from torch_geometric: https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/models/schnet.py
References
- __init__(n_out: int, dataset_for_initialization: DictDataset = None, pooling_operation: str = 'mean', n_bases: int = 16, n_layers: int = 2, n_filters: int = 16, n_hidden_channels: int = 16, aggr: str = 'mean', w_out_after_pool: bool = False, **kwargs) → None
- Parameters:
n_out (int) – Size of the output node features.
dataset_for_initialization (DictDataset, optional) – Dataset containing the graphs on which the GNN model will be applied. It is used to initialize and register the cutoff, buffer, long_range_cutoff and atomic_numbers from the dataset metadata. This is the preferred way to initialize the GNN model, as it ensures consistency between the model and the dataset. Alternatively, this can be set to None and the cutoff, buffer, long_range_cutoff and atomic_numbers can be provided as kwargs.
pooling_operation (str) – Type of pooling operation used to combine node-level features into graph-level features, either ‘mean’ or ‘sum’, by default ‘mean’
n_bases (int, optional) – Size of the basis set used for the embedding, by default 16
n_layers (int, optional) – Number of the graph convolution layers, by default 2
n_filters (int, optional) – Number of filters, by default 16
n_hidden_channels (int, optional) – Size of hidden embeddings, by default 16
aggr (str, optional) – Type of the GNN aggregation function, by default ‘mean’. Possible choices are: ‘mean’, ‘sum’, ‘max’, ‘min’, ‘mul’, ‘attention’/‘attentional’ (shared attention gate across all layers), ‘attention_separate’/‘attentional_separate’ (independent attention gate for each layer).
w_out_after_pool (bool, optional) – Whether to apply the last linear transformation from hidden to output channels after the pooling step rather than before it, by default False
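The effect of w_out_after_pool can be illustrated with a small pure-Python sketch. It is not the mlcolvar implementation: the helper names linear and sum_pool, the toy weights, and the assumption that the last linear layer carries a bias term are all hypothetical. Under those assumptions, the two orderings of a sum pooling and the output layer differ by (n_nodes - 1) times the bias.

```python
from typing import List


def linear(x: List[float], w: List[List[float]], b: List[float]) -> List[float]:
    """Apply y = W x + b to a single feature vector (hypothetical helper)."""
    return [sum(wi * xi for wi, xi in zip(row, x)) + bi for row, bi in zip(w, b)]


def sum_pool(nodes: List[List[float]]) -> List[float]:
    """Sum node-level feature vectors channel by channel (hypothetical helper)."""
    return [sum(channel) for channel in zip(*nodes)]


# Toy hidden features: 3 nodes, 2 hidden channels (made-up numbers).
hidden = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
w = [[0.5, -1.0]]  # maps 2 hidden channels to 1 output channel
b = [0.25]         # assumed bias of the last linear layer

# Analogue of w_out_after_pool=False: transform every node, then pool.
out_before = sum_pool([linear(h, w, b) for h in hidden])

# Analogue of w_out_after_pool=True: pool hidden features, then transform once.
out_after = linear(sum_pool(hidden), w, b)

print(out_before, out_after)
```

Without a bias term, or with a mean pooling, both orderings of a purely linear output layer give the same result in this sketch.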
Methods
__init__(n_out[, ...])
forward(data) – The forward pass.
Resets all learnable parameters of the module.
- embed_edge(data: Dict[str, Tensor], normalize: bool = True) → Tuple[Tensor, Tensor, Tensor]
Performs the edge embedding of the model from a torch_geometric.data.Batch object.
- Parameters:
data (Dict[str, torch.Tensor]) – The data dict. Usually obtained from the to_dict method of a torch_geometric.data.Batch object.
normalize (bool) – Whether to return the normalized distance vectors, by default True.
- Returns:
edge_lengths (torch.Tensor (shape: [n_edges, 1])) – The edge lengths.
edge_length_embeddings (torch.Tensor (shape: [n_edges, n_bases])) – The edge length embeddings.
edge_unit_vectors (torch.Tensor (shape: [n_edges, 3])) – The normalized edge vectors.
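The three return values and their shapes can be pictured with a minimal pure-Python sketch. It does not reproduce the library code: the function name embed_edges, the Gaussian radial basis, the evenly spaced centers, the default cutoff value and the sender-to-receiver direction convention are all assumptions made for illustration.

```python
import math
from typing import List, Sequence, Tuple


def embed_edges(
    positions: Sequence[Sequence[float]],
    edge_index: Sequence[Tuple[int, int]],
    n_bases: int = 4,
    cutoff: float = 5.0,
) -> Tuple[List[List[float]], List[List[float]], List[List[float]]]:
    """Return edge lengths [n_edges, 1], length embeddings [n_edges, n_bases]
    and unit vectors [n_edges, 3] for a list of (i, j) edges.

    Assumes n_bases >= 2 and no zero-length edges.
    """
    # Assumed basis: Gaussians with centers evenly spaced on [0, cutoff].
    width = cutoff / (n_bases - 1)
    centers = [cutoff * k / (n_bases - 1) for k in range(n_bases)]

    lengths, embeddings, unit_vectors = [], [], []
    for i, j in edge_index:
        vec = [pj - pi for pi, pj in zip(positions[i], positions[j])]
        r = math.sqrt(sum(c * c for c in vec))
        lengths.append([r])
        embeddings.append([math.exp(-((r - mu) / width) ** 2) for mu in centers])
        unit_vectors.append([c / r for c in vec])
    return lengths, embeddings, unit_vectors


# Two atoms 5 length units apart, connected in both directions.
positions = [[0.0, 0.0, 0.0], [3.0, 4.0, 0.0]]
edges = [(0, 1), (1, 0)]
edge_lengths, edge_length_embeddings, edge_unit_vectors = embed_edges(positions, edges)
print(edge_lengths[0], edge_unit_vectors[0])
```

In this sketch the unit vectors correspond to normalize=True; returning vec instead of the normalized components would be the analogue of normalize=False.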
- forward(data: Dict[str, Tensor]) → Tensor
The forward pass.
- Parameters:
data (Dict[str, torch.Tensor]) – The data dict. Usually obtained from the to_dict method of a torch_geometric.data.Batch object.
- pooling(input: Tensor, data: Dict[str, Tensor]) → Tensor
Performs pooling of the node-level outputs to obtain a graph-level output.
- Parameters:
input (torch.Tensor) – Node-level features to be pooled
data (Dict[str, torch.Tensor]) – Data batch containing the graph data information
- Returns:
Pooled output
- Return type:
torch.Tensor
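As a rough picture of what pooling node-level features into graph-level features means, the following pure-Python sketch groups node features by a per-node graph index and reduces each group with a mean or a sum. The function name pool, the batch list of graph indices and the n_graphs argument are hypothetical names chosen for illustration, not the arguments of the method documented above.

```python
from typing import List, Sequence


def pool(
    node_features: Sequence[Sequence[float]],
    batch: Sequence[int],
    n_graphs: int,
    operation: str = "mean",
) -> List[List[float]]:
    """Reduce node-level features [n_nodes, n_channels] to graph-level
    features [n_graphs, n_channels], using batch[i] as the graph of node i."""
    n_channels = len(node_features[0])
    sums = [[0.0] * n_channels for _ in range(n_graphs)]
    counts = [0] * n_graphs

    # Accumulate every node into the slot of the graph it belongs to.
    for features, graph in zip(node_features, batch):
        counts[graph] += 1
        for channel, value in enumerate(features):
            sums[graph][channel] += value

    if operation == "sum":
        return sums
    if operation == "mean":
        # max(n, 1) avoids a division by zero for graphs without nodes.
        return [[v / max(n, 1) for v in row] for row, n in zip(sums, counts)]
    raise ValueError(f"Unknown pooling operation: {operation!r}")


# Three nodes: the first two belong to graph 0, the last one to graph 1.
features = [[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]]
batch = [0, 0, 1]
mean_pooled = pool(features, batch, n_graphs=2, operation="mean")
sum_pooled = pool(features, batch, n_graphs=2, operation="sum")
print(mean_pooled, sum_pooled)
```

A sum pooling makes the output scale with the number of nodes in each graph, whereas a mean pooling does not, which is the main practical difference between the two pooling_operation choices.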
Attributes
T_destination
call_super_init
dump_patches
in_features
out_features
training