Paramter Server

2018-01-26  本文已影响1057人  raincoffee

Paramter Server

​ Author:lyp@ Date:2017/01/26

在深度神经网络计算框架中,参数服务器是一个非常重要的基础概念,而其不同的实现对计算效果和计算能力都有直接的影响。

请你自学参数服务器的概念,并给出一个综述,介绍什么是参数服务器,它对机器学习的作用是什么,一般实现有哪些方案,各自又有哪些优缺点。

背景

1. 问题的提出?

​ 在大规模数据上跑机器学习任务是过去十多年内系统架构师面临的主要挑战之一,许多模型和抽象先后用于这一任务。

​ 现在的大数据机器学习系统,通常数据在1TB到1PB之间,参数范围在10^9 和10^12左右。而往往这些模型的参数需要被所有的worker节点频繁的访问,这就会带来很多问题和挑战:

  1. 访问这些巨量的参数,需要大量的网络带宽支持;
  2. 很多机器学习算法都是连续型的,只有上一次迭代完成(各个worker都完成)之后,才能进行下一次迭代,这就导致了如果机器之间性能差距大(木桶理论),就会造成性能的极大损失;
  3. 在分布式中,容错能力是非常重要的。很多情况下,算法都是部署到云环境中的(这种环境下,机器是不可靠的,并且job也是有可能被抢占的);

2. 业内如何解决?

​ 如何解决这些问题呢?对于机器学习分布式优化,有很多大公司在做了,包括:Amazon,Baidu,Facebook,Google,Microsoft 和 Yahoo。也有一些开源的项目,比如:YahooLDA 和 Petuum 和Graphlab。

​ 从最开始的MPI,到Hadoop,Spark 以及Paramter Server。都曾广泛应用于机器学习处理任务。总结一下:

Paramter Server发展历程

​ 参数服务器也经历了多次发展。

Paramter Server架构设计

1. Paramter Server 整体架构

PS架构主要包括两大部分。那就是一个参数服务器组server group 和多个工作组。在parameter server中,每个 server 实际上都只负责分到的部分参数(servers共同维持一个全局的共享参数),而每个 work 也只分到部分数据和处理任务;

image

一些概念解释:

2. Paramter Server通信设计

image

- ==Asynchronous Tasks and Dependency & Flexible Consistency==

​ 体会一下Asynchronous Task 跟 Synchronous Task 的区别。

​ 如果 iter12 需要在 iter11 computation,push 跟 pull 都完成后才能开始,那么就是Synchronous,反之就是Asynchronous.如iter 11 在 iter10计算完成后就开始执行。

![image](https://img.haomeiwen.com/i2472711/92152ae529e7940b.jpg?imageMogr2/auto-orient/strip%7CimageView2/2/w/1240)


​ 参数服务器和工作节点之间的通信都属于远程调用,那么,远程调用是比较耗时的行为,如果每次都保持同步的话,那么训练相对于单节点来说是减慢了许多的,因为远程调用的耗时。因而,PS框架让远程调用成为一部调用,比如参数的push和pull发出之后,立即使用当前值开始进行下一步的梯度计算,如上图,迭代11发出push和pull的请求后,立马开始进行梯度计算,而此时,使用的还是迭代10的值。

​ **Asynchronous Task**:能够提高系统的效率(因为节省了很多等待的过程),但是,它的缺点就是容易降低算法的收敛速率;

​ 所以,系统性能跟算法收敛速率之间是存在一个trade-off的,你需要同时考虑:

```xml
算法对于参数非一致性的敏感度;
训练数据特征之间的关联度;
硬盘的存储容量;

​ 考虑到用户使用的时候会有不同的情况,parameter server 为用户提供了多种任务依赖方式:


image

Paramter Server架构实现

1. Vector Clock

​ 为参数服务器中的每个参数添加一个时间戳,来跟踪参数的更新和防止重复发送数据。基于此,通信中的梯度更新数据中也应该有时间戳,防止重复更新。

​ 如果每个参数都有一个时间戳,那么参数众多,时间戳也众多。好在,parameter server 在push跟pull的时候,都是rang-based,这就带来了一个好处:这个range里面的参数共享的是同一个时间戳,这显然可以大大降低了空间复杂度。

2. Messages

​ Message是节点间交互的主要格式。一条 message 包括:时间戳,len(range)对k-v.

​ $[vc(R),(k1,v1),...,(kp,vp)]kj∈Randj∈{1,...p}$

​ 这是parameter server 中最基本的通信格式,不仅仅是共享的参数才有,task 的message也是这样的格式,只要把这里的(key, value) 改成 (task ID, 参数/返回值)。

​ Messages may carry a subset of all available keys within range R. The missing keys are assigned the same timestamp without changing their values.

​ 由于机器学习问题通常都需要很高的网络带宽,因此信息的压缩是必须的。

3. Replication and Consistency

​ parameter server 在数据一致性上,使用的是传统的一致性哈希算法,参数key与server node id被插入到一个hash ring中。具体实现可以参考另一篇blog一致性hash算法详解。动态增加和移除节点的同时还能保证系统存储与key分配的性能效率.

image

​ 两种方式保证slave跟master之间的数据一致性:

  1. 默认的复制方式: Chain replication (强一致性, 可靠):

    image

a. 更新:只能发生在数据头节点,然后更新逐步后移,直到更新到达尾节点,并由尾节点向客户确认更新成功;
b. 查询:为保证强一致性,客户查询只能在尾节点进行;

  1. Replication after Aggregation
    image

两个worker 节点分别向server传送x和y。server 首先通过一定方式(如:f(x+y) )进行aggregate,然后再进行复制操作;

当有n个worker的时候,复制只需要k/n的带宽。通常来说,k(复制次数)是一个很小的常数,而n的值大概是几百到几千;

4. Server Management

由于key的range特性,当参数服务器集群中增加一个节点时,步骤如下:

​ 在第二步,从其他节点上取数据的时候,其他节点上的操作也分为两步,第一是拷贝数据,这可能也会导致key range的切分。第二是不再接受和这些数据有关的消息,而是进行转发,转发到新节点。

​ 在第三步,收到广播信息后,节点会删除对应区间的数据,然后,扫描所有的和R有关发送出去的还没收到回复的消息,当这些消息回复时,转发到新节点。

​ 节点的离开与节点的加入类似。

5. Worker Management

添加工作节点比添加服务器节点要简单一些,步骤如下:

​ 当一个节点离开的时候,task scheduler可能会寻找一个替代,但恢复节点是十分耗时的工作,同时,损失一些数据对最后的结果可能影响并不是很大。所以,系统会让用户进行选择,是恢复节点还是不做处理。这种机制甚至可以允许用户删掉跑的最慢的节点来提升速度。

PS-lite 实现

PS-Lite是PS架构的一个轻量级的实现。它提供了push,pull,wait等APIs。整个项目代码量不多。

A light and efficient implementation of the parameter server framework. It provides clean yet powerful APIs. For example, a worker node can communicate with the server nodes by

  • Push(keys, values): push a list of (key, value) pairs to the server nodes
  • Pull(keys): pull the values from servers for a list of keys
  • Wait: wait untill a push or pull finished.

A simple example:

  std::vector<uint64_t> key = {1, 3, 5};
  std::vector<float> val = {1, 1, 1};
  std::vector<float> recv_val;
  ps::KVWorker<float> w;
  w.Wait(w.Push(key, val));
  w.Wait(w.Pull(key, &recv_val));

总体概览

整个项目的类图如下:

image

节点角色ID

image

一共三种类型,从上图可以看出Scheduler节点只有一个,多个Worker和多个Server可以组成一个Group,因此有WorkerGroup和ServerGroup;还有Worker节点和Server节点。每个节点以及每一个Group都有唯一确定的ID。
Scheduler、ServerGroup、WorkerGroup节点ID确定如下:

/** \brief node ID for the scheduler */
static const int kScheduler = 1;
/**
 * \brief the server node group ID
 *
 * group id can be combined:
 * - kServerGroup + kScheduler means all server nodes and the scheuduler
 * - kServerGroup + kWorkerGroup means all server and worker nodes
 */
static const int kServerGroup = 2;
/** \brief the worker node group ID */
static const int kWorkerGroup = 4;

上述定义在base.h中。

1、2、4的二进制表示分别为:001、010、001。这样可以做Group之间的合并,例如要和ServerGroup和WorkerGroup发信息,只需要destination node id设为2+4=6。
1-7用来表示节点的组合。单个节点的ID从8开始。单个Server和单个Worker节点从自己的rank(0、1、2……)转换到其ID:

/**
   * \brief convert from a worker rank into a node id
   * \param rank the worker rank
   */
  static inline int WorkerRankToID(int rank) {
    return rank * 2 + 9;
  }
  /**
   * \brief convert from a server rank into a node id
   * \param rank the server rank
   */
  static inline int ServerRankToID(int rank) {
    return rank * 2 + 8;
  }

ID到其rank转换:

  static inline int IDtoRank(int id) {
    return std::max((id - 8) / 2, 0);
  }

Postofficetd::unordered_map<int, std::vector<int>> node_ids_==保存了Node/NodeGroup与连接节点集合的对应关系==。

消息封装

zz

通信机制

Scheduler节点管理所有节点的地址。每个节点要知道Scheduler节点的IP、port;启动时绑定一个本地端口,并向Scheduler节点报告。==Scheduler节点在每个几点启动后,给节点分配ID,把节点信息通知出去(例如Worker节点要知道Server节点IP和端口,Server节点要知道Worker节点的IP和端口)==。节点在建立连接后,才会正式启动。

测试链接的过程:

至此,通信连接建立完成。

同步策略

异步工作时,Worker计算参数可能要依赖前面Pull是否完成。如果需要等待某一步操作,可以调用SimpleApp::Wait操作。具体实现是调用了Customer::WaitRequest(),它会跟踪request和response数量是否相同,直到相同才会返回;tracker_类型为std::vector<std::pair<int, int>>,记录了request和response数量,这个数据结构一直增长,会造成内存一直增长。

消息处理流程

每个节点都监听了本地一个端口;该连接的节点在启动时已经连接。 上述通信机制的时候已经描述过。

回顾一下通信机制中VAN start()方法的内容:

而针对消息处理流程,主要的逻辑集中在上述标黄那一步开始的。

对于==Server节点==:

  1. Van::Receiving()函数是单独一个线程来接收数据。数据接收后,根据不同命令执行不同动作,例如Control::ADD_NODE就是添加节点。如果需要下一步处理,会将消息传递给Customer::Accept函数。

    void Van::Receiving() {
      Meta nodes;
      Meta recovery_nodes;  // store recovery nodes
      recovery_nodes.control.cmd = Control::ADD_NODE;
    
      while (true) {
        Message msg;
        int recv_bytes = RecvMsg(&msg);
        // For debug, drop received message
        if (ready_.load() && drop_rate_ > 0) {
          unsigned seed = time(NULL) + my_node_.id;
          if (rand_r(&seed) % 100 < drop_rate_) {
            LOG(WARNING) << "Drop message " << msg.DebugString();
            continue;
          }
        }
    
        CHECK_NE(recv_bytes, -1);
        recv_bytes_ += recv_bytes;
        if (Postoffice::Get()->verbose() >= 2) {
          PS_VLOG(2) << msg.DebugString();
        }
        // duplicated message
        if (resender_ && resender_->AddIncomming(msg)) continue;
    
        if (!msg.meta.control.empty()) {
          // control msg
          auto& ctrl = msg.meta.control;
          if (ctrl.cmd == Control::TERMINATE) {
            ProcessTerminateCommand();
            break;
          } else if (ctrl.cmd == Control::ADD_NODE) {
            ProcessAddNodeCommand(&msg, &nodes, &recovery_nodes);
          } else if (ctrl.cmd == Control::BARRIER) {
            ProcessBarrierCommand(&msg);
          } else if (ctrl.cmd == Control::HEARTBEAT) {
            ProcessHearbeat(&msg);
          } else {
            LOG(WARNING) << "Drop unknown typed message " << msg.DebugString();
          }
        } else {
          ProcessDataMsg(&msg);
        }
      }
    }
    

  2. Customer::Accept()函数将消息添加到一个队列recv_queue_Customer::Receiving()是一个线程在运行,从队列取消息处理;处理过程中会使用函数对象recv_handle_处理消息,这个函数对象是SimpleApp::Process函数。

    void Customer::Receiving() {
      while (true) {
        Message recv;
        recv_queue_.WaitAndPop(&recv);
        if (!recv.meta.control.empty() &&
            recv.meta.control.cmd == Control::TERMINATE) {
          break;
        }
        //该线程处理消息
        recv_handle_(recv);
        if (!recv.meta.request) {
          std::lock_guard<std::mutex> lk(tracker_mu_);
          tracker_[recv.meta.timestamp].second++;
          tracker_cond_.notify_all();
        }
      }
    }
    
  3. SimpleApp::Process根据是消息类型(请求or响应,调用用户注册的函数来处理消息,request_handle_response_handle_分别处理请求和响应。

对于Worker节点,上面第3点略有不同。因为Worker都是通过PushPull来通信,而且参数都是key-value对。Pull·参数时,通过KVWorker::Process调用回调函数来处理消息。

调试及启动流程

PS Lite通过环境变量和外界交互。

启动流程:
1、首先启动Scheduler节点。这是要固定好Server和Worker数量。
2、启动Worker或Server节点。启动时连接Scheduler节点,绑定本地端口,并向Scheduler节点注册自己信息。
3、Scheduler等待所有Worker节点都注册后,给其分配id,并把节点信息传送出去。此时Scheduler节点已经准备好。
4、Worker或Server接收到Scheduler传送的信息后,建立对应节点的连接。此时Worker或Server已经准备好。

调试时,通过环境变量来控制调试日志。
PS_VERBOSE=1,会打印连接日志。
PS_VERBOSE=2,会打印所有数据通信日志。

源码test中连接事例

#include "ps/ps.h"
using namespace ps;

void StartServer() {
  if (!IsServer()) return;
  auto server = new KVServer<float>(0);
  //设置kv默认处理handle, 可以自定义
  server->set_request_handle(KVServerDefaultHandle<float>());
  RegisterExitCallback([server](){ delete server; });
}

void RunWorker() {
  if (!IsWorker()) return;
  KVWorker<float> kv(0, 0);

  // init
  int num = 10000;
  std::vector<Key> keys(num);
  std::vector<float> vals(num);

  int rank = MyRank();
  srand(rank + 7);
  for (int i = 0; i < num; ++i) {
    keys[i] = kMaxKey / num * i + rank;
    vals[i] = (rand() % 1000);
  }

  // push
  int repeat = 50;
  std::vector<int> ts;
  for (int i = 0; i < repeat; ++i) {
    ts.push_back(kv.Push(keys, vals));

    // to avoid too frequency push, which leads huge memory usage
    if (i > 10) kv.Wait(ts[ts.size()-10]);
  }
  for (int t : ts) kv.Wait(t);

  // pull
  std::vector<float> rets;
  kv.Wait(kv.Pull(keys, &rets));

  float res = 0;
  for (int i = 0; i < num; ++i) {
    res += fabs(rets[i] - vals[i] * repeat);
  }
  CHECK_LT(res / repeat, 1e-5);
  LL << "error: " << res / repeat;
}

int main(int argc, char *argv[]) {
  // setup server nodes
  StartServer();
  // start system
  Start(0);
  // run worker nodes
  RunWorker();
  // stop system
  Finalize(0, true);
  return 0;
}

其他

该部分内容暂未完成,进行学习。

PaddlePaddle

//@TODO

Tensorflow

//@TODO

Adam

//@TODO

Adam框架仍然基于Multi-Spert架构,这个架构的大体含义就是将集群分为如下几个部分:

  1. 数据服务类。存储数据,数据备份。向计算节点提供数据。
  2. 训练模型类。训练模型,然后更新参数。
  3. 参数服务器。维护一个共享的模型,计算节点计算完成后,可以向参数服务器发送请求更新参数。

参考文献

参考文献

  1. Scaling Distributed Machine Learning with the Parameter Server
  2. Parameter Server for Distributed Machine Learning
  3. PS-Lite Documents

参考Blog

  1. MPI 在大规模机器学习领域的前景如何
  2. 参数服务器——分布式机器学习的新杀器
  3. Allreduce (or MPI) vs. Parameter server approaches
  4. 横向对比三大分布式机器学习平台:Spark、PMLS、TensorFlow
  5. 机器学习入门:线性回归及梯度下降
  6. 详解并行逻辑回归
  7. 一致性HASH算法详解
  8. 【深度学习&分布式】Parameter Server 详解
  9. parameter_server架构
  10. 【分布式计算】MapReduce的替代者-Parameter Server
  11. Adam:大规模分布式机器学习框架
  12. ParameterServer入门和理解
  13. PS-Lite源码分析
  14. Google Protocol Buffer 的使用和原理
  15. 几种机器学习框架的对比和选择
  16. tensorflow架构
  17. 如何评价百度开源的深度学习框架 PaddlePaddle?

参考项目

  1. https://github.com/dmlc/ps-lite
上一篇 下一篇

猜你喜欢

热点阅读