Skip to content
Snippets Groups Projects
Partitioner.py 2.94 KiB
Newer Older
Rishi Sharma's avatar
Rishi Sharma committed
from random import Random

""" Adapted from https://pytorch.org/tutorials/intermediate/dist_tuto.html """


class Partition(object):
    """
    A view onto a subset of an indexable dataset.

    Position ``i`` of this object maps to ``data[index[i]]``; the underlying
    dataset is referenced, not copied.

    """

    def __init__(self, data, index):
        """
        Constructor. Stores references to the dataset and the index list.

        Parameters
        ----------
        data : indexable
            The full dataset this partition draws from.
        index : list
            Indices into ``data`` that make up this partition.

        """
        self.index = index
        self.data = data

    def __len__(self):
        """
        Number of samples in this partition.

        Returns
        -------
        int
            The length of the index list (not of the full dataset).

        """
        count = len(self.index)
        return count

    def __getitem__(self, index):
        """
        Fetch a sample by its position within this partition.

        Parameters
        ----------
        index : int
            Position inside the partition; translated through the index
            list before looking up the dataset.

        Returns
        -------
        Data
            ``data[self.index[index]]``

        """
        return self.data[self.index[index]]


class DataPartitioner(object):
    """
    Class to partition the dataset

    The indices ``0 .. len(data) - 1`` are shuffled once with a seeded RNG and
    then cut into consecutive, disjoint chunks, one chunk per entry of
    ``sizes``.

    """

    def __init__(self, data, sizes=(1.0,), seed=1234):
        """
        Constructor. Partitions the data according to the parameters

        Parameters
        ----------
        data : indexable
            An indexable list of data items
        sizes : sequence of float, optional
            Fractions of the dataset, one per process. Partition ``i``
            receives ``int(sizes[i] * len(data))`` indices (the fractional
            remainder is dropped). If the fractions sum to more than 1, later
            partitions are truncated, possibly to empty, once the indices run
            out.
        seed : int, optional
            Seed for generating a random subset

        Raises
        ------
        ValueError
            If any fraction in ``sizes`` is negative.

        """
        self.data = data
        self.partitions = []
        data_len = len(data)
        indexes = list(range(data_len))
        # A fixed seed makes the permutation reproducible: instances built
        # with the same data length and seed produce identical partitions.
        Random(seed).shuffle(indexes)

        # Advance a running offset instead of re-slicing the remaining tail
        # on every iteration, which copied O(len(data)) items per partition.
        start = 0
        for frac in sizes:
            if frac < 0:
                raise ValueError(
                    "Partition fractions must be non-negative, got %r" % (frac,)
                )
            part_len = int(frac * data_len)
            self.partitions.append(indexes[start : start + part_len])
            start += part_len

    def use(self, rank):
        """
        Get the partition for the process with the given `rank`

        Parameters
        ----------
        rank : int
            Rank of the current process; indexes ``self.partitions`` directly,
            so a rank >= ``len(sizes)`` raises IndexError.

        Returns
        -------
        Partition
            The dataset partition of the current process

        """
        return Partition(self.data, self.partitions[rank])
Rishi Sharma's avatar
Rishi Sharma committed

Jeffrey Wigger's avatar
Jeffrey Wigger committed

Rishi Sharma's avatar
Rishi Sharma committed
class SimpleDataPartitioner(DataPartitioner):
    """
    Class to partition the dataset without shuffling

    The indices are kept in their original order, so each partition is a
    contiguous range of ``data``. ``use(rank)`` is inherited from the parent.

    """

    def __init__(self, data, sizes=(1.0,)):
        """
        Constructor. Partitions the data according to the parameters

        Parameters
        ----------
        data : indexable
            An indexable list of data items
        sizes : sequence of float, optional
            Fractions of the dataset, one per process. Partition ``i``
            receives ``int(sizes[i] * len(data))`` consecutive indices (the
            fractional remainder is dropped). If the fractions sum to more
            than 1, later partitions are truncated, possibly to empty.

        Raises
        ------
        ValueError
            If any fraction in ``sizes`` is negative.

        """
        # The parent constructor is intentionally not called: it would
        # shuffle the indices, which this class must not do.
        self.data = data
        self.partitions = []
        data_len = len(data)

        # Without shuffling, each partition is just [start, end) clipped to
        # the dataset length, so build it from a range rather than slicing a
        # full index list repeatedly.
        start = 0
        for frac in sizes:
            if frac < 0:
                raise ValueError(
                    "Partition fractions must be non-negative, got %r" % (frac,)
                )
            part_len = int(frac * data_len)
            end = min(start + part_len, data_len)
            self.partitions.append(list(range(start, end)))
            start += part_len