幻方萤火 | 并行训练中的快速保存与加载 hfai.checkpoint

High-Flyer    May 23, 2022

分布式训练中模型的保存,特别是大模型,常常需要耗费很多的时间,降低了整体的GPU利用率。针对这类问题,幻方AI进行了攻关,优化过往深度学习模型单机训练保存的方法,研发出分布式 checkpoint 方案,大幅度降低模型保存与加载上的开销。

分布式 checkpoint

当我们进行分布式模型训练,特别是在训练大模型时,保存 checkpoint 需要较长的时间,这不仅浪费集群的计算资源,并且给集群整体的任务调度带来管理成本。为此幻方AI提供了一个分布式保存 checkpoint 的功能。

该功能的基本原理是:假设有 N 块 GPU,我们把模型参数和优化器参数切分成 N 个部分,然后每个分布 rank 把对应的部分写入文件系统;在读取的时候我们从文件系统中读出所有 checkpoint 拼接出完整的模型参数和优化器参数。除了模型参数和优化器参数,其他的信息会由 rank 0 进行保存。

使用方法:

from hfai.checkpoint import save, load

model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
others = {'epoch': epoch, 'step': step+1}
save('latest.pt', model, optimizer, others=others)

state = load('latest.pt', map_location='cpu')
epoch, step = state['epoch'], state['step']
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])

集群上实际测试性能如下:

0.png

可以看到,随着并行程度的增加,使用hfai.checkpoint进行模型保存加载的耗时越来越少。

自动断点训练

对于幻方萤火集群上的训练,需要接收集群统一的调度信号进行训练任务的管理。这里幻方AI提供了一个 hfai.checkpoint.init 函数帮助用户进行断点训练,该函数会自动加载上次保存的模型、优化器等状态,返回上次训练的epoch和step,我们可以通过epoch和step进行优雅断点训练。同时,我们会向model注册一个成员函数 try_save ,通过这个函数可以在打断训练之前保存训练的状态。

使用方法:

model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

start_epoch, start_step, others = hfai.checkpoint.init(model, optimizer, ckpt_path='latest.pt')
for epoch in range(start_epoch, epochs):
    for step, (x, y) in enumerate(dataloader):
        if step < start_step:
            continue
        
        output = model(x)
        loss_fn(y, output).backward()
        model.try_save(epoch, step, others=None)

通过上述封装,您可以在代码中省去很多断点操作,简单方便地将代码适配幻方萤火的超算系统,进一步降低门槛。


本文作者: High-Flyer


您可以转载、不违背作品原意地摘录及引用本技术博客的内容,但必须遵守以下条款: 署名 — 您应当署名原作者,但不得以任何方式暗示幻方为您背书,亦不会对幻方的权利造成任何负面影响。 非商业性使用 — 您不得将本技术博客内容用于商业目的。 禁止演绎 — 如果基于该内容改编、转换、或者再创作,您不得公开或分发被修改内容,该内容仅可供个人使用。