mlcolvar.core.nn.graph.BaseGNN
- class mlcolvar.core.nn.graph.BaseGNN(n_out: int, dataset_for_initialization: DictDataset, pooling_operation: str | None = None, n_bases: int = 6, n_polynomials: int = 6, basis_type: str = 'bessel', cutoff: float = None, buffer: float = None, long_range_cutoff: float = None, atomic_numbers: List[int] = None)
Bases: Module
Base class for Graph Neural Network (GNN) models.
- __init__(n_out: int, dataset_for_initialization: DictDataset, pooling_operation: str | None = None, n_bases: int = 6, n_polynomials: int = 6, basis_type: str = 'bessel', cutoff: float = None, buffer: float = None, long_range_cutoff: float = None, atomic_numbers: List[int] = None) -> None
Initializes the core of a GNN model, taking care of edge embeddings.
- Parameters:
n_out (int) – Number of the output scalar node features.
dataset_for_initialization (DictDataset, optional) – Dataset containing the graphs to which the GNN model will be applied. Its metadata is used to initialize and register the cutoff, buffer, and atomic_numbers. This is the preferred way to initialize the GNN model, as it ensures consistency between the model and the dataset. Alternatively, this can be set to None and the cutoff, buffer, and atomic_numbers provided as arguments.
pooling_operation (str or None) – Type of pooling operation to combine node-level features into graph-level features (‘mean’ or ‘sum’). If None, pooling is disabled and node-level outputs are returned unchanged.
n_bases (int, optional) – Size of the basis set used for the embedding, by default 6
n_polynomials (int, optional) – Order of the polynomials in the basis functions, by default 6
basis_type (str, optional) – Type of the basis function, by default ‘bessel’
cutoff (float) – When dataset_for_initialization is not provided, the cutoff radius of the basis functions, by default None. Should be the same as the cutoff radius used to build the graphs.
buffer (float, optional) – When dataset_for_initialization is not provided, the additional buffer radius used to find active environment atoms, by default None. Should be the same as the buffer used to build the graphs.
long_range_cutoff (float) – Cutoff radius for the long-range edges defined on subsystem atoms. If negative, no long-range interactions are considered, by default -1.0
atomic_numbers (List[int]) – When dataset_for_initialization is not provided, the atomic numbers mapping, by default None. Should be the same as the atomic numbers mapping used to build the graphs.
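To illustrate how n_bases, n_polynomials, and cutoff interact, here is a minimal sketch of a Bessel radial basis multiplied by a polynomial envelope. It assumes the form commonly used in the GNN literature and is not taken from the mlcolvar source; the function names bessel_embedding and polynomial_envelope are hypothetical, and plain Python floats stand in for tensors.

```python
import math
from typing import List


def polynomial_envelope(r: float, cutoff: float, p: int = 6) -> float:
    """Smooth envelope: 1 at r = 0, decaying to 0 at r = cutoff (assumed form)."""
    if r >= cutoff:
        return 0.0
    x = r / cutoff
    return (
        1.0
        - (p + 1) * (p + 2) / 2.0 * x**p
        + p * (p + 2) * x ** (p + 1)
        - p * (p + 1) / 2.0 * x ** (p + 2)
    )


def bessel_embedding(
    r: float, cutoff: float, n_bases: int = 6, p: int = 6
) -> List[float]:
    """Embed one edge length r > 0 into n_bases radial features."""
    prefactor = math.sqrt(2.0 / cutoff)
    envelope = polynomial_envelope(r, cutoff, p)
    return [
        prefactor * math.sin(n * math.pi * r / cutoff) / r * envelope
        for n in range(1, n_bases + 1)
    ]


# One edge of length 1.0 with a cutoff of 5.0 yields six features.
features = bessel_embedding(1.0, cutoff=5.0)
```

In this sketch every feature vanishes for edges at or beyond the cutoff, which is why the cutoff given to the model should match the one used to build the graphs.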
Methods
__init__(n_out, dataset_for_initialization) – Initializes the core of a GNN model, taking care of edge embeddings.
embed_edge(data[, normalize]) – Performs the model edge embedding from a torch_geometric.data.Batch object.
pooling(input, data) – Performs pooling of the node-level outputs to obtain a graph-level output.
- embed_edge(data: Dict[str, Tensor], normalize: bool = True) -> Tuple[Tensor, Tensor, Tensor]
Performs the model edge embedding from a torch_geometric.data.Batch object.
- Parameters:
data (Dict[str, torch.Tensor]) – The data dict. Usually from the to_dict method of a torch_geometric.data.Batch object.
normalize (bool) – Whether to return the normalized distance vectors, by default True.
- Returns:
edge_lengths (torch.Tensor (shape: [n_edges, 1])) – The edge lengths.
edge_length_embeddings (torch.Tensor (shape: [n_edges, n_bases])) – The edge length embeddings.
edge_unit_vectors (torch.Tensor (shape: [n_edges, 3])) – The normalized edge vectors.
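The shapes of the returned edge lengths and unit vectors can be illustrated with a small sketch. It is an assumption-laden illustration rather than the library's implementation: the helper edge_geometry is hypothetical, it takes positions and an edge index directly instead of the data dict, it ignores periodic-boundary shifts, and it uses nested Python lists in place of tensors.

```python
import math
from typing import List, Tuple


def edge_geometry(
    positions: List[List[float]],
    edge_index: Tuple[List[int], List[int]],
    normalize: bool = True,
) -> Tuple[List[List[float]], List[List[float]]]:
    """Return edge lengths ([n_edges, 1]) and edge vectors ([n_edges, 3])."""
    senders, receivers = edge_index
    lengths, vectors = [], []
    for i, j in zip(senders, receivers):
        # Vector pointing from the sender atom to the receiver atom.
        vec = [b - a for a, b in zip(positions[i], positions[j])]
        length = math.sqrt(sum(c * c for c in vec))
        if normalize:
            vec = [c / length for c in vec]
        lengths.append([length])
        vectors.append(vec)
    return lengths, vectors


# Two atoms 5.0 apart, connected by one edge in each direction.
positions = [[0.0, 0.0, 0.0], [3.0, 4.0, 0.0]]
edge_index = ([0, 1], [1, 0])
lengths, unit_vectors = edge_geometry(positions, edge_index)
```

With normalize=False the same sketch returns the raw distance vectors instead of unit vectors.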
- pooling(input: Tensor, data: Dict[str, Tensor]) -> Tensor
Performs pooling of the node-level outputs to obtain a graph-level output.
- Parameters:
input (torch.Tensor) – Node-level features to be pooled.
data (Dict[str, torch.Tensor]) – Data batch containing the graph data information.
- Returns:
Pooled output
- Return type:
torch.Tensor
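The effect of the 'mean' and 'sum' pooling options can be sketched as follows. This is a hedged illustration, not the mlcolvar code: pool_nodes is a hypothetical helper, the batch argument (one graph index per node) is an assumed stand-in for the information carried in the data dict, and nested Python lists replace tensors.

```python
from typing import List, Optional


def pool_nodes(
    node_features: List[List[float]],
    batch: List[int],
    n_graphs: int,
    operation: Optional[str] = "mean",
) -> List[List[float]]:
    """Combine node-level features into one feature row per graph."""
    if operation is None:
        # Pooling disabled: node-level outputs are returned unchanged.
        return node_features
    if operation not in ("mean", "sum"):
        raise ValueError(f"Unknown pooling operation: {operation!r}")

    n_features = len(node_features[0])
    sums = [[0.0] * n_features for _ in range(n_graphs)]
    counts = [0] * n_graphs
    for row, graph_id in zip(node_features, batch):
        counts[graph_id] += 1
        for k, value in enumerate(row):
            sums[graph_id][k] += value

    if operation == "sum":
        return sums
    # Mean: divide each graph's sum by its node count (guarding empty graphs).
    return [[v / max(c, 1) for v in row] for row, c in zip(sums, counts)]


# Three nodes: the first two belong to graph 0, the last to graph 1.
node_features = [[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]]
batch = [0, 0, 1]
pooled_sum = pool_nodes(node_features, batch, n_graphs=2, operation="sum")
pooled_mean = pool_nodes(node_features, batch, n_graphs=2, operation="mean")
```

Sum pooling scales with the number of nodes in a graph, whereas mean pooling does not, which is usually the deciding factor between the two.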
Attributes
T_destination
call_super_init
dump_patches
in_features
out_features
training