↑ 点击蓝字 关注极市平台

作者丨颜挺帅@知乎(已授权)
来源丨https://zhuanlan.zhihu.com/p/519410235
编辑丨极市平台
极市导读
文本介绍了Pytorch中弹性训练的原理,重点分析了ElasticAgent和Rendezvous的实现,同时也介绍了Failover和Scale Up/Down是如何处理的。>>加入极市CV技术交流群,走在计算机视觉的最前沿

1 介绍

弹性训练是一个相对较新的领域,但是其对训练的容错和集群利用率的提升都有着很大的帮助。当前主流的深度学习框架中Pytorch和Horovod提供了对弹性训练的支持。在实操教程|Pytorch - 弹性训练极简实现( 附源码)一文中我介绍了如何使用Pytorch来实现弹性训练。本文将对Pytorch的弹性训练实现原理进行分析,帮助大家对Pytorch的弹性有更深刻的了解。后续文章我们也会介绍Horovod弹性训练的使用和实现分析。

2 使用回顾

Pytorch在1.9.0引入了torchrun,用其替代1.9.0以前版本的torch.distributed.launch。torchrun在torch.distributed.launch 功能的基础上主要新增了两个功能:
  • Failover: 当worker训练失败时,会自动重新启动所有worker继续进行训练;
  • Elastic: 可以动态增加或或删除node节点;
弹性训练代码同DDP代码编写的思路基本一致,只要在DDP代码上增加以下两点即可:
  • checkpoint处理:由于再每次增加或删除node时,会将所有worker kill掉,然后再重新启动所有worker进行训练。因此,在训练代码中要对训练的状态进行保存,以保证重启后能接着上次的状态继续训练。
  • 超参调解:由于node节点数的变化,会导致global batch size的变化,因此我们的learning rate一般也要做相应的调整,保证训练出的模质量不受影响。
这里不再展示代码,具体代码细节可以参考实操教程|Pytorch - 弹性训练极简实现( 附源码)中的内容。
当编写完弹性训练代码后,我们可以使用torchrun来启动弹性训练任务:
  • --nnodes=1:3 :表示当前训练任务接受最少1个node,最多3个node参与分布式训练;
  • --nproc_per_node=4:表示每个node上节点有4个process
  • --max_restarts=3: worker group最大的重启次数;这里需要注意的是,node fail、node scale down和node scale up都会导致restart;
  • --rdzv_id=1:一个unique的job id,所有node均使用同一个job id;
  • --rdzv_backend: rendezvous的backend实现,默认支持c10d和etcd两种;rendezvous用于多个node之间的通信和协调;
  • --rdzv_endpoint:rendezvous的地址,应该为一个node的host ip和port;
torchrun \

--nnodes=1:3\

--nproc_per_node=4\

--max_restarts=3\

--rdzv_id=1\

--rdzv_backend=c10d\

--rdzv_endpoint="192.0.0.1:1234"\

train_elastic.py

3 整体架构

弹性调度的架构如上图所示,其中最关键角色为elastic agent。在每个Node上面都有一个elastic agent进程,其负责管理当前Node上面的所有workers。
当我们调用torchrun 命令启动弹性训练任务后:
  • 首先,elastic agent会触发rendezvous 流程; rendezvous的功能是在所有elastic agent间做协调和同步,该接口会一直阻塞直到至少min个elastic agent加入进来后返回;
  • 然后,elastic agent会启动当前Node的所有workers
  • 最后,elastic agent会监控当前Node上所有workers的运行状态,并根据workers的状态进行相应的处理(例如restart worker)

4 Elastic Agent

本小结,我们详细分析下Elastic Agent的实现。Elastic Agent在Pytorch代码中由以下对象构成:
  • Elastic Agent是抽象基类
  • SimpleElasticAgent提供了更完整的Agent接口,并且实现了部分接口
  • LocalElasticAgent则是实现剩余的接口
Elastic Agent在代码中的调用逻辑如下:
  • torch.distributed.launcher.api:launch_agent() 弹性训练逻辑的入口;
    • 首先、会构建一个RendezvousParameters来描述Rendezvous调用时所需要的参数,例如min_nodes/max_nodes/endpoint等;
    • 然后、构建WorkerSpec描述当前Node上启动Wokers的信息, 例如max_restart/entrypoint等;
    • 再然后,构建LocalElasticAgent对象;
    • 最后,调用LocalElasticAgent的run接口启动当前node的workers进行弹性训练;
  • Elastic run接口主要由两个部分逻辑组成:
    • 若process group的状态为succeeded:调用_exit_barrier接口等待所有node上agent相应并退出
    • 若process group的状态为unhealthyfailed: 如果重试次数小于_remaining_restart则restart所有worker进程,否则stop所有worker,并退出;
    • 若process group的状态为healthy: 则判断当前是否有node等待加入,如果有则restart_worker;(注:restart worker的实现逻辑是先stop 所有worker,然后在调用_initialize_workers)
    • SimpleElasticAgent._initialize_workers:先调用_rendezvous等待至少min 个node加入,然后调用_start_workers接口在当前node上启动worker process
    • while loop monitor worker:while循环,监控上一步启动process的状态

5 Rendezvous

5.1 基本概念

Pytorch中Rendezvous的实现涉及到很多概念,我们这里先把这些概念一一介绍下,然后再介绍Rendezvous的实现这样会清晰很多。
首先是_RendezvousState,每个ElasticAgent上都会存储一份_RendezvousState,并会在必要时进行彼此间的同步,_RendezvousState存储的内容如下:
  • round: The current round of the rendezvous.
  • complete: A boolean value indicating whether the current round of the rendezvous is complete.
  • deadline: The time at which the current round of the rendezvous will be considered complete if it is still waiting for nodes to join.
  • closed: A boolean value indicating whether the rendezvous is closed.
  • participants: A dictionary of the participants and their corresponding ranks.
  • wait_list:A set of nodes that are waiting to participate in the next round of the rendezvous.
  • last_heartbeats: A dictionary containing each node's last heartbeat time.
那_RendezvousState是如何在所有ElasticAgent间进行同步的呢,Pytorch中又提出了Store的概念,在Pytorch中有TCPStoreFileStoreHashStore三种类型,在弹性训练场景,默认使用TCPStore。
TCPStore的典型用法如下:
  • 其是一个典型的server-client架构,我们在process1上启动server,在proess2上启动client,通过TCPStore的set和get接口可以进行数据的设置和获取
  • 在Rendezvous实现中即是通过TCPStore来对_RendezvousState进行设置和获取的。
import
 torch.distributed 
as
 dist

from
 datetime 
import
 timedelta


# Run on process 1 (server)
server_store = dist.TCPStore(
"127.0.0.1"
1234
2
True
, timedelta(seconds=
30
))


# Run on process 2 (client)
client_store = dist.TCPStore(
"127.0.0.1"
1234
2
False
)


# Use any of the store methods from either the client or server after initialization
server_store.set(
"first_key"
"first_value"
)

client_store.get(
"first_key"
)

Pytorch的Rendezvous实现中,通过C10dRendezvousBackend对TCPStore进行了封装,并提供了set_stateget_state接口,方便state的操作。(注:Pytorch中还提供了EtcdRendezvousBackend,该类型的RendezvousBackend通过Etcd来进行_RendezvousState的同步)。
C10dRendezvousBackend的主要实现如下,可以很清晰的看到get_state和set_state的实现,均是对store接口的调用.
classC10dRendezvousBackend(RendezvousBackend):
defget_state(self) -> Optional[Tuple[bytes, Token]]:
"""See base class."""
        base64_state: bytes = self._call_store(
"get"
, self._key)


return
 self._decode_state(base64_state)


defset_state
(

        self, state: bytes, token: Optional[Token] = None

    )
 -> Optional[Tuple[bytes, Token, bool]]:

"""See base class."""
        base64_state_str: str = b64encode(state).decode()


if
 token:

# Shortcut if we know for sure that the token is not valid.
ifnot
 isinstance(token, bytes):

                result = self.get_state()

if
 result 
isnotNone
:

                    tmp = *result, 
False
# Python 3.6 does not support tuple unpacking in return
# statements.
return
 tmp

returnNone

            token = token.decode()

else
:

            token = self._NULL_SENTINEL


        base64_state: bytes = self._call_store(
"compare_set"
, self._key, token, base64_state_str)


        state_token_pair = self._decode_state(base64_state)

if
 state_token_pair 
isNone
:

returnNone

        new_state, new_token = state_token_pair


# C10d Store's compare_set method does not offer an easy way to find out
# whether our write attempt was successful. As a brute-force solution we
# perform a bitwise comparison of our local state and the remote state.
return
 new_state, new_token, new_state == state


def_call_store(self, store_op: str, *args, **kwargs) -> Any:
try
:

return
 getattr(self._store, store_op)(*args, **kwargs)

except
 (ValueError, RuntimeError, TimeoutError) 
as
 exc:

raise
 RendezvousConnectionError(

"The connection to the C10d store has failed. See inner exception for details."
            ) 
from
 exc    

在RendezvousBackend的基础上,Pytorch提出了一个更偏向业务层面的概念**_RendezvousStateHolder**,其提供了_RendezvousState进行获取、同步、标记更新的接口,这些接口的实现均是调用RendezvousBackend的set_state和get_state完成的。
_RendezvousStateHolder的定义如下:
class_RendezvousStateHolder(ABC):
"""Holds the shared rendezvous state synced with other nodes."""

defstate(self) -> _RendezvousState:
"""Gets the local state."""

defsync(self) -> Optional[bool]:
"""Reads or writes the latest state.


        Returns:

            A boolean value indicating whether the local state, in case marked

            as dirty, was successfully synced with other nodes.

        """


defmark_dirty(self) -> None:
"""Marks the local state as dirty."""
Rendezvous的基础设置都准备好了,状态在 _RendezvousState中保存,状态的同步通过 _RendezvousStateHolder来完成,此时还差一项,就是Rendezvous state的是如何变更的。这个变更通过 _RendezvousXXXOp_RendezvousOpExecutor共同来完成。
Pytorch首先提供了_RendezvousExitOp/_RendezvousJoinOp/_RendezvousCloseOp/_RendezvousKeepAliveOp来对应ElasticAgent的退出、加入、Rendezvous关闭和心跳保保持四个操作。这些OP的实现逻辑是根据OP的类型和当前_RendezvousState的内容来决定来返回一个action,_RendezvousOpExecutor则执行对应的action。
例如_RendezvousExitOp 对应ElasticAgent的退出操作
  • 如果当前节点仍旧在participants列表中,则返回一个REMOVE_FROM_PARTICIPANTS,_RendezvousOpExecutor在接收到这个action后会执行_remove_from_participants逻辑;
  • 如果当前节点没有在participants列表中,返回FINISH,这个状态_RendezvousOpExecutor不会做任何操作;
class_RendezvousExitOp:
"""Represents a rendezvous exit operation."""

def__call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
if
 ctx.node 
in
 ctx.state.participants:

if
 time.monotonic() > deadline:

return
 _Action.ERROR_TIMEOUT

return
 _Action.REMOVE_FROM_PARTICIPANTS

return
 _Action.FINISH



_DistributedRendezvousOpExecutor的核心接口如下:
  • run提供了执行Rendezvous op的总入口
  • 其他接口则对应了Rendezvous op返回的action的实现。这些action的实现本质上都是对_RendezvousState内容的修改,例如_mark_rendezvous_closed是将_RendezvousState的close字段设置为了True。
class_DistributedRendezvousOpExecutor:
defrun(self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float,) -> None:
def_keep_alive(self) -> None:
def_add_to_participants(self)
def_add_to_wait_list(self)
def_remove_from_participants(self)
def_remove_from_wait_list(self)
def_mark_rendezvous_complete(self)
def_mark_rendezvous_closed(self)
:

        self._state.closed = 
True
最后一个要介绍的概念是RendezvousHandler,其是Rendezvous系统最上层的对外接口,ElasticAgent通过该接口来在所有节点间进行协调。在Pytorch中提供了DynamicRendezvousHandler、EtcdRendezvousHandler和StaticTCPRendezvous三种实现,这里我们仅关注DynamicRendezvousHandler。
RendezvousHandler中最核心的接口是next_rendezvous,ElasticAgent会调用该接口来等待至少min个node的加入。他们实现我们后面再进行讲解。
上面介绍的这些概念,可以通过如下的关系图来进行描述。

5.2 实现逻辑

在熟系完Rendezvous的基本概念后,我们现在可以来看其实现逻辑了。
首先,我们看DynamicRendezvousHandler.next_rendezvous的实现逻辑(注:ElasticAgent通过调用该接口实现的node间的协调)。DynamicRendezvousHandler.next_rendezvous 一共由5个步骤组成:
  • DynamicRendezvousHandler._stop_heartbeats():停止先TCPStore的心跳操作,通过调用定时器_PeriodicTimer的cancel接口实现;
  • Execute Exit OP:执行退出逻辑,如果当前node已经在participants中了,则先把当前节点从_RendezvousState的participants列表中删除;
  • Execute Join OP: 下图仅描述了一个常规的场景,源码中还有一些特殊情况需要处理;
    • 将自己加入到_RendezvousState的participants列表中;
    • 向TCPStore发起心跳,等待至少min个node加入;
    • 当_RendezvousState的participants的个数大于min时,mark rendezvous;
    • 此时,Join OP执行完成,返回给_RendezvousOpExecutor 个Finish action;
  • DynamicRendezvousHandler._start_heartbeats(): 开启心跳,这个逻辑通过_PeriodicTimer定期执行_RendezvousKeepAliveOp实现;_RendezvousKeepAliveOp的操作则是对_RendezvousState的last_heartbeats进行更新来实现;
  • DynamicRendezvousHandler._get_world():从_RendezvousState中获取当前rank和work_size信息;
下面我们再看下Rendezvous的OP是如何执行的。上文提到OP是通过_DistributedRendezvousOpExecutor.run()接口统一来完成的。
  • 主流程包裹在while循环中,直到OP的action为finish方可退出循环;
  • 首先,会调用_BackendRendezvousStateHolder.sync()接口在所有node间进行_RendezvousState的同步;
    • 若当前node有内容需要更新,则调用C10dRendezvousBackend.set_state()来更新;若没有,则调用C10dRendezvousBackend.get_state()来获取最新的state;
    • 若获取了最新的state,则对当前node上存储的state进行更新;
  • 然后,调用当前需要执行的OP,OP接口会返回一个ACTION,_DistributedRendezvousOpExecutor则根据ACTION的内容执行keep_alive/add_to_participants/add_to_wait_list等操作;

6 Failover

Failover分为两种情况:
  • ElasticAgent Process正常,但是worker process 出错
  • ElasticAgent Process 异常退出

6.1 Worker Fail

对于worker fail的场景,worker process的异常状态会被ElasticAgent捕获,实现逻辑在SimpleElasticAgent的_invoke_run接口中。
  • 该接口实现中会循环monitor 当前node上所有worker process的状态,如果process 异常,则会进行入UNHEALTHY/FAILED状态的处理流程。
  • 如果当前重试的次数小于_remain_restart,则会发起restart worker的流程
restart worker的实现逻辑也很清晰:
  • 先stop 点前node上所有worker
  • 然后重新走_initialize_workers逻辑来进行Rendezvous和start worker
def_restart_workers(self, worker_group: WorkerGroup) -> None:
"""

        Restarts (stops, rendezvous, starts) all local workers in the group.

        """


        role = worker_group.spec.role

        log.info(
f"[{role}] Stopping worker group"
)

        self._stop_workers(worker_group)

        worker_group.state = WorkerState.STOPPED

        self._initialize_workers(worker_group)

6.2 ElasticAgent Fail

首先,我们看下当一个node Fail掉后,弹性训练是如何运行的。这有两个node:node0和node1,开始node0和node1同时进行分布式训练,当训练到一定时间后,我们将node1 kill掉。
这是node1上的日志:
[763] epoch 14 (rank = 4, local_rank = 0) loss = 1.2388396263122559

[765] epoch 14 (rank = 6, local_rank = 2) loss = 1.4543075561523438

[766] epoch 14 (rank = 7, local_rank = 3) loss = 1.0290627479553223

[764] epoch 14 (rank = 5, local_rank = 1) loss = 1.1143463850021362

^CTraceback (most recent call last):

Traceback (most recent call last):

File "/opt/conda/bin/torchrun", line 33, in <module>

sys.exit(load_entry_point('torch==1.11.0', 'console_scripts', 'torchrun')())

File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper

return f(*args, **kwargs)

File "/opt/conda/lib/python3.8/site-packages/torch/distributed/run.py", line 724, in main

run(args)

File "/opt/conda/lib/python3.8/site-packages/torch/distributed/run.py", line 715, in run

elastic_launch(

File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 131, in __call__

return launch_agent(self._config, self._entrypoint, list(args))

File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 236, in launch_agent

result = agent.run()

File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/metrics/api.py", line 125, in wrapper

result = f(*args, **kwargs)

File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/agent/server/api.py", line 709, in run

result = self._invoke_run(role)

File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/agent/server/api.py", line 850, in _invoke_run

time.sleep(monitor_interval)

File "/opt/conda/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 60, in _terminate_process_handler

raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)

torch.distributed.elastic.multiprocessing.api.SignalException: Process 759 got signal: 2
这是node0上的日志,我们可以得出以下结论:
  • 当Elastic Agent退出时,会导致其他存活的Elastic Agent中的process 运行失败;这是因为剩余process无法在正常进行collective communication了;
  • 存活的Elastic Agent会按照UNHEALTHY/FAILED的处理逻辑来重启本机的worker;若失败的Elastic Agent没有重启,则剩余的Elastic Agent重新构建worker group继续进行训练,若失败的Elastic Agent重新启动(例如kubernetes中job提供重启的机制),则会重新加入到整个训练任务中;
# 1) 此时node0和node1共同进行分布式训练

...

[11762] epoch 14 (rank = 2, local_rank = 2) loss = 1.1763713359832764 [702/1958]

[11760] epoch 14 (rank = 0, local_rank = 0) loss = 1.324049949645996


# 2) 此时node1被kill掉,因此当执行collective communication时,会报出异常

[E ProcessGroupNCCL.cpp:406] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete d

ata. To avoid this inconsistency, we are taking the entire process down.

terminate called after throwing an instance of 'std::runtime_error'

what(): NCCL error: unhandled system error, NCCL version 21.0.3

ncclSystemError: System call (socket, malloc, munmap, etc) failed.


# 3)stop 其他三个process

WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 11761 closing signal SIGTERM

WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 11762 closing signal SIGTERM

WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 11763 closing signal SIGTERM

ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 0 (pid: 11760) of binary: /opt/conda/bin/python


# 4)重新走_initialize_workers逻辑

[11828] Initializing process group with: {'MASTER_ADDR': 'iZ2ze9q3ftqtxtqlkrk6tuZ', 'MASTER_PORT': '40539', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}[11825] Initializing process group

with: {'MASTER_ADDR': 'iZ2ze9q3ftqtxtqlkrk6tuZ', 'MASTER_PORT': '40539', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}

[11826] Initializing process group with: {'MASTER_ADDR': 'iZ2ze9q3ftqtxtqlkrk6tuZ', 'MASTER_PORT': '40539', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}

[11827] Initializing process group with: {'MASTER_ADDR': 'iZ2ze9q3ftqtxtqlkrk6tuZ', 'MASTER_PORT': '40539', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}

[11827] (rank = 2, local_rank = 2) train worker starting...

[11828] (rank = 3, local_rank = 3) train worker starting...

[11825] (rank = 0, local_rank = 0) train worker starting...

[11826] (rank = 1, local_rank = 1) train worker starting...


# 5)node0 独自进行分布式训练

load checkpoint from checkpoint.ptload checkpoint from checkpoint.ptload checkpoint from checkpoint.ptload checkpoint from checkpoint.pt

[11826] epoch 14 (rank = 1, local_rank = 1) loss = 0.839302122592926

[11828] epoch 14 (rank = 3, local_rank = 3) loss = 0.8971960544586182

[11825] epoch 14 (rank = 0, local_rank = 0) loss = 1.3382269144058228

7 Scale Up/Down

Scale Down的可以理解为上文中Elastic Agent退出,但是没有重启的场景,因此这里不再赘述。
Scale UP这里要再介绍一下,Scale UP的流程仍旧可以用上图进行描述:
  • 当有新的节点加入时,由于当前Elastic已经建立一个的Rendezvous,其无法加入,所以当前Node会被加入到_RendezvousState的wait_list中
  • 当ElasticAgent和对应的worker process都正常运行时,monitor会返回Healthy的状态;此时,ElasticAgent会检查_RendezvousState的waiting list的node个数,发现waiting list大于0,则出发restart worker来发起新一轮的Rendezvous以将新的加入,这样新的Node加入到了worker group中;

8 总结

文本介绍了Pytorch中弹性训练的原理,重点分析了ElasticAgent和Rendezvous的实现,同时也介绍了Failover和Scale Up/Down是如何处理的。后续我们还会介绍Horovod 弹性训练,敬请期待。
公众号后台回复“CVPR 2022”获取论文合集打包下载~
△点击卡片关注极市平台,获取最新CV干货
极市干货
CV技术社群邀请函 #
△长按添加极市小助手
添加极市小助手微信(ID : cvmart4)
备注:姓名-学校/公司-研究方向-城市(如:小极-北大-目标检测-深圳)
即可申请加入极市目标检测/图像分割/工业检测/人脸/医学影像/3D/SLAM/自动驾驶/超分辨率/姿态估计/ReID/GAN/图像增强/OCR/视频理解等技术交流群
每月大咖直播分享、真实项目需求对接、求职内推、算法竞赛、干货资讯汇总、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企视觉开发者互动交流~
觉得有用麻烦给个在看啦~
继续阅读
阅读原文