mlcolvar.core.nn.graph.GVPModel

class mlcolvar.core.nn.graph.GVPModel(n_out: int, dataset_for_initialization: DictDataset = None, pooling_operation: str = 'mean', n_bases: int = 8, n_polynomials: int = 6, n_layers: int = 1, n_messages: int = 2, n_feedforwards: int = 2, n_scalars_node: int = 8, n_vectors_node: int = 8, n_scalars_edge: int = 8, drop_rate: int = 0.1, activation: str = 'SiLU', basis_type: str = 'bessel', smooth: bool = False, **kwargs)

Bases: BaseGNN

The Geometric Vector Perceptron (GVP) model [1, 2] with vector gate [2].
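To make the vector-gate idea concrete, here is a minimal pure-Python sketch of a single GVP step in the spirit of [1, 2]. It is not mlcolvar's implementation: the weight shapes, the hidden vector width, the SiLU/sigmoid choices, and all helper names (gvp, mix_channels, make_params, rotate_z) are assumptions made for illustration. Vector channels are only combined linearly across channels and then rescaled by gates computed from rotation-invariant scalars, which is what keeps the vector output rotation-equivariant; the last lines check this numerically.

```python
import math
import random


def matvec(W, x):
    """Matrix (list of rows) times vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def mix_channels(W, V):
    """Linear map over vector channels only: out[k] = sum_j W[k][j] * V[j].

    The x, y, z components are never mixed, so this map commutes with rotations.
    """
    return [[sum(W[k][j] * V[j][c] for j in range(len(V))) for c in range(3)]
            for k in range(len(W))]


def norm(v):
    return math.sqrt(sum(c * c for c in v))


def silu(x):
    return x / (1.0 + math.exp(-x))


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def gvp(s, V, params):
    """One GVP step with a vector gate (sketch); returns (scalars, vectors)."""
    W_h, W_u, W_m, b_m, W_g, b_g = params
    V_h = mix_channels(W_h, V)                      # hidden vector channels
    s_h = s + [norm(v) for v in V_h]                # scalars + rotation-invariant norms
    s_out = [silu(z + b) for z, b in zip(matvec(W_m, s_h), b_m)]
    V_u = mix_channels(W_u, V_h)                    # output vector channels
    # Vector gate: one weight in (0, 1) per output vector channel, from the scalars.
    gate = [sigmoid(z + b) for z, b in zip(matvec(W_g, s_out), b_g)]
    V_out = [[g * c for c in v] for g, v in zip(gate, V_u)]
    return s_out, V_out


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def make_params(n_s_in, n_v_in, n_s_out, n_v_out, rng):
    n_h = max(n_v_in, n_v_out)                      # hidden vector width (assumption)
    return (random_matrix(n_h, n_v_in, rng),        # W_h
            random_matrix(n_v_out, n_h, rng),       # W_u
            random_matrix(n_s_out, n_s_in + n_h, rng),  # W_m
            [0.0] * n_s_out,                        # b_m
            random_matrix(n_v_out, n_s_out, rng),   # W_g
            [0.0] * n_v_out)                        # b_g


def rotate_z(v, theta):
    """Rotate a 3D vector by theta radians about the z axis."""
    cs, sn = math.cos(theta), math.sin(theta)
    return [cs * v[0] - sn * v[1], sn * v[0] + cs * v[1], v[2]]


# Equivariance check: rotating the input vectors rotates the output vectors
# and leaves the output scalars unchanged.
rng = random.Random(0)
params = make_params(n_s_in=2, n_v_in=3, n_s_out=4, n_v_out=2, rng=rng)
s = [0.5, -1.0]
V = [[1.0, 0.0, 0.0], [0.0, 2.0, 1.0], [0.3, -0.4, 0.5]]
s1, V1 = gvp(s, V, params)
s2, V2 = gvp(s, [rotate_z(v, 0.7) for v in V], params)
```

In the model, widths such as n_scalars_node and n_vectors_node would play the role of the scalar and vector channel counts chosen here by hand.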

References

__init__(n_out: int, dataset_for_initialization: DictDataset = None, pooling_operation: str = 'mean', n_bases: int = 8, n_polynomials: int = 6, n_layers: int = 1, n_messages: int = 2, n_feedforwards: int = 2, n_scalars_node: int = 8, n_vectors_node: int = 8, n_scalars_edge: int = 8, drop_rate: int = 0.1, activation: str = 'SiLU', basis_type: str = 'bessel', smooth: bool = False, **kwargs) → None

Initializes a Geometric Vector Perceptron (GVP) model.

Parameters:
  • n_out (int) – Number of output scalar node features.

  • dataset_for_initialization (DictDataset, optional) – Dataset containing the graphs to which the GNN model will be applied. It is used to initialize and register the cutoff, buffer, and atomic_numbers from the dataset metadata. This is the preferred way to initialize the GNN model, as it ensures consistency between the model and the dataset. Alternatively, set this to None and provide cutoff, buffer, and atomic_numbers as kwargs.

  • pooling_operation (str) – Type of pooling operation used to combine node-level features into graph-level features, either ‘mean’ or ‘sum’, by default ‘mean’.

  • n_bases (int) – Size of the basis set used for the embedding, by default 8.

  • n_polynomials (int) – Order of the polynomials in the basis functions, by default 6.

  • n_layers (int) – Number of graph convolution layers, by default 1.

  • n_messages (int) – Number of GVP layers to be used in the message functions, by default 2.

  • n_feedforwards (int) – Number of GVP layers to be used in the feedforward functions, by default 2.

  • n_scalars_node (int) – Size of the scalar channel of the node embedding in hidden layers, by default 8.

  • n_vectors_node (int) – Size of the vector channel of the node embedding in hidden layers, by default 8.

  • n_scalars_edge (int) – Size of the scalar channel of the edge embedding in hidden layers, by default 8.

  • drop_rate (float) – Dropout probability used in all dropout layers, by default 0.1.

  • activation (str) – Name of the activation function to be used in the GVPs (case sensitive), by default SiLU.

  • basis_type (str) – Type of the basis function, by default bessel.

  • smooth (bool) – Whether to use the smoothed GVPConv, by default False.
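The roles of n_bases, n_polynomials, and basis_type='bessel' can be illustrated with a radial Bessel basis multiplied by a smooth polynomial cutoff envelope, the construction used in DimeNet and NequIP. Whether mlcolvar computes exactly this form, and whether n_polynomials is the envelope exponent p, is an assumption here; the function names and the cutoff value are hypothetical.

```python
import math


def polynomial_envelope(r, cutoff, p=6):
    """Smooth cutoff: 1 at r = 0, decays to 0 (with vanishing derivatives) at r = cutoff."""
    x = r / cutoff
    if x >= 1.0:
        return 0.0
    return (1.0
            - (p + 1.0) * (p + 2.0) / 2.0 * x ** p
            + p * (p + 2.0) * x ** (p + 1)
            - p * (p + 1.0) / 2.0 * x ** (p + 2))


def bessel_basis(r, cutoff, n_bases=8):
    """Radial Bessel functions (2 / cutoff) * sin(n * pi * r / cutoff) / r, n = 1..n_bases.

    Assumes r > 0 (edges never have zero length).
    """
    return [2.0 / cutoff * math.sin(n * math.pi * r / cutoff) / r
            for n in range(1, n_bases + 1)]


def embed_length(r, cutoff, n_bases=8, n_polynomials=6):
    """Hypothetical edge-length embedding of size n_bases."""
    env = polynomial_envelope(r, cutoff, n_polynomials)
    return [env * b for b in bessel_basis(r, cutoff, n_bases)]


# One edge of length 1.0 with a (hypothetical) cutoff of 5.0 -> 8 features.
features = embed_length(1.0, cutoff=5.0)
```

The envelope makes every feature go to zero smoothly at the cutoff, so an atom crossing the cutoff sphere does not produce a jump in the model output.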

Methods

__init__(n_out[, ...]) – Initializes a Geometric Vector Perceptron (GVP) model.

forward(data) – The forward pass.

embed_edge(data: Dict[str, Tensor], normalize: bool = True) → Tuple[Tensor, Tensor, Tensor]

Computes the model's edge embedding from the data dict of a torch_geometric.data.Batch object.

Parameters:
  • data (Dict[str, torch.Tensor]) – The data dict. Usually from the to_dict method of a torch_geometric.data.Batch object.

  • normalize (bool) – Whether to return the normalized distance vectors, by default True.

Returns:

  • edge_lengths (torch.Tensor (shape: [n_edges, 1])) – The edge lengths.

  • edge_length_embeddings (torch.Tensor (shape: [n_edges, n_bases])) – The edge length embeddings.

  • edge_unit_vectors (torch.Tensor (shape: [n_edges, 3])) – The normalized edge vectors.
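The geometric part of an edge embedding can be sketched in plain Python: from node positions and an edge list, compute one length per edge (kept as shape [n_edges, 1]) and the edge vectors, normalized when normalize is True. The helper name, the list-based inputs, and the direction convention (target position minus source position) are assumptions, not mlcolvar's code; in the model the lengths would additionally be passed through the radial basis to produce edge_length_embeddings.

```python
import math


def embed_edge_geometry(positions, edge_index, normalize=True):
    """Hypothetical helper returning per-edge lengths and edge vectors.

    positions:  list of [x, y, z], one per node
    edge_index: pair of lists (sources, targets)
    Assumed convention: edge vector = positions[target] - positions[source].
    """
    sources, targets = edge_index
    lengths, vectors = [], []
    for i, j in zip(sources, targets):
        vec = [positions[j][c] - positions[i][c] for c in range(3)]
        d = math.sqrt(sum(c * c for c in vec))
        lengths.append([d])                                 # shape [n_edges, 1]
        vectors.append([c / d for c in vec] if normalize else vec)
    return lengths, vectors


positions = [[0.0, 0.0, 0.0], [3.0, 4.0, 0.0], [0.0, 0.0, 2.0]]
edge_index = ([0, 1, 0], [1, 0, 2])   # edges 0->1, 1->0, 0->2
lengths, units = embed_edge_geometry(positions, edge_index)
print(lengths)  # [[5.0], [5.0], [2.0]]
```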

forward(data: Dict[str, Tensor]) → Tensor

The forward pass.

Parameters:

data (Dict[str, torch.Tensor]) – The data dict, usually obtained from the to_dict method of a torch_geometric.data.Batch object.

pooling(input: Tensor, data: Dict[str, Tensor]) → Tensor

Pools the node-level outputs to obtain a graph-level output.

Parameters:
  • input (torch.Tensor) – Node-level features to be pooled.

  • data (Dict[str, torch.Tensor]) – Data batch containing the graph information.

Returns:

Pooled output

Return type:

torch.Tensor
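The two pooling_operation options can be sketched in plain Python. A torch_geometric Batch carries a batch vector that assigns each node to its graph; the sketch groups node feature rows by that index and then either sums them or averages them. The function name and list-based inputs are hypothetical, and a tensor implementation would typically use a scatter/segment reduction instead of a Python loop.

```python
def pool(node_features, batch, n_graphs, operation="mean"):
    """Aggregate node-level feature rows into one row per graph.

    node_features: list of feature lists, one per node
    batch:         graph index of each node (like torch_geometric's `batch` vector)
    """
    if operation not in ("mean", "sum"):
        raise ValueError(f"unsupported pooling operation: {operation!r}")
    n_feat = len(node_features[0])
    totals = [[0.0] * n_feat for _ in range(n_graphs)]
    counts = [0] * n_graphs
    for row, g in zip(node_features, batch):
        counts[g] += 1
        for k, value in enumerate(row):
            totals[g][k] += value
    if operation == "sum":
        return totals
    # Mean pooling; max(c, 1) guards against graphs with no nodes.
    return [[v / max(c, 1) for v in tot] for tot, c in zip(totals, counts)]


features = [[1.0, 2.0], [3.0, 4.0], [10.0, 0.0]]
batch = [0, 0, 1]   # first two nodes in graph 0, last node in graph 1
print(pool(features, batch, n_graphs=2, operation="mean"))  # [[2.0, 3.0], [10.0, 0.0]]
print(pool(features, batch, n_graphs=2, operation="sum"))   # [[4.0, 6.0], [10.0, 0.0]]
```

Mean pooling keeps the output scale independent of the number of atoms, while sum pooling makes it grow with system size, which matters when graphs of different sizes share one model.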

Attributes

T_destination

call_super_init

dump_patches

in_features

out_features

training