Source code for mlcolvar.data.graph.atomic

import warnings
import numpy as np
import mdtraj as md
from dataclasses import dataclass
from typing import List, Iterable, Optional

"""
The helper functions for atomic data. This module is taken from MACE directly:
https://github.com/ACEsuit/mace/blob/main/mace/tools/utils.py
https://github.com/ACEsuit/mace/blob/main/mace/data/utils.py
"""

__all__ = ['AtomicNumberTable', 'Configuration', 'Configurations']


[docs] class AtomicNumberTable: """The atomic number table. Used to map between one hot encodings and a given set of actual atomic numbers. """
[docs] def __init__(self, zs: List[int]) -> None: """Initializes an atomi number table object Parameters ---------- zs: List[int] The atomic numbers in this table """ self.zs = zs self.masses = [1.0] * len(zs) for i in range(len(zs)): try: m = md.element.Element.getByAtomicNumber(zs[i]).mass self.masses[i] = m except Exception: warnings.warn( 'Can not assign mass for atom number: {:d}'.format(zs[i]) )
def __len__(self) -> int: """Number of elements in the table""" return len(self.zs) def __str__(self) -> str: return f'AtomicNumberTable: {tuple(s for s in self.zs)}'
[docs] def index_to_z(self, index: int) -> int: """Maps the encoding to the actual atomic number Parameters ---------- index: int Index of the encoding to be mapped """ return self.zs[index]
[docs] def index_to_symbol(self, index: int) -> str: """Map the encoding to the atomic symbol Parameters ---------- index: int Index of the encoding to be mapped """ return md.element.Element.getByAtomicNumber(self.zs[index]).symbol
[docs] def z_to_index(self, atomic_number: int) -> int: """Maps an atomic number to the encoding. Parameters ---------- atomic_number: int The atomic number to be mapped """ return self.zs.index(atomic_number)
[docs] def zs_to_indices(self, atomic_numbers: np.ndarray) -> np.ndarray: """Maps an array of atomic number to the encodings. Parameters ---------- atomic_numbers: numpy.ndarray The atomic numbers to be mapped """ to_index_fn = np.vectorize(self.z_to_index) return to_index_fn(atomic_numbers)
[docs] @classmethod def from_zs(cls, atomic_numbers: Iterable[int]) -> 'AtomicNumberTable': """Build the table from an array atomic numbers. Parameters ---------- atomic_numbers: Iterable[int] The atomic numbers to be used for building the table """ z_set = set() for z in atomic_numbers: z_set.add(z) return cls(sorted(list(z_set)))
def get_masses(atomic_numbers: Iterable[int]) -> List[float]: """Get atomic masses from atomic numbers. Parameters ---------- atomic_numbers: Iterable[int] The atomic numbers for which to return the atomic masses """ return AtomicNumberTable.from_zs(atomic_numbers).masses.copy()
[docs] @dataclass class Configuration: """ Internal helper class that describe a given configuration of the system. Parameters ---------- atomic_numbers: np.ndarray Atomic numbers of the atoms in the system. Shape: [n_atoms] positions: np.ndarray Positions of the atoms in the system. Shape: [n_atoms, 3], units: Ang cell: np.ndarray Cell of the system. Shape: [n_atoms, 3], units: Ang pbc: Optional[tuple] Periodic boundary conditions of the system. Shape: [3] node_labels: Optional[np.ndarray] Node labels of the graph. Shape: [n_atoms, n_node_labels] graph_labels: Optional[np.ndarray] Graph-level labels of the configuration. Shape: [n_graph_labels, 1] weight: Optional[float] Weight of the configuration. Shape: [] system: Optional[np.ndarray] Indices of the system atoms. Shape: [n_system_atoms] environment: Optional[np.ndarray] Indices of the environment atoms. Shape: [n_environment_atoms] subsystem: Optional[np.ndarray] Indices of the subsystem atoms for long-range interactions. Shape: [n_subsystem_atoms] """ atomic_numbers: np.ndarray # shape: [n_atoms] positions: np.ndarray # shape: [n_atoms, 3], units: Ang cell: np.ndarray # shape: [n_atoms, 3], units: Ang pbc: Optional[tuple] # shape: [3] node_labels: Optional[np.ndarray] # shape: [n_atoms, n_node_labels] graph_labels: Optional[np.ndarray] # shape: [n_graph_labels, 1] weight: Optional[float] = 1.0 # shape: [] system: Optional[np.ndarray] = None # shape: [n_system_atoms] environment: Optional[np.ndarray] = None # shape: [n_environment_atoms] subsystem: Optional[np.ndarray] = None # shape: [n_subsystem]
Configurations = List[Configuration] def test_atomic_number_table() -> None: table = AtomicNumberTable([1, 6, 7, 8]) numbers = np.array([1, 7, 6, 8]) assert ( table.zs_to_indices(numbers) == np.array([0, 2, 1, 3], dtype=int) ).all() numbers = np.array([1, 1, 1, 6, 8, 1]) assert ( table.zs_to_indices(numbers) == np.array([0, 0, 0, 1, 3, 0], dtype=int) ).all() table_1 = AtomicNumberTable.from_zs([6] * 3 + [1] * 10 + [7] * 3 + [8] * 2) assert table_1.zs == table.zs