"""
Femnist.py: FEMNIST dataset wrapper and reference models (logistic
regression and CNN) for decentralizepy.
"""

import json
import logging
import os
from collections import defaultdict

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader

from decentralizepy.datasets.Data import Data
from decentralizepy.datasets.Dataset import Dataset
from decentralizepy.datasets.Partitioner import DataPartitioner
from decentralizepy.mappings.Mapping import Mapping
from decentralizepy.models.Model import Model

# Number of output classes the models below are built for.
NUM_CLASSES = 62
# Height and width of one image, in pixels.
IMAGE_SIZE = (28, 28)
# Number of values in one flattened image; derived so it cannot drift from
# IMAGE_SIZE.
FLAT_SIZE = IMAGE_SIZE[0] * IMAGE_SIZE[1]
# NOTE(review): not referenced in the visible part of this file; presumably
# the number of distinct pixel intensity levels -- confirm before relying on it.
PIXEL_RANGE = 256.0


class Femnist(Dataset):
    """
    Class for the FEMNIST dataset

    """

    def __read_file__(self, file_path):
        """
        Read data from the given json file

        Parameters
        ----------
        file_path : str
            The file path

        Returns
        -------
        tuple
            (users, num_samples, data): the values stored under the
            "users", "num_samples" and "user_data" keys of the JSON object

        Raises
        ------
        KeyError
            If one of the three keys is missing from the file

        """
        # Decode explicitly as UTF-8 instead of relying on the platform's
        # locale default, so the same file parses identically everywhere.
        with open(file_path, "r", encoding="utf-8") as inf:
            client_data = json.load(inf)
        return (
            client_data["users"],
            client_data["num_samples"],
            client_data["user_data"],
        )

    def __read_dir__(self, data_dir):
        """
        Function to read all the FEMNIST data files in the directory

        Parameters
        ----------
        data_dir : str
            Path to the folder containing the data files

        Returns
        -------
        3-tuple
            A tuple containing list of clients, number of samples per client,
            and the data items per client. The third element is a defaultdict
            that yields None for clients that were never read.

        """
        clients = []
        num_samples = []
        data = defaultdict(lambda: None)

        # os.listdir returns entries in arbitrary order; sorting makes the
        # order of the returned clients reproducible across runs and machines.
        for file_name in sorted(os.listdir(data_dir)):
            if not file_name.endswith(".json"):
                continue
            u, n, d = self.__read_file__(os.path.join(data_dir, file_name))
            clients.extend(u)
            num_samples.extend(n)
            data.update(d)
        return clients, num_samples, data

    def file_per_user(self, dir, write_dir):
        """
        Function to read all the FEMNIST data files and write one file per user

        Each output file is named ``<client>.json`` and has the same three
        keys that ``__read_file__`` reads ("users", "num_samples",
        "user_data"), restricted to that single client.

        Parameters
        ----------
        dir : str
            Path to the folder containing the data files
        write_dir : str
            Path to the folder to write the files. Created if it does not exist.

        """
        clients, num_samples, train_data = self.__read_dir__(dir)
        os.makedirs(write_dir, exist_ok=True)
        for index, client in enumerate(clients):
            samples = train_data[client]
            my_data = {
                "users": [client],
                # Must be a list: __read_dir__ calls extend() on this value
                # when the per-user files are read back.
                "num_samples": [num_samples[index]],
                "user_data": {client: {"x": samples["x"], "y": samples["y"]}},
            }
            file_name = client + ".json"
            with open(os.path.join(write_dir, file_name), "w") as of:
                json.dump(my_data, of)
            print("Created File: ", file_name)

    def load_trainset(self):
        """
        Loads the training set. Partitions it if needed.

        The sorted list of ``*.json`` files in ``self.train_dir`` is split by
        DataPartitioner according to ``self.sizes``; only the files assigned
        to this process' uid are read. When ``self.sizes`` is None it is
        replaced by equal fractions, with the remainder added to the last
        process.

        Sets ``self.uid``, ``self.clients``, ``self.num_samples``,
        ``self.train_x`` (float32, shape (N, 1, 28, 28)) and ``self.train_y``
        (int64, shape (N,)).

        """
        logging.info("Loading training set.")
        files = sorted(
            f for f in os.listdir(self.train_dir) if f.endswith(".json")
        )
        c_len = len(files)

        if self.sizes is None:  # Equal distribution of data among processes
            e = c_len // self.n_procs
            frac = e / c_len
            self.sizes = [frac] * self.n_procs
            # Whatever the integer division left over goes to the last process
            # so that the fractions sum to 1.0.
            self.sizes[-1] += 1.0 - frac * self.n_procs
            logging.debug("Size fractions: %s", self.sizes)

        self.uid = self.mapping.get_uid(self.rank, self.machine_id)
        my_clients = DataPartitioner(files, self.sizes).use(self.uid)
        my_train_data = {"x": [], "y": []}
        self.clients = []
        self.num_samples = []
        logging.debug("Clients Length: %d", c_len)
        logging.debug("My_clients_len: %d", len(my_clients))
        for i in range(len(my_clients)):
            cur_file = my_clients[i]
            clients, _, train_data = self.__read_file__(
                os.path.join(self.train_dir, cur_file)
            )
            for cur_client in clients:
                client_samples = train_data[cur_client]
                self.clients.append(cur_client)
                my_train_data["x"].extend(client_samples["x"])
                my_train_data["y"].extend(client_samples["y"])
                self.num_samples.append(len(client_samples["y"]))

        # Flat samples -> (N, 28, 28, 1) -> channels-first (N, 1, 28, 28)
        self.train_x = (
            np.array(my_train_data["x"], dtype=np.dtype("float32"))
            .reshape(-1, 28, 28, 1)
            .transpose(0, 3, 1, 2)
        )
        self.train_y = np.array(my_train_data["y"], dtype=np.dtype("int64")).reshape(-1)
        logging.info("train_x.shape: %s", str(self.train_x.shape))
        logging.info("train_y.shape: %s", str(self.train_y.shape))
        assert self.train_x.shape[0] == self.train_y.shape[0]
        assert self.train_x.shape[0] > 0

    def load_testset(self):
        """
        Loads the testing set.

        Reads every data file in ``self.test_dir`` and concatenates the
        samples of all clients. Sets ``self.test_x`` (float32, shape
        (N, 1, 28, 28)) and ``self.test_y`` (int64, shape (N,)).

        """
        logging.info("Loading testing set.")
        _, _, d = self.__read_dir__(self.test_dir)
        test_x = []
        test_y = []
        for test_data in d.values():
            test_x.extend(test_data["x"])
            test_y.extend(test_data["y"])

        # Flat samples -> (N, 28, 28, 1) -> channels-first (N, 1, 28, 28)
        self.test_x = (
            np.array(test_x, dtype=np.dtype("float32"))
            .reshape(-1, 28, 28, 1)
            .transpose(0, 3, 1, 2)
        )
        self.test_y = np.array(test_y, dtype=np.dtype("int64")).reshape(-1)
        logging.info("test_x.shape: %s", str(self.test_x.shape))
        logging.info("test_y.shape: %s", str(self.test_y.shape))
        assert self.test_x.shape[0] == self.test_y.shape[0]
        assert self.test_x.shape[0] > 0

    def __init__(
        self,
        rank: int,
        machine_id: int,
        mapping: Mapping,
        n_procs="",
        train_dir="",
        test_dir="",
        sizes="",
        test_batch_size=1024,
    ):
        """
        Constructor which reads the data files, instantiates and partitions the dataset

        Parameters
        ----------
        rank : int
            Rank of the current process (to get the partition).
        machine_id : int
            Machine ID
        mapping : decentralizepy.mappings.Mapping
            Mapping to convert rank, machine_id -> uid for data partitioning
            It also provides the total number of global processes
        n_procs : optional
            Accepted for call compatibility; not used by this constructor.
        train_dir : str, optional
            Path to the training data files. Required to instantiate the training set
            The training set is partitioned according to the number of global processes and sizes
        test_dir : str, optional
            Path to the testing data files. Required to instantiate the testing set
        sizes : list(float), optional
            A list of fractions specifying how much data to allot each process. Sum of fractions should be 1.0
            By default, each process gets an equal amount.
        test_batch_size : int, optional
            Batch size during testing. Default value is 1024

        """
        super().__init__(
            rank,
            machine_id,
            mapping,
            train_dir,
            test_dir,
            sizes,
            test_batch_size,
        )
        # NOTE(review): the two guards below were reconstructed. The calls
        # were indented one level deeper than the constructor body, and
        # get_trainset/get_testset test these same flags. The flags are
        # presumably set by the base-class constructor -- confirm.
        if self.__training__:
            self.load_trainset()

        if self.__testing__:
            self.load_testset()

        # TODO: Add Validation

    def get_client_ids(self):
        """
        Return the ids of all clients held by the current process

        Returns
        -------
        list(str)
            The client ids, exactly as stored in ``self.clients``.

        """
        client_ids = self.clients
        return client_ids

    def get_client_id(self, i):
        """
        Function to get the client id of the ith sample

        Samples are laid out client after client, in the order of
        ``self.clients``; ``self.num_samples[j]`` is the number of samples
        that belong to ``self.clients[j]``.

        Parameters
        ----------
        i : int
            Index of the sample

        Returns
        -------
        str
            Client ID

        Raises
        ------
        IndexError
            If the sample index is out of bounds (negative, or not smaller
            than the total number of samples)

        """
        if i < 0:
            raise IndexError("i is out of bounds!")

        # Exclusive upper bound of the sample range of the client being
        # examined. It has to advance with every client; otherwise every
        # index is compared against the first client's sample count only.
        upper = 0
        for client, count in zip(self.clients, self.num_samples):
            upper += count
            if i < upper:
                return client

        raise IndexError("i is out of bounds!")

    def get_trainset(self, batch_size=1, shuffle=False):
        """
        Function to get the training set

        Parameters
        ----------
        batch_size : int, optional
            Batch size for learning
        shuffle : bool, optional
            Passed through to the DataLoader's ``shuffle`` argument

        Returns
        -------
        torch.utils.data.DataLoader
            Loader over decentralizepy.datasets.Data built from the training
            samples and labels

        Raises
        ------
        RuntimeError
            If the training set was not initialized

        """
        if not self.__training__:
            raise RuntimeError("Training set not initialized!")
        return DataLoader(
            Data(self.train_x, self.train_y), batch_size=batch_size, shuffle=shuffle
        )

    def get_testset(self):
        """
        Function to get the test set

        Returns
        -------
        torch.utils.data.DataLoader
            Loader over decentralizepy.datasets.Data built from the test
            samples and labels, batched by ``self.test_batch_size``

        Raises
        ------
        RuntimeError
            If the test set was not initialized

        """
        if not self.__testing__:
            raise RuntimeError("Test set not initialized!")
        return DataLoader(
            Data(self.test_x, self.test_y), batch_size=self.test_batch_size
        )

    def test(self, model, loss):
        """
        Function to evaluate model on the test dataset.

        Parameters
        ----------
        model : decentralizepy.models.Model
            Model to evaluate
        loss : torch.nn.loss
            Loss function to evaluate

        Returns
        -------
        tuple
            (accuracy, loss_value): overall accuracy in percent, and the
            mean of the per-batch loss values

        """
        testloader = self.get_testset()

        logging.debug("Test Loader instantiated.")

        correct_pred = [0] * NUM_CLASSES
        total_pred = [0] * NUM_CLASSES

        total_correct = 0
        total_predicted = 0

        with torch.no_grad():
            loss_val = 0.0
            count = 0
            for elems, labels in testloader:
                outputs = model(elems)
                loss_val += loss(outputs, labels).item()
                count += 1
                _, predictions = torch.max(outputs, 1)
                # Convert each batch to plain ints once. The logging arguments
                # are passed lazily so the per-sample message is only built
                # when debug logging is enabled.
                for label, prediction in zip(labels.tolist(), predictions.tolist()):
                    logging.debug("%s predicted as %s", label, prediction)
                    if label == prediction:
                        correct_pred[label] += 1
                        total_correct += 1
                    total_pred[label] += 1
                    total_predicted += 1

        logging.debug("Predicted on the test set")

        for key, value in enumerate(correct_pred):
            if total_pred[key] != 0:
                accuracy = 100 * float(value) / total_pred[key]
            else:
                # No test sample of this class: reported as 100 %.
                accuracy = 100.0
            logging.debug("Accuracy for class {} is: {:.1f} %".format(key, accuracy))

        accuracy = 100 * float(total_correct) / total_predicted
        loss_val = loss_val / count
        logging.info("Overall accuracy is: {:.1f} %".format(accuracy))
        return accuracy, loss_val


class LogisticRegression(Model):
    """
    Class for a Logistic Regression Neural Network for FEMNIST

    A single linear layer from the flattened image to the class scores.

    """

    def __init__(self):
        """
        Constructor. Instantiates the Logistic Regression Model
            with 28*28 Input and 62 output classes

        """
        super().__init__()
        self.fc1 = nn.Linear(FLAT_SIZE, NUM_CLASSES)

    def forward(self, x):
        """
        Forward pass of the model

        Parameters
        ----------
        x : torch.tensor
            The input torch tensor

        Returns
        -------
        torch.tensor
            The output torch tensor: the linear layer's outputs, one per
            class. No softmax is applied here.

        """
        # Keep the first (batch) dimension, flatten everything else.
        x = torch.flatten(x, start_dim=1)
        x = self.fc1(x)
        return x


class CNN(Model):
    """
    Class for a CNN Model for FEMNIST

    Two 5x5 convolution + ReLU + 2x2 max-pool stages followed by two fully
    connected layers.

    """

    def __init__(self):
        """
        Constructor. Instantiates the CNN Model
            with 28*28*1 Input and 62 output classes

        """
        super().__init__()
        # 1.6 million params
        self.conv1 = nn.Conv2d(1, 32, 5, padding=2)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 5, padding=2)
        # 7 * 7: the padded convolutions keep the spatial size, and each of
        # the two 2x2 pools halves it (28 -> 14 -> 7).
        self.fc1 = nn.Linear(7 * 7 * 64, 512)
        self.fc2 = nn.Linear(512, NUM_CLASSES)

    def forward(self, x):
        """
        Forward pass of the model

        Parameters
        ----------
        x : torch.tensor
            The input torch tensor

        Returns
        -------
        torch.tensor
            The output torch tensor: the last linear layer's outputs, one
            per class. No softmax is applied here.

        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Keep the first (batch) dimension, flatten everything else.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x