fedavg学习(个人学习笔记)

运行代码

一、环境配置

1. 首先确保你已经在项目目录下激活了虚拟环境,不然会下载到默认路径!

 在终端进入项目目录,并运行以下命令:

<python路径>\Scripts\activate

 激活后会显示自己的需要你环境名,如下:

2. pytorch官网下载对应电脑配置的pytorch,选好cuda版本

不清楚如何选择版本的可以先进入命令行窗口,执行命令nvidia-smi查看驱动版本号

如图,表格右上角就是我的版本号 :

因为我的版本比较高,所以我在官网上选择的是这个:

3. 复制刚刚官网下面的coomand,直接在自己的python虚拟环境中执行即可一键安装,非常快捷,如果安装配置成功就可以运行代码了。

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

注意这段指令是我所安装的版本,大家安装时应该根据自己电脑的配置在官网进行选择 

4. 检查安装:可以在终端运行如下指令检查pytorch、cuda是否成功安装

python
import torch
print(torch.cuda.device_count())
torch.cuda.is_available()

 如果成功,终端会显示:

1
True

到这里就配置好了,当然过程中可能没有这么顺利,会出现很多报错,需要一个一个检查修复。 

我之前第一次就是出了很多报错,自己一个一个研究了很久,网上搜了很多博客,还和同学讨论,最后也是解决了,比较艰辛

二、运行测试

1. 在pycharm终端输入下面指令 测试一下

 python main_fed.py --dataset cifar --epoch 10 --num_channel 3 --gpu 0 --model cnn --iid  

 如果cuda没有配置成功,这一步会报错!!!

该指令选择了 CIFAR 数据集,卷积神经网络 (CNN),在使用独立同分布 (IID) 的数据进行训练,轮次为10轮,结果如下图:

 2. 修改一下参数继续测试,这次选择了 MNIST数据集跑50轮

python main_fed.py --dataset mnist --iid --num_channels 1 --model cnn --epochs 50 --gpu 0

运行结果如下

 3. 继续测试,这次选择 MNIST数据集,非独立同分布(non-iid)数据跑50轮

python main_fed.py --dataset mnist  --num_channels 1 --model cnn --epochs 50 --gpu 0 

运行结果:

很好,运行测试正常,说明环境配置没问题。可以继续修改参数运行获得想要的实验数据

 代码学习

一、背景介绍

联邦学习(Federated Learning)是一种分布式机器学习方法,旨在训练模型而不需要将原始数据集传输到中央服务器上。在传统的机器学习方法中,通常会将所有数据集中到一个中央服务器进行训练,但这样做可能涉及到数据隐私和安全的问题。

相比之下,联邦学习通过保持数据在本地设备或边缘设备上进行训练,将机器学习的计算推向数据而不是将数据传输到中央服务器。这意味着个体的隐私和数据安全得到了更好的保护。

在联邦学习中,数据所有者(通常是用户设备,如智能手机、传感器等)在本地训练模型并仅将模型的参数更新上传到中央服务器。中央服务器将这些参数整合在一起,更新全局模型,然后将更新后的全局模型参数发送回参与方。这个过程可以迭代多次,直到全局模型收敛或达到预定的停止准则。

通过联邦学习,数据隐私得到了保护,因为原始数据不需要共享或暴露给中央服务器或其他参与方。并且,联邦学习还可以解决数据集分布不均衡的问题,因为每个参与方可以保留自己本地的数据分布。

联邦学习可以应用于各种场景,如医疗保健、物联网、金融等,以利用分布式数据进行模型训练,并同时保护数据隐私和安全。

联邦平均(Federated Averaging)是联邦学习(Federated Learning)中的一个重要概念。在联邦学习中,多个参与方各自训练本地模型,并将本地模型的参数更新发送到中央服务器,中央服务器根据这些参数更新来更新全局模型。而联邦平均是一种常用的方法,用于整合来自各个参与方的模型参数更新,从而更新全局模型。

具体而言,联邦平均的过程一般包括以下步骤:

  1. 初始阶段,中央服务器初始化全局模型的参数。
  2. 各个参与方在本地进行模型训练,获得本地模型的参数更新。
  3. 参与方将本地模型的参数更新发送到中央服务器。
  4. 中央服务器收集所有参与方的参数更新,并计算这些参数更新的平均值。
  5. 中央服务器使用平均值来更新全局模型的参数。
  6. 更新后的全局模型参数传输回各个参与方,以便它们在下一轮训练中使用。

这样,通过不断迭代上述步骤,全局模型就能够在保障数据隐私的前提下,融合来自多个参与方的知识,从而达到更好的模型表现。

二、main_fed核心模块

一些基本概念:
  • IID(独立同分布):在IID划分方式下,假设数据集中的样本之间是独立的,且来自于同一分布。换句话说,每个样本对于模型而言都是等价的,没有区别。在IID划分的数据集中,用户之间的数据分布是相同的,每个用户的样本是从整个数据集中随机采样得到的。

  • Non-IID(非独立非同分布):在Non-IID划分方式下,数据集中的样本之间可能存在一定的相关性或分布差异。这意味着不同用户之间的数据分布可能不同,每个用户的样本可能来自于不同的数据分布或代表不同的特定情况。Non-IID划分方式通常用于更真实的场景,其中用户的数据可能包含个性化特征或特定领域的数据。

  • MNIST 数据集:

    • 全名: Modified National Institute of Standards and Technology database
    • 内容: 包含手写数字的灰度图像,数字从0到9。
    • 规模: 60,000个训练样本和10,000个测试样本。
    • 图像大小: 28x28 像素。
    • 用途: 通常用于学习和测试图像分类算法的基础。因为它相对较小,可以用于快速原型设计。
  • CIFAR 数据集:

    • 全名: Canadian Institute for Advanced Research dataset
    • 内容: 包含10个类别的彩色图像,每个类别包含6000张图像。
    • 规模: 50,000个训练样本和10,000个测试样本。
    • 图像大小: 32x32 像素。
    • 类别: 飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。
    • 用途: 用于测试图像分类算法的性能,因为它相对较为复杂,包含彩色图像和多个类别。
核心代码:

主要如下,已经标记注释:

loss_train = []  # 用于存储每轮训练的损失
cv_loss, cv_acc = [], []  # 用于存储交叉验证的损失和准确率
val_loss_pre, counter = 0, 0  # 初始化验证集上的前一次损失值和计数器
net_best = None  # 用于存储在验证集上表现最好的神经网络模型
best_loss = None  # 用于存储在验证集上的最佳损失值
val_acc_list, net_list = [], []  # 用于存储验证集的准确率和神经网络模型列表

if args.all_clients: 
    print("Aggregation over all clients")
    w_locals = [w_glob for i in range(args.num_users)]  # 将全局模型的权重复制给所有用户
for iter in range(args.epochs):  # 迭代训练轮数
    loss_locals = []  # 用于存储每个客户端的本地损失
    if not args.all_clients:  # 如果不使用全局模型
        w_locals = []  # 清空本地模型参数列表
    m = max(int(args.frac * args.num_users), 1)  # 计算用于本轮训练的客户端数量
    idxs_users = np.random.choice(range(args.num_users), m, replace=False)  # 随机选择一部分客户端参与本轮训练
    for idx in idxs_users:  # 迭代每一个被选择的客户端
        local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])  # 创建本地更新对象
        w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))  # 在选定的客户端上进行训练,得到本地模型和损失
        if args.all_clients:
            w_locals[idx] = copy.deepcopy(w)  # 将本地客户端模型的权重复制给全局模型
        else:
            w_locals.append(copy.deepcopy(w))  # 将本地客户端模型的权重添加到本地模型参数列表中
        loss_locals.append(copy.deepcopy(loss))  # 记录本地客户端的损失
    # 更新全局模型权重
    w_glob = FedAvg(w_locals)  # 使用FedAvg算法聚合所有客户端的权重得到新的全局模型

这段代码实现了联邦学习的核心逻辑:每轮迭代中,随机选择一部分客户端参与训练,在每个客户端上训练本地模型并计算本地损失,然后使用FedAvg算法将所有客户端的本地模型进行聚合,更新全局模型的权重。在整个训练过程中,还记录了损失值、准确率等指标,并保存了最好的模型和损失值。

    # plot loss curve
    plt.figure()
    plt.plot(range(len(loss_train)), loss_train)
    plt.ylabel('train_loss')
    plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid))

    # testing
    net_glob.eval()
    acc_train, loss_train = test_img(net_glob, dataset_train, args)
    acc_test, loss_test = test_img(net_glob, dataset_test, args)
    print("Training accuracy: {:.2f}".format(acc_train))
    print("Testing accuracy: {:.2f}".format(acc_test))

这段代码完成两个工作: 

  1. 绘制损失曲线:

    • 使用 plt.figure() 创建一个新的图形窗口。
    • 使用 plt.plot(range(len(loss_train)), loss_train) 绘制损失曲线,其中 range(len(loss_train)) 生成横坐标,loss_train 是纵坐标上的损失值。
    • 使用 plt.ylabel('train_loss') 添加纵坐标标签。
    • 使用 plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid)) 保存绘制的图形为文件,文件名通过 format 方法动态生成,其中包含数据集、模型、训练轮数、通信轮数和分布方式等参数信息,便于区分与查找。
  2. 模型测试:

    • 使用 net_glob.eval() 将全局模型切换到评估模式。
    • 分别调用 test_img 函数对训练集 dataset_train 和测试集 dataset_test 进行测试,得到训练集和测试集上的准确率和损失值。
    • 使用 print 函数输出训练集和测试集的准确率信息。

三、models包

参考学习博客:【联邦学习新手必看】手把手教你读懂FedAvg代码,并顺利运行

1. Fed.py

文件包括一个fedavg函数,作用是将客户端的权重数据聚合,返回计算得到的联邦平均模型参数 w_avg

import copy
import torch

def FedAvg(w):
    # 创建一个深拷贝的模型参数对象
    w_avg = copy.deepcopy(w[0])
    
    # 对模型参数的每个键进行循环迭代
    for k in w_avg.keys():
        # 迭代每个模型参数(除了第一个)
        for i in range(1, len(w)):
            # 将每个模型参数的值加到联邦平均模型参数上
            w_avg[k] += w[i][k]
        
        # 计算平均值,即将累加的值除以模型数量
        # 这行代码的作用是将模型参数 w_avg[k] 的值与模型数量 len(w) 相除,以获得联邦平均模型参数的值。
        # 由于 w_avg 是一个字典,w_avg[k] 是一个 PyTorch 张量对象。
        # 因此,这行代码的目的是对模型参数进行逐元素的除法运算,并返回一个新的具有相同形状的张量,其中每个元素等于相应位置上 w_avg[k] 的元素除以 len(w)
        w_avg[k] = torch.div(w_avg[k], len(w))
    
    # 返回计算得到的联邦平均模型参数
    return w_avg
2. Nets.py

这段代码定义了三个神经网络模型类:MLPCNNMnist 和 CNNCifar,它们都是基于 PyTorch 的 nn.Module 类来构建的。这部分主要通过博客

  • MLP 类定义了一个简单的多层感知机(MLP)模型。其结构包括一个输入层、一个隐藏层和一个输出层。
class MLP(nn.Module):
    # __init__ 方法中,通过 nn.Linear 定义了输入层和隐藏层之间的线性变换,nn.ReLU 定义了激活函数,nn.Dropout 定义了一个用于防止过拟合的随机失活层。
    def __init__(self, dim_in, dim_hidden, dim_out):
        super(MLP, self).__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)

    # forward 方法中,对输入进行了展平操作,并依次通过输入层、激活函数、随机失活层和隐藏层,最终返回模型的输出。
    def forward(self, x):
        x = x.view(-1, x.shape[1]*x.shape[-2]*x.shape[-1])
        x = self.layer_input(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.layer_hidden(x)
        return x
  •  CNNMnist 类定义了一个用于处理MNIST数据集的卷积神经网络(CNN)模型。该模型通过两个卷积层提取图像特征,然后通过线性层进行分类。ReLU激活函数和最大池化层用于非线性变换和特征降采样,Dropout层用于减少过拟合。
class CNNMnist(nn.Module):
    def __init__(self, args):
        super(CNNMnist, self).__init__()
        #用于提取图像的特征。args.num_channels 表示输入图像的通道数,10表示输出通道数,kernel_size=5 表示卷积核的大小为5x5
        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
        
        #创建第二个二维卷积层,进一步提取特征
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        
        #创建一个二维Dropout层,用于随机失活卷积层的输出特征图。
        self.conv2_drop = nn.Dropout2d()
        
        #创建一个线性层,将卷积层输出的特征图转换为50维的向量。
        self.fc1 = nn.Linear(320, 50)
        
        #创建最后一个线性层,将50维的向量映射到类别数量
        self.fc2 = nn.Linear(50, args.num_classes)

    def forward(self, x):
        #通过第一个卷积层,并应用ReLU激活函数和最大池化层来提取特征
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        
        #通过第二个卷积层,并应用ReLU激活函数、Dropout和最大池化层来进一步提取特征
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        
        #对特征图进行形状变换,将其展平为一维向量
        x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3])
        
        #通过线性层进行特征到隐藏层的线性变换,并应用ReLU激活函数
        x = F.relu(self.fc1(x))
        
        #应用Dropout层,随机失活一部分隐藏层神经元
        x = F.dropout(x, training=self.training)
        
        #通过线性层进行隐藏层到输出层的线性变换
        x = self.fc2(x)
        return x
  •  CNNCifar用于处理CIFAR数据集。该模型通过两个卷积层提取图像特征,然后通过线性层进行分类。ReLU激活函数和最大池化层用于非线性变换和特征降采样。
class CNNCifar(nn.Module):
    def __init__(self, args):
        super(CNNCifar, self).__init__()
        #创建一个二维卷积层,用于提取图像的特征。输入通道数为3,输出通道数为6,卷积核大小为5x5。
        self.conv1 = nn.Conv2d(3, 6, 5)
        
        #创建一个最大池化层,用于特征降采样
        self.pool = nn.MaxPool2d(2, 2)
        
        #创建第二个二维卷积层,进一步提取特征
        self.conv2 = nn.Conv2d(6, 16, 5)
        
        #创建一个线性层,将卷积层输出的特征图转换为120维的向量
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        
        #创建一个线性层,将120维的向量映射到84维的向量
        self.fc2 = nn.Linear(120, 84)
        
        #创建最后一个线性层,将84维的向量映射到类别数量
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        #通过第一个卷积层,并应用ReLU激活函数和最大池化层来提取特征
        x = self.pool(F.relu(self.conv1(x)))
        
        #通过第二个卷积层,并应用ReLU激活函数和最大池化层来进一步提取特征
        x = self.pool(F.relu(self.conv2(x)))
        
        #对特征图进行形状变换,将其展平为一维向量
        x = x.view(-1, 16 * 5 * 5)
        
        #通过线性层进行特征到隐藏层的线性变换,并应用ReLU激活函数
        x = F.relu(self.fc1(x))
        
        #通过线性层进行隐藏层到隐藏层的线性变换,并应用ReLU激活函数
        x = F.relu(self.fc2(x))
        
        #通过线性层进行隐藏层到输出层的线性变换
        x = self.fc3(x)
        return x
3. test.py

该函数用于对给定的测试数据集进行模型评估。它通过迭代数据加载器,对每个批量的数据进行前向传播和损失计算,然后累加损失和正确分类的样本数。最后计算平均测试损失和准确率,并将其返回

#net_g表示要测试的模型,datatest表示测试数据集,args表示其他参数。在函数开头,将net_g设置为评估模式
def test_img(net_g, datatest, args):
    net_g.eval()
    # testing
    #计算测试损失和正确分类的样本数
    test_loss = 0
    correct = 0
    data_loader = DataLoader(datatest, batch_size=args.bs)
    l = len(data_loader)
    # 对数据加载器进行迭代,每次迭代获取一个批量的数据和对应的目标标签
    for idx, (data, target) in enumerate(data_loader):
        if args.gpu != -1:
            data, target = data.cuda(), target.cuda()
        #调用net_g模型对数据进行前向传播
        log_probs = net_g(data)
        # sum up batch loss
        # 使用交叉熵损失函数F.cross_entropy计算损失并累加到test_loss中
        test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
        # get the index of the max log-probability
        # 利用预测的对数概率计算预测的类别,并与目标标签进行比较,统计正确分类的样本数
        y_pred = log_probs.data.max(1, keepdim=True)[1]
        correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()

    # 计算平均测试损失和准确率
    test_loss /= len(data_loader.dataset)
    accuracy = 100.00 * correct / len(data_loader.dataset)
    #是否打印详细的测试结果
    if args.verbose:
        print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, correct, len(data_loader.dataset), accuracy))
    return accuracy, test_loss
4. Update.py

定义了一个名为DatasetSplit的自定义数据集类,继承自Dataset类。通过使用DatasetSplit类,可以从原始数据集中创建一个子数据集,该子数据集仅包含特定的样本。这在分割数据集用于训练和验证时非常有用,可以根据索引划分数据集并创建相应的训练集和验证集。

class DatasetSplit(Dataset):
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image, label

 定义了一个名为LocalUpdate的类,用于在本地进行模型的训练和更新。train方法中,通过迭代数据加载器的批次,对模型进行前向传播、计算损失、反向传播和参数更新,最终返回模型的状态字典和训练周期的平均损失

class LocalUpdate(object):
    def __init__(self, args, dataset=None, idxs=None):
        #保存传入的参数args,用于配置训练过程中的超参数
        self.args = args
        
        #保存一个交叉熵损失函数的实例,用于计算训练过程中的损失
        self.loss_func = nn.CrossEntropyLoss()
        
        #用于保存选择的客户端
        self.selected_clients = []
        
        #创建一个数据加载器DataLoader,加载一个子数据集DatasetSplit,其中子数据集由参数dataset和idxs指定,设置批量大小为self.args.local_bs,并进行随机洗牌
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)

    def train(self, net):
        #将模型设置为训练模式
        net.train()
        # train and update
        # 创建一个torch.optim.SGD的优化器,使用net.parameters()作为优化器的参数,设置学习率为self.args.lr和动量为self.args.momentum
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)

        #用于保存每个训练周期的损失
        epoch_loss = []
        for iter in range(self.args.local_ep):
            #用于保存每个批次的损失
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                #清零模型参数的梯度
                net.zero_grad()
                
                #通过模型进行前向传播,获取预测的对数概率
                log_probs = net(images)
                
                #使用损失函数计算损失
                loss = self.loss_func(log_probs, labels)
                
                #对损失进行反向传播和参数更新
                loss.backward()
                optimizer.step()
                
                #批次索引能被10整除,打印当前训练进度和损失
                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        iter, batch_idx * len(images), len(self.ldr_train.dataset),
                               100. * batch_idx / len(self.ldr_train), loss.item()))
                #计算每个训练周期的平均损失,并将其添加到epoch_loss中
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
        #返回模型的状态字典和所有训练周期的平均损失
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)

四、utils包

util 包通常用于放置与具体业务逻辑无关的通用工具、辅助函数和辅助类。这些工具和函数在代码开发过程中可以被多个模块或组件共享和重用。

1. option.py

这段代码是一个 Python 脚本,被用作命令行工具以解析和处理用户输入的命令行参数。作用是帮助用户在命令行中指定和配置训练所需的参数,例如指定训练轮数、选择模型类型、设置学习率等,以便在代码中使用这些参数进行模型训练。

下面大致标注了每个参数的用途:

def args_parser():
    # 创建 ArgumentParser 对象
    parser = argparse.ArgumentParser()
    
    # federated arguments(联邦学习参数)
    parser.add_argument('--epochs', type=int, default=10, help="rounds of training")  # 训练轮数
    parser.add_argument('--num_users', type=int, default=100, help="number of users: K")  # 用户数量
    parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients: C")  # 客户端比例
    parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E")  # 本地训练轮数
    parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")  # 本地批大小
    parser.add_argument('--bs', type=int, default=128, help="test batch size")  # 测试批大小
    parser.add_argument('--lr', type=float, default=0.01, help="learning rate")  # 学习率
    parser.add_argument('--momentum', type=float, default=0.5, help="SGD momentum (default: 0.5)")  # 动量
    parser.add_argument('--split', type=str, default='user', help="train-test split type, user or sample")  # 数据拆分方式

    # model arguments(模型参数)
    parser.add_argument('--model', type=str, default='mlp', help='model name')  # 模型名称
    parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')  # 卷积核数量
    parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
                        help='comma-separated kernel size to use for convolution')  # 卷积核大小
    parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")  # 归一化方式
    parser.add_argument('--num_filters', type=int, default=32, help="number of filters for conv nets")  # 卷积网络滤波器数量
    parser.add_argument('--max_pool', type=str, default='True', help="Whether use max pooling rather than strided convolutions")  # 是否使用最大池化

    # other arguments(其他参数)
    parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")  # 数据集名称
    parser.add_argument('--iid', action='store_true', help='whether i.i.d or not')  # 是否独立同分布
    parser.add_argument('--num_classes', type=int, default=10, help="number of classes")  # 类别数量
    parser.add_argument('--num_channels', type=int, default=3, help="number of channels of imges")  # 图像通道数
    parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU")  # GPU ID
    parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')  # 提前停止轮数
    parser.add_argument('--verbose', action='store_true', help='verbose print')  # 是否详细打印
    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')  # 随机种子
    parser.add_argument('--all_clients', action='store_true', help='aggregation over all clients')  # 是否聚合所有客户端
    # 解析并返回命令行参数
    args = parser.parse_args()
    return args
2. sample.py

这个部分是一个用于从数据集中随机抽样生成独立同分布(IID)或非独立同分布(non-IID)数据的辅助函数。

对于 MNIST 数据集,有两个函数:mnist_iid 和 mnist_noniid

  • mnist_iid 函数从 MNIST 数据集随机抽取独立同分布的数据,返回一个字典,其中键是客户端的索引,值是对应客户端的图像索引集合。

import numpy as np

def mnist_iid(dataset, num_users):
    """
    从MNIST数据集中独立随机采样生成指定数量的IID(Independent and Identically Distributed)客户端数据集。

    :param dataset: MNIST数据集
    :param num_users: 客户端数量
    :return: 包含客户端数据集索引的字典
    """
    # 计算每个客户端所需的样本数量
    num_items = int(len(dataset) / num_users)

    # 初始化字典来存储每个客户端数据集的样本索引
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]

    # 为每个客户端生成对应的样本索引
    for i in range(num_users):
        # 从所有样本索引中无序不放回地选择num_items个样本,并将它们存储在字典中
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))

        # 更新所有样本索引,去掉已分配给当前客户端的样本
        all_idxs = list(set(all_idxs) - dict_users[i])
    
    # 返回包含客户端数据集索引的字典
    return dict_users
  • mnist_noniid 函数从 MNIST 数据集随机抽取非独立同分布的数据,返回一个字典,其中键是客户端的索引,值是对应客户端的图像索引集合。

def mnist_noniid(dataset, num_users):
    """
    从MNIST数据集中非IID(Non-IID)地随机采样生成指定数量的客户端数据集。

    :param dataset: MNIST数据集
    :param num_users: 客户端数量
    :return: 包含客户端数据集索引的字典
    """
    # 定义切分的数量和每个切分中的图像数量
    num_shards, num_imgs = 200, 300

    # 创建一个切分索引列表
    idx_shard = [i for i in range(num_shards)]

    # 创建一个字典来存储每个客户端数据集的样本索引
    dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}

    # 获取图像的索引和标签
    idxs = np.arange(num_shards*num_imgs)
    labels = dataset.train_labels.numpy()

    # 对标签进行排序
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # 划分和分配索引
    for i in range(num_users):
        # 随机选择2个切分
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        
        # 将切分对应的图像索引添加到对应客户端的字典中
        for rand in rand_set:
            dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
    
    # 返回包含客户端数据集索引的字典
    return dict_users

对于 CIFAR10 数据集,有一个函数 cifar_iid

  • cifar_iid 函数从 CIFAR10 数据集随机抽取独立同分布的数据,返回一个字典,其中键是客户端的索引,值是对应客户端的图像索引集合。
def cifar_iid(dataset, num_users):
    """
    从CIFAR10数据集中随机采样生成指定数量的IID(I.I.D)客户端数据集。

    :param dataset: CIFAR10数据集
    :param num_users: 客户端数量
    :return: 包含客户端数据集索引的字典
    """
    # 计算每个客户端的图像数量
    num_items = int(len(dataset)/num_users)

    # 创建一个空字典来存储每个客户端数据集的样本索引
    dict_users = {}

    # 创建一个包含所有图像索引的列表
    all_idxs = [i for i in range(len(dataset))]

    # 遍历每个客户端
    for i in range(num_users):
        # 从所有图像索引中随机选择指定数量的索引,并且不允许重复选择
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))

        # 从所有图像索引中删除已分配给当前客户端的索引
        all_idxs = list(set(all_idxs) - dict_users[i])
    
    # 返回包含客户端数据集索引的字典
    return dict_users

 五、结语

该博客记录了很多学习过程的笔记,主要用于学习知识和回顾,参考借鉴了很多博客,也还有很多要改进的地方,继续加油吧!