[源码解析] TensorFlow 分布式之 ParameterServerStrategy V1
阅读原文时间:2022年05月11日阅读:2

[源码解析] TensorFlow 分布式之 ParameterServerStrategy V1

目录

本章我们看看 ParameterServerStrategy,就是第一版代码。研究这个是因为目前工业界还有很多公司在使用,而且其内部机制也比较清晰易懂,值得我们分析。

安利两个github,都是非常好的学习资料,推荐。

https://github.com/yuhuiaws/ML-study

https://github.com/Jack47/hack-SysML

另外推荐西门宇少的最新大作让Pipeline在Transformer LM上沿着Token level并行起来——TeraPipe

本系列其他文章是:

[翻译] TensorFlow 分布式之论文篇 "TensorFlow : Large-Scale Machine Learning on Heterogeneous Distributed Systems"

[翻译] TensorFlow 分布式之论文篇 "Implementation of Control Flow in TensorFlow"

[源码解析] TensorFlow 分布式环境(1) --- 总体架构

[源码解析] TensorFlow 分布式环境(2)---Master 静态逻辑

[源码解析] TensorFlow 分布式环境(3)--- Worker 静态逻辑

[源码解析] TensorFlow 分布式环境(4) --- WorkerCache

[源码解析] TensorFlow 分布式环境(5) --- Session

[源码解析] TensorFlow 分布式环境(7) --- Worker 动态逻辑

[源码解析] TensorFlow 分布式环境(8) --- 通信机制

[翻译] 使用 TensorFlow 进行分布式训练

[源码解析] TensorFlow 分布式 DistributedStrategy 之基础篇

[源码解析] TensorFlow 之 分布式变量

[源码解析] TensorFlow 分布式之 MirroredStrategy

[源码解析] TensorFlow 分布式之 MirroredStrategy 分发计算

参数服务器训练是一种常见的数据并行方法,用于在多台机器上扩展机器学习模型。一个参数服务器训练集群由工作者和参数服务器组成。变量是在参数服务器上创建的,它们在每个步骤中被工作者读取和更新。默认情况下,工作者独立地读取和更新这些变量,而不互相同步。在这种配置下,它被称为异步训练。

Tensorflow 支持两种方式实现 parameter server:低阶 API 创建 parameter server 集群方式和 tf.distribute.Strategy 中的 ParameterServerStrategy。ParameterServerStrategyV1 的主要作用就是把变量分布在 ps 之上,计算分布在 worker 之上。我们将从几个方面来研究:

  • 如何与集群建立连接。
  • 如何获取数据。
  • 如何生成变量。
  • 如何运行。

1.1 总体逻辑

ParameterServerStrategyV1 是一个异步的多工作者参数服务器 tf.distribution 策略。这个策略需要两个角色:工作者(worker)和参数服务器。变量和对这些变量的更新将被分配给参数服务器,其他操作则被分配给 工作者。

当每个工作者有一个以上的 GPU 时,操作将被复制到所有 GPU 上,但变量不会被复制,每个工作者共享一个共同的视图,以确定一个变量被分配到哪个参数服务器。缺省状态下,ParameterServerStrategyV1 使用 TFConfigClusterResolver 来查找多工作者的配置,这需要一个 'TF_CONFIG' 环境变量,并且 'TF_CONFIG' 必须有一个集群规格。

该类假设每个工作者独立运行相同的代码,而但参数服务器则运行一个标准服务器。这意味着,虽然每个工作者将在所有 GPU 上同步计算一个梯度更新,但工作器之间的更新是异步进行的。即使只有 CPU 或一个 GPU,也应该调用"call_for_each_replica(fn, …)" 来进行任何可能跨副本复制的操作(即多个 GPU)。当定义"fn" 时,需要注意以下几点:

  1. 一般不建议在策略的作用域(scope)内再打开一个设备作用域。设备作用域(即调用 tf.device)将合并或者覆盖操作的设备,但不会改变变量的设备。
  2. 也不建议在策略的作用域(scope)内再打开一个 colocation 作用域(strategy.extended.colocate_vars_with),对于 colocating variables,则使用strategy.extended.colocate_vars_with 。协同操作可能会产生设备分配冲突。

注意:该策略仅适用于 Estimator API。当你创建"RunConfig"时,把这个策略的一个实例传递给"experimental_distribute"参数。而这个"RunConfig"的实例应该被传递给"Estimator"实例,然后在这个"Estimator" 实例上调用"train_and_evaluate"。

1.2 使用

ParameterServerStrategy 的使用样例如下:

  strategy = tf.distribute.experimental.ParameterServerStrategy()
  run_config = tf.estimator.RunConfig(
      experimental_distribute.train_distribute=strategy)
  estimator = tf.estimator.Estimator(config=run_config)
  tf.estimator.train_and_evaluate(estimator,...)

1.3 定义

ParameterServerStrategyV1 的定义和初始化没有什么可以研究的,主要是使用 ParameterServerStrategyExtended 完成初始化,摘录如下:

@tf_export(v1=["distribute.experimental.ParameterServerStrategy"])
class ParameterServerStrategyV1(distribute_lib.StrategyV1):
  def __init__(self, cluster_resolver=None):
  """Initializes this strategy with an optional cluster_resolver.

    Args:
      cluster_resolver: Optional
        tf.distribute.cluster_resolver.ClusterResolver object. Defaults to a
        tf.distribute.cluster_resolver.TFConfigClusterResolver.
  """
    if cluster_resolver is None:
      cluster_resolver = TFConfigClusterResolver()
    super(ParameterServerStrategyV1, self).__init__(
        ParameterServerStrategyExtended(
            self, cluster_resolver=cluster_resolver))
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
      "ParameterServerStrategy")

ParameterServerStrategyExtended 派生自 distribute_lib.StrategyExtendedV1,提供了可以分布式感知的算法附加 API。

class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
"""Implementation of ParameterServerStrategy and CentralStorageStrategy."""

  def __init__(self,
               container_strategy,
               cluster_resolver=None,
               compute_devices=None,
               parameter_device=None):
    super(ParameterServerStrategyExtended, self).__init__(container_strategy)
    self._initialize_strategy(
        cluster_resolver=cluster_resolver,
        compute_devices=compute_devices,
        parameter_device=parameter_device)

    # We typically don't need to do all-reduce in this strategy.
    self._cross_device_ops = (
        cross_device_ops_lib.ReductionToOneDevice(reduce_to_device=_LOCAL_CPU))

2.1 初始化

这部分完成了获取集群信息的工作。_initialize_strategy 依据 spec 不同选择启动本地还是多工作者,我们只研究多工作者的情况。

  def _initialize_strategy(self,
                           cluster_resolver=None,
                           compute_devices=None,
                           parameter_device=None):
    if cluster_resolver and cluster_resolver.cluster_spec():
      self._initialize_multi_worker(cluster_resolver)
    else:
      self._initialize_local(
          compute_devices, parameter_device, cluster_resolver=cluster_resolver)

_initialize_multi_worker 这里会做一系列配置,比如:

  • 获取 gpu 数量。

  • 从集群配置之中获取信息。

  • 设定工作设备和输入设备名称。

  • 设定计算设备列表。

  • 分配设备策略。

  • 得到参数服务器设备列表。

    def _initialize_multi_worker(self, cluster_resolver):
    """Initialize devices for multiple workers.

    It creates variable devices and compute devices. Variables and operations
    will be assigned to them respectively. We have one compute device per
    replica. The variable device is a device function or device string. The
    default variable device assigns variables to parameter servers in a
    round-robin fashion.
    
    Args:
      cluster_resolver: a descendant of ClusterResolver object.
    
    Raises:
      ValueError: if the cluster doesn't have ps jobs.

    """
    # 获取gpu数量
    if isinstance(cluster_resolver, TFConfigClusterResolver):
    num_gpus = context.num_gpus()
    else:
    num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

    # Save the num_gpus_per_worker for configure method.
    self._num_gpus_per_worker = num_gpus
    
    # 从集群配置之中获取信息
    cluster_spec = cluster_resolver.cluster_spec()
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
    assert cluster_spec.as_dict()
    
    # 设定工作设备和输入设备名称
    self._worker_device ="/job:%s/task:%d" % (task_type, task_id)
    self._input_host_device = numpy_dataset.SingleDevice(self._worker_device)
    
    # Define compute devices which is a list of device strings and one for each
    # replica. When there are GPUs, replicate operations on these GPUs.
    # Otherwise, place operations on CPU.
    
    # 设定计算设备列表
    if num_gpus > 0:
      compute_devices = tuple(
        "%s/device:GPU:%d" % (self._worker_device, i)
          for i in range(num_gpus))
    else:
      compute_devices = (self._worker_device,)
    
    self._compute_devices = [
        device_util.canonicalize(d) for d in compute_devices]
    
    # In distributed mode, place variables on ps jobs in a round-robin fashion.
    # Note that devices returned from replica_device_setter are not
    # canonical and therefore we don't canonicalize all variable devices to
    # make them consistent.
    # TODO(yuefengz): support passing a strategy object to control variable
    # assignment.
    
    # 分配设备策略,变量放到哪个设备上
    num_ps_replicas = len(cluster_spec.as_dict().get("ps", []))
    self._variable_device = device_setter.replica_device_setter(
        ps_tasks=num_ps_replicas, # 参数服务器
        worker_device=self._worker_device, # 工作设备
        merge_devices=True,
        cluster=cluster_spec)
    
    # The _parameter_devices is needed for the parameter_devices property
    # and is a list of all variable devices. Here parameter devices are all
    # tasks of the"ps" job.
    
    # 得到参数服务器设备列表
    self._parameter_devices = tuple(map("/job:ps/task:{}".format,
                                        range(num_ps_replicas)))
    
    # Add a default device so that ops without specified devices will not end up
    # on other workers.
    self._default_device = self._worker_device
    self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                task_id)
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id

2.2 分配设备

我们接下来看看如何分配设备。在目前状态下,分配设备就是给每个计算图指定一个设备名字,后续真正运行时候,系统会根据这个设备名字再具体进行分配。

2.2.1 replica_device_setter

replica_device_setter 返回一个设备函数 device function(或者说是策略),当为副本建立计算图时候,此策略将提供信息,该信息用来指导计算图应该分配到哪个设备上。设备函数与 with tf.device(device_function) 一起使用。当构建时候,Operation 会自动被映射到设备函数提供的设备之上。设备约束首先从最内部的上下文添加,然后向外工作。如果 'cluster' 为 'None' 且 'ps_tasks' 为 0,则返回的函数为 no-op。否则,'ps_tasks' 的值派生自 'cluster'。如果'ps_tasks' 数值不为0,则后续变量就放到ps_device之上,否则放到 worker_device 之上。

@tf_export(v1=["train.replica_device_setter"])
def replica_device_setter(ps_tasks=0,
                          ps_device="/job:ps",
                          worker_device="/job:worker",
                          merge_devices=True,
                          cluster=None,
                          ps_ops=None,
                          ps_strategy=None):
"""Return a device function to use when building a Graph for replicas.

  Device Functions are used in with tf.device(device_function): statement to
  automatically assign devices to Operation objects as they are constructed,
  Device constraints are added from the inner-most context first, working
  outwards. The merging behavior adds constraints to fields that are yet unset
  by a more inner context. Currently the fields are (job, task, cpu/gpu).

  If cluster is None, and ps_tasks is 0, the returned function is a no-op.
  Otherwise, the value of ps_tasks is derived from cluster.

  Args:
    ps_tasks: Number of tasks in the ps job.  Ignored if cluster is
      provided.
    ps_device: String.  Device of the ps job.  If empty no ps job is used.
      Defaults to ps.
    worker_device: String.  Device of the worker job.  If empty no worker
      job is used.
    merge_devices: Boolean. If True, merges or only sets a device if the
      device constraint is completely unset. merges device specification rather
      than overriding them.
    cluster: ClusterDef proto or ClusterSpec.
    ps_ops: List of strings representing Operation types that need to be
      placed on ps devices.  If None, defaults to STANDARD_PS_OPS.
    ps_strategy: A callable invoked for every ps Operation (i.e. matched by
      ps_ops), that takes the Operation and returns the ps task index to
      use.  If None, defaults to a round-robin strategy across all ps
      devices.

  Returns:
    A function to pass to tf.device().

  Raises:
    TypeError if cluster is not a dictionary or ClusterDef protocol buffer,
    or if ps_strategy is provided but not a callable.
"""
  if cluster is not None:
    if isinstance(cluster, server_lib.ClusterSpec):
      cluster_spec = cluster.as_dict()
    else:
      cluster_spec = server_lib.ClusterSpec(cluster).as_dict()
    # Get ps_job_name from ps_device by stripping"/job:".
    ps_job_name = pydev.DeviceSpec.from_string(ps_device).job
    if ps_job_name not in cluster_spec or cluster_spec[ps_job_name] is None:
      return None
    ps_tasks = len(cluster_spec[ps_job_name])

  if ps_tasks == 0:
    return None

  if ps_ops is None:
    ps_ops = list(STANDARD_PS_OPS)

  if ps_strategy is None:
    ps_strategy = _RoundRobinStrategy(ps_tasks)

  chooser = _ReplicaDeviceChooser(ps_tasks, ps_device, worker_device,
                                  merge_devices, ps_ops, ps_strategy)
  return chooser.device_function

2.2.2 _RoundRobinStrategy

默认情况下,ps 任务上只放置变量 op,并且 placement strategy 是以 round-robin 机制在 ps tasks 之间进行分配。也可以采用比如 tf.contrib.training.GreedyLoadBalancingStrategy。

# To build a cluster with two ps jobs on hosts ps0 and ps1, and 3 worker
# jobs on hosts worker0, worker1 and worker2.
cluster_spec = {
  "ps": ["ps0:2222","ps1:2222"],
  "worker": ["worker0:2222","worker1:2222","worker2:2222"]}
with
tf.device(tf.compat.v1.train.replica_device_setter(cluster=cluster_spec)):
  # Build your graph
  v1 = tf.Variable(...)  # assigned to /job:ps/task:0
  v2 = tf.Variable(...)  # assigned to /job:ps/task:1
  v3 = tf.Variable(...)  # assigned to /job:ps/task:0
# Run compute

_RoundRobinStrategy 具体如下:

class _RoundRobinStrategy(object):
"""Returns the next ps task index for placement in round-robin order.

  This class is not to be used directly by users.  See instead
  replica_device_setter() below.
"""

  def __init__(self, num_tasks):
  """Create a new _RoundRobinStrategy.

    Args:
      num_tasks: Number of ps tasks to cycle among.
  """
    self._num_tasks = num_tasks
    self._next_task = 0

  def __call__(self, unused_op):
  """Choose a ps task index for the given Operation.

    Args:
      unused_op: An Operation to be placed on ps.

    Returns:
      The next ps task index to use for the Operation. Returns the next
      index, in the range [offset, offset + num_tasks).
  """
    task = self._next_task
    self._next_task = (self._next_task + 1) % self._num_tasks
    return task

2.2.3 _ReplicaDeviceChooser

replica_device_setter 返回的是 _ReplicaDeviceChooser.device_function。就是使用 _ps_strategy 来返回设备名字。这里会依据_ps_tasks的信息来决定变量放在 ps_device 之上还是worker_device之上。

class _ReplicaDeviceChooser(object):
"""Class to choose devices for Ops in a replicated training setup.

  This class is not to be used directly by users.  See instead
  replica_device_setter() below.
"""

  def __init__(self, ps_tasks, ps_device, worker_device, merge_devices, ps_ops,
               ps_strategy):
  """Create a new _ReplicaDeviceChooser.

    Args:
      ps_tasks: Number of tasks in the ps job.
      ps_device: String.  Name of the ps job.
      worker_device: String.  Name of the worker job.
      merge_devices: Boolean. Set to True to allow merging of device specs.
      ps_ops: List of strings representing Operation types that need to be
        placed on ps devices.
      ps_strategy: A callable invoked for every ps Operation (i.e. matched by
        ps_ops), that takes the Operation and returns the ps task index to
        use.
  """
    self._ps_tasks = ps_tasks
    self._ps_device = ps_device
    self._worker_device = worker_device
    self._merge_devices = merge_devices
    self._ps_ops = ps_ops
    self._ps_strategy = ps_strategy

  def device_function(self, op):
  """Choose a device for op.

    Args:
      op: an Operation.

    Returns:
      The device to use for the Operation.
  """
    # If we don't return early here, either merge_devices is True, or op.device
    # is empty (in which case merging is a no-op). So we can always merge below.
    if not self._merge_devices and op.device:
      return op.device

    current_device = pydev.DeviceSpec.from_string(op.device or"")

    # The ps_device will be used for specified ops (ps_ops) whenever it is
    # present and ps_tasks is non-zero. However, its task number will only be
    # set (using ps_strategy) if there is a job field in ps_device that won't be
    # changed by the job field (if present) in current_device.
    node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
    if self._ps_tasks and self._ps_device and node_def.op in self._ps_ops:
      ps_device = pydev.DeviceSpec.from_string(self._ps_device)

      current_job, ps_job = current_device.job, ps_device.job
      if ps_job and (not current_job or current_job == ps_job):
        # 这里使用了策略
        ps_device = ps_device.replace(task=self._ps_strategy(op))

      ps_device = ps_device.make_merged_spec(current_device)
      return ps_device.to_string()

    worker_device = pydev.DeviceSpec.from_string(self._worker_device or"")
    worker_device = worker_device.make_merged_spec(current_device)
    return worker_device.to_string()

设备相关的逻辑总结如下:

图 1 分配设备

初始化之后,ParameterServerStrategyExtended如下:

我们接下来看看如何获取训练数据。distribute_datasets_from_function 是调用基类 的 distribute_datasets_from_function,所以我们要看看 StrategyBase。

  def distribute_datasets_from_function(self, dataset_fn, options=None):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
        "InputReplicationMode.PER_REPLICA"
        "is only supported in"
        "experimental_distribute_datasets_from_function"
        "of tf.distribute.MirroredStrategy")
    self._raise_pss_error_if_eager()
    super(ParameterServerStrategyV1, self).distribute_datasets_from_function(
        dataset_fn=dataset_fn, options=options)

3.1 StrategyBase

distribute_datasets_from_function 作用是依靠调用 'dataset_fn' 来分发 tf.data.Dataset。用户传入的参数 dataset_fn 是一个输入函数。这个输入参数带有 InputContext 参数,并返回一个 tf.data.Dataset 实例。dataset_fn 得到的数据集应该是已按每个副本的批大小(即全局批大小除以同步副本的数量)进行分批次和分片的。Tf.distribute.Strategy.distribute_datasets_from_function 本身不会做分批次和分片操作。

dataset_fn 将在每个工作者的 CPU device 上调用并且会生成一个数据集,其中该工作者上的每个 replica 都会将一个输入 batch 移出队列(即,如果一个工作者有两个副本,则每个 step 之中,两个 batches 将会被从 Dataset 之中移出队列)。这种方法有多种用途。首先,它允许您指定自己的分批切分逻辑。(相比之下,tf.distribute.experimental_distribute_dataset 为您进行分批和分片。)例如,experimental_distribute_dataset 无法切分输入文件,则可以使用此方法来自定义手动切分数据集(避免experimental_distribute_dataset 中的慢回调行为)。在数据集无限大的情况下,分片可以通过依据随机种子的不同来创建数据集副本。另外,dataset_fn 应该使用 tf.distribute.InputContext 的实例来得到分批和输入分片的信息。

具体调用方式如下:

def per_worker_dataset_fn():
    return strategy.distribute_datasets_from_function(
        lambda x: tf.data.Dataset.from_tensor_slices([3] * 3))

这里我们发现,distribute_datasets_from_function 则又回到了派生类 _distribute_datasets_from_function 方法。

def distribute_datasets_from_function(self, dataset_fn, options=None):
    return self._extended._distribute_datasets_from_function(dataset_fn, options)

3.2 _distribute_datasets_from_function

_distribute_datasets_from_function 则调用了 InputContext 来获取数据。

  def _distribute_datasets_from_function(self, dataset_fn, options):
    if self._cluster_spec:
      input_pipeline_id = multi_worker_util.id_in_cluster(
          self._cluster_spec, self._task_type, self._task_id)
      num_input_pipelines = multi_worker_util.worker_count(
          self._cluster_spec, self._task_type)
    else:
      input_pipeline_id = 0
      num_input_pipelines = 1

    input_context = distribute_lib.InputContext(
        num_input_pipelines=num_input_pipelines,
        input_pipeline_id=input_pipeline_id,
        num_replicas_in_sync=self._num_replicas_in_sync)

    return input_lib.get_distributed_datasets_from_function(
        dataset_fn,
        self._input_workers_with_options(options), [input_context],
        self._container_strategy(),
        options=options)

3.3 InputLib

这部分代码在 tensorflow/python/distribute/input_lib.py,主要就是获取 iterator。

def get_distributed_datasets_from_function(dataset_fn,
                                           input_workers,
                                           input_contexts,
                                           strategy,
                                           options=None):
"""Returns a distributed dataset from the given input function.

  This is a common function that is used by all strategies to return a
  distributed dataset. The distributed dataset instance returned is different
  depending on if we are in a TF 1 or TF 2 context. The distributed dataset
  instances returned differ from each other in the APIs supported by each of
  them.

  Args:
    dataset_fn: a function that returns a tf.data.Dataset instance.
    input_workers: an InputWorkers object which specifies devices on which
        iterators should be created.
    input_contexts: A list of InputContext instances to be passed to call(s)
        to dataset_fn. Length and order should match worker order in
        worker_device_pairs.
    strategy: a tf.distribute.Strategy object, used to run all-reduce to
        handle last partial batch.
    options: Default is None. tf.distribute.InputOptions used to control
        options on how this dataset is distributed.

  Returns:
    A distributed dataset instance.

  Raises:
    ValueError: if options.experimental_replication_mode and
    options.experimental_place_dataset_on_device are not consistent
"""
  if tf2.enabled():
    return DistributedDatasetsFromFunction(input_workers, strategy,
                                           input_contexts, dataset_fn, options)
  else:
    return DistributedDatasetsFromFunctionV1(input_workers, strategy,
                                             input_contexts, dataset_fn,
                                             options)

DistributedDatasetsFromFunctionV1 则会返回 DistributedIteratorV1,既然得到了 iterator,就可以从数据集之中获得数据了。

class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction):
"""Inputs created from dataset function."""

  def _make_initializable_iterator(self, shared_name=None):
  """Get an initializable iterator for DistributedDatasetsFromFunctionV1."""
    del shared_name  # Unused
    # Eager mode generates already initialized iterators. Hence we cannot create
    # an initializable iterator.
    if context.executing_eagerly():
      raise ValueError("Cannot create initializable iterator in Eager mode."
                     "Please use iter() instead.")
    return self._get_iterator()

  def _make_one_shot_iterator(self):
  """Get an iterator for iterating over DistributedDatasetsFromFunctionV1."""
    # Graph mode with one shot iterator is disabled because we have to call
    # initialize on the iterator which is only required if we are using a
    # tf.distribute strategy.
    if not context.executing_eagerly():
      raise ValueError("Cannot create a one shot iterator. Please use"
                     "make_initializable_iterator() instead.")
    return self._get_iterator()

  def _get_iterator(self):
    iterators = _create_iterators_per_worker(self._datasets,
                                             self._input_workers, True,
                                             self._options)
    iterator = DistributedIteratorV1(self._input_workers, iterators,
                                     self._strategy,
                                     self._enable_get_next_as_optional)
    iterator._element_spec = self._element_spec  # pylint: disable=protected-access

    # When async eager is enabled, sometimes the iterator may not finish
    # initialization before passing to a multi device function, add a sync point
    # here to make sure all underlying iterators are initialized.
    if context.executing_eagerly():
      context.async_wait()

    return iterator

  def __iter__(self):
    if (ops.executing_eagerly_outside_functions() or
        ops.get_default_graph().building_function):
      return self._get_iterator()

    raise RuntimeError("__iter__() is only supported inside of tf.function"
                     "or when eager execution is enabled.")

4.1 StrategyBase

scope 就是调用基类的方法。

  def scope(self):
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1, self).scope()

StrategyBase 的 scope 方法返回一个 Context manager,其使用当前策略来建立分布式变量,当进入 Strategy.scope 时会发生:

  • "strategy" 成为全局上下文内的 "当前" strategy 。在这个作用域内,tf.distribute.get_strategy() 将返回此策略。在此范围之外,它返回默认的无操作策略。

  • 进入此作用域也会进入"cross-replica context"。

  • "scope"内的变量创建被策略拦截。每个策略都定义了它想要如何影响变量的创建。像 'MirroredStrategy'、'TPUStrategy' 和 'MultiWorkerMirroredStrategy' 这样的同步策略会在每个副本上创建复制的变量,而 'ParameterServerStrategy' 在参数服务器上创建变量。这是使用自定义的 tf.variable_creator_scope 完成的。

  • 在某些策略中,还可以输入默认的设备作用域:比如在"MultiWorkerMirroredStrategy"中,为每个工作者输入默认的设备作用域 "/CPU:0"。

    def scope(self):
    """Context manager to make the strategy current and distribute variables.

    This method returns a context manager, and is used as follows:
    
    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0","GPU:1"])
    >>> # Variable created inside scope:
    >>> with strategy.scope():
    ...   mirrored_variable = tf.Variable(1.)
    >>> mirrored_variable
    MirroredVariable:{
      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
      1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
    }
    >>> # Variable created outside scope:
    >>> regular_variable = tf.Variable(1.)
    >>> regular_variable
    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
    
    Returns:
      A context manager.

    """
    return self._extended._scope(self)

既然是调用了 extended,我们就接着分析。

4.2 StrategyExtendedV2

_scope 则配置了如何创建变量,如何获取变量,如何获取变量作用域。具体返回给用户一个 _CurrentDistributionContext,用户使用比如 creator_with_resource_vars 会调用到 派生策略的 _create_variable 来创建变量。

  def _scope(self, strategy):
  """Implementation of tf.distribute.Strategy.scope()."""

    def creator_with_resource_vars(next_creator, **kwargs):
    """Variable creator to use in _CurrentDistributionContext."""
      _require_strategy_scope_extended(self)
      kwargs["use_resource"] = True
      kwargs["distribute_strategy"] = strategy

      # Unwrap initial_value if it is a CheckpointInitialValue to avoid
      # dereferencing a Tensor that is without a name. We still need to
      # propagate the metadata it's holding.
      if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue):
        checkpoint_restore_uid = kwargs[
          "initial_value"].checkpoint_position.restore_uid
        kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
      elif isinstance(kwargs["initial_value"],
                      trackable.CheckpointInitialValueCallable):
        checkpoint_restore_uid = kwargs[
          "initial_value"].checkpoint_position.restore_uid
      elif (isinstance(kwargs["initial_value"], functools.partial) and
            isinstance(kwargs["initial_value"].func,
                       trackable.CheckpointInitialValueCallable)):
        # Some libraries (e.g, Keras) create partial function out of initializer
        # to bind shape/dtype, for example:
        #  initial_val = functools.partial(initializer, shape, dtype=dtype)
        # Therefore to get the restore_uid we need to examine the"func" of
        # the partial function.
        checkpoint_restore_uid = kwargs[
          "initial_value"].func.checkpoint_position.restore_uid
      else:
        checkpoint_restore_uid = None

      # 这里调用派生策略的 _create_variable
      created = self._create_variable(next_creator, **kwargs)

      if checkpoint_restore_uid is not None:
        # pylint: disable=protected-access
        # Let the checkpointing infrastructure know that the variable was
        # already restored so it doesn't waste memory loading the value again.
        # In this case of CheckpointInitialValueCallable this may already be
        # done by the final variable creator, but it doesn't hurt to do it
        # again.
        created._maybe_initialize_trackable()
        created._update_uid = checkpoint_restore_uid
       return created

    def distributed_getter(getter, *args, **kwargs):
      return getter(*args, **kwargs)

    return _CurrentDistributionContext(
        strategy,
        variable_scope.variable_creator_scope(creator_with_resource_vars),
        variable_scope.variable_scope(
            variable_scope.get_variable_scope(),
            custom_getter=distributed_getter), self._default_device)

4.2 创建变量

上面讲到了 creator_with_resource_vars 会调用到派生策略的 _create_variable 来创建变量这里我们就看看 PS 如何处理。初始化时候配置了 self._variable_device,这样就知道了应该如何分配变量到设置之上。在后续代码之中有 with ops.device(self._variable_device),这就是把后续作用域之中的变量放到self._variable_device之上。

self._variable_device = device_setter.replica_device_setter(
        ps_tasks=num_ps_replicas, # 参数服务器
        worker_device=self._worker_device, # 工作设备
        merge_devices=True,
        cluster=cluster_spec)

创建变量如下:

  def _create_variable(self, next_creator, **kwargs):

    # 创建变量
    var_creator = self._create_var_creator(next_creator, **kwargs)

    if"colocate_with" in kwargs:
      colocate_with = kwargs["colocate_with"]
      if isinstance(colocate_with, numpy_dataset.SingleDevice):
        with ops.device(colocate_with.device):
          return var_creator(**kwargs)
      with ops.device(None):
        with ops.colocate_with(colocate_with):
          return var_creator(**kwargs)

    with ops.colocate_with(None, ignore_existing=True):
      #
      with ops.device(self._variable_device): # 这里使用到了 replica_device_setter
        return var_creator(**kwargs)

具体建立变量是通过 _create_var_creator。这里主要的是调用了 ps_values.AggregatingVariable 生成变量。

  def _create_var_creator(self, next_creator, **kwargs):
    if self._num_replicas_in_sync > 1:
      aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
      if aggregation not in (
          vs.VariableAggregation.NONE,
          vs.VariableAggregation.SUM,
          vs.VariableAggregation.MEAN,
          vs.VariableAggregation.ONLY_FIRST_REPLICA
      ):
        raise ValueError("Invalid variable aggregation mode:" + aggregation +
                       " for variable:" + kwargs["name"])

      def var_creator(**kwargs):
      """Create an AggregatingVariable and fix up collections."""
        # Record what collections this variable should be added to.
        collections = kwargs.pop("collections", None)
        if collections is None:
          collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        kwargs["collections"] = []

        # Create and wrap the variable.
        v = next_creator(**kwargs)

        # 建立变量
        wrapped = ps_values.AggregatingVariable(self._container_strategy(), v,
                                                aggregation)

        # Add the wrapped variable to the requested collections.
        # The handling of eager mode and the global step matches
        # ResourceVariable._init_from_args().
        if not context.executing_eagerly():
          g = ops.get_default_graph()
          # If"trainable" is True, next_creator() will add the contained
          # variable to the TRAINABLE_VARIABLES collection, so we manually
          # remove it and replace with the wrapper. We can't set"trainable"
          # to False for next_creator() since that causes functions like
          # implicit_gradients to skip those variables.
          if kwargs.get("trainable", True):
            collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            if v in l:
              l.remove(v)
          g.add_to_collections(collections, wrapped)
        elif ops.GraphKeys.GLOBAL_STEP in collections:
          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped)

        return wrapped

      return var_creator
    else:
      return next_creator

4.3 PS 变量

AggregatingVariable 就是为变量加了一个 wrapper,这样对于变量的操作就落到了 strategy 之上。这里只给出了部分代码。

# Variable used in PSStrategy TF 1, TF2 and CentralStorageStrategy.
class AggregatingVariable(variables_lib.Variable, core.Tensor):
"""A wrapper around a variable that aggregates updates across replicas."""

  def __init__(self, strategy, v, aggregation):
    self._distribute_strategy = strategy
    self._v = v
    # NOTE: We don't use"_distributed_container" here because we don't want
    # to trigger that code path in regroup().
    v._aggregating_container = weakref.ref(self)  # pylint: disable=protected-access
    self._aggregation = aggregation

  def get(self):
    return self._v

  @property
  def distribute_strategy(self):
    return self._distribute_strategy

  def __getattr__(self, name):
    return getattr(self._v, name)

  def _assign_func(self, *args, **kwargs):
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      f = kwargs.pop("f")

      # 这里使用了跨副本上下文
      if ds_context.in_cross_replica_context():
        if distribute_lib.get_update_replica_id() is not None:
          # We are calling an assign function in an update context.
          return f(self._v, *args, **kwargs)

        # We are calling an assign function in cross replica context, wrap it in
        # an update call.
        # 使用策略来更新
        return self._distribute_strategy.extended.update(
            self, f, args=args, kwargs=kwargs)
      else:
        replica_context = ds_context.get_replica_context()
        assert replica_context
        # We are calling an assign function in replica context.
        # We reduce the value we want to assign/add/sub. More details about how
        # we handle the different use cases can be found in the _reduce method.
        # We call the function with the reduced value.
        if self._aggregation == vs.VariableAggregation.NONE:
          raise ValueError(
              values_util.aggregation_error_msg.format(
                  variable_type="AggregatingVariable"))

        def merge_fn(strategy,
                     value,
                     use_locking=False,
                     name=None,
                     read_value=True):
          v = values_util.apply_aggregation(strategy, value, self._aggregation,
                                            self)
          if name and isinstance(name, values.PerReplica):
            name = name.values[0]
          return strategy.extended.update(
              self,
              f,
              args=(v,),
              kwargs={
                "use_locking": use_locking,
                "name": name,
                "read_value": read_value
              })
        return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)

  def assign_sub(self, *args, **kwargs):
    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
    return self._assign_func(f=assign_sub_fn, *args, **kwargs)

  def assign_add(self, *args, **kwargs):
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._assign_func(f=assign_add_fn, *args, **kwargs)

  def assign(self, *args, **kwargs):
    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    return self._assign_func(f=assign_fn, *args, **kwargs)

  @property
  def initializer(self):
    return self._v.initializer

  def initialized_value(self):
    return self._v.initialized_value()

  @property
  def initial_value(self):
    return self._v.initial_value

  # 省略大部分代码

具体逻辑如下,第一个操作序列是建立变量,第二个操作序列是处理变量。

图 2 创建变量

我们接下来看看 ParameterServerStrategyV1 如何运行。

5.1 基类

ParameterServerStrategyV1 其实调用了基类 StrategyV1 的 run 方法,具体定义在 tensorflow/python/distribute/distribute_lib.py。具体在前文之中我们已经分析过,这里为了行文完整,再次列举出来如下.

这个方法是用 tf.distribution 对象分发计算的主要方法。它在每个副本上调用fn。如果args或kwargs有tf.distribution.DistributedValues,当 fn 在一个特定的副本上执行时,它将与对应于该副本的 tf.distributed.DistributedValues 的组件一起执行。

tf.distribution.DistributedValues 的例子如下:由 tf.distribution.DistributedDataset 产生的tf.distribution.Strategy.experimental_distribute_dataset 或 tf.distribution.Strategy.Dataset 的 tf.distributedDataset,

fn 在副本上下文被调用,fn可以调用tf.distribution.get_replica_context()来访问诸如all_reduce等成员。args 或kwargs 中的所有参数可以是一个嵌套的张量结构,例如一个张量列表,在这种情况下,args 和 kwargs 将被传递给在每个副本上调用的 fn。或者 args 或 kwargs 可以是包含张量或复合张量的tf.compat.v1.TensorInfo.CompositeTensor 的 tf.distributedValues,在这种情况下,每个fn调用将得到与其副本对应的tf.distributedValues的组件。

重要的是:根据 tf.distribution.Strategy 的实现和是否启用 eager execution,fn可能被调用一次或多次。如果 fn被注解为 tf.function 或者 tf.distribution.Strategy.run 在 tf.function 中被调用(默认情况下 tf.function 中禁止 eager execution),fn 在每个副本中被调用一次以生成 Tensorflow 图,然后被重新用于新输入的执行。

run 方法之中,主要就是调用了 call_for_each_replica。

  def run(self, fn, args=(), kwargs=None, options=None):
  """Invokes fn on each replica, with the given arguments.
  """
    del options

    if not isinstance(args, (list, tuple)):
      raise ValueError(
        "positional args must be a list or tuple, got {}".format(type(args)))

    with self.scope():
      # tf.distribute supports Eager functions, so AutoGraph should not be
      # applied when the caller is also in Eager mode.
      fn = autograph.tf_convert(
          fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
      return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)

Extend

执行来到了 StrategyExtendedV2,其实际上调用的是派生类的 _call_for_each_replica。

  def call_for_each_replica(self, fn, args=(), kwargs=None):
  """Run fn once per replica.

    fn may call tf.get_replica_context() to access methods such as
    replica_id_in_sync_group and merge_call().

    merge_call() is used to communicate between the replicas and
    re-enter the cross-replica context. All replicas pause their execution
    having encountered a merge_call() call. After that the
    merge_fn-function is executed. Its results are then unwrapped and
    given back to each replica call. After that execution resumes until
    fn is complete or encounters another merge_call().  Example:

    ```python
    # Called once in"cross-replica" context.
    def merge_fn(distribution, three_plus_replica_id):
      # sum the values across replicas
      return sum(distribution.experimental_local_results(three_plus_replica_id))

    # Called once per replica in distribution, in a"replica" context.
    def fn(three):
      replica_ctx = tf.get_replica_context()
      v = three + replica_ctx.replica_id_in_sync_group
      # Computes the sum of the v values across all replicas.
      s = replica_ctx.merge_call(merge_fn, args=(v,))
      return s + v

    with distribution.scope():
      # in"cross-replica" context
      ...
      merged_results = distribution.run(fn, args=[3])
      # merged_results has the values from every replica execution of fn.
      # This statement prints a list:
      print(distribution.experimental_local_results(merged_results))
    ```

    Args:
      fn: function to run (will be run once per replica).
      args: Tuple or list with positional arguments for fn.
      kwargs: Dict with keyword arguments for fn.

    Returns:
      Merged return value of fn across all replicas.
  """
    _require_cross_replica_or_default_context_extended(self)
    if kwargs is None:
      kwargs = {}
    with self._container_strategy().scope():
      return self._call_for_each_replica(fn, args, kwargs)

5.2 派生

派生类 ParameterServerStrategyExtended 的 _call_for_each_replica 如下:

  def _call_for_each_replica(self, fn, args, kwargs):
    return mirrored_run.call_for_each_replica(self._container_strategy(), fn,
                                              args, kwargs)

具体 mirrored_run 部分已经在前文分析过,这里不再赘述,具体逻辑如下:

图 3 运行

或者从另一个角度如下图所示:

https://www.youtube.com/watch?v=B2Tpv_N7wkg&ab_channel=TensorFlow