Source code for mlcolvar.data.dataset

import torch
import numpy as np
from mlcolvar.core.transform.utils import Statistics
from torch.utils.data import Dataset

__all__ = ["DictDataset"]


class DictDataset(Dataset):
    """Define a torch dataset from a dictionary of lists/array/tensors and names.

    E.g. { 'data' : torch.Tensor([1,2,3,4]),
           'labels' : [0,0,1,1],
           'weights' : np.asarray([0.5,1.5,1.5,0.5]) }

    Every entry is stored as a ``torch.Tensor`` and all entries must have the
    same length along their first dimension.
    """

    def __init__(
        self,
        dictionary: dict = None,
        feature_names=None,
        create_ref_idx: bool = False,
        **kwargs,
    ):
        """Create a Dataset from a dictionary or from a list of kwargs.

        Parameters
        ----------
        dictionary : dict
            Dictionary with names and tensors
        feature_names : array-like
            List or numpy array with feature names
        create_ref_idx : bool
            If True and no ``'ref_idx'`` entry is present, add a ``'ref_idx'``
            entry holding ``torch.arange(len(self))`` as a reference index of
            the entries for shuffling and slicing.
        **kwargs
            Additional named entries; they are merged into ``dictionary`` and
            take precedence over entries with the same name.

        Raises
        ------
        TypeError
            If ``dictionary`` is given but is not a ``dict``.
        ValueError
            If no entries are provided, or if the entries do not all have the
            same length.
        """
        if (dictionary is not None) and (not isinstance(dictionary, dict)):
            raise TypeError(
                f"DictDataset requires a dictionary , not {type(dictionary)}."
            )

        # Merge into a fresh dict so the caller's dictionary is never modified.
        merged = {**(dictionary or {}), **kwargs}
        if len(merged) == 0:
            raise ValueError("Empty datasets are not supported")

        # Store every entry as a torch.Tensor.
        self._dictionary = {
            key: self._as_tensor(val) for key, val in merged.items()
        }
        self.feature_names = feature_names

        # All entries must share the length of the first one.
        lengths = [len(val) for val in self._dictionary.values()]
        self.length = lengths[0]
        if any(n != self.length for n in lengths[1:]):
            raise ValueError("not all arrays in dictionary have same length!")

        # Optional index of entries, used as a shuffling/slicing reference.
        if create_ref_idx and "ref_idx" not in self._dictionary:
            self._dictionary["ref_idx"] = torch.arange(
                len(self), dtype=torch.int
            )

    @staticmethod
    def _as_tensor(value):
        """Return ``value`` unchanged if it is a tensor, else ``torch.Tensor(value)``."""
        if isinstance(value, torch.Tensor):
            return value
        return torch.Tensor(value)

    def __getitem__(self, index):
        """Return one stored entry (string index) or a dict of indexed entries.

        A string index returns the tensor stored under that name. Any other
        index is applied to every stored tensor and the results are returned
        as a new dictionary with the same keys.
        """
        if isinstance(index, str):
            return self._dictionary[index]
        return {key: val[index] for key, val in self._dictionary.items()}

    def __setitem__(self, index, value):
        """Set or replace the entry named ``index``.

        Raises
        ------
        ValueError
            If ``len(value)`` differs from the length of the dataset.
        NotImplementedError
            If ``index`` is not a string.
        """
        if not isinstance(index, str):
            raise NotImplementedError(
                f"Only string indexes can be set, {type(index)} is not supported."
            )
        if len(value) != len(self):
            raise ValueError(
                f"length of value ({len(value)}) != length of dataset ({len(self)})."
            )
        # Convert like __init__ does, so slicing and __repr__ (which reads
        # ``.shape``) keep working after a list/array is assigned.
        self._dictionary[index] = self._as_tensor(value)

    def __len__(self):
        """Return the length of the first stored entry."""
        return len(next(iter(self._dictionary.values())))

    def get_stats(self):
        """Compute statistics ('mean','Std','Min','Max') of the dataset.

        The ``'ref_idx'`` entry, if present, is skipped.

        Returns
        -------
        stats
            dictionary of dictionaries with statistics
        """
        return {
            key: Statistics(self._dictionary[key]).to_dict()
            for key in self.keys
            if key != "ref_idx"
        }

    def __repr__(self) -> str:
        """Return the entry names together with their tensor shapes."""
        entries = ",".join(
            f' "{key}": {list(val.shape)}'
            for key, val in self._dictionary.items()
        )
        return "DictDataset(" + entries + " )"

    @property
    def keys(self):
        """Names of the stored entries, as a tuple."""
        return tuple(self._dictionary.keys())

    @property
    def feature_names(self):
        """Feature names."""
        return self._feature_names

    @feature_names.setter
    def feature_names(self, value):
        # Stored as a numpy array of strings; None is kept as None.
        self._feature_names = (
            np.asarray(value, dtype=str) if value is not None else value
        )
def test_DictDataset():
    """Smoke-test DictDataset with mixed inputs and two kinds of loaders.

    Builds a dataset from a tensor, a plain list and a numpy array, prints
    its length and a few indexed views, then draws one batch from a torch
    ``DataLoader`` and one from ``DictLoader``.
    """
    # Entries deliberately use three different container types.
    samples = dict(
        data=torch.Tensor([[1.0], [2.0], [0.3], [0.4]]),
        labels=[0, 0, 1, 1],
        weights=np.asarray([0.5, 1.5, 1.5, 0.5]),
    )
    ds = DictDataset(samples)

    print(len(ds))
    print(ds[0])
    head_data = ds[0:2]["data"]
    print(head_data)
    print(head_data.dtype)

    # Standard torch dataloader.
    from torch.utils.data import DataLoader

    torch_loader = DataLoader(ds, batch_size=1)
    print(next(iter(torch_loader))["data"])

    # Loader from the sibling ``dataloader`` module.
    from .dataloader import DictLoader

    dict_loader = DictLoader(ds, batch_size=1)
    print(next(iter(dict_loader)))