机器学习平台

PS-Lite核心类

2019-03-17  本文已影响0人  王勇1024

简单看一下各个类以及它们之间的关系

Van

Van用于向各个节点发送消息。Van类负责建立链接、使用receiving thread监听收到的message,Van只定义接口,具体实现是依赖ZMQVan(源码只允许使用zmqvan)

/**
 * Van用于向各个节点发送远程消息
 * 如果设置变量PS_RESEND=1,van就会对超过PS_RESEND_TIMEOUT毫秒还未收到ACK的消息进行重发
 */
class Van {
 public:
    // 工厂方法
    static Van *Create(const std::string &type);
    /** 构造函数 */
    Van() {}
    /**析构函数 */
    virtual ~Van() {}
    // 启动van。调用Send方法前必须调用Start方法,该方法会初始化对其它节点的连接,启动接收线程
    // 接收线程会持续接收消息。如果是系统控制消息,会把消息转交给postoffice,其它消息会转交给对应的app
    virtual void Start(int customer_id);
    // 发送消息,该方法是线程安全的,返回发送的字节数,如果发送失败则返回-1
    int Send(const Message &msg);
    // 返回当前节点
    inline const Node &my_node() const {
      CHECK(ready_) << "call Start() first";
      return my_node_;
    }
    // 停止接收线程
    virtual void Stop();
    // 获取下一个可用的时间戳,该方法是线程安全的
    inline int GetTimestamp() { return timestamp_++; }
    // 是否准备好发送消息
    inline bool IsReady() { return ready_; }

 protected:
    // 连接到指定节点
    virtual void Connect(const Node &node) = 0;
    // 与当前节点的端口绑定
    virtual int Bind(const Node &node, int max_retry) = 0;
    // 该方法会阻塞,直到接收大消息,并返回收到到的字节数,如果接收失败或超时,则返回-1
    virtual int RecvMsg(Message *msg) = 0;
    // 发送消息并返回字节数
    virtual int SendMsg(const Message &msg) = 0;
    // 将元数据转成字符串
    void PackMeta(const Meta &meta, char **meta_buf, int *buf_size);
    // 将字符串转成元数据
    void UnpackMeta(const char *meta_buf, int buf_size, Meta *meta);

    Node scheduler_;
    Node my_node_;
    bool is_scheduler_;
    std::mutex start_mu_;
};
}  // namespace ps
#endif  // PS_INTERNAL_VAN_H_

Resender

Resender主要用于实现消息重发,如果在指定时间内没有收到ACK消息,则对消息进行重发
在分布式系统中,通信也是不可靠的,丢包、延时都是必须考虑的场景。PS Lite 设计了 Resender 类来提高通信的可靠性,它引入了 ACK 机制。即:

Postoffice

Postoffice是全局管理类,单例模式创建。管理当前节点角色、其他节点的连接、心跳信息、配置信息。顾名思义,Postoffice类会维护了一张全局的“地址簿”,记录了所有节点的信息。

class Postoffice {
 public:
  // 返回单例
  static Postoffice* Get() {
    static Postoffice e; return &e;
  }
  // 获取Van实例
  Van* van() { return van_; }
  // 启动系统。该方法会被阻塞,直到所有的节点都启动
  void Start(int customer_id, const char* argv0, const bool do_barrier);
  // 停止系统。所有的节点在退出时都需要调用该方法
  void Finalize(const int customer_id, const bool do_barrier = true);
  // 添加消费者
  void AddCustomer(Customer* customer);
  // 移除消费者
  void RemoveCustomer(Customer* customer);
  // 获取指定的消费者
  Customer* GetCustomer(int app_id, int customer_id, int timeout = 0) const;
  /**
   * \brief get the id of a node (group), threadsafe
   *
   * if it is a  node group, return the list of node ids in this
   * group. otherwise, return {node_id}
   */
  const std::vector<int>& GetNodeIDs(int node_id) const {
    const auto it = node_ids_.find(node_id);
    CHECK(it != node_ids_.cend()) << "node " << node_id << " doesn't exist";
    return it->second;
  }
  // 获取所有节点key的范围
  const std::vector<Range>& GetServerKeyRanges();
  /**
   * \brief the template of a callback
   */
  using Callback = std::function<void()>;
  /**
   * \brief Register a callback to the system which is called after Finalize()
   *
   * The following codes are equal
   * \code {cpp}
   * RegisterExitCallback(cb);
   * Finalize();
   * \endcode
   *
   * \code {cpp}
   * Finalize();
   * cb();
   * \endcode
   * \param cb the callback function
   */
  void RegisterExitCallback(const Callback& cb) {
    exit_callback_ = cb;
  }
  /**
   * \brief convert from a worker rank into a node id
   * \param rank the worker rank
   */
  static inline int WorkerRankToID(int rank) {
    return rank * 2 + 9;
  }
  /**
   * \brief convert from a server rank into a node id
   * \param rank the server rank
   */
  static inline int ServerRankToID(int rank) {
    return rank * 2 + 8;
  }
  /**
   * \brief convert from a node id into a server or worker rank
   * \param id the node id
   */
  static inline int IDtoRank(int id) {
  /** 返回worker节点数量 */
  int num_workers() const { return num_workers_; }
  /** 返回server节点数量 */
  int num_servers() const { return num_servers_; }
  /** \brief Returns the rank of this node in its group
   *
   * Each worker will have a unique rank within [0, NumWorkers()). So are
   * servers. This function is available only after \ref Start has been called.
   */
  int my_rank() const { return IDtoRank(van_->my_node().id); }
  /** 如果当前节点是worker节点,则返回true */
  int is_worker() const { return is_worker_; }
  /**如果当前节点是server节点,则返回true */
  int is_server() const { return is_server_; }
  /** 如果当前节点是scheduler节点,则返回true */
  int is_scheduler() const { return is_scheduler_; }
  /** \brief Returns the verbose level. */
  int verbose() const { return verbose_; }
  /** \brief Return whether this node is a recovery node */
  bool is_recovery() const { return van_->my_node().is_recovery; }
  /**
   * \brief barrier
   * \param node_id the barrier group id
   */
  void Barrier(int customer_id, int node_group);
  // 处理控制消息,van收到控制消息后会调用该方法
  void Manage(const Message& recv);
  // 更新心跳记录
  void UpdateHeartbeat(int node_id, time_t t) {
    std::lock_guard<std::mutex> lk(heartbeat_mu_);
    heartbeats_[node_id] = t;
  }
  // 获取一定时间内未报告心跳消息的节点ID
  std::vector<int> GetDeadNodes(int t = 60);
};

Customer

Customer用来通信,跟踪request和response。每一个连接对应一个Customer实例,连接对方的id和Customer实例的id相同。

/**
 * \brief The object for communication.
 *
 * As a sender, a customer tracks the responses for each request sent.
 *
 * It has its own receiving thread which is able to process any message received
 * from a remote node with `msg.meta.customer_id` equal to this customer's id
 */
class Customer {
 public:
  /**
   * \brief the handle for a received message
   * \param recved the received message
   */
  using RecvHandle = std::function<void(const Message& recved)>;

  /**
   * \brief constructor
   * \param app_id the globally unique id indicating the application the postoffice
   *               serving for
   * \param customer_id the locally unique id indicating the customer of a postoffice
   * \param recv_handle the functino for processing a received message
   */
  Customer(int app_id, int customer_id, const RecvHandle& recv_handle);

  /**
   * \brief desconstructor
   */
  ~Customer();

  /**
   * \brief return the globally unique application id
   */
  inline int app_id() { return app_id_; }


  /**
   * \brief return the locally unique customer id
   */
  inline int customer_id() { return customer_id_; }

  /**
   * \brief get a timestamp for a new request. threadsafe
   * \param recver the receive node id of this request
   * \return the timestamp of this request
   */
  int NewRequest(int recver);


  /**
   * \brief wait until the request is finished. threadsafe
   * \param timestamp the timestamp of the request
   */
  void WaitRequest(int timestamp);

  /**
   * \brief return the number of responses received for the request. threadsafe
   * \param timestamp the timestamp of the request
   */
  int NumResponse(int timestamp);

  /**
   * \brief add a number of responses to timestamp
   */
  void AddResponse(int timestamp, int num = 1);

  /**
   * \brief accept a received message from \ref Van. threadsafe
   * \param recved the received the message
   */
  inline void Accept(const Message& recved) {
    recv_queue_.Push(recved);
  }
};

心跳机制

为了记录网络的可达性,PS Lite 设计了心跳机制。具体而言:

路由

在多 Server 架构下,一个很重要的问题是如何分布多个参数。换句话说,给定一个参数的键,如何确定其存储在哪一台 Server 上。路由功能直接影响到 Worker 在 Push/Pull 阶段的通信。

PS Lite 将路由逻辑放置在 Worker 端,采用范围划分的策略,即每一个 Server 有自己固定负责的键的范围。这个范围是在 Worker 启动的时候确定的。具体代码参见方法 Postoffice::GetServerKeyRanges(),细节如下:

需要注意的是,在不能刚好整除的情况下,键域上界的一小段被丢弃了。

调试

在系统运行中,我们经常希望能打印一些收到的消息来方便定位问题。PS Lite 通过环境变量 PS_DROP_MSG 提供了这一功能,其值代表输出消息的概率(不含百分号)。

例如,我们启动某个 Server 前,配置了环境变量 PS_DROP_MSG=70。那么该 Server 进程会按照 70% 的概率随机打印其收到的消息。

参考资料

MXNet之ps-lite及parameter server原理
PS-Lite源码分析
ps-lite源码剖析

上一篇 下一篇

猜你喜欢

热点阅读