目录
参数服务器是机器学习训练一种范式,是为了解决分布式机器学习问题的一个编程框架,其主要包括服务器端,客户端和调度器,与其他范式相比,参数服务器把模型参数存储和更新提升为主要组件,并且使用多种方法提高了处理能力。
本文是参数服务器系列第一篇,介绍ps-lite的总体设计和基础模块 Postoffice。
l
如果做一个类比,参数服务器是机器学习领域的分布式内存数据库,其作用是存储模型和更新模型。
我们来看看机器学习的几个步骤,这些步骤不断循环往复。
如果使用参数服务器训练,我们可以把如上步骤对应如下:
具体如下图:
FP/BP +--------+ Gather/Sum FP/BP +-------+ Gather/Sum
+----------> | grad 1 +------+ +----------------------> |grad 2 +-----------+
| +--------+ | | +-------+ |
+-----+----+ v +--------------+-------------------+ v
| | +---+----------+ Update | | +------+-----+ Update +------------------+
| weight 1 | | total grad 1 +--------->+weight 2 = weight 1 - total grad 1| |total grad 2+--------> |weight 2 = ...... |
| | +---+----------+ | | +------+-----+ +------------------+
+-----+----+ ^ +--------------+-------------------+ ^
| FP/BP +--------+ | | FP/BP +-------+ |
+----------> | grad 2 +------+ +----------------------> |grad 2 +-----------+
+--------+ Gather/Sum +-------+ Gather/Sum
手机如下:
因此我们可以推导出参数服务器之中各个模块的作用:
参数服务器属于机器学习训练的一个范式,具体可以分为三代(目前各大公司应该有自己内部最新实现,可以算为第四代)。
在参数服务器之前,大部分分布式机器学习算法是通过定期同步来实现的,比如集合通信的all-reduce,或者 map-reduce类系统的reduce步骤。但是定期同步有两个问题:
因此,当async sgd出现之后,就有人提出了参数服务器。
参数服务器的概念最早来自于Alex Smola于2010年提出的并行LDA的框架。它通过采用一个分布式的Memcached作为存放共享参数的存储,这样就提供了有效的机制用于分布式系统中不同的Worker之间同步模型参数,而每个Worker只需要保存他计算时所以来的一小部分参数即可,也避免了所有进程在一个时间点上都停下来同步。但是独立的kv对带来了很大的通信开销,而且服务端端难以编程。
第二代由Google的Jeff Dean进一步提出了第一代Google大脑的解决方案:DistBelief。DistBelief将巨大的深度学习模型分布存储在全局的参数服务器中,计算节点通过参数服务器进行信息传递,很好地解决了SGD和L-BFGS算法的分布式训练问题。
再后来就是李沐所在的DMLC组所设计的参数服务器。根据论文中所写,该parameter server属于第三代参数服务器,就是提供了更加通用的设计。架构上包括一个Server Group和若干个Worker Group。
我们首先用沐神论文中的图来看看系统架构。
解释一下图中整体架构中每个模块:
在分布式计算梯度时,系统的数据流如下:
图中每个步骤的作用为:
上面两个图的依据是其原始代码。ps-lite 是后来的精简版代码,所以有些功能在 ps-lite 之中没有提供。
从网上找到了一些 ps-lite发展历程,可以看到其演进的思路。
第一代是parameter,针对特定算法(如逻辑回归和LDA)进行了设计和优化,以满足规模庞大的工业机器学习任务(数百亿个示例和10-100TB数据大小的功能)。
后来尝试为机器学习算法构建一个开源通用框架。 该项目位于dmlc / parameter_server。
鉴于其他项目的需求不断增长,创建了ps-lite,它提供了一个干净的数据通信API和一个轻量级的实现。 该实现基于dmlc / parameter_server,但为不同的项目重构了作业启动器,文件IO和机器学习算法代码,如dmlc-core和wormhole
根据在开发dmlc / mxnet期间学到的经验,从v1进一步重构了API和实现。 主要变化包括:
ps-lite 其实是Paramter Server的实现的一个框架,其中参数处理具体相关策略需用户自己实现。
Parameter Server包含三种角色:Worker,Server,Scheduler。具体关系如下图:
具体角色功能为:
其中引入scheduler的好处如下:
引入一个 scheduler 模块,则会形成一个比较经典的三角色分布式系统架构;worker 和 server 的角色和职责不变,而 scheduler 模块则有比较多的选择:
引入 scheduler 模块的另一个好处是给实现模型并行留出了空间;
scheduler 模块不仅有利于实现模型并行训练范式,还有其他好处:比如通过针对特定模型参数相关性的理解,对参数训练过程进行细粒度的调度,可以进一步加快模型收敛速度,甚至有机会提升模型指标。
熟悉分布式系统的同学可能会担心 scheduler 模块的单点问题,这个通过 raft、zab 等 paxos 协议可以得到比较好的解决。
ps-lite系统中的一些基础模块如下:
Environment:一个单例模式的环境变量类,它通过一个 std::unordered_map<std::string, std::string> kvs
维护了一组 kvs 借以保存所有环境变量名以及值;
PostOffice:一个单例模式的全局管理类,一个 node 在生命期内具有一个PostOffice,依赖它的类成员对Node进行管理;
Van:通信模块,负责与其他节点的网络通信和Message的实际收发工作。PostOffice持有一个Van成员;
SimpleApp:KVServer和KVWorker的父类,它提供了简单的Request, Wait, Response,Process功能;KVServer和KVWorker分别根据自己的使命重写了这些功能;
Customer:每个SimpleApp对象持有一个Customer类的成员,且Customer需要在PostOffice进行注册,该类主要负责:
Node :信息类,存储了本节点的对应信息,每个 Node 可以使用 hostname + port 来唯一标识。
从源码中的例子可以看出,使用ps-lite 提供的脚本 local.sh 可以启动整个系统,这里 test_connection 为编译好的可执行程序。
./local.sh 2 3 ./test_connection
具体 local.sh 代码如下。注意,在shell脚本中,有三个shift,这就让脚本中始终使用$1。
针对我们的例子,脚本参数对应了就是
可以从脚本中看到,本脚本做了两件事:
具体如下:
#!/bin/bash
# set -x
if [ $# -lt 3 ]; then
echo "usage: $0 num_servers num_workers bin [args..]"
exit -1;
fi
# 对环境变量进行各种配置,此后不同节点都会从这些环境变量中获取信息
export DMLC_NUM_SERVER=$1
shift
export DMLC_NUM_WORKER=$1
shift
bin=$1
shift
arg="$@"
# start the scheduler
export DMLC_PS_ROOT_URI='127.0.0.1'
export DMLC_PS_ROOT_PORT=8000
export DMLC_ROLE='scheduler'
${bin} ${arg} &
# start servers
export DMLC_ROLE='server'
for ((i=0; i<${DMLC_NUM_SERVER}; ++i)); do
export HEAPPROFILE=./S${i}
${bin} ${arg} &
done
# start workers
export DMLC_ROLE='worker'
for ((i=0; i<${DMLC_NUM_WORKER}; ++i)); do
export HEAPPROFILE=./W${i}
${bin} ${arg} &
done
wait
我们依然使用官方例子看看。
ps-lite 使用的是 C++语言,其中 worker, server, scheduler 都使用同一套代码。这会让习惯于Java,python的同学非常不适应,大家需要适应一个阶段。
针对这个示例程序,起初会让人疑惑,为什么每次程序运行,代码中都会启动 scheduler,worker,server?其实,从下面注释就能看出来,具体执行是依据环境变量来决定。如果环境变量设置了本次角色是 server,则不会启动 scheduler 和 worker。
#include <cmath>
#include "ps/ps.h"
using namespace ps;
void StartServer() {
if (!IsServer()) {
return;
}
auto server = new KVServer<float>(0);
server->set_request_handle(KVServerDefaultHandle<float>()); //注册functor
RegisterExitCallback([server](){ delete server; });
}
void RunWorker() {
if (!IsWorker()) return;
KVWorker<float> kv(0, 0);
// init
int num = 10000;
std::vector<Key> keys(num);
std::vector<float> vals(num);
int rank = MyRank();
srand(rank + 7);
for (int i = 0; i < num; ++i) {
keys[i] = kMaxKey / num * i + rank;
vals[i] = (rand() % 1000);
}
// push
int repeat = 50;
std::vector<int> ts;
for (int i = 0; i < repeat; ++i) {
ts.push_back(kv.Push(keys, vals)); //kv.Push()返回的是该请求的timestamp
// to avoid too frequency push, which leads huge memory usage
if (i > 10) kv.Wait(ts[ts.size()-10]);
}
for (int t : ts) kv.Wait(t);
// pull
std::vector<float> rets;
kv.Wait(kv.Pull(keys, &rets));
// pushpull
std::vector<float> outs;
for (int i = 0; i < repeat; ++i) {
// PushPull on the same keys should be called serially
kv.Wait(kv.PushPull(keys, vals, &outs));
}
float res = 0;
float res2 = 0;
for (int i = 0; i < num; ++i) {
res += std::fabs(rets[i] - vals[i] * repeat);
res2 += std::fabs(outs[i] - vals[i] * 2 * repeat);
}
CHECK_LT(res / repeat, 1e-5);
CHECK_LT(res2 / (2 * repeat), 1e-5);
LL << "error: " << res / repeat << ", " << res2 / (2 * repeat);
}
int main(int argc, char *argv[]) {
// start system
Start(0); // Postoffice::start(),每个node都会调用到这里,但是在 Start 函数之中,会依据本次设定的角色来不同处理,只有角色为 scheduler 才会启动 Scheduler。
// setup server nodes
StartServer(); // Server会在其中做有效执行,其他节点不会有效执行。
// run worker nodes
RunWorker(); // Worker 会在其中做有效执行,其他节点不会有效执行。
// stop system
Finalize(0, true); //结束。每个节点都需要执行这个函数。
return 0;
}
其中KVServerDefaultHandle是functor,用与处理server收到的来自worker的请求,具体如下:
/**
* \brief an example handle adding pushed kv into store
*/
template <typename Val>
struct KVServerDefaultHandle { //functor,用与处理server收到的来自worker的请求
// req_meta 是存储该请求的一些元信息,比如请求来自于哪个节点,发送给哪个节点等等
// req_data 是发送过来的数据
// server 是指向当前server对象的指针
void operator()(
const KVMeta& req_meta, const KVPairs<Val>& req_data, KVServer<Val>* server) {
size_t n = req_data.keys.size();
KVPairs<Val> res;
if (!req_meta.pull) { //收到的是pull请求
CHECK_EQ(n, req_data.vals.size());
} else { //收到的是push请求
res.keys = req_data.keys; res.vals.resize(n);
}
for (size_t i = 0; i < n; ++i) {
Key key = req_data.keys[i];
if (req_meta.push) { //push请求
store[key] += req_data.vals[i]; //此处的操作是将相同key的value相加
}
if (req_meta.pull) { //pull请求
res.vals[i] = store[key];
}
}
server->Response(req_meta, res);
}
std::unordered_map<Key, Val> store;
};
Postoffice 是一个单例模式的全局管理类,其维护了系统的一个全局信息,具有如下特点:
请注意:这些代码都是在 Postoffice 类内,没有按照角色分开成多个模块。
类 UML 图如下:
下面我们只给出关键变量和成员函数说明,因为每个节点都包含一个 PostOffice,所以 PostOffice 的数据结构中包括了各种节点所需要的变量,会显得比较繁杂。
主要变量作用如下:
主要函数作用如下:
具体如下:
class Postoffice {
/**
* \brief start the system
*
* This function will block until every nodes are started.
* \param argv0 the program name, used for logging.
* \param do_barrier whether to block until every nodes are started.
*/
void Start(int customer_id, const char* argv0, const bool do_barrier);
/**
* \brief terminate the system
*
* All nodes should call this function before existing.
* \param do_barrier whether to do block until every node is finalized, default true.
*/
void Finalize(const int customer_id, const bool do_barrier = true);
/**
* \brief barrier
* \param node_id the barrier group id
*/
void Barrier(int customer_id, int node_group);
/**
* \brief process a control message, called by van
* \param the received message
*/
void Manage(const Message& recv);
/**
* \brief update the heartbeat record map
* \param node_id the \ref Node id
* \param t the last received heartbeat time
*/
void UpdateHeartbeat(int node_id, time_t t) {
std::lock_guard<std::mutex> lk(heartbeat_mu_);
heartbeats_[node_id] = t;
}
/**
* \brief get node ids that haven't reported heartbeats for over t seconds
* \param t timeout in sec
*/
std::vector<int> GetDeadNodes(int t = 60);
private:
void InitEnvironment();
Van* van_;
mutable std::mutex mu_;
// app_id -> (customer_id -> customer pointer)
std::unordered_map<int, std::unordered_map<int, Customer*>> customers_;
std::unordered_map<int, std::vector<int>> node_ids_;
std::mutex server_key_ranges_mu_;
std::vector<Range> server_key_ranges_;
bool is_worker_, is_server_, is_scheduler_;
int num_servers_, num_workers_;
std::unordered_map<int, std::unordered_map<int, bool> > barrier_done_;
int verbose_;
std::mutex barrier_mu_;
std::condition_variable barrier_cond_;
std::mutex heartbeat_mu_;
std::mutex start_mu_;
int init_stage_ = 0;
std::unordered_map<int, time_t> heartbeats_;
Callback exit_callback_;
/** \brief Holding a shared_ptr to prevent it from being destructed too early */
std::shared_ptr<Environment> env_ref_;
time_t start_time_;
DISALLOW_COPY_AND_ASSIGN(Postoffice);
};
首先我们介绍下 node id 映射功能,就是如何在逻辑节点和物理节点之间做映射,如何把物理节点划分成各个逻辑组,如何用简便的方法做到给组内物理节点统一发消息。
三个逻辑组的定义如下:
/** \brief node ID for the scheduler */
static const int kScheduler = 1;
/**
* \brief the server node group ID
*
* group id can be combined:
* - kServerGroup + kScheduler means all server nodes and the scheuduler
* - kServerGroup + kWorkerGroup means all server and worker nodes
*/
static const int kServerGroup = 2;
/** \brief the worker node group ID */
static const int kWorkerGroup = 4;
node id 是物理节点的唯一标示,rank 是每一个逻辑概念(scheduler,work,server)内部的唯一标示。这两个标示由一个算法来确定。
如下面代码所示,如果配置了 3 个worker,则 worker 的 rank 从 0 ~ 2,那么这几个 worker 实际对应的 物理 node ID 就会使用 WorkerRankToID 来计算出来。
for (int i = 0; i < num_workers_; ++i) {
int id = WorkerRankToID(i);
for (int g : {id, kWorkerGroup, kWorkerGroup + kServerGroup,
kWorkerGroup + kScheduler,
kWorkerGroup + kServerGroup + kScheduler}) {
node_ids_[g].push_back(id);
}
}
具体计算规则如下:
/**
* \brief convert from a worker rank into a node id
* \param rank the worker rank
*/
static inline int WorkerRankToID(int rank) {
return rank * 2 + 9;
}
/**
* \brief convert from a server rank into a node id
* \param rank the server rank
*/
static inline int ServerRankToID(int rank) {
return rank * 2 + 8;
}
/**
* \brief convert from a node id into a server or worker rank
* \param id the node id
*/
static inline int IDtoRank(int id) {
#ifdef _MSC_VER
#undef max
#endif
return std::max((id - 8) / 2, 0);
}
这样我们可以知道,1-7 的id表示的是node group,单个节点的id 就从 8 开始。
而且这个算法保证server id为偶数,node id为奇数。
因为有时请求要发送给多个节点,所以ps-lite用了一个 map 来存储每个 node group / single node 对应的实际的node节点集合,即 确定每个id值对应的节点id集。
std::unordered_map<int, std::vector<int>> node_ids_
如何使用这个node_ids_?我们还是需要看之前的代码:
for (int i = 0; i < num_workers_; ++i) {
int id = WorkerRankToID(i);
for (int g : {id, kWorkerGroup, kWorkerGroup + kServerGroup,
kWorkerGroup + kScheduler,
kWorkerGroup + kServerGroup + kScheduler}) {
node_ids_[g].push_back(id);
}
}
我们回忆一下之前的节点信息:
所以,为了实现 “设置 1-7 内任意一个数字 可以发送给其对应的 所有node” 这个功能,对于每一个新节点,需要将其对应多个id(node,node group)上,这些id组就是本节点可以与之通讯的节点。例如对于 worker 2 来说,其 node id 是 2 * 2 + 8 = 12,所以需要将它与
这 5 个id 相对应,即需要在 node_ids_ 这个映射表中对应的 4, 4 + 1, 4 + 2, 4 +1 + 2, 12 这五个 item 之中添加。就是上面代码中的内部 for 循环条件。即,node_ids_ [4], node_ids_ [5],node_ids_ [6],node_ids_ [7] ,node_ids_ [12] 之中,都需要把 12 添加到 vector 最后。
workers 跟 servers 之间通过 push 跟 pull 来通信。worker 通过 push 将计算好的梯度发送到server,然后通过 pull 从server更新参数。
parameter server 中,参数都是可以被表示成(key, value)的集合,比如一个最小化损失函数的问题,key就是feature ID,而value就是它的权值。对于稀疏参数来说,value不存在的key,就可以认为value是0。
把参数表示成 k-v, 形式更自然,易于理解和编程实现。
分布式算法有两个额外成本:数据通信成本,负载均衡不理想和机器性能差异导致的同步成本。
对于高维机器学习训练来说,因为高频特征更新极为频繁,所会导致网络压力极大。如果每一个参数都设一个key并且按key更新,那么会使得通信变得更加频繁低效,为了抹平这个问题,就需要有折衷和平衡,即,
利用机器学习算法的特性,给每个key对应的value赋予一个向量或者矩阵,这样就可以一次性传递多个参数,权衡了融合与同步的成本。
做这样的操作的前提是假设参数是有顺序的。缺点是在对于稀疏模型来说,总会在向量或者矩阵里会有参数为0,这在单个参数状态下是不用存的,所以,造成了数据的冗余。
但这样做有两点好处:
为了提高计算性能和带宽效率,参数服务器也会采用批次更新的办法,来减轻高频 key 的压力。比如把minibatch之中高频key合并成一个minibatch进行更新。
ps-lite 允许用户使用 Range Push 跟 Range Pull 操作。
路由功能指的就是:Worker 在做 Push/Pull 时候,如何知道把消息发送给哪些 Servers。
我们知道,ps-lite 是多 Server 架构,一个很重要的问题是如何分布多个参数。比如给定一个参数的键,如何确定其存储在哪一台 Server 上。所以必然有一个路由逻辑用来确立 key与server的对应关系。
PS Lite 将路由逻辑放置在 Worker 端,采用范围划分的策略,即每一个 Server 有自己固定负责的键的范围。这个范围是在 Worker 启动的时候确定的。细节如下:
[MAX/N*i, MAX/N*(i+1))
。需要注意的是,在不能刚好整除的情况下,键域上界的一小段被丢弃了。
具体实现如下:
首先,ps-lite的key只支持int类型。
#if USE_KEY32
/*! \brief Use unsigned 32-bit int as the key type */
using Key = uint32_t;
#else
/*! \brief Use unsigned 64-bit int as the key type */
using Key = uint64_t;
#endif
/*! \brief The maximal allowed key value */
static const Key kMaxKey = std::numeric_limits<Key>::max();
其次,将int范围均分即可
const std::vector<Range>& Postoffice::GetServerKeyRanges() {
if (server_key_ranges_.empty()) {
for (int i = 0; i < num_servers_; ++i) {
server_key_ranges_.push_back(Range(
kMaxKey / num_servers_ * i,
kMaxKey / num_servers_ * (i+1)));
}
}
return server_key_ranges_;
}
从之前分析中我们可以知道,ps-lite 是通过环境变量来控制具体节点。
具体某个节点属于哪一种取决于启动节点之前设置了哪些环境变量以及其数值。
环境变量包括:节点角色,worker&server个数、ip、port等。
InitEnvironment 函数就是创建了 Van,得到了 worker 和 server 的数量,得到了本节点的类型。
void Postoffice::InitEnvironment() {
const char* val = NULL;
std::string van_type = GetEnv("DMLC_PS_VAN_TYPE", "zmq");
van_ = Van::Create(van_type);
val = CHECK_NOTNULL(Environment::Get()->find("DMLC_NUM_WORKER"));
num_workers_ = atoi(val);
val = CHECK_NOTNULL(Environment::Get()->find("DMLC_NUM_SERVER"));
num_servers_ = atoi(val);
val = CHECK_NOTNULL(Environment::Get()->find("DMLC_ROLE"));
std::string role(val);
is_worker_ = role == "worker";
is_server_ = role == "server";
is_scheduler_ = role == "scheduler";
verbose_ = GetEnv("PS_VERBOSE", 0);
}
主要就是:
具体代码如下:
void Postoffice::Start(int customer_id, const char* argv0, const bool do_barrier) {
start_mu_.lock();
if (init_stage_ == 0) {
InitEnvironment();
// init node info.
// 对于所有的worker,进行node设置
for (int i = 0; i < num_workers_; ++i) {
int id = WorkerRankToID(i);
for (int g : {id, kWorkerGroup, kWorkerGroup + kServerGroup,
kWorkerGroup + kScheduler,
kWorkerGroup + kServerGroup + kScheduler}) {
node_ids_[g].push_back(id);
}
}
// 对于所有的server,进行node设置
for (int i = 0; i < num_servers_; ++i) {
int id = ServerRankToID(i);
for (int g : {id, kServerGroup, kWorkerGroup + kServerGroup,
kServerGroup + kScheduler,
kWorkerGroup + kServerGroup + kScheduler}) {
node_ids_[g].push_back(id);
}
}
// 设置scheduler的node
for (int g : {kScheduler, kScheduler + kServerGroup + kWorkerGroup,
kScheduler + kWorkerGroup, kScheduler + kServerGroup}) {
node_ids_[g].push_back(kScheduler);
}
init_stage_++;
}
start_mu_.unlock();
// start van
van_->Start(customer_id);
start_mu_.lock();
if (init_stage_ == 1) {
// record start time
start_time_ = time(NULL);
init_stage_++;
}
start_mu_.unlock();
// do a barrier here
if (do_barrier) Barrier(customer_id, kWorkerGroup + kServerGroup + kScheduler);
}
总的来讲,schedular节点通过计数的方式实现各个节点的同步。具体来说就是:
ps-lite 使用 Barrier 来控制系统的初始化,就是大家都准备好了再一起前进。这是一个可选项。具体如下:
Node会调用 Barrier 函数 告知Scheduler,随即自己进入等待状态。
注意,调用时候是
if (do_barrier) Barrier(customer_id, kWorkerGroup + kServerGroup + kScheduler);
这就是说,等待所有的 group,即 scheduler 节点也要给自己发送消息。
void Postoffice::Barrier(int customer_id, int node_group) {
if (GetNodeIDs(node_group).size() <= 1) return;
auto role = van_->my_node().role;
if (role == Node::SCHEDULER) {
CHECK(node_group & kScheduler);
} else if (role == Node::WORKER) {
CHECK(node_group & kWorkerGroup);
} else if (role == Node::SERVER) {
CHECK(node_group & kServerGroup);
}
std::unique_lock<std::mutex> ulk(barrier_mu_);
barrier_done_[0][customer_id] = false;
Message req;
req.meta.recver = kScheduler;
req.meta.request = true;
req.meta.control.cmd = Control::BARRIER;
req.meta.app_id = 0;
req.meta.customer_id = customer_id;
req.meta.control.barrier_group = node_group; // 记录了等待哪些
req.meta.timestamp = van_->GetTimestamp();
van_->Send(req); // 给 scheduler 发给 BARRIER
barrier_cond_.wait(ulk, [this, customer_id] { // 然后等待
return barrier_done_[0][customer_id];
});
}
处理等待的动作在 Van 类之中,我们提前放出来。
具体ProcessBarrierCommand逻辑如下:
如果 msg->meta.request 为true,说明是 scheduler 收到消息进行处理。
request==false
的BARRIER
消息。如果 msg->meta.request 为 false,说明是收到消息这个 respones,可以解除barrier了,于是进行处理,调用 Manage 函数 。
barrier_done_
置为true,然后通知所有等待条件变量barrier_cond_.notify_all()
。void Van::ProcessBarrierCommand(Message* msg) {
auto& ctrl = msg->meta.control;
if (msg->meta.request) { // scheduler收到了消息,因为 Postoffice::Barrier函数 会在发送时候做设置为true。
if (barrier_count_.empty()) {
barrier_count_.resize(8, 0);
}
int group = ctrl.barrier_group;
++barrier_count_[group]; // Scheduler会对Barrier请求进行计数
if (barrier_count_[group] ==
static_cast
barrier_count_[group] = 0;
Message res;
res.meta.request = false; // 回复时候,这里就是false
res.meta.app_id = msg->meta.app_id;
res.meta.customer_id = msg->meta.customer_id;
res.meta.control.cmd = Control::BARRIER;
for (int r : Postoffice::Get()->GetNodeIDs(group)) {
int recver_id = r;
if (shared_node_mapping_.find(r) == shared_node_mapping_.end()) {
res.meta.recver = recver_id;
res.meta.timestamp = timestamp_++;
Send(res);
}
}
}
} else { // 说明这里收到了 barrier respones,可以解除 barrier了。具体见上面的设置为false处。
Postoffice::Get()->Manage(*msg);
}
}
Manage 函数就是解除了 barrier。
void Postoffice::Manage(const Message& recv) {
CHECK(!recv.meta.control.empty());
const auto& ctrl = recv.meta.control;
if (ctrl.cmd == Control::BARRIER && !recv.meta.request) {
barrier_mu_.lock();
auto size = barrier_done_[recv.meta.app_id].size();
for (size_t customer_id = 0; customer_id < size; customer_id++) {
barrier_done_[recv.meta.app_id][customer_id] = true;
}
barrier_mu_.unlock();
barrier_cond_.notify_all(); // 这里解除了barrier
}
}
具体示意如下:
+
Scheduler | Worker
+ | +
| | |
| | |
+--------------------------------+ | +-----------------+
| | | | |
| | | | |
| | | | |
| v | | v
| receiver_thread_ | | receiver_thread_
| + | | |
| | | | |
v BARRIER | | BARRIER v |
Postoffice::Barrier +-----------------> | <---------------------+ Postoffice::Barrier |
+ | | + |
| | | | |
| | | | |
| | | | |
| v | | |
v | v |
barrier_cond_.wait ProcessBarrierCommand | barrier_cond_.wait |
| + | | |
| | | | |
| All Nodes OK | | | |
| | | | |
| +--------------+ | BARRIER | |
| | +----------------------------------------------> |
| | BARRIER | | | |
| +------------> | | | |
| | | | |
| | | | |
+<-------------------------------< | | <---------------+
| barrier_cond_.notify_all | | barrier_cond_.notify_all
v | v
+
手机如下:
至此,Postoffice的分析我们初步完成,其余功能我们将会结合 Van 和 Customer 在后续文章中分析。
★★★★★★关于生活和技术的思考★★★★★★
微信公众账号:罗西的思考
如果您想及时得到个人撰写文章的消息推送,或者想看看个人推荐的技术资料,敬请关注。
基于Parameter Server的可扩展分布式机器学习架构
Talk - Scaling Distributed Machine Learning with System and Algorithm Co-design 笔记
手机扫一扫
移动阅读更方便
你可能感兴趣的文章