import copy
from collections import defaultdict
from typing import Union
import torch
import torch_geometric
from torch_geometric.data import Data, HeteroData
from torch_geometric.transforms import BaseTransform
from mlcolvar.data import DictDataset, DictModule
from mlcolvar.data.graph import atomic
from mlcolvar.data.graph.neighborhood import get_neighborhood
from mlcolvar.utils.plot import pbar
from typing import List
__all__ = ["create_dataset_from_configurations", "create_test_graph_input"]
def _create_pyg_data_from_configuration(
    config: atomic.Configuration,
    atomic_numbers: atomic.AtomicNumberTable,
    cutoff: float,
    buffer: float = 0.0,
    long_range_cutoff: float = -1.0,
) -> torch_geometric.data.Data:
    """Build the torch_geometric graph data object from a configuration.

    Parameters
    ----------
    config: mlcolvar.data.graph.atomic.Configuration
        The configuration from which to generate the graph data.
        It is only read, never modified.
    atomic_numbers: mlcolvar.data.graph.atomic.AtomicNumberTable
        The atomic number table used to build the node attributes
    cutoff: float
        The graph cutoff radius
    buffer: float
        Buffer size used in finding active environment atoms if
        restricting the neighborhood to a subsystem (i.e., system + environment),
        `see also mlcolvar.data.graph.neighborhood.get_neighborhood`
    long_range_cutoff : float
        Cutoff radius for the long-range edges defined on subsystem atoms.
        If negative, no long-range interactions are considered, by default -1.0

    Returns
    -------
    torch_geometric.data.Data
        Graph data holding `edge_index`, `shifts`, `unit_shifts`, `positions`,
        `cell`, `node_attrs` (one-hot species), `node_labels`, `graph_labels`,
        `n_system`, `n_env`, `system_masks`, `weight`, `subsystem_masks` and
        `edge_masks_lr`. Long-range edges, when present, are appended after
        the short-range ones and flagged in `edge_masks_lr`.
    """
    assert config.graph_labels is None or len(config.graph_labels.shape) == 2
    # a subsystem and a positive long-range cutoff must be given together
    assert not ((config.subsystem is not None) ^ (long_range_cutoff > 0))

    float_dtype = torch.get_default_dtype()

    # NOTE: here we do not take care about the nodes that are not taking part
    # the graph, like, we don't even change the node indices in `edge_index`.
    # Here we simply ignore them, and rely on the `RemoveIsolatedNodes` method
    # that will be called later (in `create_dataset_from_configurations`).
    edge_index, shifts, unit_shifts = get_neighborhood(
        positions=config.positions,
        cutoff=cutoff,
        cell=config.cell,
        pbc=config.pbc,
        system_indices=config.system,
        environment_indices=config.environment,
        buffer=buffer,
    )

    subsystem = None
    n_edges_lr = 0
    if config.subsystem is not None:
        # work on a local array so the caller's configuration is left untouched
        subsystem = np.array(config.subsystem)
        edge_index_lr, shifts_lr, unit_shifts_lr = get_neighborhood(
            positions=config.positions[subsystem],
            cutoff=long_range_cutoff,
            cell=config.cell,
            pbc=config.pbc,
        )
        # the search ran on the subsystem slice: translate its local indices
        # back to indices of the full configuration
        edge_index_lr = np.vstack([subsystem[edge_index_lr[0]],
                                   subsystem[edge_index_lr[1]]])
        n_edges_lr = edge_index_lr.shape[1]

        # long-range edges are appended after the short-range ones
        edge_index = np.hstack([edge_index, edge_index_lr])
        shifts = np.vstack([shifts, shifts_lr])
        unit_shifts = np.vstack([unit_shifts, unit_shifts_lr])

    edge_index = torch.tensor(edge_index, dtype=torch.long)
    shifts = torch.tensor(shifts, dtype=float_dtype)
    unit_shifts = torch.tensor(unit_shifts, dtype=float_dtype)
    positions = torch.tensor(config.positions, dtype=float_dtype)
    cell = torch.tensor(config.cell, dtype=float_dtype)
    node_labels = (
        torch.tensor(config.node_labels, dtype=float_dtype)
        if config.node_labels is not None else None
    )
    graph_labels = (
        torch.tensor(config.graph_labels, dtype=float_dtype)
        if config.graph_labels is not None else None
    )
    weight = (
        torch.tensor(config.weight, dtype=float_dtype)
        if config.weight is not None else 1
    )

    # get indices from atomic numbers and convert to one_hot
    indices = atomic_numbers.zs_to_indices(config.atomic_numbers)
    one_hot = to_one_hot(
        torch.tensor(indices, dtype=torch.long).unsqueeze(-1),
        n_classes=len(atomic_numbers),
    )
    n_nodes = one_hot.shape[0]

    # set n_system and system_masks: without an explicit selection every atom
    # is a system atom
    if config.system is not None:
        n_system_atoms = len(config.system)
        system_masks = torch.zeros((n_nodes, 1), dtype=torch.bool)
        system_masks[config.system, 0] = True
    else:
        n_system_atoms = n_nodes
        system_masks = torch.ones((n_nodes, 1), dtype=torch.bool)
    n_system = torch.tensor([[n_system_atoms]], dtype=float_dtype)

    # set subsystem_masks and edge_masks_lr (all False without a subsystem)
    subsystem_masks = torch.zeros((n_nodes, 1), dtype=torch.bool)
    edge_masks_lr = torch.zeros((edge_index.shape[1], 1), dtype=torch.bool)
    if subsystem is not None:
        subsystem_masks[subsystem, 0] = True
        # Use an explicit start offset: a `[-n_edges_lr:]` slice would select
        # *every* edge when no long-range edge was found (n_edges_lr == 0).
        edge_masks_lr[edge_index.shape[1] - n_edges_lr:, 0] = True

    # set n_env: every atom that is not a system atom
    n_env = torch.tensor([[n_nodes - n_system_atoms]], dtype=float_dtype)

    pyg_data = torch_geometric.data.Data(
        edge_index=edge_index,
        shifts=shifts,
        unit_shifts=unit_shifts,
        positions=positions,
        cell=cell,
        node_attrs=one_hot,
        node_labels=node_labels,
        graph_labels=graph_labels,
        n_system=n_system,
        n_env=n_env,
        system_masks=system_masks,
        weight=weight,
        subsystem_masks=subsystem_masks,
        edge_masks_lr=edge_masks_lr,
    )
    return pyg_data
# Public entry point: build a graph DictDataset from a list of configurations.
def create_dataset_from_configurations(config: atomic.Configurations,
                                       atomic_numbers: atomic.AtomicNumberTable,
                                       cutoff: float,
                                       buffer: float = 0.0,
                                       long_range_cutoff: float = -1.0,
                                       atom_names: List = None,
                                       remove_isolated_nodes: bool = False,
                                       show_progress: bool = True
                                       ) -> DictDataset:
    """Build DictDataset object containing torch_geometric graph data objects from configurations.

    Parameters
    ----------
    config: mlcolvar.graph.utils.atomic.Configurations
        The configurations from which to generate the dataset
    atomic_numbers: mlcolvar.graph.utils.atomic.AtomicNumberTable
        The atomic number table used to build the node attributes
    cutoff: float
        The graph cutoff radius
    buffer: float
        Buffer size used in finding active environment atoms if
        restricting the neighborhood to a subsystem (i.e., system + environment),
        `see also mlcolvar.data.graph.neighborhood.get_neighborhood`
    long_range_cutoff : float
        Cutoff radius for the long-range edges defined on subsystem atoms.
        If negative, no long-range interactions are considered, by default -1.0
    atom_names, List[str]
        Names of the system atoms. If using truncated graphs, use the system atoms only.
        If None, the names `X0`, `X1`, ... are generated, one per system atom
        of the first configuration.
    remove_isolated_nodes: bool
        If to remove isolated nodes from the dataset
    show_progress: bool
        If to show the progress bar

    Returns
    -------
    DictDataset
        Dataset of type 'graphs' whose `data_list` entry holds one
        torch_geometric data object per configuration. Each data object gets a
        `system_names_idx` field; the metadata stores `atomic_numbers`,
        `cutoff`, `buffer`, `long_range_cutoff`, `system_idx`,
        `system_atoms_names` and `is_truncated_graph`.
    """
    if show_progress:
        items = pbar(config, frequency=0.0001, prefix='Making graphs')
    else:
        items = config

    # create a list of torch_geometric data objects, one for each configuration
    data_list = [
        _create_pyg_data_from_configuration(config=c,
                                            atomic_numbers=atomic_numbers,
                                            cutoff=cutoff,
                                            buffer=buffer,
                                            long_range_cutoff=long_range_cutoff,
                                            )
        for c in items
    ]

    # The graph is truncated if any configuration lists environment atoms.
    # Inspect `config` itself rather than `items`: when `items` is the
    # progress-bar wrapper it has already been iterated above and may be
    # exhausted. Emptiness is tested with `len`, because an environment that
    # only contains atom index 0 is non-empty but evaluates to False with `.any()`.
    truncated = any(
        c.environment is not None and len(c.environment) > 0 for c in config
    )

    n_system = int(data_list[0]['n_system'].item())

    # get atom names if needed
    if atom_names is None:
        atom_names = [f"X{i}" for i in range(n_system)]

    if remove_isolated_nodes:
        # If the number of atoms equals the number of cell vectors (three),
        # pyg regards the cell as a node attribute and `_RemoveIsolatedNodes`
        # slices it together with the nodes. So we save the cell before the
        # removal and set it back afterwards.
        transform = _RemoveIsolatedNodes()
        used_idx = set()  # system-atom indices used at least once in the dataset
        pruned_list = []
        for d in data_list:
            # A node survives the removal iff it appears in at least one edge.
            # Tracking this by index is exact, whereas matching rounded
            # coordinates before/after the removal also matches atoms that
            # merely share one coordinate value with a kept atom.
            kept_nodes = torch.zeros(d['positions'].shape[0], dtype=torch.bool)
            kept_nodes[torch.unique(d['edge_index'])] = True

            # index of each kept system atom, counted among the system atoms
            # only (environment atoms are not named)
            system_mask = d['system_masks'][:, 0]
            kept_system_idx = torch.where(kept_nodes[system_mask])[0].to(torch.int64)

            cell = d.cell.clone()
            d = transform(d)
            d.cell = cell
            d['system_names_idx'] = kept_system_idx

            used_idx.update(kept_system_idx.tolist())
            pruned_list.append(d)

        data_list = pruned_list
        unique_idx = torch.tensor(sorted(used_idx), dtype=torch.int64)

    # if not remove_isolated_nodes we simply take all the system atoms
    else:
        unique_idx = torch.arange(n_system).to(torch.int64)
        for d in data_list:
            d['system_names_idx'] = unique_idx

    # we also save the names of the atoms that have been actually used
    # NOTE(review): the single-index case is kept exactly as before; it looks
    # like it may yield a bare string rather than a one-element list - confirm
    # what consumers of `system_atoms_names` expect.
    names_array = np.array(atom_names)
    unique_names = names_array[unique_idx] if len(unique_idx) > 1 else np.array(names_array[unique_idx])
    unique_names = unique_names.tolist()

    dataset = DictDataset(dictionary={'data_list': data_list},
                          metadata={'atomic_numbers': atomic_numbers.zs,
                                    'cutoff': cutoff,
                                    'buffer': buffer,
                                    'long_range_cutoff': long_range_cutoff,
                                    'system_idx': unique_idx,
                                    'system_atoms_names': unique_names,
                                    'is_truncated_graph': truncated},
                          data_type='graphs')
    return dataset
def to_one_hot(indices: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Generates one-hot encoding with `n_classes` classes from `indices`

    Parameters
    ----------
    indices: torch.Tensor (shape: [N, 1])
        Node indices
    n_classes: int
        Number of classes

    Returns
    -------
    encoding: torch.tensor (shape: [N, n_classes])
        The one-hot encoding
    """
    # same leading dimensions as `indices`, last dimension widened to n_classes
    target_shape = (*indices.shape[:-1], n_classes)
    encoding = torch.zeros(target_shape, device=indices.device)
    # in-place scatter: write a 1 in the column selected by each index
    encoding.scatter_(dim=-1, index=indices, value=1)
    return encoding
class _RemoveIsolatedNodes(BaseTransform):
    r"""Removes isolated nodes from the graph

    This is taken from pytorch_geometric with a small modification to avoid the bug when n_nodes==n_edges
    """

    def forward(self,
                data: Union[Data, HeteroData],
                ) -> Union[Data, HeteroData]:
        """Remove isolated nodes from graphs in a pytorch_geometric Data object

        Nodes that appear in no `edge_index` are dropped, the remaining nodes
        are renumbered consecutively and node attributes are sliced
        accordingly. The `shifts` and `unit_shifts` entries are never sliced.

        Parameters
        ----------
        data : Union[Data, HeteroData]
            Pytorch_geometric Data object containing the graph data

        Returns
        -------
        Union[Data, HeteroData]
            Pytorch_geometric Data object containing the modified graph data
        """
        def endpoint_types(store):
            # stores without a key use `None` for both endpoints
            if store._key is None:
                return None, None
            src_type, _, dst_type = store._key
            return src_type, dst_type

        edge_stores = [store for store in data.edge_stores if 'edge_index' in store]

        # Gather all nodes that occur in at least one edge (across all types):
        touched = defaultdict(list)
        for store in edge_stores:
            src_type, dst_type = endpoint_types(store)
            touched[src_type].append(store.edge_index[0])
            touched[dst_type].append(store.edge_index[1])
        kept_ids = {node_type: torch.cat(ids).unique() for node_type, ids in touched.items()}

        # Build, per node type, the lookup table old node id -> new node id:
        relabel = {}
        for node_store in data.node_stores:
            node_type = node_store._key
            if node_type not in kept_ids:
                # this node type takes part in no edge: nothing is kept
                kept_ids[node_type] = torch.empty(0, dtype=torch.long)
            kept = kept_ids[node_type]
            assert data.num_nodes is not None
            lookup = kept.new_zeros(data.num_nodes)
            lookup[kept] = torch.arange(kept.numel(), device=lookup.device)
            relabel[node_type] = lookup

        # Rewrite the edges in terms of the new node ids:
        for store in edge_stores:
            src_type, dst_type = endpoint_types(store)
            new_src = relabel[src_type][store.edge_index[0]]
            new_dst = relabel[dst_type][store.edge_index[1]]
            store.edge_index = torch.stack([new_src, new_dst], dim=0)

        # Slice the node-level attributes, reading from a shallow copy while
        # writing into `data`:
        snapshot = copy.copy(data)
        never_sliced = ['shifts', 'unit_shifts']
        for out_store, old_store in zip(data.node_stores, snapshot.node_stores):
            kept = kept_ids[old_store._key]
            for key, value in old_store.items():
                if key == 'num_nodes':
                    out_store.num_nodes = kept.numel()
                elif old_store.is_node_attr(key) and key not in never_sliced:
                    out_store[key] = value[kept]
        return data
def create_test_graph_input(output_type: str,
                            n_atoms: int = 3,
                            n_samples: int = 60,
                            n_states: int = 2,
                            random_weights = False,
                            add_noise = True,
                            environment: bool = False,
                            long_range: bool = False,):
    """
    Util function to generate several types of mock graph data objects for testing purposes.
    The graphs are created drawing positions from a predefined set of positions that cover most use cases.
    It can generate: one or some configuration objects, a dataset, a datamodule, a batch of example inputs or a single item.

    Parameters
    ----------
    output_type : str
        Type of graph data object to create. Can be: 'configuration', 'configurations', 'datamodule', 'dataset', 'batch', 'example'
    n_atoms : int, optional
        Number of atoms for creating the graph, either 3 or 4, by default 3
    n_samples : int, optional
        Number of samples per state to create, by default 60
    n_states : int, optional
        Number of states for which to create data, by default 2. Configurations are then labelled accordingly.
    random_weights : bool, optional
        If to assign random weights to the entries, otherwise unitary weights are given, by default False
    add_noise : bool, optional
        If to add a random noise for each entry to the predefined positions, by default True
    environment : bool, optional
        Whether to include environment nodes in the graph, by default False.
        If True, the last atom is the environment and all the others are the system.
    long_range : bool, optional
        Whether to include long-range edges in the graph, by default False

    Returns
    -------
    Graph data object of the chosen type, or None if `output_type` is not one
    of the supported values.

    Raises
    ------
    ValueError
        If `n_atoms` is neither 3 nor 4.
    """
    if n_atoms == 3:
        numbers = [8, 1, 1]
        node_labels = np.array([[0], [1], [1]])
        _ref_positions = np.array(
            [
                [[0.0, 0.0, 0.0], [0.07, 0.07, 0.0], [0.07, -0.07, 0.0]],
                [[0.0, 0.0, 0.0], [-0.07, 0.07, 0.0], [0.07, 0.07, 0.0]],
                [[0.0, 0.0, 0.0], [0.07, -0.07, 0.0], [0.07, 0.07, 0.0]],
                [[0.0, 0.0, 0.0], [0.0, -0.07, 0.07], [0.0, 0.07, 0.07]],
                [[0.0, 0.0, 0.0], [0.07, 0.0, 0.07], [-0.07, 0.0, 0.07]],
                [[0.1, 0.0, 1.1], [0.17, 0.07, 1.1], [0.17, -0.07, 1.1]],
            ],
            dtype=np.float64
        )
    elif n_atoms == 4:
        numbers = [8, 1, 1, 8]
        node_labels = np.array([[0], [1], [1], [0]])
        _ref_positions = np.array(
            [
                [[0.0, 0.0, 0.0], [0.07, 0.07, 0.0], [0.07, -0.07, 0.0], [0.05, -0.05, 0.0]],
                [[0.0, 0.0, 0.0], [-0.07, 0.07, 0.0], [0.07, 0.07, 0.0], [0.05, 0.05, 0.0]],
                [[0.0, 0.0, 0.0], [0.07, -0.07, 0.0], [0.07, 0.07, 0.0], [0.05, 0.05, 0.0]],
                [[0.0, 0.0, 0.0], [0.0, -0.07, 0.07], [0.0, 0.07, 0.07], [0.0, 0.05, 0.05]],
                [[0.0, 0.0, 0.0], [0.11, 0.11, 0.11], [-0.07, 0.0, 0.07], [-0.05, 0.0, 0.05]],
                [[0.1, 0.0, 1.1], [0.17, 0.07, 1.1], [0.17, -0.07, 1.1], [0.15, -0.05, 1.1]],
            ],
            dtype=np.float64
        )
    else:
        raise ValueError(f'Example input can be generated either with 3 or 4 atoms, found {n_atoms}')

    # With an environment, the last atom is the only environment atom and the
    # others form the system. The environment index is derived from `n_atoms`
    # so that it is always a valid atom index (2 for three atoms, 3 for four).
    if environment:
        system_atoms = list(range(n_atoms - 1))
        environment_atoms = [n_atoms - 1]
        buffer = 0.1
    else:
        system_atoms = None
        environment_atoms = None
        buffer = 0.0
    subsystem_atoms = [0, 1] if long_range else None

    # fixed seed so that the generated data are reproducible
    np.random.seed(0)
    n_total = n_samples * n_states
    idx = np.random.randint(low=0, high=6, size=(n_total))
    positions = _ref_positions[idx, :, :]

    # let's add some noise to the positions for fun
    if add_noise:
        noise = np.random.randn(*positions.shape) * 1e-5
        positions = positions + noise

    cell = np.identity(3, dtype=float) * 0.2

    # state i owns the samples [n_samples * i, n_samples * (i + 1)) and gets label i
    graph_labels = np.zeros((n_total, 1, 1))
    for i in range(1, n_states):
        graph_labels[n_samples * i:] += 1

    atomic_numbers = atomic.AtomicNumberTable.from_zs(numbers)

    if random_weights:
        weights = np.random.random_sample((n_total, 1, 1))
    else:
        weights = np.ones((n_total, 1, 1))

    config = [
        atomic.Configuration(
            atomic_numbers=numbers,
            positions=positions[i],
            cell=cell,
            pbc=[True] * 3,
            node_labels=node_labels,
            graph_labels=graph_labels[i],
            weight=weights[i],
            system=system_atoms,
            environment=environment_atoms,
            subsystem=subsystem_atoms,
        ) for i in range(0, n_total)
    ]

    if output_type == 'configuration':
        return config[0]
    if output_type == 'configurations':
        return config

    dataset = create_dataset_from_configurations(
        config=config,
        atomic_numbers=atomic_numbers,
        cutoff=0.1,
        show_progress=False,
        remove_isolated_nodes=True,
        buffer=buffer,
        long_range_cutoff=0.3 if long_range else -1.0,
    )
    if output_type == 'dataset':
        return dataset

    datamodule = DictModule(
        dataset,
        lengths=(0.8, 0.2),
        batch_size=0,
        shuffle=False,
    )
    if output_type == 'datamodule':
        return datamodule

    datamodule.setup()
    batch = next(iter(datamodule.train_dataloader()))
    if output_type == 'batch':
        return batch

    # single graph taken from the batch, with all nodes assigned to graph 0
    example = batch['data_list'].get_example(0)
    example['batch'] = torch.zeros(len(example['positions']), dtype=torch.int64)
    if output_type == 'example':
        return example

    return None
def create_graph_tracing_example(n_species: int,
                                 environment: bool = False,
                                 long_range: bool = False) -> dict:
    """
    Util to create a tracing example for graph based models.

    A single three-atom configuration (all atoms of the same species) is drawn
    at random from a predefined set, turned into a graph without removing
    isolated nodes, and returned as a dictionary.

    Parameters
    ----------
    n_species : int
        Number of chemical species to be considered in the model.
    environment : bool, optional
        Whether to include environment nodes in the graph, by default False.
    long_range : bool, optional
        Whether to include long-range edges in the graph, by default False.

    Returns
    -------
    dict
        Tracing graph input example as dict.
    """
    species = [1, 1, 1]
    per_atom_labels = np.array([[0], [0], [0]])
    templates = np.array(
        [
            [[0.0, 0.0, 0.0], [0.07, 0.07, 0.0], [0.07, -0.07, 0.0]],
            [[0.0, 0.0, 0.0], [-0.07, 0.07, 0.0], [0.07, 0.07, 0.0]],
            [[0.0, 0.0, 0.0], [0.07, -0.07, 0.0], [0.07, 0.07, 0.0]],
            [[0.0, 0.0, 0.0], [0.0, -0.07, 0.07], [0.0, 0.07, 0.07]],
            [[0.0, 0.0, 0.0], [0.11, 0.11, 0.11], [-0.07, 0.0, 0.07]],
            [[0.1, 0.0, 1.1], [0.17, 0.07, 1.1], [0.17, -0.07, 1.1]],
        ],
        dtype=np.float64
    )

    # draw one template at random; the leading axis of length one is kept
    pick = np.random.randint(low=0, high=6, size=1)
    coords = templates[pick, :, :]

    box = np.identity(3, dtype=float) * 0.2
    per_graph_labels = np.zeros((1, 1, 1))
    table = atomic.AtomicNumberTable.from_zs(species)

    if environment:
        system_atoms, environment_atoms = [0, 1], [2]
    else:
        system_atoms = environment_atoms = None
    subsystem_atoms = [0, 1] if long_range else None
    unit_weights = np.ones((1, 1, 1))

    single_config = atomic.Configuration(
        atomic_numbers=species,
        positions=coords[0],
        cell=box,
        pbc=[True] * 3,
        node_labels=per_atom_labels,
        graph_labels=per_graph_labels[0],
        weight=unit_weights[0],
        system=system_atoms,
        environment=environment_atoms,
        subsystem=subsystem_atoms,
    )

    # here we do not remove isolated nodes
    dataset = create_dataset_from_configurations(
        config=[single_config],
        atomic_numbers=table,
        cutoff=0.1,
        long_range_cutoff=0.3 if long_range else -1.0,
        show_progress=False,
        remove_isolated_nodes=False,
    )

    datamodule = DictModule(dataset, lengths=(0.8, 0.2), batch_size=0, shuffle=False)
    datamodule.setup()
    first_batch = next(iter(datamodule.train_dataloader()))

    # single graph taken from the batch, with all nodes assigned to graph 0
    graph = first_batch['data_list'].get_example(0)
    graph['batch'] = torch.zeros(len(graph['positions']), dtype=torch.int64)

    example = graph.to_dict()
    # only one species is present in the data: pad the node attributes with
    # zero columns up to `n_species` columns (one row per atom)
    padding = torch.zeros(len(species), n_species - 1)
    example['node_attrs'] = torch.cat((example['node_attrs'], padding), 1)
    return example
# ===============================================================================
# ===============================================================================
# ==================================== TESTS ====================================
# ===============================================================================
# ===============================================================================
import numpy as np
def test_to_one_hot() -> None:
    """Check the one-hot encoding of three indices over four classes."""
    node_indices = torch.tensor([[0], [2], [1]], dtype=torch.int64)
    expected = torch.tensor(
        [[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0]], dtype=torch.int64
    )
    encoding = to_one_hot(node_indices, 4)
    assert (encoding == expected).all()
def test_from_configuration() -> None:
    """Check graphs built from single configurations, with and without
    system/environment selections and buffer, and the labels of a dataset
    built from a list of configurations."""
    # fake atomic numbers, cell, graph label, node labels
    species = [8, 1, 1]
    box = np.identity(3, dtype=float) * 0.2
    g_labels = np.array([[1]])
    n_labels = np.array([[0], [1], [1]])
    table = atomic.AtomicNumberTable.from_zs(species)

    # edges of the fully connected three-atom graph
    all_pairs = [[0, 0, 1, 1, 2, 2],
                 [2, 1, 0, 2, 1, 0]]

    def build(coords, **subset):
        # configuration sharing everything but positions and atom selections
        return atomic.Configuration(
            atomic_numbers=species,
            positions=coords,
            cell=box,
            pbc=[True] * 3,
            node_labels=n_labels,
            graph_labels=g_labels,
            **subset,
        )

    def same(actual, expected):
        return (actual == torch.tensor(expected)).all()

    coords = np.array([[0.0, 0.0, 0.0],
                       [0.07, 0.07, 0.0],
                       [0.07, -0.07, 0.0]],
                      dtype=float)

    # --- configuration using all atoms ---
    data = _create_pyg_data_from_configuration(build(coords), table, 0.1)

    # check edges and shifts are created correctly
    assert same(data['edge_index'], all_pairs)
    assert same(data['shifts'], [[0.0, 0.0, 0.0],
                                 [0.0, 0.0, 0.0],
                                 [0.0, 0.0, 0.0],
                                 [0.0, 0.2, 0.0],
                                 [0.0, -0.2, 0.0],
                                 [0.0, 0.0, 0.0]])
    assert same(data['unit_shifts'], [[0.0, 0.0, 0.0],
                                      [0.0, 0.0, 0.0],
                                      [0.0, 0.0, 0.0],
                                      [0.0, 1.0, 0.0],
                                      [0.0, -1.0, 0.0],
                                      [0.0, 0.0, 0.0]])

    # check correct storage
    assert same(data['positions'], [[0.0, 0.0, 0.0],
                                    [0.07, 0.07, 0.0],
                                    [0.07, -0.07, 0.0]])
    assert same(data['cell'], [[0.2, 0.0, 0.0],
                               [0.0, 0.2, 0.0],
                               [0.0, 0.0, 0.2]])
    assert same(data['node_attrs'], [[0.0, 1.0],
                                     [1.0, 0.0],
                                     [1.0, 0.0]])
    assert same(data['node_labels'], [[0.0], [1.0], [1.0]])
    assert same(data['graph_labels'], [[1.0]])
    assert data['weight'] == 1.0

    # --- two atoms (1 system, 1 env) as a subset ---
    data = _create_pyg_data_from_configuration(
        build(coords, system=[1], environment=[2]), table, 0.1
    )
    # check edges and shift are computed correctly
    assert same(data['edge_index'], [[1, 2],
                                     [2, 1]])
    assert same(data['shifts'], [[0.0, 0.2, 0.0],
                                 [0.0, -0.2, 0.0]])
    assert same(data['unit_shifts'], [[0.0, 1.0, 0.0],
                                      [0.0, -1.0, 0.0]])

    # --- three atoms (1 system, 2 env) as a subset and no buffer ---
    data = _create_pyg_data_from_configuration(
        build(coords, system=[0], environment=[1, 2]), table, 0.1
    )
    assert same(data['edge_index'], all_pairs)

    # --- check if pbc and cutoffs works. now the third atoms is too far ---
    coords = np.array([[0.0, 0.0, 0.0],
                       [0.07, 0.07, 0.0],
                       [0.07, -0.08, 0.0]],
                      dtype=float)
    far_config = build(coords, system=[0], environment=[1, 2])

    # same cutoff: the third atom is not included anymore
    data = _create_pyg_data_from_configuration(far_config, table, 0.1)
    assert same(data['edge_index'], [[0, 1],
                                     [1, 0]])

    # slightly larger cutoff: the edge with the third atom is created once again
    data = _create_pyg_data_from_configuration(far_config, table, 0.11)
    assert same(data['edge_index'], all_pairs)

    # with buffer layer: the third atom should be included but with no edge
    # to the system atom
    data = _create_pyg_data_from_configuration(far_config, table, 0.1, 0.01)
    assert same(data['edge_index'], [[0, 1, 1, 2],
                                     [1, 0, 2, 1]])
    assert same(data['shifts'], [[0.0, 0.0, 0.0],
                                 [0.0, 0.0, 0.0],
                                 [0.0, 0.2, 0.0],
                                 [0.0, -0.2, 0.0]])
    assert same(data['unit_shifts'], [[0.0, 0.0, 0.0],
                                      [0.0, 0.0, 0.0],
                                      [0.0, 1.0, 0.0],
                                      [0.0, -1.0, 0.0]])

    # --- dataset from a list of configurations labelled 0..9 ---
    config_list = [
        atomic.Configuration(
            atomic_numbers=species,
            positions=coords,
            cell=box,
            pbc=[True] * 3,
            node_labels=n_labels,
            graph_labels=np.array([[label]]),
        )
        for label in range(0, 10)
    ]
    dataset = create_dataset_from_configurations(config_list,
                                                 table,
                                                 0.1,
                                                 show_progress=False)

    # check if the labels of the entries are created correctly
    assert dataset.metadata['atomic_numbers'] == [1, 8]
    for entry in (0, 2, 4):
        assert same(dataset[entry]['data_list']['graph_labels'], [[float(entry)]])

    assert dataset.metadata['atomic_numbers'] == [1, 8]
    assert same(dataset[0]['data_list']['graph_labels'], [[0.0]])
    assert same(dataset[-1]['data_list']['graph_labels'], [[9.0]])
def test_from_configuration_long_cutoff() -> None:
    """Check graphs built from single configurations that define a subsystem
    with long-range edges."""
    def build(species, coords, box, n_labels, g_labels, **subset):
        return atomic.Configuration(
            atomic_numbers=species,
            positions=coords,
            cell=box,
            pbc=[True] * 3,
            node_labels=n_labels,
            graph_labels=g_labels,
            **subset,
        )

    def same(actual, expected):
        return (actual == torch.tensor(expected)).all()

    def lr_flags(n_short, n_long):
        # expected `edge_masks_lr`: short-range edges first, long-range last
        return [[0]] * n_short + [[1]] * n_long

    # ------------------------- three-atom system -------------------------
    species = [8, 1, 1]
    coords = np.array([[0.0, 0.0, 0.0],
                       [0.07, 0.07, 0.0],
                       [0.07, -0.08, 0.0]],
                      dtype=float)
    box = np.identity(3, dtype=float) * 0.2
    g_labels = np.array([[1]])
    n_labels = np.array([[0], [1], [1]])
    table = atomic.AtomicNumberTable.from_zs(species)

    # system and subsystem on atoms 1 and 2
    config = build(species, coords, box, n_labels, g_labels,
                   system=[1, 2], environment=[0], subsystem=[1, 2])
    data = _create_pyg_data_from_configuration(
        config, table, 0.1, long_range_cutoff=0.11
    )

    # check edges and shifts are created correctly
    assert same(data['edge_index'], [[0, 1, 1, 2, 1, 2], [1, 0, 2, 1, 2, 1]])
    assert same(data['shifts'], [
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, 0.2, 0.0],
        [0.0, -0.2, 0.0],
        [0.0, 0.2, 0.0],
        [0.0, -0.2, 0.0],
    ])
    assert same(data['unit_shifts'], [
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, -1.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, -1.0, 0.0],
    ])
    assert same(data['positions'], [
        [0.0, 0.0, 0.0],
        [0.07, 0.07, 0.0],
        [0.07, -0.08, 0.0],
    ])
    assert same(data['cell'], [
        [0.2, 0.0, 0.0],
        [0.0, 0.2, 0.0],
        [0.0, 0.0, 0.2],
    ])
    assert same(data['node_attrs'], [[0.0, 1.0], [1.0, 0.0], [1.0, 0.0]])
    assert same(data['edge_masks_lr'], lr_flags(4, 2))
    assert same(data['node_labels'], [[0.0], [1.0], [1.0]])
    assert same(data['graph_labels'], [[1.0]])
    assert same(data['edge_masks_lr'], lr_flags(4, 2))
    assert same(data['system_masks'], [[0], [1], [1]])
    assert same(data['subsystem_masks'], [[0], [1], [1]])
    assert data['weight'] == 1.0

    # system and subsystem on atoms 0 and 2
    config = build(species, coords, box, n_labels, g_labels,
                   system=[0, 2], environment=[1], subsystem=[0, 2])
    data = _create_pyg_data_from_configuration(
        config, table, 0.1, long_range_cutoff=0.11
    )

    # check edges and masks are created correctly
    assert same(data['edge_index'], [[0, 1, 1, 2, 0, 2], [1, 0, 2, 1, 2, 0]])
    assert same(data['edge_masks_lr'], lr_flags(4, 2))
    assert same(data['system_masks'], [[1], [0], [1]])
    assert same(data['subsystem_masks'], [[1], [0], [1]])

    # -------------------------- six-atom system --------------------------
    species = [8, 1, 1, 8, 1, 1]
    coords = np.array(
        [
            [0.0, 0.0, 0.0], [0.07, 0.07, 0.0], [0.07, -0.07, 0.0],
            [0.0, 0.8, 0.0], [0.07, 0.88, 0.0], [0.07, 0.73, 0.0],
        ],
        dtype=float
    )
    box = np.identity(3, dtype=float)
    g_labels = np.array([[1]])
    n_labels = np.array([[0], [1], [1], [0], [1], [1]])
    table = atomic.AtomicNumberTable.from_zs(species)

    config = build(species, coords, box, n_labels, g_labels,
                   system=[0, 3], environment=[1, 2, 4, 5], subsystem=[0, 3])

    # without buffer
    data = _create_pyg_data_from_configuration(
        config, table, 0.1, long_range_cutoff=0.4
    )

    # check edges and shifts are created correctly
    assert same(data['edge_index'], [
        [[0, 0, 1, 2, 3, 5, 0, 3], [2, 1, 0, 0, 5, 3, 3, 0]]
    ])
    assert same(data['shifts'], [
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, -1.0, 0.0],
        [0.0, 1.0, 0.0],
    ])
    assert same(data['unit_shifts'], [
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, -1.0, 0.0],
        [0.0, 1.0, 0.0],
    ])
    assert same(data['edge_masks_lr'], lr_flags(6, 2))
    assert same(data['subsystem_masks'], [[1], [0], [0], [1], [0], [0]])

    # same configuration, now with a buffer
    data = _create_pyg_data_from_configuration(
        config, table, 0.1, buffer=0.011, long_range_cutoff=0.4
    )

    # check edges and shifts are created correctly
    assert same(data['edge_index'], [
        [0, 0, 1, 2, 2, 3, 4, 5, 0, 3],
        [2, 1, 0, 4, 0, 5, 2, 3, 3, 0]
    ])
    assert same(data['shifts'], [
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, -1.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, -1.0, 0.0],
        [0.0, 1.0, 0.0],
    ])
    assert same(data['unit_shifts'], [
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, -1.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, -1.0, 0.0],
        [0.0, 1.0, 0.0],
    ])
    assert same(data['edge_masks_lr'], lr_flags(8, 2))
def test_from_configurations() -> None:
    """Check single-entry datasets built with isolated-node removal, with all
    atoms and with a system/environment subset."""
    # fake atomic numbers, positions, cell, graph label, node labels
    species = [8, 1, 1]
    coords = np.array([[0.0, 0.0, 0.0],
                       [0.07, 0.07, 0.0],
                       [0.07, -0.07, 0.0]],
                      dtype=float)
    box = np.identity(3, dtype=float) * 0.2
    g_labels = np.array([[1]])
    n_labels = np.array([[0], [1], [1]])
    table = atomic.AtomicNumberTable.from_zs(species)
    expected_cell = [[0.2, 0.0, 0.0],
                     [0.0, 0.2, 0.0],
                     [0.0, 0.0, 0.2]]

    def first_graph(**subset):
        # build a one-configuration dataset (the multiple-configuration
        # function is used even for a single one) and return its only graph
        config = atomic.Configuration(
            atomic_numbers=species,
            positions=coords,
            cell=box,
            pbc=[True] * 3,
            node_labels=n_labels,
            graph_labels=g_labels,
            **subset,
        )
        entry = create_dataset_from_configurations([config],
                                                   table,
                                                   0.1,
                                                   remove_isolated_nodes=True,
                                                   show_progress=False
                                                   )[0]
        return entry['data_list']

    def same(actual, expected):
        return (actual == torch.tensor(expected)).all()

    # --- configuration using all atoms ---
    data = first_graph()

    # check edges and shifts are created correctly
    assert same(data['edge_index'], [[0, 0, 1, 1, 2, 2],
                                     [2, 1, 0, 2, 1, 0]])
    assert same(data['shifts'], [[0.0, 0.0, 0.0],
                                 [0.0, 0.0, 0.0],
                                 [0.0, 0.0, 0.0],
                                 [0.0, 0.2, 0.0],
                                 [0.0, -0.2, 0.0],
                                 [0.0, 0.0, 0.0]])
    assert same(data['unit_shifts'], [[0.0, 0.0, 0.0],
                                      [0.0, 0.0, 0.0],
                                      [0.0, 0.0, 0.0],
                                      [0.0, 1.0, 0.0],
                                      [0.0, -1.0, 0.0],
                                      [0.0, 0.0, 0.0]])

    # check correct storage
    assert same(data['positions'], [[0.0, 0.0, 0.0],
                                    [0.07, 0.07, 0.0],
                                    [0.07, -0.07, 0.0]])
    assert same(data['cell'], expected_cell)
    assert same(data['node_attrs'], [[0.0, 1.0],
                                     [1.0, 0.0],
                                     [1.0, 0.0]])
    assert same(data['node_labels'], [[0.0], [1.0], [1.0]])
    assert same(data['graph_labels'], [[1.0]])
    assert data['weight'] == 1.0

    # --- one system atom and one environment atom ---
    data = first_graph(system=[1], environment=[2])

    assert same(data['positions'], [[0.07, 0.07, 0.0],
                                    [0.07, -0.07, 0.0]])
    assert same(data['cell'], expected_cell)
    assert same(data['node_attrs'], [[1.0, 0.0],
                                     [1.0, 0.0]])
    assert same(data['edge_index'], [[0, 1],
                                     [1, 0]])
    assert same(data['shifts'], [[0.0, 0.2, 0.0],
                                 [0.0, -0.2, 0.0]])
    assert same(data['unit_shifts'], [[0.0, 1.0, 0.0],
                                      [0.0, -1.0, 0.0]])
def test_from_configurations_long_cutoff() -> None:
    """Check single-entry datasets built with isolated-node removal when the
    configuration defines a subsystem with long-range edges."""
    # fake atomic numbers, positions, cell, graph label, node labels
    species = [8, 1, 1, 8, 1, 1]
    coords = np.array(
        [
            [0.0, 0.0, 0.0], [0.07, 0.07, 0.0], [0.07, -0.07, 0.0],
            [0.0, 0.0, 0.9], [0.07, 0.08, 0.9], [0.07, -0.07, 0.9],
        ],
        dtype=float
    )
    box = np.identity(3, dtype=float) * 1.1
    g_labels = np.array([[1]])
    n_labels = np.array([[0], [1], [1], [0], [1], [1]])
    table = atomic.AtomicNumberTable.from_zs(species)

    config = atomic.Configuration(
        atomic_numbers=species,
        positions=coords,
        cell=box,
        pbc=[True] * 3,
        node_labels=n_labels,
        graph_labels=g_labels,
        system=[0, 3],
        environment=[1, 2, 4, 5],
        subsystem=[0, 3],
    )

    def first_graph(short_cutoff):
        # build a one-configuration dataset (the multiple-configuration
        # function is used even for a single one) and return its only graph
        entry = create_dataset_from_configurations(
            [config],
            table,
            cutoff=short_cutoff,
            long_range_cutoff=0.4,
            remove_isolated_nodes=True,
            show_progress=False
        )[0]
        return entry['data_list']

    def same(actual, expected):
        return (actual == torch.tensor(expected)).all()

    no_shift = [0.0, 0.0, 0.0]

    # --- short-range cutoff 0.11 ---
    data = first_graph(0.11)

    # check edges and shifts are created correctly
    assert same(data['edge_index'], [
        [[0, 0, 1, 2, 3, 3, 4, 5, 0, 3], [2, 1, 0, 0, 5, 4, 3, 3, 3, 0]]
    ])
    assert same(data['shifts'],
                [no_shift] * 8 + [[0.0, 0.0, -1.1], [0.0, 0.0, 1.1]])
    assert same(data['unit_shifts'],
                [no_shift] * 8 + [[0.0, 0.0, -1.0], [0.0, 0.0, 1.0]])
    assert same(data['edge_masks_lr'], [[0]] * 8 + [[1]] * 2)

    # --- short-range cutoff 0.1 ---
    data = first_graph(0.1)

    # check edges and shifts are created correctly
    assert same(data['edge_index'], [
        [[0, 0, 1, 2, 3, 4, 0, 3], [2, 1, 0, 0, 4, 3, 3, 0]]
    ])
    assert same(data['shifts'],
                [no_shift] * 6 + [[0.0, 0.0, -1.1], [0.0, 0.0, 1.1]])
    assert same(data['unit_shifts'],
                [no_shift] * 6 + [[0.0, 0.0, -1.0], [0.0, 0.0, 1.0]])
    assert same(data['edge_masks_lr'], [[0]] * 6 + [[1]] * 2)
    # only five of the six atoms are left in the graph
    assert same(data['positions'], [
        [0.0, 0.0, 0.0],
        [0.07, 0.07, 0.0],
        [0.07, -0.07, 0.0],
        [0.0, 0.0, 0.9],
        [0.07, -0.07, 0.9]
    ])