程序笔记   发布时间:2022-07-19  发布网站:大佬教程  code.js-code.com
大佬教程收集整理的这篇文章主要介绍了[源码解析] 深度学习分布式训练框架 horovod (9) --- 启动 on spark大佬教程大佬觉得挺不错的,现在分享给大家,也给大家做个参考。

[源码解析] 深度学习分布式训练框架 horovod (9) --- 启动 on spark

目录
  • [源码解析] 深度学习分布式训练框架 horovod (9) --- 启动 on spark
    • 0x00 摘要
    • 0x01 总体架构图
    • 0x02 第一阶段 :Horovod 启动
      • 2.1 Driver服务 :SparkDriverservice
      • 2.2 启动spark task : _make_spark_thread
      • 2.3 等待 spark task 启动结束
        • 2.3.1 _notify_and_register_task_addresses
        • 2.3.2 driver.wait_for_initial_registration
      • 2.4 等待
        • 2.3.1 Barrier 1 in Driver
        • 2.3.2 Barrier 2 in task
        • 2.3.3 总体等待流程
    • 0x03 第二阶段 :Spark Task 启动
      • 3.1 具体spark启动逻辑 :_task_fn
      • 3.2 SparkTaskservice
        • 3.2.1 SparkTaskservice 定义
        • 3.2.2 基本功能
      • 3.3 注册Task
        • 3.3.1 发送注册请求
        • 3.3.2 Driver处理
      • 3.4 Task 等待下一步通知
    • 0x04 第三阶段:Driver 通知 task 注册成功
      • 4.1 _notify_and_register_task_addresses
      • 4.2 notify_and_register
      • 4.3 wait_for_task_to_task_address_updates
      • 4.4 等待 In Task
        • 4.4.1 wait_for_command_termination
        • 4.4.2 _command_thread
    • 0x05 总结
    • 0xEE 个人信息

0x00 摘要

Horovod 是Uber于2017年发布的一个易于使用的高性能的分布式训练框架,在业界得到了广泛应用。

本系列将通过源码分析来带领大家了解 Horovod。这几篇介绍 horovod 如何运行在 spark 之上。本文是第九篇,介绍 horovod on spark 如何启动。

本系列其他文章如下:

[源码解析] 深度学习分布式训练框架 Horovod (1) --- 基础知识

[源码解析] 深度学习分布式训练框架 horovod (2) --- 从使用者角度切入

[源码解析] 深度学习分布式训练框架 horovod (3) --- Horovodrun背后做了什么

[源码解析] 深度学习分布式训练框架 horovod (4) --- 网络基础 & Driver

[源码解析] 深度学习分布式训练框架 horovod (5) --- 融合框架

[源码解析] 深度学习分布式训练框架 horovod (6) --- 后台线程架构

[源码解析] 深度学习分布式训练框架 horovod (7) --- DiStributedoptimizer

[源码解析] 深度学习分布式训练框架 horovod (8) --- on spark

0x01 总体架构图

首先,我们还是要祭出架构图,这样大家可以按图索骥。

[源码解析] 深度学习分布式训练框架 horovod (9) --- 启动 on spark

总体来说,Horovod on Spark 的总体逻辑分为以下阶段:

  • 启动 SparkDriverservice 服务,利用 _make_spark_thread 启动 Spark task,然后 horovod 会等待启动结束;
  • 多线程在 spark executor 之中启动 spark task,每个task之中运行一个 SparkTaskservice,SparkTaskservice 会向 hovorod 主进程中的 SparkDriverTask 进行注册,并且等待下一步运行启动的指令;
  • Horovod 收到所有 task 结束的信息之后,通知各个 task,进入下一阶段;
  • Horovod 调用 mpi_run (又利用到 mpirun_rsh.py)在每一个 spark executor 上启动 orted 进程,以启动 MPI cluster;
  • orted 在每一个 executor 之上运行训练代码;

我们下面就具体看看如何启动。

0x02 第一阶段 :Horovod 启动

本部分主要逻辑是:启动 SparkDriverservice 服务,利用 _make_spark_thread 启动 Spark task,然后 horovod 会等待启动结束。

2.1 Driver服务 :SparkDriverservice

SparkDriverservice 继承了 driver_service.basicDriverservice,所以其内部启动了一个 socket server,可以进行网络交互。

Horovod 利用 SparkDriverservice 来和 Spark executor(通过其中运行的SparkTaskservice)交互,比如收集信息,让 spark 启动训练job等等。这是一个 RPC 机制

具体 SparkDriverservice 的功能可以参见其内部处理的各种 request,比如

  • Coderequest :SparkTaskservice会用来请求用户代码;
  • TaskHostHashInDicesrequest :获取 task host 地址;
  • TaskIndexByRankrequest :从 rank 获取到 task index;
  • SetLocalRankToRankrequest :从 local rank 得到 rank 信息;
  • WaitForTaskShutdownrequest :等待 shutdown;

和前文介绍的 HorovodRunDriverservice 有些类似。

其中,其成员变量 _fn 就是训练函数,以后当 SparkTaskservice 请求代码的时候,就通过 CodeResponse 把 _fn 直接发送回去。这样就解决了代码发布问题

class SparkDriverservice(driver_service.basicDriverservicE):
    NAME = 'driver service'

    def __init__(self, initial_np, num_proc, fn, args, kwargs, key, nics):
        super(SparkDriverservice, self).__init__(num_proc,
                                                 SparkDriverservice.NAME,
                                                 key, nics)
        self._initial_np = initial_np
        self._fn = fn # 保存用户代码
        self._args = args # 用户参数
        self._kwargs = kwargs 
        self._key = key
        self._nics = nics # 网卡信息
        self._ranks_to_inDices = {}
        self._spark_job_failed = false
        self._lock = threading.Lock()
        self._task_shutdown = threading.Event()

    def _handle(self, req, client_address):

        if isinstance(req, TaskHostHashInDicesrequest): # 获取 task host 地址
            return TaskHostHashInDicesResponse(self._task_host_hash_inDices[req.host_hash])

        if isinstance(req, SetLocalRankToRankrequest): # 从 local rank 得到 rank 信息
            self._lock.acquire()

            try:
                # get index for host and local_rank
                inDices = self._task_host_hash_inDices[req.host]
                index = inDices[req.local_rank]

                values = list(self._ranks_to_inDices.values())
                prev_pos = values.index(indeX) if index in values else None
                if prev_pos is not None:
                    prev_rank = list(self._ranks_to_inDices.keys())[prev_pos]
                    del self._ranks_to_inDices[prev_rank]

                # memorize rank's index
                self._ranks_to_inDices[req.rank] = index
            finally:
                self._lock.release()
            return SetLocalRankToRankResponse(indeX)

        if isinstance(req, TaskIndexByRankrequest): # 是从 rank 获取到 task index
            self._lock.acquire()
            try:
                return TaskIndexByRankResponse(self._ranks_to_inDices[req.rank])
            finally:
                self._lock.release()

        if isinstance(req, Coderequest): # SparkTaskservice会用来请求用户代码
            return CodeResponse(self._fn, self._args, self._kwargs)

        if isinstance(req, WaitForTaskShutdownrequest): # 等待任务结束
            self._task_shutdown.wait()
            return network.AckResponse()

        return super(SparkDriverservice, self)._handle(req, client_address)

2.2 启动spark task : _make_spark_thread

在 Horovod.spark.run 之中,_make_spark_thread 建立了 thread。这里关键代码是:

@H_891_176@mapper = _make_mapper(driver.addresses(), setTings, use_gloo, is_elastiC)
result = procs.mapPartitionsWithIndex(mapper).collect()

@H_991_152@mapPartitionsWithIndex 这句代码会促使 Spark 在多个 Executor 之中运行 mapper 函数,并且得到运行结果。

即创建 setTings.num_procSpark tasks,每个 task 会运行 mapper(_task_fn), 外部的 run 函数会等待这些执行结果。其实如果需要使用RDD,也许可以使用 foreachPartition,这样每个结点上将会在内存中持有RDD的一个分区。

def _make_spark_thread(spark_context, spark_job_group, driver, result_queue,
                       setTings, use_gloo, is_elastiC):
    """Creates `setTings.num_proc` Spark tasks in a parallel thread."""
    
    def run_spark():
        """Creates `setTings.num_proc` Spark tasks, each execuTing `_task_fn` and waits for them to terminate."""
        try:
            spark_context.setJobGroup(spark_job_group, "Horovod Spark Run", interruptOnCancel=TruE)
            procs = spark_context.range(0, numSlices=setTings.max_np if setTings.elastic else setTings.num_proC)
            # We assume that folks caring about security will enable Spark RPC encryption,
            # thus ensuring that key that is passed here remains secret.
            mapper = _make_mapper(driver.addresses(), setTings, use_gloo, is_elastiC)
            # 促使 Spark 在多个 Executor 之中运行 mapper 函数,并且得到运行结果
            result = procs.mapPartitionsWithIndex(mapper).collect()
            result_queue.put(result)
        except:
            driver.notify_spark_job_failed()
            raise

    spark_thread = in_thread(target=run_spark, daemon=falsE)
    return spark_thread

2.3 等待 spark task 启动结束

启动了 spark task 之后,horovod 主进程会调用如下来等待 task 全部 启动完成。

# wait for all tasks to register, notify them and initiate task-to-task address registration
_notify_and_register_task_addresses(driver, setTings)

即,run 函数中,当 _make_spark_thread 之后,horovod 主进程调用 _notify_and_register_task_addresses,从而调用 driver.wait_for_initial_registration(setTings.start_timeout) ,进行总体等待。

等待的内容是:等待所有 num_proc tasks 来注册。当所有 spark thread 都ready 之后,主 horovod 进程会继续运行

[源码解析] 深度学习分布式训练框架 horovod (9) --- 启动 on spark

2.3.1 _notify_and_register_task_addresses

horovod 主进程之中,会使用 _notify_and_register_task_addresses等待这些 spark task 来注册,从而调用 driver.wait_for_initial_registration(setTings.start_timeout) ,进行总体等待。

注意,同时发送注册请求之后, spark task 自己也调用 task.wait_for_initial_registration 等待 horovod 再通知下一阶段的启动。

而在horovod 主进程的 _notify_and_register_task_addresses 其实也很复杂:

  • 调用 driver.wait_for_initial_registration 等待task来注册,需要等待 num_proc 个task;
  • 利用 notify_and_register 注册task,并且通知各个 task 开始下一步;

具体代码如下:

def _notify_and_register_task_addresses(driver, setTings, notify=TruE):
    # wait for num_proc tasks to register
    # 等待task来注册,需要等待 num_proc 个task
    driver.wait_for_initial_registration(setTings.start_timeout) 

    def notify_and_register(indeX): # 注册task,并且通知各个 task 开始下一步
        task_client = task_service.SparkTaskClient(index,
                                                   driver.task_addresses_for_driver(indeX),
                                                   setTings.key, setTings.verbosE)

        if notify:
            task_client.notify_initial_registration_complete()

        next_task_index = (index + 1) % setTings.num_proc
        next_task_addresses = driver.all_task_addresses(next_task_indeX)
        task_to_task_addresses = task_client.get_task_addresses_for_task(next_task_index, next_task_addresses)
        driver.register_task_to_task_addresses(next_task_index, task_to_task_addresses)

    for index in driver.task_inDices():
        in_thread(notify_and_register, (index,)) #在thread之中启动task

    driver.wait_for_task_to_task_address_updates(setTings.start_timeout)

我们目前只能看其第一步 “等待注册”。

2.3.2 driver.wait_for_initial_registration

在这里 SparkDriverSerivce 首先等待所有 spark executor 注册。

在 class BasicDriverservice(network.basicservicE): 有如下代码,可以看到,只有全部 _num_proc 注册完成,当所有 spark thread 都ready 之后,主 horovod 进程会继续运行。

这里关键是:while len(self._all_task_addresses) < self._num_proc就是等待 self._all_task_addresses 的数目达到 _num_proc。

class BasicDriverservice(network.basicservicE):
  
  def wait_for_initial_registration(self, timeout):
      self._wait_cond.acquire()
      try:
          # 等待 self._all_task_addresses 的数目达到 _num_proc
          while len(self._all_task_addresses) < self._num_proc:
              self._wait_cond.wait(timeout.remaining())
              timeout.check_time_out_for('tasks to start')
      finally:
          self._wait_cond.release()

2.4 等待

关于等待代码,我们要做一下特殊说明,具体看图。

[源码解析] 深度学习分布式训练框架 horovod (9) --- 启动 on spark

这里有两套 wait_for_initial_registration。可以认为是两套 barrier

就是:

  • barrier 1 :SparkDriverSerivce 等待所有 SparkTaskSerivce ready;
  • barrier 2 :所有 SparkTaskSerivce 需要一起运行,所以 SparkTaskSerivce们 都在等待 barrier 2。SparkDriverSerivce 会通知 这些 SparkTaskSerivce 一起发动;

2.3.1 Barrier 1 in Driver

在 run 函数中,当 _make_spark_thread 之后,horovod 主进程调用 _notify_and_register_task_addresses,从而调用 driver.wait_for_initial_registration(setTings.start_timeout) ,进行总体等待。

等待的内容是:等待所有 num_proc tasks 来注册。当所有 spark thread 都ready 之后,主 horovod 进程会继续运行。这里关键是:

while len(self._all_task_addresses) < self._num_proc

就是等待 self._all_task_addresses 的数目达到 _num_proc。

def wait_for_initial_registration(self, timeout):
    self._wait_cond.acquire()
    try:
        while len(self._all_task_addresses) < self._num_proc:
            self._wait_cond.wait(timeout.remaining())
            timeout.check_time_out_for('tasks to start')
    finally:
        self._wait_cond.release()

在 BasicDriverservice 之中,如果收到了 spark executor 的注册请求就进行处理,这里最重要是:

self._all_task_addresses[req.index] = req.task_addresses

当所有的 spark executor 都注册了,这里就等待成功

2.3.2 Barrier 2 in task

每个 spark thread 在 _task_fn 之中运行,就是在 spark task 之中运行。这里也可以看出来是 Spark task 的一个总体流程

  • 首先 调用 register_task
  • 其次 调用 task.wait_for_initial_registration(setTings.start_timeout)
  • 然后 调用 wait_for_command_termination 来等待结束;

task.wait_for_initial_registration 会等待 self._initial_registration_complete = True 这个条件,就是等待 register_task 注册完成。

每个 Spark Executor 都有一个 SparkTaskservice,所以 每个spark task 都有自己的 _initial_registration_complete。

hovorod.run 主进程会逐一通知每个 SparkTaskservice 的 _initial_registration_complete。

即,哪个 SparkTaskservice 好了,就通知哪个 SparkTaskservice 的 _initial_registration_complete。这样,这个 SparkTaskservice 就可以正式运行了。

2.3.3 总体等待流程

总体等待流程具体如图,数字就是执行顺序:

@H_616_357@
  • SparkDriverSerivce 调用 driver.wait_for_initial_registration 来等待 SparkTaskSerivce 的注册,这是 barrier 1
  • SparkTaskSerivce 1 进行注册,然后 SparkTaskSerivce 1 自己也调用 task.wait_for_initial_registration 等待 horovod 再通知下一阶段的启动,这是 barrier 2
  • SparkTaskSerivce 2 进行注册,然后 SparkTaskSerivce 2 自己也调用 task.wait_for_initial_registration 等待 horovod 再通知下一阶段的启动,这是 barrier 2
  • hovorod.run 主进程在发现所有 task 都注册之后,barrier 1 等待结束,会逐一通知每个 SparkTaskservice 的 _initial_registration_complete。只有 4 完成之后,两个 SparkTaskSerivce 才能继续执行 5,6;
  • SparkTaskSerivce 1 对于 barrier 2 等待结束,继续执行
  • SparkTaskSerivce 2 对于 barrier 2 等待结束,继续执行
  •     SparkTaskSerivce 1          SparkTaskSerivce 2            SparkDriverSerivce
    
                +                           +                             +
                |                           |                             |
                |                           |                             |
                |                           |                             |
                |                           |                             |   1
                |                           |                             |
                |                           |                             |
                |                           |                             v
                |                           |
                |                           |         +--------------------------------------+
                |                           |         | barrier 1                            |
                |                           |   2     |                                      |
                |          3                +-------> |                                      |
                |                           |         |                                      |
                +-----------------------------------> | driver.wait_for_initial_registration |
                |                           |         |                                      |
                |                           |         |                                      |
                |                           |         |                                      |
                |                           |         +--------------------+-----------------+
                |                           |                              |
                |                           |                              |
    +-----------+----------------------+    |                  4           |
    |barrier 2                         | <---------------------------------+
    |                                  |    |                              |
    |task.wait_for_initial_registration|    |                              |
    |                                  |    |                              |
    +-----------+----------------------+    |                              |
                |                           |                              |
                |             +-------------+----------------------+       |
                |             | barrier 2                          |   4   |
                | 6           |                                    +<------+
                |             | task.wait_for_initial_registration |       |
                |             |                                    |       |
                |             +-------------+----------------------+       |
                |                           |                              |
                |                           |                              |
                |                           |  5                           |
                |                           |                              |
                v                           v                              v
    
    

    我们接下来详细介绍 task 启动内容 和 driver 后续工作。

    0x03 第二阶段 :Spark Task 启动

    本阶段我们详细介绍下 Spark Task 的启动过程。

    这部分主要功能是:多线程在 spark executor 之中启动 spark task,每个spark task会运行_task_fn函数,_task_fn函数会运行一个 SparkTaskservice,SparkTaskSerivce 会向 hovorod 主进程中的 SparkDriverTask 进行注册,并且等待下一步运行启动的指令;

    此时程序(不是训练程序,而是 SparkTaskservice)已经在 Spark Executor内部运行了。我们看看在 spark Executor 之中,是如何启动运行 SparkTaskservice 的。

    3.1 具体spark启动逻辑 :_task_fn

    Horovod 在 thread 里面通过 _make_mapper 来让 Spark 运行 _task_fn。

    def _make_mapper(driver_addresses, setTings, use_gloo, is_elastiC):
    
        def _mapper(index, _):
            yield _task_fn(index, driver_addresses, key, setTings, use_gloo, is_elastiC)
    
        return _mapper
    

    _task_fn 的作用是为了注册 horovod 进入到 spark task。即,在每一个 spark task (executor) 之中启动一个 SparkTaskservice。

    一定要注意:这些 SparkTaskservice 是运行在 spark executor 之中,通过网络与 horovod 之中的 SparkDriverservice 交互

    可以看到,_task_fn 的总体逻辑是:

    • 启动 SparkTaskservice;
    • 通过 driver_service.SparkDriverClient.register_task 来向 horovod 中的 Driver 注册;
    • 通过 task.wait_for_initial_registration(setTings.start_timeout) 来等待下一步启动的开始指示;
    • 如果下一步开始启动了,则调用 task.wait_for_command_termination() 等待结束;

    具体如下:

    def _task_fn(index, driver_addresses, key, setTings, use_gloo, is_elastiC):
        setTings.key = key
        hosthash = host_hash(salt='{}-{}'.format(index, time.time()) if is_elastic else NonE)
        os.environ['HOROVOD_HOSTNAME'] = hosthash
        # 启动 SparkTaskservice,SparkTaskservice本身包括一个socket server,可以和driver交互
        task = task_service.SparkTaskservice(index, setTings.key, setTings.nics,...)
        try:
            driver_client = driver_service.SparkDriverClient(driver_addresses, setTings.key, setTings.verbosE)
            # 向 horovod 中的 Driver 注册
            driver_client.register_task(index, task.addresses(), hosthash)
    
            # 这里依然运行在spark task之中,但因为不是SparkTaskservice,所以只是做协助工作,最后静静等待
            if not is_elastic:
                # 等待下一步启动的开始指示
                task.wait_for_initial_registration(setTings.start_timeout)
                task_inDices_on_this_host = driver_client.task_host_hash_inDices(hosthash)
                local_rank_zero_index = task_inDices_on_this_host[0]
            else:
                local_rank_zero_index = None
    
            if is_elastic:
    						...... # 后续文章会介绍
            elif use_gloo or index == local_rank_zero_index:
                # Either Gloo or first task with MPI.
                # 使用Gloo或者使用MPI的第一个task,让这个task做操作
                task.wait_for_command_start(setTings.start_timeout)
                # 等待结束
                task.wait_for_command_termination()
            else:
                # The other tasks with MPI need to wait for the first task to finish.
                # 让其他的task等待第一个task结束
                first_task_addresses = driver_client.all_task_addresses(local_rank_zero_indeX)
                first_task_client = 
                    task_service.SparkTaskClient(local_rank_zero_index,
                                                 first_task_addresses, setTings.key,
                                                 setTings.verbosE)
                # 调用 task.wait_for_command_termination() 等待结束  
                first_task_client.wait_for_command_termination()
    
            return task.fn_result()
        finally:
            task.shutdown()
    

    3.2 SparkTaskservice

    再次强调如下代码:

    task = task_service.SparkTaskservice(index, setTings.key, setTings.nics,...)

    每一个_task_fn 中都定义了一个 SparkTaskservice,即每一个 Spark Executor 都会生成一个(或者多个) SparkTaskservice,在 spark task 之中运行并且作用。

    3.2.1 SparkTaskservice 定义

    SparkTaskservice 定义如下,因为继承了BasicTaskservice,所以其内部最终也会启动一个 socket server,以便同 horovod 中的 SparkDriverservice 交互:

    class SparkTaskservice(task_service.basicTaskservicE):
        NAME_FORMAT = 'task service #%d'
    
        def __init__(self, index, key, nics, minimum_command_lifetime_s, verbose=0):
            # on a Spark cluster we need our Train function to see the Spark worker environment
            # this includes PYTHONPATH, HADOOP_TOKEN_FILE_LOCATION and _HOROVOD_SECRET_KEY
            env = os.environ.copy()
    
            # we inject the secret key here
            env[secret.HOROVOD_SECRET_KEY] = codec.dumps_base64(key)
    
            # we also need to provide the current working dir to mpirun_exec_fn.py
            env['HOROVOD_SPARK_WORK_DIR'] = os.getcwd()
    
            super(SparkTaskservice, self).__init__(SparkTaskservice.NAME_FORMAT % index,
                                                   index, key, nics, env, verbosE)
            self._key = key
            self._minimum_command_lifetime_s = minimum_command_lifetime_s
            self._minimum_command_lifetime = None
    

    3.2.2 基本功能

    SparkTaskservice 的基本功能如下。

    • _run_command 将会被用来在 spark 之中启动训练job;
    • _handle 会处理 GetTask@R_654_10586@skAddressesrequest,用来获取 task 地址,也会处理resourcesrequest,返回资源;
    • _get_resources 将返回 spark 资源;
    • wait_for_command_termination 会等待命令执行结束;

    具体代码如下:

    def _run_command(self, command, env, event,
                     stdout=None, stderr=None, index=None,
                     prefix_output_with_timestamp=falsE):
        # 在 spark 之中启动训练job
        super(SparkTaskservice, self)._run_command(command, env, event,
                                                   stdout, stderr, index,
                                                   prefix_output_with_timestamp)
    
        if self._minimum_command_lifetime_s is not None:
            self._minimum_command_lifetime = timeout.Timeout(self._minimum_command_lifetime_s,
                                                             message='Just measuring runtime')
    
    def _handle(self, req, client_address):
        # 返回资源
        if isinstance(req, resourcesrequest):
            return resourcesResponse(self._get_resources())
    
        # 获取 task 地址  
        if isinstance(req, GetTask@R_654_10586@skAddressesrequest):
            next_task_index = req.task_index
            next_task_addresses = req.all_task_addresses
            # We request interface matching to weed out all the NAT'ed interfaces.
            next_task_client = 
                SparkTaskClient(next_task_index, next_task_addresses,
                                self._key, self._verbose,
                                match_intf=TruE)
            return GetTask@R_654_10586@skAddressesResponse(next_task_client.addresses())
    
        return super(SparkTaskservice, self)._handle(req, client_address)
    
    def _get_resources(self):
        # 返回 spark 资源
        if LooseVersion(pyspark.__version__) >= LooseVersion('3.0.0'):
            task_context = pyspark.TaskContext.get()
            if task_context:
                return task_context.resources()
            else:
                print("Not running inside Spark worker, no resources available")
        return Dict()
    
    def wait_for_command_termination(self):
        """
        Waits for command termination. Ensures this method takes at least
        self._minimum_command_lifetime_s seconds to return after command started.
        """
        try:
            # 等待命令执行结束
            return super(SparkTaskservice, self).wait_for_command_termination()
        finally:
            # command terminated, make sure this method takes at least
            # self._minimum_command_lifetime_s seconds after command started
            # the client that started the command needs some time to connect again
            # to wait for the result (see horovod.spark.driver.rsh).
            if self._minimum_command_lifetime is not None:
                time.sleep(self._minimum_command_lifetime.remaining())
    

    3.3 注册Task

    下一步代码就是用来向 Driver 注册 本 task。

    driver_client.register_task(index, task.addresses(), hosthash)
    

    3.3.1 发送注册请求

    注册具体通过如下完成,这里调用了 network.py 的 _send 函数,就是通过 socket,spark executor 和 horovod driver 进行了网络交互:

    class BasicDriverClient(network.basicClient):
    
        def register_task(self, index, task_addresses, host_hash):
            self._send(RegisterTaskrequest(index, task_addresses, host_hash))
    

    3.3.2 Driver处理

    我们先来到 Horovod 中运行的 Driver来看看(下一节内容,这里提前看看

    在 BasicDriverservice 之中,如果收到了RegisterTaskrequest请求就进行处理,这里最重要是:

    self._all_task_addresses[req.index] = req.task_addresses

    这样,self._all_task_addresses 的数目就增加了。

    而我们之前提到了,horovod 正在 driver.wait_for_initial_registration 上面等待,其关键是:

    while len(self._all_task_addresses) < self._num_proc

    如果self._all_task_addresses 的数目达到了_num_proc,driver.wait_for_initial_registration 就结束了,就顺利执行。

    具体处理 RegisterTaskrequest 的代码如下,BasicDriverservice 之中有各种成员变量,用来维护各种所需信息,我们在前文 [原创 源码解析] 深度学习分布式训练框架 horovod (4) --- 网络基础 & Driver 中已经详细讲解过,_handle函数的RegisterTaskrequest 处理就是用来更新这些成员变量:

    class BasicDriverservice(network.basicservicE):
    
        def _handle(self, req, client_address):
            if isinstance(req, RegisterTaskrequest):
                self._wait_cond.acquire()
                try:
    
                    self._all_task_addresses[req.index] = req.task_addresses
                    # Just use source address for service for fast probing.
                    self._task_addresses_for_driver[req.index] = 
                        self._filter_by_ip(req.task_addresses, client_address[0])
                      
                    # Remove host hash earlier registered under this index.
                    if req.index in self._task_index_host_hash:
                        earlier_host_hash = self._task_index_host_hash[req.index]
                        if earlier_host_hash != req.host_hash:
                            self._task_host_hash_inDices[earlier_host_hash].remove(req.indeX)
    
                    # Make index -> host hash map.
                    self._task_index_host_hash[req.index] = req.host_hash
    
                    # Make host hash -> inDices map.
                    if req.host_hash not in self._task_host_hash_inDices:
                        self._task_host_hash_inDices[req.host_hash] = []
                    self._task_host_hash_inDices[req.host_hash].append(req.indeX)
                    # TODO: this sorTing is a problem in elastic horovod
                    self._task_host_hash_inDices[req.host_hash].sort()
                finally:
                    self._wait_cond.notify_all()
                    self._wait_cond.release()
                    
                return network.AckResponse()
    

    3.4 Task 等待下一步通知

    前面提到了,当 spark task 向 driver 发送注册请求之后,Spark task 通过 task.wait_for_initial_registration(setTings.start_timeout) 来等待下一步启动的开始指示。就是 driver 认为你一景注册完成了,可以开始进入下一步了。

    task.wait_for_initial_registration 会等待 self._initial_registration_complete = True 这个条件,就是等待 register_task 注册完成。

    class BasicTaskservice(network.basicservicE):
      
      def wait_for_initial_registration(self, timeout):
            self._wait_cond.acquire()
            try:
                while not self._initial_registration_complete:
                    self._wait_cond.wait(timeout.remaining())
                    timeout.check_time_out_for('tasks to start')
            finally:
                self._wait_cond.release()
    

    每个 Spark Executor 都有一个 SparkTaskservice,所以 每个spark task 都有自己的 _initial_registration_complete。

    hovorod.run 主进程会逐一通知每个 SparkTaskservice 的 _initial_registration_complete。即,哪个 SparkTaskservice 好了,就通知哪个 SparkTaskservice 的 _initial_registration_complete。

    hovorod.run 主进程 是通过发送 NotifyInitialRegistrationCompleterequest完成这一步的。

    def notify_initial_registration_complete(self):
        self._send(NotifyInitialRegistrationCompleterequest())
    

    BasicTaskservice 在等待 NotifyInitialRegistrationCompleterequest,如果收到了,就设置为 True,这样wait_for_initial_registration 就等待结束了。

    if isinstance(req, NotifyInitialRegistrationCompleterequest):
        self._wait_cond.acquire()
        try:
            self._initial_registration_complete = True
        finally:
            self._wait_cond.notify_all()
            self._wait_cond.release()
        return network.AckResponse()
    

    就说明当本 thread 注册在 horovod 之后,就算本 spark thread 启动成功了。

    +-------------------------------------+             +----------------------------------------------------+
    | Horovod Main thread                 |             | Spark Executor                                     |
    |                                     |             |                     _task_fn                       |
    |                                     |             |                        +                           |
    |                                     |             |                        |                           |
    |                                     |             |                        |                           |
    |                                     |             |                        v                           |
    | +-------------------------------+   |             |  +---------------------+------------------------+  |
    | | SparkDriverservice            |   |             |  | SparkTaskservice                             |  |
    | |                               |   |             |  |               +                              |  |
    | |                               |   |  1 register |  |               |                              |  |
    | |  self._all_task_addresses <----------------------------------------+                              |  |
    | |                               |   |             |  |               |                              |  |
    | |              +                |   |             |  |               |                              |  |
    | |              |                |   |             |  |               |                              |  |
    | |              | 3              |   |             |  |               |                              |  |
    | |              |                |   |             |  |               | 2                            |  |
    | |              v                |   |             |  |               |                              |  |
    | |  self._wait_cond.notify_all() |   |             |  |               |                              |  |
    | |              +                |   |             |  |               v                              |  |
    | |              |                |   |             |  |     +---------+---------------------------+  |  |
    | |              |                |   |             |  |     |                                     |  |  |
    | |              |                |   |             |  |     | task.wait_for_initial_registration  |  |  |
    | |              |                |   |             |  |     |                                     |  |  |
    | |              |                |   |             |  |     +-------------------------------------+  |  |
    | |              |                |   |             |  |                                              |  |
    | |              |                |   |             |  |                                              |  |
    | |              |                |   |             |  |                                              |  |
    | |              |                |   |             |  |                                              |  |
    | |              |                |   |             |  |                                              |  |
    | |              |                |   |             |  |                                              |  |
    | |              |                |   |             |  |                                              |  |
    | |              |                |   |             |  |                                              |  |
    | |              |                |   |             |  |                                              |  |
    | |              |                |   |             |  |                                              |  |
    | |              |                |   |             |  |                                              |  |
    | |              v                |   |             |  |                                              |  |
    | |                               |   |             |  |                                              |  |
    | |                               |   |             |  |                                              |  |
    | |                               |   |             |  |                                              |  |
    | +-------------------------------+   |             |  +----------------------------------------------+  |
    +-------------------------------------+             +----------------------------------------------------+
    
    

    手机如下:

    [源码解析] 深度学习分布式训练框架 horovod (9) --- 启动 on spark

    0x04 第三阶段:Driver 通知 task 注册成功

    本阶段的作用是:Horovod 收到所有 task 结束的信息之后,通知各个 task,进入下一阶段。

    4.1 _notify_and_register_task_addresses

    前面提到。在 horovod 主进程之中,会使用 _notify_and_register_task_addresses 来等待这些 spark task 来注册,从而调用 driver.wait_for_initial_registration(setTings.start_timeout) ,进行总体等待。

    注意,同时发送注册请求之后, spark task 自己也调用 task.wait_for_initial_registration 等待horovod 再通知下一阶段的启动。

    而 _notify_and_register_task_addresses 中其实也很复杂:

    • 调用 driver.wait_for_initial_registration 等待task来注册;(目前这一步已经完成
    • 利用 notify_and_register 注册task,并且通知各个 task 开始下一步;(我们这里进入后面这两步
    • 利用 driver.wait_for_task_to_task_address_updates 再次确认下所有 task 都OK;
    def _notify_and_register_task_addresses(driver, setTings, notify=TruE):
        # wait for num_proc tasks to register
        driver.wait_for_initial_registration(setTings.start_timeout)
    
        def notify_and_register(indeX):
            # 注册task,并且通知各个 task 开始下一步
            task_client = task_service.SparkTaskClient(index,
                                                       driver.task_addresses_for_driver(indeX),
                                                       setTings.key, setTings.verbosE)
    
            if notify:
                task_client.notify_initial_registration_complete()
    
            next_task_index = (index + 1) % setTings.num_proc
            next_task_addresses = driver.all_task_addresses(next_task_indeX)
            task_to_task_addresses = task_client.get_task_addresses_for_task(next_task_index, next_task_addresses)
            driver.register_task_to_task_addresses(next_task_index, task_to_task_addresses)
    
        for index in driver.task_inDices():
            in_thread(notify_and_register, (index,)) # 注册task,并且通知各个 task 开始下一步
    
        # 再次确认下所有 task 都OK    
        driver.wait_for_task_to_task_address_updates(setTings.start_timeout)
    

    4.2 notify_and_register

    可以看到 notify_and_register 的作用就是:

    • 调用 task_client.notify_initial_registration_complete() 通知 spark task 注册成功了,这样就让所有等待 task.wait_for_initial_registration 的 spark executor 一起运行下一阶段。
    • 调用 driver.register_task_to_task_addresses(next_task_index, task_to_task_addresses) 来让 Driver 完成注册。
    def wait_for_task_to_task_address_updates(self, timeout):
        self._wait_cond.acquire()
        try:
            while len(self._task_addresses_for_tasks) < self._initial_np:
                self.check_for_spark_job_failure()
                self._wait_cond.wait(timeout.remaining())
                timeout.check_time_out_for('Spark tasks to update task-to-task addresses')
        finally:
            self._wait_cond.release()
    

    4.3 wait_for_task_to_task_address_updates

    这里会再次确认所有 spark task 都OK。

    def wait_for_task_to_task_address_updates(self, timeout):
        self._wait_cond.acquire()
        try:
            while len(self._task_addresses_for_tasks) < self._initial_np:
                self.check_for_spark_job_failure()
                self._wait_cond.wait(timeout.remaining())
                timeout.check_time_out_for('Spark tasks to update task-to-task addresses')
        finally:
            self._wait_cond.release()
    

    4.4 等待 In Task

    在 Spark task 之中,如果收到了下一步启动指示之后,会调用 wait_for_command_termination 进行等待。

    其实,这一步也就意味spark exector 自己本身的逻辑任务结束了,因为以后都是 SparkTaskservice 自己独立完成的动作,它来负责训练代码的启动。既然 _task_fn 的逻辑任务已经结束,那么静静地等待即可。

    4.4.1 wait_for_command_termination

    在 horovod-master/horovod/spark/task/task_service.py

    def wait_for_command_termination(self):
        """
        Waits for command termination. Ensures this method takes at least
        self._minimum_command_lifetime_s seconds to return after command started.
        """
        try:
            return super(SparkTaskservice, self).wait_for_command_termination()
        finally:
            # command terminated, make sure this method takes at least
            # self._minimum_command_lifetime_s seconds after command started
            # the client that started the command needs some time to connect again
            # to wait for the result (see horovod.spark.driver.rsh).
            if self._minimum_command_lifetime is not None:
                time.sleep(self._minimum_command_lifetime.remaining())
    

    在 horovod-master/horovod/runner/common/service/task_service.py 中可以看到,就是等待训练代码所在的 thread 结束。

    def wait_for_command_termination(self):
        self._command_thread.join() # 马上会说明
    

    4.4.2 _command_thread

    这里对 _command_thread 略作说明。

    在 SparkTaskservice 处理 RunCommandrequest 时候,运行 Command 的 thread 就是被赋值为 _command_thread。

    class BasicTaskservice(network.basicservicE):
        def _handle(self, req, client_address):
          
            if isinstance(req, RunCommandrequest): # 运行命令请求
                self._wait_cond.acquire()
                try:
                    if self._command_thread is None:
    
                        if self._command_env:
                            env = self._command_env.copy()
                            self._add_envs(env, req.env)
                            req.env = env
    
                        self._command_abort = threading.Event()
                        self._command_stdout = Pipe() if req.capture_stdout else None
                        self._command_stderr = Pipe() if req.capture_stderr else None
                        # 配置各种参数信息
                        args = (req.command, req.env, self._command_abort,
                                self._command_stdout, self._command_stderr,
                                self._index,
                                req.prefix_output_with_timestamp)
                        # 启动一个新线程来运行命令
                        self._command_thread = in_thread(self._run_command, args)
                finally:
                    self._wait_cond.notify_all()
                    self._wait_cond.release()
                return network.AckResponse()  
    

    逻辑如下:

    +-------------------------------------+             +----------------------------------------------------+
    | Horovod Main thread                 |             | Spark Executor                                     |
    |                                     |             |                     _task_fn                       |
    |                                     |             |                        +                           |
    |                                     |             |                        |                           |
    |                                     |             |                        |                           |
    |                                     |             |                        v                           |
    | +-------------------------------+   |             |  +---------------------+------------------------+  |
    | | SparkDriverservice            |   |             |  | SparkTaskservice                             |  |
    | |                               |   |             |  |               +                              |  |
    | |                               |   |  1 register |  |               |                              |  |
    | |  self._all_task_addresses <----------------------------------------+                              |  |
    | |                               |   |             |  |               |                              |  |
    | |              +                |   |             |  |               |                              |  |
    | |              |                |   |             |  |               |                              |  |
    | |              | 3              |   |             |  |               |                              |  |
    | |              |                |   |             |  |               | 2                            |  |
    | |              v                |   |             |  |               |                              |  |
    | |  self._wait_cond.notify_all() |   |             |  |               |                              |  |
    | |              +                |   |             |  |               v                              |  |
    | |              |                |   +             +  +     +---------+---------------------------+  |  |
    | |              |            4   |  RegistrationComplete    |                                     |  |  |
    | |              |  +-----------------+-------------+--+---> | task.wait_for_initial_registration  |  |  |
    | |              |                |   |             |  |     |                                     |  |  |
    | |              |                |   |             |  |     +---------+---------------------------+  |  |
    | |              |                |   |             |  |               |                              |  |
    | |              |                |   |             |  |               |                              |  |
    | |              |                |   |             |  |               | 5                            |  |
    | |              |                |   |             |  |               |                              |  |
    | |              |                |   |             |  |               |                              |  |
    | |              |                |   |             |  |               v                              |  |
    | |              |                |   |             |  |        wait_for_command_termination          |  |
    | |              |                | 6 |  RunCommand |  |               +                              |  |
    | |              |                |   |             |  |               |                              |  |
    | |              +----------------------------------------------->     | 7                            |  |
    | |              |                |   |             |  |               v                              |  |
    | |              v                |   |             |  |        self._command_thread.join()           |  |
    | |                               |   |             |  |                                              |  |
    | |                               |   |             |  |                                              |  |
    | |                               |   |             |  |                                              |  |
    | +-------------------------------+   |             |  +----------------------------------------------+  |
    +-------------------------------------+             +----------------------------------------------------+
    
    

    手机如下:

    [源码解析] 深度学习分布式训练框架 horovod (9) --- 启动 on spark

    至此,第一阶段完成,我们下一篇继续,敬请期待。

    0x05 总结

    总体来说,Horovod on Spark 的总体逻辑分为以下阶段:

    • 启动 SparkDriverservice 服务,利用 _make_spark_thread 启动 Spark task,然后 horovod 会等待启动结束;
    • 多线程在 spark executor 之中启动 spark task,每个task之中运行一个 SparkTaskservice,SparkTaskservice 会向 hovorod 主进程中的 SparkDriverTask 进行注册,并且等待下一步运行启动的指令;
    • Horovod 收到所有 task 结束的信息之后,通知各个 task,进入下一阶段;
    • Horovod 调用 mpi_run (又利用到 mpirun_rsh.py)在每一个 spark executor 上启动 orted,以启动 MPI cluster;
    • orted 在每一个 executor 之上运行训练代码;

    本文介绍了前三个阶段,即启动阶段。下文介绍后续两个阶段,敬请期待。

    0xEE 个人信息

    ★★★★★★关于生活和技术的思★★★★★★

    微信公众账号:罗西的思

    如果您想及时得到个人撰写文章的消息推送,或者想看看个人推荐的技术资料,敬请关注。

    [源码解析] 深度学习分布式训练框架 horovod (9) --- 启动 on spark

    大佬总结

    以上是大佬教程为你收集整理的[源码解析] 深度学习分布式训练框架 horovod (9) --- 启动 on spark全部内容,希望文章能够帮你解决[源码解析] 深度学习分布式训练框架 horovod (9) --- 启动 on spark所遇到的程序开发问题。

    如果觉得大佬教程网站内容还不错,欢迎将大佬教程推荐给程序员好友。

    本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
    如您有任何意见或建议可联系处理。小编QQ:384754419,请注明来意。