FedProto:跨异构客户端的联邦原型学习(论文阅读)

avatar
作者
猴君
阅读量:5

题目:FedProto: Federated Prototype Learning across Heterogeneous Clients

网址:http://arxiv.org/abs/2105.00243

 摘要

在联邦学习(FL)中,当客户端知识在梯度空间中聚集时,客户端间的异构性通常会影响优化的收敛和泛化性能。例如,客户端可能在数据分布、网络延迟、输入/输出空间和/或模型架构方面存在差异,这很容易导致其局部梯度的不对齐。为了提高对异构的容忍度,我们提出了一种新的联邦原型学习框架(federalprototype learning, FedProto),在该框架中,客户端和服务器端通过抽象类原型而不是梯度进行通信。FedProto将从不同客户端收集到的局部原型聚合在一起,然后将全局原型发送回所有客户端,以规范局部模型的训练。在每个客户端上进行训练的目的是使局部数据的分类错误最小化,同时使得到的局部原型与相应的全局原型足够接近。

补充(什么是局部梯度不对齐): 在联邦学习中,各个客户端拥有不同且分布不均的数据,这导致每个客户端训练出的局部模型和计算出的 梯度方向各不相同。当这些梯度被发送到中心服务器进行聚合时,如果差异较大,会使得全局模型难以收 敛或者性能较差

预备知识

异构联邦学习

异构一般分为模型异构和统计异构(也叫Non-IID问题)两部分,现有的方法都只关注了一个异构,且都使用了基于梯度的聚合方法,在通信效率和解决基于梯度的攻击方面存在缺陷

本文提出的FedProto方法同时考虑了模型异构和统计异构的问题,并且这种方法提出了基于原型聚合的FL框架,只有原型在服务器和客户端之间传输。所提出的解决方案不需要聚合模型参数或梯度,因此它具有成为各种异构FL场景的健壮框架的巨大潜力。

原型学习

Prototype-based learning(原型学习)是一种机器学习方法,它的核心思想是通过存储一组代表性的样本(原型),然后使用这些原型来进行分类、回归或聚类等任务,受原型学习的启发,在异构数据集上合并原型可以有效地集成来自不同数据分布的特征表示

例如,当我们谈论“狗”时,不同的人会有一个独特的“想象图片”或“原型”来代表“狗”这个概念。由于不同的生活经历和视觉记忆,他们的原型可能会略有不同。在人与人之间交换这些特定概念的原型可以使他们获得更多关于“狗”概念的知识。将每个FL客户端视为类人智能体,我们的方法的核心思想是交换原型而不是共享模型参数或原始数据,这可以自然地匹配人类的知识获取行为。

个人理解:相比于传统的联邦学习方法,Fedproto是基于原型也就是每个类别的特征在全局进行聚合,将来自每个client的特征进行合并得到一个更加全面的特征,然后在将这个特征分发给每个client,每个client在本地再进一步训练自己掌握的的特征向全局特征靠拢,这样就实现了每个client的模型优化,且在这个过程中每个client的异构是没有影响的,因为我们只关注原型。

具体方法

FedAvg

在联邦学习中,每个客户端都有一个从分布Pi(x, y)中提取的本地私有数据集Di,其中x和y分别表示输入特征和相应的类标签

FedAvg的目标函数为(1),在给定数据集和模型的情况下,通过调整参数ω来最小化加权损失函数,从而找到最佳拟合模型参数

其中,N为所有客户端的实例总数,F为共享模型(ω为全局模型的参数),Ls 表示损失函数,此目标函数用来最小化加权的客户端损失函数

真实异构联邦场景

在现实世界中,(统计异构)Pi因客户端而异,也就导致Di的Non-IID,例如不同客户机上的Pi可以是不同类子集上的数据分布,(模型异构)Fi因客户端而异。对于第i个客户端的训练过程是最小化损失函数,目标函数变为(2)

但由于Fi有不同的模型架构,所以ωi有不同的形式和大小,全局模型参数ω不能通过平均优化ωi来优化

基于原型的聚合

一般来说,基于深度学习的模型包括两个部分:

(1)表示层(即嵌入函数),用于将输入从原始特征空间转换到嵌入空间

(在fedproto中表示层就是起到一个特征提取的作用)

我们把i客户端上参数为 𝜙𝑖 的表示层运算(输入x)表示为 fi(𝜙𝑖 ; x)

(2)决策层,为给定的学习任务做出分类决策

决策层通常指神经网络的最后一层,那么表示层则代表除了最后一层之外的其他层

原型

原型的表示是基于类别的,第i个客户端中第j类别的原型表示为(3),Di,j是本地数据集Di中j类别数据组成的子集,原型是类j中实例的嵌入向量的平均值

本地模型更新

在局部损失函数的基础上加入正则化项(局部原型和全局原型的L2距离),使局部原型接近全局原型,同时最小化分类误差的损失

LS: 监督学习的损失函数,其作用是衡量模型预测结果与真实标签之间的差距

LR: 用于衡量本地原型C(j)与相应的全局原型C¯(j)之间的距离(使用L2距离),

损失函数就由这两部分组成

注意: 根据文章,这里是通过试验和调整来确定λ的最佳值

全局原型聚合

全局原型是由局部原型加权平均得到的。具体来说,对于每个类别j,服务器会从拥有该类别的客户端中收集原型,然后对这些原型进行加权平均,得到该类别的全局原型C¯(j)。

算法

广告一刻

为您即时展示最新活动产品广告消息,让您随时掌握产品活动新动态!