Source code for mlcolvar.data.datamodule

#!/usr/bin/env python

# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
PyTorch Lightning DataModule object for DictDatasets.
"""

__all__ = ["DictModule"]


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

import math
from typing import Sequence, Union, Optional
import warnings

import torch
import torch_geometric
import numpy as np
import lightning
from torch.utils.data import Subset
from torch import default_generator, randperm

from mlcolvar.data import DictLoader, DictDataset

# =============================================================================
# DICTIONARY DATAMODULE CLASS
# =============================================================================


[docs] class DictModule(lightning.LightningDataModule): """Lightning DataModule constructed for :class:`~mlcolvar.data.dataset.DictDataset` . The DataModule automatically splits the :class:`~mlcolvar.data.dataset.DictDataset` s (using either random or sequential splitting) into training, validation, and (optionally) test sets. The class can also merge multiple :class:`~mlcolvar.data.dataset.DictDataset`s that have different keys (see example below). The datasets can have different sizes but must all have the same number of batches. Examples -------- >>> x = torch.randn((50, 2)) >>> dataset = DictDataset({'data': x, 'labels': x.square().sum(dim=1)}) >>> datamodule = DictModule(dataset, lengths=[0.75,0.2,0.05], batch_size=25) >>> # This is usually called by PyTorch Lightning. >>> datamodule.setup() >>> train_loader = datamodule.train_dataloader() >>> for batch in train_loader: ... batch_x = batch['data'] ... batch_y = batch['labels'] ... print(batch_x.shape, batch_y.shape) torch.Size([25, 2]) torch.Size([25]) torch.Size([13, 2]) torch.Size([13]) >>> val_loader = datamodule.val_dataloader() >>> test_loader = datamodule.test_dataloader() You can also iterate over multiple datasets. These can have different keys, but they must be of the same length. >>> dataset2 = DictDataset({'data2': torch.randn(50, 1), 'weights': torch.arange(50)}) >>> datamodule = DictModule([dataset, dataset2], lengths=[0.8, 0.2], batch_size=5) >>> datamodule.setup() >>> train_loader = datamodule.train_dataloader() >>> # Print first batch. >>> for batch in train_loader: ... print('batch dataset0:', list(batch['dataset0'].keys())) ... print('batch dataset1:', list(batch['dataset1'].keys())) ... break batch dataset0: ['data', 'labels'] batch dataset1: ['data2', 'weights'] """
[docs] def __init__( self, dataset: DictDataset, lengths: Sequence = (0.8, 0.2), batch_size: Union[int, Sequence] = 0, random_split: bool = True, shuffle: Union[bool, Sequence] = True, generator: Optional[torch.Generator] = None, ): """Create a ``DataModule`` wrapping a :class:`~mlcolvar.data.dataset.DictDataset`. For the ``batch_size`` and ``shuffle`` parameters, either a single value or a list-type of values (with same size as ``lengths``) can be provided. Parameters ---------- dataset : DictDataset or Sequence[DictDataset] The dataset or a list of datasets. If a list, the datasets can have different keys but they must all have the same number of samples. lengths : list-like, optional Lengths of the training, validation, and (optionally) test datasets. This must be a list of (float) fractions summing to 1. The default is ``[0.8,0.2]``. batch_size : int or list-like, optional Batch size, by default 0 (== ``len(dataset)``). random_split: bool, optional Whether to randomly split train/valid/test or sequentially, by default ``True``. shuffle : int or list-like, optional Whether to shuffle the batches in the ``DataLoader``, by default ``True``. generator : torch.Generator, optional Set random generator for reproducibility, by default ``None``. See Also -------- :class:`~mlcolvar.data.dataloader.DictLoader` The PyTorch loader built by the data module. """ super().__init__() self.dataset = dataset self.DataLoader = self._get_dataloader() self._lengths = lengths # Keeping this private for now. Changing it at runtime would # require changing dataset_split and the dataloaders. self._random_split = random_split # save generator if given, otherwise set it to torch.default_generator self.generator = generator if generator is not None else default_generator if self.generator is not None and not self._random_split: warnings.warn( "A torch.generator was provided but it is not used with random_split=False" ) # Make sure batch_size and shuffle are lists. if self._dataset_type == 'graphs' and batch_size == 0: batch_size = len(dataset) # make this explicit for torch_geometric if isinstance(batch_size, int): self.batch_size = [batch_size for _ in lengths] else: self.batch_size = batch_size if isinstance(shuffle, bool): self.shuffle = [shuffle for _ in lengths] else: self.shuffle = shuffle # This is initialized in setup(). self._dataset_split = None # dataloaders self.train_loader = None self.valid_loader = None self.test_loader = None
@property def _dataset_type(self): if not isinstance(self.dataset, list): _dataset_type = self.dataset.metadata['data_type'] else: it = iter(list(self.dataset)) _dataset_type = next(it).metadata['data_type'] if not all(d.metadata['data_type'] for d in it): raise ValueError("not all the dataset are of the same type!") return _dataset_type def _get_dataloader(self): # decide which loader to use if self._dataset_type == 'descriptors': DataLoader = DictLoader elif self._dataset_type == 'graphs': DataLoader = torch_geometric.loader.DataLoader return DataLoader
[docs] def setup(self, stage: Optional[str] = None): if self._dataset_split is None: if isinstance(self.dataset, DictDataset): self._dataset_split = self._split(self.dataset) else: # List of datasets. dataset_split = [self._split(d) for d in self.dataset] # Shape (n_datasets, n_loaders) -> (n_loaders, n_datasets) self._dataset_split = list(map(list, zip(*dataset_split)))
[docs] def train_dataloader(self): """Return training dataloader.""" self._check_setup() if self.train_loader is None: self.train_loader = self.DataLoader( self._dataset_split[0], batch_size=self.batch_size[0], shuffle=self.shuffle[0], ) return self.train_loader
[docs] def val_dataloader(self): """Return validation dataloader.""" self._check_setup() if len(self._lengths) < 2: raise NotImplementedError( "Validation dataset not available, you need to pass two lengths to datamodule." ) if self.valid_loader is None: self.valid_loader = self.DataLoader( self._dataset_split[1], batch_size=self.batch_size[1], shuffle=self.shuffle[1], ) return self.valid_loader
[docs] def test_dataloader(self): """Return test dataloader.""" self._check_setup() if len(self._lengths) < 3: raise NotImplementedError( "Test dataset not available, you need to pass three lengths to datamodule." ) if self.test_loader is None: self.test_loader = self.DataLoader( self._dataset_split[2], batch_size=self.batch_size[2], shuffle=self.shuffle[2], ) return self.test_loader
[docs] def predict_dataloader(self): raise NotImplementedError()
[docs] def teardown(self, stage: str): pass
def __repr__(self) -> str: parts = [ f"DictModule(dataset -> {self.dataset!r}", ( "train_loader -> DictLoader(" f"length={self._lengths[0]}, " f"batch_size={self.batch_size[0]}, " f"shuffle={self.shuffle[0]})" ), ] if len(self._lengths) >= 2: parts.append( "valid_loader -> DictLoader(" f"length={self._lengths[1]}, " f"batch_size={self.batch_size[1]}, " f"shuffle={self.shuffle[1]})" ) if len(self._lengths) >= 3: parts.append( "test_loader =DictLoader(" f"length={self._lengths[2]}, " f"batch_size={self.batch_size[2]}, " f"shuffle={self.shuffle[2]})" ) return ",\n\t ".join(parts) + ")" def __str__(self) -> str: return self.__repr__() def _split(self, dataset): """Perform the random or sequential spliting of a single dataset. Returns a list of Subset[DictDataset] objects. """ dataset_split = split_dataset( dataset, self._lengths, self._random_split, self.generator ) return dataset_split def _check_setup(self): """Raise an error if setup() has not been called.""" if self._dataset_split is None: raise AttributeError( "The datamodule has not been set up yet. To get the dataloaders " "outside a Lightning trainer please call .setup() first." )
[docs] def get_graph_inputs(self, mode='train'): """Generate an input that can be used as input for a GNN model Parameters ---------- mode : str, optional Type of loader to be used, either 'train' or 'val'/'valid', by default 'train' """ self.setup() if mode == 'train': loader=self.train_dataloader elif (mode=='val' or mode=='valid'): loader=self.val_dataloader else: raise ValueError(f"Mode can either be 'train', 'val', 'valid', found {mode}!") return next(iter(loader()))['data_list']
def split_dataset( dataset, lengths: Sequence, random_split: bool, generator: Optional[torch.Generator] = default_generator, ) -> list: """ Sequentially or randomly split a dataset into non-overlapping new datasets of given lengths. If random_split=True the behavior is the same as torch.utils.data.dataset.random_split. If a list of fractions that sum up to 1 is given, the lengths will be computed automatically as floor(frac * len(dataset)) for each fraction provided. After computing the lengths, if there are any remainders, 1 count will be distributed in round-robin fashion to the lengths until there are no remainders left. Optionally fix the generator for reproducible results. """ if math.isclose(sum(lengths), 1) and sum(lengths) <= 1: subset_lengths = [] for i, frac in enumerate(lengths): if frac < 0 or frac > 1: raise ValueError(f"Fraction at index {i} is not between 0 and 1") n_items_in_split = int( math.floor(len(dataset) * frac) # type: ignore[arg-type] ) subset_lengths.append(n_items_in_split) remainder = len(dataset) - sum(subset_lengths) # type: ignore[arg-type] # add 1 to all the lengths in round-robin fashion until the remainder is 0 for i in range(remainder): idx_to_add_at = i % len(subset_lengths) subset_lengths[idx_to_add_at] += 1 lengths = subset_lengths for i, length in enumerate(lengths): if length == 0: warnings.warn( f"Length of split at index {i} is 0. " f"This might result in an empty dataset." ) # Cannot verify that dataset is Sized if sum(lengths) != len(dataset): # type: ignore[arg-type] raise ValueError( "Sum of input lengths does not equal the length of the input dataset!" ) if random_split: indices = randperm(sum(lengths), generator=generator).tolist() # type: ignore[call-overload] return [ Subset(dataset, indices[offset - length : offset]) for offset, length in zip(_accumulate(lengths), lengths) ] else: return [ Subset(dataset, np.arange(offset - length, offset)) for offset, length in zip(_accumulate(lengths), lengths) ] def sequential_split(dataset, lengths: Sequence) -> list: """ Sequentially split a dataset into non-overlapping new datasets of given lengths. The behavior is the same as torch.utils.data.dataset.random_split. If a list of fractions that sum up to 1 is given, the lengths will be computed automatically as floor(frac * len(dataset)) for each fraction provided. After computing the lengths, if there are any remainders, 1 count will be distributed in round-robin fashion to the lengths until there are no remainders left. """ warnings.warn( "The function sequential_split is deprecated, use split_dataset(.., .., random_split=False, ..)", FutureWarning, stacklevel=2, ) return split_dataset(dataset=dataset, lengths=lengths, random_split=False) # Taken from python 3.5 docs, removed from PyTorch 2.3 onward def _accumulate(iterable, fn=lambda x, y: x + y): "Return running totals" # _accumulate([1,2,3,4,5]) --> 1 3 6 10 15 # _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120 it = iter(iterable) try: total = next(it) except StopIteration: return yield total for element in it: total = fn(total, element) yield total