深入理解横向联邦学习
联邦学习是Google在2017年提出来的,旨在在保护用户隐私的情况下使用用户更敏感的数据来训练机器学习模型。这种学习方式就是后来人们所谓的横向联邦学习(区别于纵向联邦学习,这个会在另一篇文章中做介绍),这也是国外企业和机构主要的研究方向,反观国内纵向联邦学习是主流(毕竟国内互联网用户没有隐私,笑)。
从基本的技术架构上来看,横向联邦学习本质上就是一种分布式机器学习,谓之联邦学习多少有点新瓶装旧酒的味道。当然,我们不能纯粹从技术角度来看待一个新事物,就像当年的云计算,我们更应该思考它的动机和场景。
云计算没有东西
回到上面分布式机器学习的话题,显然,对于Google来说,简直是信手拈来,毕竟MapReduce架构就是自己提出来的。最初,Google利用联邦学习来解决谷歌键盘如何预测用户下一个输入以及内容推荐的算法问题。一般而言,这个小问题对Google这家做出Tensorflow的业界最牛技术公司来说自然不是问题,可是它难就难在巧妇难为无米之炊,臣妾手上没有数据呀!要解决这个问题,国内的互联网公司或许直接投机取巧的偷偷把用户数据传到云端进行训练了。但是Google毕竟是号称不作恶的Google,他们的工程师灵机一动:我们在不获取用户数据的情况下使用用户数据不就可以了吗? 于是便诞生了联邦学习。
先来看一下MapReduce架构的异步分布式机器学习架构:
分布式机器学习
训练步骤如下:
- Server端对训练数据做shuffle,然后切分数据,下发到各个Worker节点,并将初始化模型参数下发到Worker节点。
- Worker节点得到训练数据,并且用初始化参数初始化各自的本地模型。
- Worker节点使用分发到自己的数据计算目标函数的梯度,并发送到Server端
- Server端得到Worker节点的梯度马上使用这个梯度更新模型参数,并将其下发到对应的worker节点(异步的方式,已经有论文证明的这种方式的正确性,出于性能考虑,业界一般都采用异步更新的方式,如果是同步的方式则是Server端得到所有节点的梯度之后,做一个加权平均再更新模型参数,然后下发到各个节点。)
- 如此循环往复,直至模型收敛到目标精度。
而横向联邦学习的基础架构跟上图一摸一样,它们之间的区别主要在于权限上的不同。分布式机器学习,Server端和Worker端都属于自己,因此自己具备所有数据以及控制权限,但是再联邦学习上,显然Server端是没法控制Worker端的,因此Worker端的性能和稳定性不可控,也不再需要数据分发,因此它的训练步骤如下:
1.Server端初始化模型参数,下发到在线的Worker端。
2.Worker端收到参数后初始化模型,用自己的数据计算梯度,发送给Server。
3.Server接收到梯度更新模型参数,发送给Worker。
4.如此循环往复,直至模型收敛到目标精度。
观察上面的这个过程,用户的数据始终只再用户自己的Worker端,Worker端向服务器发送的仅仅只是梯度而已,因此,这种方式达到了既不获取用户数据又用了用户数据的目标。
不过看完以上介绍,你是否感受到了一股浓浓的新瓶装旧酒的味道呢?不急,在细细品尝,你会发现联邦学习所面临的算法有效性以及安全性的挑战不是普通的分布式机器学习可以比拟的。
下面请你思考几个问题:
- 只把梯度上传到服务器,用户隐私就真的得到保护了吗 ?(深度梯度泄漏攻击)
- 各个用户节点的数据并不是独立同分布的,甚至可能极度的不平衡,这样子还能用原来的方式训练吗?
- 各个用户节点的设备性能和网络质量千差万别,导致通信时延差距巨大,这样子还能用原来的方式训练吗?
- 各个用户节点的设备往服务器发送的信息并非自己可控的,如何避免其中有恶意节点来搞破坏呢?
- 用户节点的数量也许远超传统的分布式机器学习,而且各个设备的网络良莠不齐,如何支撑其这么大的通信开销呢?
没错,这些才是联邦学习需要去解决的问题,它的性能和安全还有待大家去进一步的探索和研究。
与此同时,我们应该知道,并不是所有的机器学习问题都可以用在联邦学习上的,比如,大部分机器学习场景下,需要人工标注的数据集就没法用联邦学习了。