# Dataset.py — Dataset API base class for decentralizepy.
from decentralizepy import utils
from decentralizepy.mappings.Mapping import Mapping


class Dataset:
    """
    This class defines the Dataset API.
    All datasets must follow this API.

    """

    def __init__(
        self,
        rank: int,
        machine_id: int,
        mapping: Mapping,
        train_dir="",
        test_dir="",
        sizes="",
        test_batch_size="",
    ):
        """
        Constructor which reads the data files, instantiates and partitions the dataset

        Parameters
        ----------
        rank : int
            Rank of the current process (to get the partition).
        machine_id : int
            Machine ID
        mapping : decentralizepy.mappings.Mapping
            Mapping to convert rank, machine_id -> uid for data partitioning
            It also provides the total number of global processes
        train_dir : str, optional
            Path to the training data files. Required to instantiate the training set
            The training set is partitioned according to the number of global processes and sizes
        test_dir : str, optional
            Path to the testing data files. Required to instantiate the testing set
        sizes : list(int), optional
            A list of fractions specifying how much data to allot each process. Sum of fractions should be 1.0
            By default, each process gets an equal amount.
        test_batch_size : int, optional
            Batch size during testing. Default value is 64

        """
        self.rank = rank
        self.machine_id = machine_id
        self.mapping = mapping
        # the number of global processes, needed to split-up the dataset
        self.n_procs = mapping.get_n_procs()
        # "" sentinels are normalized to None (or the stated default) so that
        # downstream code can test plain truthiness.
        self.train_dir = utils.conditional_value(train_dir, "", None)
        self.test_dir = utils.conditional_value(test_dir, "", None)
        self.sizes = utils.conditional_value(sizes, "", None)
        self.test_batch_size = utils.conditional_value(test_batch_size, "", 64)
        if self.sizes and isinstance(self.sizes, str):
            # SECURITY NOTE(review): eval on a config-provided string executes
            # arbitrary code. sizes is documented as a list of fractions, so
            # ast.literal_eval would be the safe equivalent — confirm no config
            # relies on non-literal expressions before switching.
            self.sizes = eval(self.sizes)

        # Flags recording whether train/test data were supplied, i.e. whether
        # the corresponding set can be instantiated by subclasses.
        self.__training__ = bool(train_dir)
        self.__testing__ = bool(test_dir)

    def get_trainset(self):
        """
        Function to get the training set

        Returns
        -------
        torch.utils.Dataset(decentralizepy.datasets.Data)

        Raises
        ------
        RuntimeError
            If the training set was not initialized

        """
        raise NotImplementedError

    def get_testset(self):
        """
        Function to get the test set

        Returns
        -------
        torch.utils.Dataset(decentralizepy.datasets.Data)

        Raises
        ------
        RuntimeError
            If the test set was not initialized

        """