目录
Horovod 是Uber于2017年发布的一个易于使用的高性能的分布式训练框架,在业界得到了广泛应用。
本系列将通过源码分析来带领大家了解 Horovod。这几篇介绍 horovod 如何运行在 spark 之上。本文是第九篇,介绍 horovod on spark 如何启动。
本系列其他文章如下:
[源码解析] 深度学习分布式训练框架 Horovod (1) --- 基础知识
[源码解析] 深度学习分布式训练框架 horovod (2) --- 从使用者角度切入
[源码解析] 深度学习分布式训练框架 horovod (3) --- Horovodrun背后做了什么
[源码解析] 深度学习分布式训练框架 horovod (4) --- 网络基础 & Driver
[源码解析] 深度学习分布式训练框架 horovod (5) --- 融合框架
[源码解析] 深度学习分布式训练框架 horovod (6) --- 后台线程架构
[源码解析] 深度学习分布式训练框架 horovod (7) --- DistributedOptimizer
[源码解析] 深度学习分布式训练框架 horovod (8) --- on spark
首先,我们还是要祭出架构图,这样大家可以按图索骥。
总体来说,Horovod on Spark 的总体逻辑分为以下阶段:
我们下面就具体看看如何启动。
本部分主要逻辑是:启动 SparkDriverService 服务,利用 _make_spark_thread 启动 Spark task,然后 horovod 会等待启动结束。
SparkDriverService 继承了 driver_service.BasicDriverService,所以其内部启动了一个 socket server,可以进行网络交互。
Horovod 利用 SparkDriverService 来和 Spark executor(通过其中运行的SparkTaskService)交互,比如收集信息,让 spark 启动训练job等等。这是一个 RPC 机制。
具体 SparkDriverService 的功能可以参见其内部处理的各种 Request,比如
和前文介绍的 HorovodRunDriverService 有些类似。
其中,其成员变量 _fn 就是训练函数,以后当 SparkTaskService 请求代码的时候,就通过 CodeResponse 把 _fn 直接发送回去。这样就解决了代码发布问题。
class SparkDriverService(driver_service.BasicDriverService):
NAME = 'driver service'
def __init__(self, initial_np, num_proc, fn, args, kwargs, key, nics):
super(SparkDriverService, self).__init__(num_proc,
SparkDriverService.NAME,
key, nics)
self._initial_np = initial_np
self._fn = fn # 保存用户代码
self._args = args # 用户参数
self._kwargs = kwargs
self._key = key
self._nics = nics # 网卡信息
self._ranks_to_indices = {}
self._spark_job_failed = False
self._lock = threading.Lock()
self._task_shutdown = threading.Event()
def _handle(self, req, client_address):
if isinstance(req, TaskHostHashIndicesRequest): # 获取 task host 地址
return TaskHostHashIndicesResponse(self._task_host_hash_indices[req.host_hash])
if isinstance(req, SetLocalRankToRankRequest): # 从 local rank 得到 rank 信息
self._lock.acquire()
try:
# get index for host and local_rank
indices = self._task_host_hash_indices[req.host]
index = indices[req.local_rank]
values = list(self._ranks_to_indices.values())
prev_pos = values.index(index) if index in values else None
if prev_pos is not None:
prev_rank = list(self._ranks_to_indices.keys())[prev_pos]
del self._ranks_to_indices[prev_rank]
# memorize rank's index
self._ranks_to_indices[req.rank] = index
finally:
self._lock.release()
return SetLocalRankToRankResponse(index)
if isinstance(req, TaskIndexByRankRequest): # 是从 rank 获取到 task index
self._lock.acquire()
try:
return TaskIndexByRankResponse(self._ranks_to_indices[req.rank])
finally:
self._lock.release()
if isinstance(req, CodeRequest): # SparkTaskService会用来请求用户代码
return CodeResponse(self._fn, self._args, self._kwargs)
if isinstance(req, WaitForTaskShutdownRequest): # 等待任务结束
self._task_shutdown.wait()
return network.AckResponse()
return super(SparkDriverService, self)._handle(req, client_address)
在 Horovod.spark.run 之中,_make_spark_thread 建立了 thread。这里关键代码是:
mapper = _make_mapper(driver.addresses(), settings, use_gloo, is_elastic)
result = procs.mapPartitionsWithIndex(mapper).collect()
mapPartitionsWithIndex 这句代码会促使 Spark 在多个 Executor 之中运行 mapper 函数,并且得到运行结果。
即创建 settings.num_proc
个 Spark tasks,每个 task 会运行 mapper(_task_fn), 外部的 run 函数会等待这些执行结果。其实如果需要使用RDD,也许可以使用 foreachPartition
,这样每个结点上将会在内存中持有RDD的一个分区。
def _make_spark_thread(spark_context, spark_job_group, driver, result_queue,
settings, use_gloo, is_elastic):
"""Creates `settings.num_proc` Spark tasks in a parallel thread."""
def run_spark():
"""Creates `settings.num_proc` Spark tasks, each executing `_task_fn` and waits for them to terminate."""
try:
spark_context.setJobGroup(spark_job_group, "Horovod Spark Run", interruptOnCancel=True)
procs = spark_context.range(0, numSlices=settings.max_np if settings.elastic else settings.num_proc)
# We assume that folks caring about security will enable Spark RPC encryption,
# thus ensuring that key that is passed here remains secret.
mapper = _make_mapper(driver.addresses(), settings, use_gloo, is_elastic)
# 促使 Spark 在多个 Executor 之中运行 mapper 函数,并且得到运行结果
result = procs.mapPartitionsWithIndex(mapper).collect()
result_queue.put(result)
except:
driver.notify_spark_job_failed()
raise
spark_thread = in_thread(target=run_spark, daemon=False)
return spark_thread
启动了 spark task 之后,horovod 主进程会调用如下来等待 task 全部 启动完成。
# wait for all tasks to register, notify them and initiate task-to-task address registration
_notify_and_register_task_addresses(driver, settings)
即,run 函数中,当 _make_spark_thread 之后,horovod 主进程调用 _notify_and_register_task_addresses,从而调用 driver.wait_for_initial_registration(settings.start_timeout) ,进行总体等待。
等待的内容是:等待所有 num_proc tasks 来注册。当所有 spark thread 都ready 之后,主 horovod 进程会继续运行。
在 horovod 主进程之中,会使用 _notify_and_register_task_addresses
来等待这些 spark task 来注册,从而调用 driver.wait_for_initial_registration(settings.start_timeout) ,进行总体等待。
注意,同时发送注册请求之后, spark task 自己也调用 task.wait_for_initial_registration 等待 horovod 再通知下一阶段的启动。
而在horovod 主进程的 _notify_and_register_task_addresses 其实也很复杂:
具体代码如下:
def _notify_and_register_task_addresses(driver, settings, notify=True):
# wait for num_proc tasks to register
# 等待task来注册,需要等待 num_proc 个task
driver.wait_for_initial_registration(settings.start_timeout)
def notify_and_register(index): # 注册task,并且通知各个 task 开始下一步
task_client = task_service.SparkTaskClient(index,
driver.task_addresses_for_driver(index),
settings.key, settings.verbose)
if notify:
task_client.notify_initial_registration_complete()
next_task_index = (index + 1) % settings.num_proc
next_task_addresses = driver.all_task_addresses(next_task_index)
task_to_task_addresses = task_client.get_task_addresses_for_task(next_task_index, next_task_addresses)
driver.register_task_to_task_addresses(next_task_index, task_to_task_addresses)
for index in driver.task_indices():
in_thread(notify_and_register, (index,)) #在thread之中启动task
driver.wait_for_task_to_task_address_updates(settings.start_timeout)
我们目前只能看其第一步 “等待注册”。
在这里 SparkDriverSerivce 首先等待所有 spark executor 注册。
在 class BasicDriverService(network.BasicService): 有如下代码,可以看到,只有全部 _num_proc 注册完成,当所有 spark thread 都ready 之后,主 horovod 进程会继续运行。
这里关键是:while len(self._all_task_addresses) < self._num_proc
就是等待 self._all_task_addresses 的数目达到 _num_proc。
class BasicDriverService(network.BasicService):
def wait_for_initial_registration(self, timeout):
self._wait_cond.acquire()
try:
# 等待 self._all_task_addresses 的数目达到 _num_proc
while len(self._all_task_addresses) < self._num_proc:
self._wait_cond.wait(timeout.remaining())
timeout.check_time_out_for('tasks to start')
finally:
self._wait_cond.release()
关于等待代码,我们要做一下特殊说明,具体看图。
这里有两套 wait_for_initial_registration。可以认为是两套 barrier。
就是:
在 run 函数中,当 _make_spark_thread 之后,horovod 主进程调用 _notify_and_register_task_addresses,从而调用 driver.wait_for_initial_registration(settings.start_timeout) ,进行总体等待。
等待的内容是:等待所有 num_proc tasks 来注册。当所有 spark thread 都ready 之后,主 horovod 进程会继续运行。这里关键是:
while len(self._all_task_addresses) < self._num_proc
就是等待 self._all_task_addresses 的数目达到 _num_proc。
def wait_for_initial_registration(self, timeout):
self._wait_cond.acquire()
try:
while len(self._all_task_addresses) < self._num_proc:
self._wait_cond.wait(timeout.remaining())
timeout.check_time_out_for('tasks to start')
finally:
self._wait_cond.release()
在 BasicDriverService 之中,如果收到了 spark executor 的注册请求就进行处理,这里最重要是:
self._all_task_addresses[req.index] = req.task_addresses
当所有的 spark executor 都注册了,这里就等待成功。
每个 spark thread 在 _task_fn 之中运行,就是在 spark task 之中运行。这里也可以看出来是 Spark task 的一个总体流程:
register_task
;task.wait_for_initial_registration(settings.start_timeout)
;wait_for_command_termination
来等待结束;task.wait_for_initial_registration 会等待 self._initial_registration_complete = True 这个条件,就是等待 register_task
注册完成。
每个 Spark Executor 都有一个 SparkTaskService,所以 每个spark task 都有自己的 _initial_registration_complete。
hovorod.run 主进程会逐一通知每个 SparkTaskService 的 _initial_registration_complete。
即,哪个 SparkTaskService 好了,就通知哪个 SparkTaskService 的 _initial_registration_complete。这样,这个 SparkTaskService 就可以正式运行了。
总体等待流程具体如图,数字就是执行顺序:
SparkDriverSerivce 调用 driver.wait_for_initial_registration 来等待 SparkTaskSerivce 的注册,这是 barrier 1;
SparkTaskSerivce 1 进行注册,然后 SparkTaskSerivce 1 自己也调用 task.wait_for_initial_registration 等待 horovod 再通知下一阶段的启动,这是 barrier 2;
SparkTaskSerivce 2 进行注册,然后 SparkTaskSerivce 2 自己也调用 task.wait_for_initial_registration 等待 horovod 再通知下一阶段的启动,这是 barrier 2;
hovorod.run 主进程在发现所有 task 都注册之后,barrier 1 等待结束,会逐一通知每个 SparkTaskService 的 _initial_registration_complete。只有 4 完成之后,两个 SparkTaskSerivce 才能继续执行 5,6;
SparkTaskSerivce 1 对于 barrier 2 等待结束,继续执行;
SparkTaskSerivce 2 对于 barrier 2 等待结束,继续执行;
SparkTaskSerivce 1 SparkTaskSerivce 2 SparkDriverSerivce + + +
| | |
| | |
| | |
| | | 1
| | |
| | |
| | v
| |
| | +--------------------------------------+
| | | barrier 1 |
| | 2 | |
| 3 +-------> | |
| | | |
+-----------------------------------> | driver.wait_for_initial_registration |
| | | |
| | | |
| | | |
| | +--------------------+-----------------+
| | |
| | |</code></pre>
+-----------+----------------------+ | 4 |
|barrier 2 | <---------------------------------+
| | | |
|task.wait_for_initial_registration| | |
| | | |
+-----------+----------------------+ | |
| | |
| +-------------+----------------------+ |
| | barrier 2 | 4 |
| 6 | +<------+
| | task.wait_for_initial_registration | |
| | | |
| +-------------+----------------------+ |
| | |
| | |
| | 5 |
| | |
v v v
我们接下来详细介绍 task 启动内容 和 driver 后续工作。
本阶段我们详细介绍下 Spark Task 的启动过程。
这部分主要功能是:多线程在 spark executor 之中启动 spark task,每个spark task会运行_task_fn
函数,_task_fn
函数会运行一个 SparkTaskService,SparkTaskSerivce 会向 hovorod 主进程中的 SparkDriverTask 进行注册,并且等待下一步运行启动的指令;
此时程序(不是训练程序,而是 SparkTaskService)已经在 Spark Executor内部运行了。我们看看在 spark Executor 之中,是如何启动运行 SparkTaskService 的。
Horovod 在 thread 里面通过 _make_mapper 来让 Spark 运行 _task_fn。
def _make_mapper(driver_addresses, settings, use_gloo, is_elastic):
def _mapper(index, _):
yield _task_fn(index, driver_addresses, key, settings, use_gloo, is_elastic)
return _mapper
_task_fn
的作用是为了注册 horovod 进入到 spark task。即,在每一个 spark task (executor) 之中启动一个 SparkTaskService。
一定要注意:这些 SparkTaskService 是运行在 spark executor 之中,通过网络与 horovod 之中的 SparkDriverService 交互。
可以看到,_task_fn 的总体逻辑是:
具体如下:
def _task_fn(index, driver_addresses, key, settings, use_gloo, is_elastic):
settings.key = key
hosthash = host_hash(salt='{}-{}'.format(index, time.time()) if is_elastic else None)
os.environ['HOROVOD_HOSTNAME'] = hosthash
# 启动 SparkTaskService,SparkTaskService本身包括一个socket server,可以和driver交互
task = task_service.SparkTaskService(index, settings.key, settings.nics,...)
try:
driver_client = driver_service.SparkDriverClient(driver_addresses, settings.key, settings.verbose)
# 向 horovod 中的 Driver 注册
driver_client.register_task(index, task.addresses(), hosthash)
# 这里依然运行在spark task之中,但因为不是SparkTaskService,所以只是做协助工作,最后静静等待
if not is_elastic:
# 等待下一步启动的开始指示
task.wait_for_initial_registration(settings.start_timeout)
task_indices_on_this_host = driver_client.task_host_hash_indices(hosthash)
local_rank_zero_index = task_indices_on_this_host[0]
else:
local_rank_zero_index = None
if is_elastic:
...... # 后续文章会介绍
elif use_gloo or index == local_rank_zero_index:
# Either Gloo or first task with MPI.
# 使用Gloo或者使用MPI的第一个task,让这个task做操作
task.wait_for_command_start(settings.start_timeout)
# 等待结束
task.wait_for_command_termination()
else:
# The other tasks with MPI need to wait for the first task to finish.
# 让其他的task等待第一个task结束
first_task_addresses = driver_client.all_task_addresses(local_rank_zero_index)
first_task_client = \
task_service.SparkTaskClient(local_rank_zero_index,
first_task_addresses, settings.key,
settings.verbose)
# 调用 task.wait_for_command_termination() 等待结束
first_task_client.wait_for_command_termination()
return task.fn_result()
finally:
task.shutdown()
再次强调如下代码:
task = task_service.SparkTaskService(index, settings.key, settings.nics,...)
每一个_task_fn 中都定义了一个 SparkTaskService,即每一个 Spark Executor 都会生成一个(或者多个) SparkTaskService,在 spark task 之中运行并且作用。
SparkTaskService 定义如下,因为继承了BasicTaskService,所以其内部最终也会启动一个 socket server,以便同 horovod 中的 SparkDriverService 交互:
class SparkTaskService(task_service.BasicTaskService):
NAME_FORMAT = 'task service #%d'
def __init__(self, index, key, nics, minimum_command_lifetime_s, verbose=0):
# on a Spark cluster we need our train function to see the Spark worker environment
# this includes PYTHONPATH, HADOOP_TOKEN_FILE_LOCATION and _HOROVOD_SECRET_KEY
env = os.environ.copy()
# we inject the secret key here
env[secret.HOROVOD_SECRET_KEY] = codec.dumps_base64(key)
# we also need to provide the current working dir to mpirun_exec_fn.py
env['HOROVOD_SPARK_WORK_DIR'] = os.getcwd()
super(SparkTaskService, self).__init__(SparkTaskService.NAME_FORMAT % index,
index, key, nics, env, verbose)
self._key = key
self._minimum_command_lifetime_s = minimum_command_lifetime_s
self._minimum_command_lifetime = None
SparkTaskService 的基本功能如下。
具体代码如下:
def _run_command(self, command, env, event,
stdout=None, stderr=None, index=None,
prefix_output_with_timestamp=False):
# 在 spark 之中启动训练job
super(SparkTaskService, self)._run_command(command, env, event,
stdout, stderr, index,
prefix_output_with_timestamp)
if self._minimum_command_lifetime_s is not None:
self._minimum_command_lifetime = timeout.Timeout(self._minimum_command_lifetime_s,
message='Just measuring runtime')
def _handle(self, req, client_address):
# 返回资源
if isinstance(req, ResourcesRequest):
return ResourcesResponse(self._get_resources())
# 获取 task 地址
if isinstance(req, GetTaskToTaskAddressesRequest):
next_task_index = req.task_index
next_task_addresses = req.all_task_addresses
# We request interface matching to weed out all the NAT'ed interfaces.
next_task_client = \
SparkTaskClient(next_task_index, next_task_addresses,
self._key, self._verbose,
match_intf=True)
return GetTaskToTaskAddressesResponse(next_task_client.addresses())
return super(SparkTaskService, self)._handle(req, client_address)
def _get_resources(self):
# 返回 spark 资源
if LooseVersion(pyspark.__version__) >= LooseVersion('3.0.0'):
task_context = pyspark.TaskContext.get()
if task_context:
return task_context.resources()
else:
print("Not running inside Spark worker, no resources available")
return dict()
def wait_for_command_termination(self):
"""
Waits for command termination. Ensures this method takes at least
self._minimum_command_lifetime_s seconds to return after command started.
"""
try:
# 等待命令执行结束
return super(SparkTaskService, self).wait_for_command_termination()
finally:
# command terminated, make sure this method takes at least
# self._minimum_command_lifetime_s seconds after command started
# the client that started the command needs some time to connect again
# to wait for the result (see horovod.spark.driver.rsh).
if self._minimum_command_lifetime is not None:
time.sleep(self._minimum_command_lifetime.remaining())
下一步代码就是用来向 Driver 注册 本 task。
driver_client.register_task(index, task.addresses(), hosthash)
注册具体通过如下完成,这里调用了 network.py 的 _send 函数,就是通过 socket,spark executor 和 horovod driver 进行了网络交互:
class BasicDriverClient(network.BasicClient):
def register_task(self, index, task_addresses, host_hash):
self._send(RegisterTaskRequest(index, task_addresses, host_hash))
我们先来到 Horovod 中运行的 Driver来看看(下一节内容,这里提前看看)。
在 BasicDriverService 之中,如果收到了RegisterTaskRequest请求就进行处理,这里最重要是:
self._all_task_addresses[req.index] = req.task_addresses
这样,self._all_task_addresses 的数目就增加了。
而我们之前提到了,horovod 正在 driver.wait_for_initial_registration 上面等待,其关键是:
while len(self._all_task_addresses) < self._num_proc
如果self._all_task_addresses
的数目达到了_num_proc
,driver.wait_for_initial_registration 就结束了,就顺利执行。
具体处理 RegisterTaskRequest 的代码如下,BasicDriverService 之中有各种成员变量,用来维护各种所需信息,我们在前文 [原创 源码解析] 深度学习分布式训练框架 horovod (4) --- 网络基础 & Driver 中已经详细讲解过,_handle函数的RegisterTaskRequest 处理就是用来更新这些成员变量:
class BasicDriverService(network.BasicService):
def _handle(self, req, client_address):
if isinstance(req, RegisterTaskRequest):
self._wait_cond.acquire()
try:
self._all_task_addresses[req.index] = req.task_addresses
# Just use source address for service for fast probing.
self._task_addresses_for_driver[req.index] = \
self._filter_by_ip(req.task_addresses, client_address[0])
# Remove host hash earlier registered under this index.
if req.index in self._task_index_host_hash:
earlier_host_hash = self._task_index_host_hash[req.index]
if earlier_host_hash != req.host_hash:
self._task_host_hash_indices[earlier_host_hash].remove(req.index)
# Make index -> host hash map.
self._task_index_host_hash[req.index] = req.host_hash
# Make host hash -> indices map.
if req.host_hash not in self._task_host_hash_indices:
self._task_host_hash_indices[req.host_hash] = []
self._task_host_hash_indices[req.host_hash].append(req.index)
# TODO: this sorting is a problem in elastic horovod
self._task_host_hash_indices[req.host_hash].sort()
finally:
self._wait_cond.notify_all()
self._wait_cond.release()
return network.AckResponse()
前面提到了,当 spark task 向 driver 发送注册请求之后,Spark task 通过 task.wait_for_initial_registration(settings.start_timeout) 来等待下一步启动的开始指示。就是 driver 认为你一景注册完成了,可以开始进入下一步了。
task.wait_for_initial_registration 会等待 self._initial_registration_complete = True 这个条件,就是等待 register_task
注册完成。
class BasicTaskService(network.BasicService):
def wait_for_initial_registration(self, timeout):
self._wait_cond.acquire()
try:
while not self._initial_registration_complete:
self._wait_cond.wait(timeout.remaining())
timeout.check_time_out_for('tasks to start')
finally:
self._wait_cond.release()
每个 Spark Executor 都有一个 SparkTaskService,所以 每个spark task 都有自己的 _initial_registration_complete。
hovorod.run 主进程会逐一通知每个 SparkTaskService 的 _initial_registration_complete。即,哪个 SparkTaskService 好了,就通知哪个 SparkTaskService 的 _initial_registration_complete。
hovorod.run 主进程 是通过发送 NotifyInitialRegistrationCompleteRequest完成这一步的。
def notify_initial_registration_complete(self):
self._send(NotifyInitialRegistrationCompleteRequest())
BasicTaskService 在等待 NotifyInitialRegistrationCompleteRequest,如果收到了,就设置为 True,这样wait_for_initial_registration 就等待结束了。
if isinstance(req, NotifyInitialRegistrationCompleteRequest):
self._wait_cond.acquire()
try:
self._initial_registration_complete = True
finally:
self._wait_cond.notify_all()
self._wait_cond.release()
return network.AckResponse()
就说明当本 thread 注册在 horovod 之后,就算本 spark thread 启动成功了。
+-------------------------------------+ +----------------------------------------------------+
| Horovod Main thread | | Spark Executor |
| | | _task_fn |
| | | + |
| | | | |
| | | | |
| | | v |
| +-------------------------------+ | | +---------------------+------------------------+ |
| | SparkDriverService | | | | SparkTaskService | |
| | | | | | + | |
| | | | 1 register | | | | |
| | self._all_task_addresses <----------------------------------------+ | |
| | | | | | | | |
| | + | | | | | | |
| | | | | | | | | |
| | | 3 | | | | | | |
| | | | | | | | 2 | |
| | v | | | | | | |
| | self._wait_cond.notify_all() | | | | | | |
| | + | | | | v | |
| | | | | | | +---------+---------------------------+ | |
| | | | | | | | | | |
| | | | | | | | task.wait_for_initial_registration | | |
| | | | | | | | | | |
| | | | | | | +-------------------------------------+ | |
| | | | | | | | |
| | | | | | | | |
| | | | | | | | |
| | | | | | | | |
| | | | | | | | |
| | | | | | | | |
| | | | | | | | |
| | | | | | | | |
| | | | | | | | |
| | | | | | | | |
| | | | | | | | |
| | v | | | | | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
| +-------------------------------+ | | +----------------------------------------------+ |
+-------------------------------------+ +----------------------------------------------------+
手机如下:
本阶段的作用是:Horovod 收到所有 task 结束的信息之后,通知各个 task,进入下一阶段。
前面提到。在 horovod 主进程之中,会使用 _notify_and_register_task_addresses
来等待这些 spark task 来注册,从而调用 driver.wait_for_initial_registration(settings.start_timeout) ,进行总体等待。
注意,同时发送注册请求之后, spark task 自己也调用 task.wait_for_initial_registration 等待horovod 再通知下一阶段的启动。
而 _notify_and_register_task_addresses 中其实也很复杂:
调用 driver.wait_for_initial_registration 等待task来注册;(目前这一步已经完成)
利用 notify_and_register 注册task,并且通知各个 task 开始下一步;(我们这里进入后面这两步)
利用 driver.wait_for_task_to_task_address_updates 再次确认下所有 task 都OK;
def _notify_and_register_task_addresses(driver, settings, notify=True):
# wait for num_proc tasks to register
driver.wait_for_initial_registration(settings.start_timeout)
def notify_and_register(index):
# 注册task,并且通知各个 task 开始下一步
task_client = task_service.SparkTaskClient(index,
driver.task_addresses_for_driver(index),
settings.key, settings.verbose)if notify:
task_client.notify_initial_registration_complete()
next_task_index = (index + 1) % settings.num_proc
next_task_addresses = driver.all_task_addresses(next_task_index)
task_to_task_addresses = task_client.get_task_addresses_for_task(next_task_index, next_task_addresses)
driver.register_task_to_task_addresses(next_task_index, task_to_task_addresses)
for index in driver.task_indices():
in_thread(notify_and_register, (index,)) # 注册task,并且通知各个 task 开始下一步
# 再次确认下所有 task 都OK
driver.wait_for_task_to_task_address_updates(settings.start_timeout)
可以看到 notify_and_register 的作用就是:
调用 task_client.notify_initial_registration_complete() 通知 spark task 注册成功了,这样就让所有等待 task.wait_for_initial_registration 的 spark executor 一起运行下一阶段。
调用 driver.register_task_to_task_addresses(next_task_index, task_to_task_addresses) 来让 Driver 完成注册。
def wait_for_task_to_task_address_updates(self, timeout):
self._wait_cond.acquire()
try:
while len(self._task_addresses_for_tasks) < self._initial_np:
self.check_for_spark_job_failure()
self._wait_cond.wait(timeout.remaining())
timeout.check_time_out_for('Spark tasks to update task-to-task addresses')
finally:
self._wait_cond.release()
这里会再次确认所有 spark task 都OK。
def wait_for_task_to_task_address_updates(self, timeout):
self._wait_cond.acquire()
try:
while len(self._task_addresses_for_tasks) < self._initial_np:
self.check_for_spark_job_failure()
self._wait_cond.wait(timeout.remaining())
timeout.check_time_out_for('Spark tasks to update task-to-task addresses')
finally:
self._wait_cond.release()
在 Spark task 之中,如果收到了下一步启动指示之后,会调用 wait_for_command_termination 进行等待。
其实,这一步也就意味着 spark exector 自己本身的逻辑任务结束了,因为以后都是 SparkTaskService 自己独立完成的动作,它来负责训练代码的启动。既然 _task_fn
的逻辑任务已经结束,那么静静地等待即可。
在 horovod-master/horovod/spark/task/task_service.py
def wait_for_command_termination(self):
"""
Waits for command termination. Ensures this method takes at least
self._minimum_command_lifetime_s seconds to return after command started.
"""
try:
return super(SparkTaskService, self).wait_for_command_termination()
finally:
# command terminated, make sure this method takes at least
# self._minimum_command_lifetime_s seconds after command started
# the client that started the command needs some time to connect again
# to wait for the result (see horovod.spark.driver.rsh).
if self._minimum_command_lifetime is not None:
time.sleep(self._minimum_command_lifetime.remaining())
在 horovod-master/horovod/runner/common/service/task_service.py 中可以看到,就是等待训练代码所在的 thread 结束。
def wait_for_command_termination(self):
self._command_thread.join() # 马上会说明
这里对 _command_thread 略作说明。
在 SparkTaskService 处理 RunCommandRequest 时候,运行 Command 的 thread 就是被赋值为 _command_thread。
class BasicTaskService(network.BasicService):
def _handle(self, req, client_address):
if isinstance(req, RunCommandRequest): # 运行命令请求
self._wait_cond.acquire()
try:
if self._command_thread is None:
if self._command_env:
env = self._command_env.copy()
self._add_envs(env, req.env)
req.env = env
self._command_abort = threading.Event()
self._command_stdout = Pipe() if req.capture_stdout else None
self._command_stderr = Pipe() if req.capture_stderr else None
# 配置各种参数信息
args = (req.command, req.env, self._command_abort,
self._command_stdout, self._command_stderr,
self._index,
req.prefix_output_with_timestamp)
# 启动一个新线程来运行命令
self._command_thread = in_thread(self._run_command, args)
finally:
self._wait_cond.notify_all()
self._wait_cond.release()
return network.AckResponse()
逻辑如下:
+-------------------------------------+ +----------------------------------------------------+
| Horovod Main thread | | Spark Executor |
| | | _task_fn |
| | | + |
| | | | |
| | | | |
| | | v |
| +-------------------------------+ | | +---------------------+------------------------+ |
| | SparkDriverService | | | | SparkTaskService | |
| | | | | | + | |
| | | | 1 register | | | | |
| | self._all_task_addresses <----------------------------------------+ | |
| | | | | | | | |
| | + | | | | | | |
| | | | | | | | | |
| | | 3 | | | | | | |
| | | | | | | | 2 | |
| | v | | | | | | |
| | self._wait_cond.notify_all() | | | | | | |
| | + | | | | v | |
| | | | + + + +---------+---------------------------+ | |
| | | 4 | RegistrationComplete | | | |
| | | +-----------------+-------------+--+---> | task.wait_for_initial_registration | | |
| | | | | | | | | | |
| | | | | | | +---------+---------------------------+ | |
| | | | | | | | | |
| | | | | | | | | |
| | | | | | | | 5 | |
| | | | | | | | | |
| | | | | | | | | |
| | | | | | | v | |
| | | | | | | wait_for_command_termination | |
| | | | 6 | RunCommand | | + | |
| | | | | | | | | |
| | +-----------------------------------------------> | 7 | |
| | | | | | | v | |
| | v | | | | self._command_thread.join() | |
| | | | | | | |
| | | | | | | |
| | | | | | | |
| +-------------------------------+ | | +----------------------------------------------+ |
+-------------------------------------+ +----------------------------------------------------+
手机如下:
至此,第一阶段完成,我们下一篇继续,敬请期待。
总体来说,Horovod on Spark 的总体逻辑分为以下阶段:
本文介绍了前三个阶段,即启动阶段。下文介绍后续两个阶段,敬请期待。
★★★★★★关于生活和技术的思考★★★★★★
微信公众账号:罗西的思考
如果您想及时得到个人撰写文章的消息推送,或者想看看个人推荐的技术资料,敬请关注。
手机扫一扫
移动阅读更方便
你可能感兴趣的文章