机器学习平台

ps-lite概述

2019-03-16  本文已影响0人  王勇1024

概述

ps-lite旨在构建高可用分布式的机器学习应用。在ps-lite框架中,多个节点运行在多台物理机器上用于处理机器学习问题。通常会运行一个schedule节点和多个worker/server节点。、

ps-lite架构

分布式优化

假设我们想要解决下面的问题:

其中(yi,xi)是样本集,w是权重。
我们想要通过minibatch随机梯度下降(SGD,其中batch大小是b)的方式来解决这个问题。在时间t时,该算法首先随机挑选b个样本,然后通过下面的公式更新权重w

我们给出两个例子来说明ps-lite实现分布式解决这一问题的基本思想。

异步SGD

在第一个例子中,我们将SGD扩展为异步SGD。我们让server节点来维护w,server k获取到w的第k个分片,标识为wk。当从worker接收到梯度后,server k会更新它所维护的权重:

t = 0;
while (Received(&grad)) {
  w_k -= eta(t) * grad;
  t++;
}

\color{red}{reveived}方法返回server从任意worker节点接收的梯段,\color{red}{eta}方法返回时间\color{red}{t}时的训练速率。
对于一个worker,每次它都会做四件事情:

Read(&X, &Y);  // 读取一个minibatch X 和 Y
Pull(&w);           // 从server pull当前的权重
ComputeGrad(X, Y, w, &grad);  // 计算梯度
Push(grad);    // push梯度到server

ps-lite会提供\color{red}{push}\color{red}{pull}方法,用于和存有正确部分数据的server进行通信。

异步SGD的语义和单机版本不同。因为单机版本worker之间没有通信,所以就可能导致当一个worker节点正在计算梯度时权重发生变化。换句话说,每个worker都可能正在使用过期的权重。下图展示了2个server节点和3个worker节点的通信过程:

同步SGD

同步SGD的语义与单机算法完全相同,该模式使用scheduler来管理数据的同步。

for (t = 0, t < num_iteration; ++t) {
  for (i = 0; i < num_worker; ++i) {
     IssueComputeGrad(i, t);
  }
  for (i = 0; i < num_server; ++i) {
     IssueUpdateWeight(i, t);
  }
  WaitAllFinished();
}

\color{red}{IssueComputeGrad}\color{red}{IssueUpdateWeight}发布命令给worker和server,这个过程中\color{red}{WaitAllFinished}函数会一直等待,指定所有命令发布完成。
当worker接受到命令后,它会执行下面的函数:

ExecComputeGrad(i, t) {
   Read(&X, &Y);  // 读取 b / num_workers 个minibatch样本
   Pull(&w);           // 从server拉取最新的权重
   ComputeGrad(X, Y, w, &grad);  // 计算梯度
   Push(grad);       // push梯度到server
}

这个过程和异步SGD几乎一模一样,只是每次要处理b / num_workers个样本。
而server节点相对于异步SGD还要执行额外的一些步骤:

ExecUpdateWeight(i, t) {
   for (j = 0; j < num_workers; ++j) {
      Receive(&grad);
      aggregated_grad += grad;
   }
   w_i -= eta(t) * aggregated_grad;
}

选择哪种方式?

与单机算法相比,分布式算法增加了两个额外的开销,一是数据通信开销,即通过网络发送数据的开销;另一个是由于不完善的负载均衡和机器性能差异带来的同步开销。这两个开销可能会主宰大规模集群和TB级别数据的应用性能。

假设:

变量名称 变量含义
f 凸函数
n 样本数量
m worker数量
b minibatch大小
\tau 最大延迟
Tcomm 一个minibatch的数据通信开销
Tsync 同步开销

权衡结果如下:

SGD 收敛放缓 额外开销
同步 \sqrt b \frac{n}{b}(T_{comm}+T_{sync})
异步 \sqrt{b\tau } \frac{n}{mb}T_{comm}

从中我们得到如下结论:

上一篇下一篇

猜你喜欢

热点阅读