mlcolvar.core.nn.graph.GVPModel
- class mlcolvar.core.nn.graph.GVPModel(n_out: int, dataset_for_initialization: DictDataset = None, pooling_operation: str = 'mean', n_bases: int = 8, n_polynomials: int = 6, n_layers: int = 1, n_messages: int = 2, n_feedforwards: int = 2, n_scalars_node: int = 8, n_vectors_node: int = 8, n_scalars_edge: int = 8, drop_rate: int = 0.1, activation: str = 'SiLU', basis_type: str = 'bessel', smooth: bool = False, **kwargs)
Bases: BaseGNN
The Geometric Vector Perceptron (GVP) model [1, 2] with a vector gate [2].
References
- __init__(n_out: int, dataset_for_initialization: DictDataset = None, pooling_operation: str = 'mean', n_bases: int = 8, n_polynomials: int = 6, n_layers: int = 1, n_messages: int = 2, n_feedforwards: int = 2, n_scalars_node: int = 8, n_vectors_node: int = 8, n_scalars_edge: int = 8, drop_rate: int = 0.1, activation: str = 'SiLU', basis_type: str = 'bessel', smooth: bool = False, **kwargs) → None
Initializes a Geometric Vector Perceptron (GVP) model.
- Parameters:
n_out (int) – Number of output scalar node features.
dataset_for_initialization (DictDataset, optional) – Dataset containing the graphs to which the GNN model will be applied. It is used to initialize and register the cutoff, buffer, and atomic_numbers from the dataset metadata. This is the preferred way to initialize the GNN model, as it ensures consistency between the model and the dataset. Alternatively, set this to None and provide cutoff, buffer, and atomic_numbers as keyword arguments (kwargs), by default None.
pooling_operation (str) – Type of pooling operation used to combine node-level features into graph-level features, either ‘mean’ or ‘sum’, by default ‘mean’.
n_bases (int) – Size of the basis set used for the embedding, by default 8.
n_polynomials (int) – Order of the polynomials in the basis functions, by default 6.
n_layers (int) – Number of graph convolution layers, by default 1.
n_messages (int) – Number of GVP layers to be used in the message functions, by default 2.
n_feedforwards (int) – Number of GVP layers to be used in the feedforward functions, by default 2.
n_scalars_node (int) – Size of the scalar channel of the node embedding in hidden layers, by default 8.
n_vectors_node (int) – Size of the vector channel of the node embedding in hidden layers, by default 8.
n_scalars_edge (int) – Size of the scalar channel of the edge embedding in hidden layers, by default 8.
drop_rate (float) – Dropout probability used in all dropout layers, by default 0.1.
activation (str) – Name of the activation function to be used in the GVPs (case sensitive), by default SiLU.
basis_type (str) – Type of the basis function, by default bessel.
smooth (bool) – Whether to use the smoothed GVPConv, by default False.
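The model mixes scalar and vector node channels (n_scalars_node, n_vectors_node) through GVPs with a vector gate [2]. As a rough illustration of that idea, the pure-Python sketch below implements one toy vector-gated GVP layer with random weights: vector channels are only mixed linearly (no bias), their norms feed the scalar branch, and a sigmoid gate computed from the SiLU-activated scalars rescales each output vector. The names VectorGatedGVP, mix_vectors, and rotate_z, the hidden width max(n_v_in, n_v_out), and the exact gate placement are assumptions of this sketch, not mlcolvar's implementation. The usage lines check the property the design is meant to guarantee: rotating the input vectors leaves the scalar outputs unchanged and rotates the vector outputs.

```python
# Hypothetical toy sketch of a vector-gated GVP layer; not mlcolvar's GVP code.
import math
import random
from typing import List, Sequence

Matrix = List[List[float]]


def silu(x: float) -> float:
    # SiLU(x) = x * sigmoid(x)
    return x / (1.0 + math.exp(-x))


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def linear(W: Matrix, x: Sequence[float], b: Sequence[float] = ()) -> List[float]:
    # Dense layer on scalars: W @ x, plus b when a bias is given
    out = [sum(w * xi for w, xi in zip(row, x)) for row in W]
    return [o + bi for o, bi in zip(out, b)] if b else out


def mix_vectors(W: Matrix, V: Sequence[Sequence[float]]) -> Matrix:
    # Channel mixing of 3D vectors without bias: out[i] = sum_j W[i][j] * V[j].
    # Acting only on the channel axis keeps the result rotation-equivariant.
    return [[sum(w * v[c] for w, v in zip(row, V)) for c in range(3)] for row in W]


def random_matrix(rows: int, cols: int, rng: random.Random) -> Matrix:
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


class VectorGatedGVP:
    """Toy GVP with a vector gate (random, untrained weights)."""

    def __init__(self, n_s_in: int, n_v_in: int, n_s_out: int, n_v_out: int, seed: int = 0):
        rng = random.Random(seed)
        n_h = max(n_v_in, n_v_out)  # hidden vector width: an assumption of this sketch
        self.W_h = random_matrix(n_h, n_v_in, rng)            # vectors -> hidden vectors
        self.W_mu = random_matrix(n_v_out, n_h, rng)          # hidden -> output vectors
        self.W_m = random_matrix(n_s_out, n_s_in + n_h, rng)  # [scalars, norms] -> scalars
        self.b_m = [rng.uniform(-1.0, 1.0) for _ in range(n_s_out)]
        self.W_g = random_matrix(n_v_out, n_s_out, rng)       # scalars -> one gate per vector
        self.b_g = [rng.uniform(-1.0, 1.0) for _ in range(n_v_out)]

    def __call__(self, s: Sequence[float], V: Sequence[Sequence[float]]):
        V_h = mix_vectors(self.W_h, V)
        norms = [math.sqrt(sum(c * c for c in v)) for v in V_h]  # rotation-invariant
        s_m = linear(self.W_m, list(s) + norms, self.b_m)
        V_mu = mix_vectors(self.W_mu, V_h)
        # Vector gate: invariant scalars decide how much of each output vector passes
        gate = [sigmoid(g) for g in linear(self.W_g, [silu(x) for x in s_m], self.b_g)]
        V_out = [[g * c for c in v] for g, v in zip(gate, V_mu)]
        s_out = [silu(x) for x in s_m]
        return s_out, V_out


def rotate_z(v: Sequence[float], theta: float) -> List[float]:
    # Rotate a 3D vector by angle theta about the z axis
    c, s = math.cos(theta), math.sin(theta)
    return [c * v[0] - s * v[1], s * v[0] + c * v[1], v[2]]


# Usage: 4 scalars + 3 vectors in, 5 scalars + 2 vectors out
layer = VectorGatedGVP(n_s_in=4, n_v_in=3, n_s_out=5, n_v_out=2)
s = [0.1, -0.2, 0.3, 0.5]
V = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.5], [0.3, -0.2, 1.0]]
s1, V1 = layer(s, V)
s2, V2 = layer(s, [rotate_z(v, 0.7) for v in V])
# s2 should match s1, and V2 should match the rotated V1 (up to rounding)
```

Because the vectors never pass through a bias or an elementwise nonlinearity, the only way nonlinearity reaches them is through the invariant gate, which is what keeps the layer equivariant.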
Methods
__init__(n_out[, ...]) – Initializes a Geometric Vector Perceptron (GVP) model.
forward(data) – The forward pass.
- embed_edge(data: Dict[str, Tensor], normalize: bool = True) → Tuple[Tensor, Tensor, Tensor]
Performs the model's edge embedding from the data of a torch_geometric.data.Batch object.
- Parameters:
data (Dict[str, torch.Tensor]) – The data dict, usually obtained from the to_dict method of a torch_geometric.data.Batch object.
normalize (bool) – Whether to return the normalized distance vectors, by default True.
- Returns:
edge_lengths (torch.Tensor (shape: [n_edges, 1])) – The edge lengths.
edge_length_embeddings (torch.Tensor (shape: [n_edges, n_bases])) – The edge length embeddings.
edge_unit_vectors (torch.Tensor (shape: [n_edges, 3])) – The normalized edge vectors.
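As a sketch of what the three returned quantities contain, the pure-Python example below computes, for each edge, its length, an n_bases-long radial embedding, and the (optionally normalized) edge vector. It assumes a DimeNet-style Bessel basis multiplied by a polynomial envelope of order n_polynomials that vanishes at the cutoff; this pairing, the sender-to-receiver direction convention, and the absence of periodic-boundary shifts are assumptions, and all function names are hypothetical. It works on plain lists rather than on the tensors of the data dict.

```python
# Hypothetical sketch of an edge embedding; the basis/envelope choice is an assumption.
import math
from typing import List, Sequence, Tuple


def bessel_basis(d: float, cutoff: float, n_bases: int) -> List[float]:
    # e_n(d) = sqrt(2 / c) * sin(n * pi * d / c) / d, for n = 1..n_bases
    pref = math.sqrt(2.0 / cutoff)
    return [pref * math.sin(n * math.pi * d / cutoff) / d for n in range(1, n_bases + 1)]


def polynomial_envelope(d: float, cutoff: float, p: int) -> float:
    # Smooth envelope equal to 1 at d = 0 and 0 at (and beyond) d = cutoff
    x = d / cutoff
    if x >= 1.0:
        return 0.0
    return (1.0
            - (p + 1) * (p + 2) / 2.0 * x ** p
            + p * (p + 2) * x ** (p + 1)
            - p * (p + 1) / 2.0 * x ** (p + 2))


def embed_edges(
    positions: Sequence[Sequence[float]],
    edge_index: Sequence[Tuple[int, int]],
    cutoff: float,
    n_bases: int = 8,
    n_polynomials: int = 6,
    normalize: bool = True,
):
    """Return (lengths [n_edges][1], embeddings [n_edges][n_bases], vectors [n_edges][3])."""
    lengths, embeddings, vectors = [], [], []
    for src, dst in edge_index:
        vec = [positions[dst][k] - positions[src][k] for k in range(3)]
        d = math.sqrt(sum(c * c for c in vec))  # assumes no zero-length edges (no self-loops)
        env = polynomial_envelope(d, cutoff, n_polynomials)
        lengths.append([d])
        embeddings.append([b * env for b in bessel_basis(d, cutoff, n_bases)])
        vectors.append([c / d for c in vec] if normalize else vec)
    return lengths, embeddings, vectors


# Usage: three atoms, three directed edges, cutoff 5.0
positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 0.0)]
edges = [(0, 1), (0, 2), (1, 2)]
lengths, emb, units = embed_edges(positions, edges, cutoff=5.0)
```

The envelope makes every embedding go to zero smoothly as an edge approaches the cutoff, so atoms entering or leaving the neighbor list do not cause jumps in the features.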
- forward(data: Dict[str, Tensor]) → Tensor
The forward pass.
- Parameters:
data (Dict[str, torch.Tensor]) – The data dict, usually obtained from the to_dict method of a torch_geometric.data.Batch object.
- pooling(input: Tensor, data: Dict[str, Tensor]) → Tensor
Performs pooling of the node-level outputs to obtain a graph-level output.
- Parameters:
input (torch.Tensor) – Node-level features to be pooled.
data (Dict[str, torch.Tensor]) – Data batch containing the graph data.
- Returns:
Pooled graph-level output.
- Return type:
torch.Tensor
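Pooling with ‘mean’ or ‘sum’ amounts to a reduction over the graph index of each node. The sketch below assumes the data dict carries a per-node batch vector, as torch_geometric.data.Batch objects do, and uses plain lists; pool is a hypothetical name. In tensor code the same reduction would typically be written as an index-based scatter, for example with torch.Tensor.index_add_.

```python
# Hypothetical sketch of mean/sum pooling over a per-node batch index.
from typing import List, Sequence


def pool(
    node_features: Sequence[Sequence[float]],
    batch: Sequence[int],
    operation: str = "mean",
) -> List[List[float]]:
    """Reduce node features [n_nodes][n_feat] to graph features [n_graphs][n_feat]."""
    n_graphs = max(batch) + 1
    n_feat = len(node_features[0])
    sums = [[0.0] * n_feat for _ in range(n_graphs)]
    counts = [0] * n_graphs
    # batch[i] is the index of the graph that node i belongs to
    for feats, g in zip(node_features, batch):
        counts[g] += 1
        for k, f in enumerate(feats):
            sums[g][k] += f
    if operation == "sum":
        return sums
    if operation == "mean":
        # max(c, 1) avoids dividing by zero for a graph index with no nodes
        return [[v / max(c, 1) for v in row] for row, c in zip(sums, counts)]
    raise ValueError(f"unknown pooling operation: {operation!r}")


# Usage: nodes 0 and 1 belong to graph 0, node 2 to graph 1
feats = [[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]]
batch = [0, 0, 1]
graph_sum = pool(feats, batch, "sum")    # [[4.0, 6.0], [10.0, 20.0]]
graph_mean = pool(feats, batch, "mean")  # [[2.0, 3.0], [10.0, 20.0]]
```

‘sum’ makes the output scale with the number of nodes, while ‘mean’ keeps it independent of system size, which is one reason to prefer one over the other.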
Attributes
T_destination
call_super_init
dump_patches
in_features
out_features
training