萤火跑模型 | CLIP-GEN 无需文本训练即可文字生成图像

小杞    July 22, 2022

随着Transformer模型的发展,近些年多模态模型获得了长足的发展,使得不同任务不同领域可以实现特征的打通,变换出很多新奇好玩的场景。其中非常热门的就是让AI学会看文作图,即文字生成图像,如OpenAI的CLIP模型,其基于带文字的图像数据集上,训练出很惊艳的效果。

然而,收集这些带文字的图像数据集成本非常高。最近字节在Arxiv上发表了一项文本生成图像 (text2img) 的工作,其利用对抗网络GAN改造CLIP模型,使得 CLIP-GEN 可以不依赖带文字描述的图片数据集,直接使用无文本图像数据集进行训练,通过预训练好的 CLIP 模型建立起文本和图像的映射关系。在很多实验数据中表明,它的效果比VQGAN-CLIP要真实,尤其是泛化能力还比不少用大量“文本-图像”数据对训练出来的模型要好很多。

幻方AI最近复现了该项工作,并通过幻方自研的 3FShfreduce算子,对模型训练和推导进行优化。我们在hfai数据仓库中开源了训练数据,模型代码,旨在帮助研究者和开发者们降低研究门槛。

论文标题:CLIP-GEN: Language-Free Training of a Text-to-Image Generator with CLIP

论文地址https://arxiv.org/abs/2203.00386

模型仓库https://github.com/HFAiLab/clip-gen

模型介绍

下图是 CLIP-GEN 模型的整体结构:

arch

CLIP-GEN 模型主要由两部分组成,第一部分是一个 VQGAN 模型,用来学习如何把图像编码成一系列的图像标记(image tokens),通过这些图像标记解码还原成一张图片;第二部分是一个 condition transformer 模型,用来学习如何把文字的 CLIP embedding 映射到图像标记(image tokens)中。

训练、推理的过程分为三步:

  1. 预训练 VQGAN:输入图像,把图像编码到码本空间,然后再从码本空间解码为图像。经过预训练之后我们就能够把图像表示成一个图像标记(image tokens) 下的离散序列。
  2. 训练 condition transformer:通过第一步的预训练,我们已经能够通过图像标记(image tokens) 来生成图片了,在此基础上,我们训练一个 condition transformer,其旨在学习如何把图像的 CLIP embedding 映射到图像标记。训练的过程中 VQGAN 的参数保持不变。
  3. 文字生成图片:由于在 CLIP 中,文字和图像共享同一个嵌入空间,我们可以直接把文字的 CLIP embedding 作为 condition transformer 的输入映射到图像标记上,然后通过 VQGAN 的 decoder 来生成图片

数据集

对于 CLIP-GEN 的训练,我们采用了 COCO Caption 数据集,包含 20 万张图文对(训练的过程中没有使用文本)。我们把 COCO Caption 数据集转换成 ffrecord 格式,整合到了 hfai 数据仓库中,可以直接通过以下方式直接使用:

import hfai

dataset = hfai.datasets.CocoCaption(split='train', transform=transform)

有关更多内容,可以访问 hfai 官方文档:https://doc.hfai.high-flyer.cn/index.html

模型训练与优化加速

幻方AI复现了GLIP-GEN模型,并验证其效果。通过幻方自研的 3FShfreduce算子等优化工具,对模型训练和推导进行优化和加速,具体的包括:

  • hfai ddp:采用 hfreduce 优化多机多卡通信
  • hfai nn: 重构深度学习算子,提升性能
  • hfai datasets: 采用高效数据样本格式 ffrecord,充分发挥 3FS 存储带宽性能

下面进行详细描述。

hfai DDP 通信加速

hfai DDP 内部采用了幻方自研的 hfreduce 高性能通讯框架,能有效提升模型的训练速度,使用方法只需要修改一行代码:

# from torch.nn.parallel import DistributedDataParallel
from hfai.nn.parallel import DistributedDataParallel

# ...... initialize model

model = DistributedDataParallel(model, device_ids=[local_rank])

hfai nn 算子加速

为了进一步提升模型训练速度,我们可以使用 hfai.nn 里的高性能算子,相比于 PyTorch 能带来明显的提升,使用方法只需要增加一行代码:

import hfai

model = hfai.nn.to_hfai(model)  # 自动替换为 hfai 高性能算子

使用说明

幻方AI将所复现的模型和优化的方法都进行了开源,统一归集到 hfai 模型仓库 (https://github.com/HFAiLab/hfai-models) 中,欢迎大家来 star。

  1. 下载 CLIP 预训练模型:下载 CLIP 后放至 pretrained/clip_vit_b32.pt,该预训练模型来自 OpenAI.

  2. 在 COCO 上训练 VQGAN:通过 hfai python 提交任务至萤火集群

    hfai python train_vqgan.py --ds coco -- -n 1 -p 30
  3. 在 COCO 上训练 Conditional GPT:通过 hfai python 提交任务至萤火集群

    hfai python train_gpt.py --ds coco --vqgan_ckpt /path/to/vqgan/ckpt -- -n 4 -p 30

训练结果

我们来看看训练完成后,一些文本生成图像的效果。

tower

bus

train

可以看到,不利用带文本的图像数据集,CLIP-GEN 所生成的效果还是非常逼真的。

体验总结

CLIP-GEN 将对抗网络 GAN 用于改造 CLIP 模型,使得 CLIP-GEN 可以不依赖带文字描述的图片数据集,直接使用无文本图像数据集进行训练,这极大降低了数据收集的成本,推动了该领域研究的发展。通过预训练好的 CLIP 模型建立起文本和图像的映射关系,在很多实验数据中表明,CLIP-GEN 的效果比 VQGAN-CLIP 要真实,尤其是泛化能力还比不少用大量“文本-图像”数据对训练出来的模型要好很多。

综合体验打分如下:

  1. 研究指数:★★★★

    该模型是多模态领域的最新研究成果,降低了数据收集的成本,推动了该领域的发展。

  2. 开源指数:★★★

    代码没有开源,但所依赖的方法有其他开源版本,容易复现。

  3. 门槛指数:★★★

    数据规模大,模型适中,适合多级多卡数据并行训练。一般单卡训练难度比较大。

  4. 通用指数:★★★★

    该方法适用于多模态研究场景,能在很多类似场景下应用。

  5. 适配指数:★★★★★

    依赖简单,很容易与幻方AI的训练优化工具结合,提效明显。

幻方 AI 紧跟 AI 研究的前沿浪潮,致力于用领先算力助力AI落地与价值创造,欢迎各方数据研究者与开发者们一同共建。


本文作者: 小杞


您可以转载、不违背作品原意地摘录及引用本技术博客的内容,但必须遵守以下条款: 署名 — 您应当署名原作者,但不得以任何方式暗示幻方为您背书,亦不会对幻方的权利造成任何负面影响。 非商业性使用 — 您不得将本技术博客内容用于商业目的。 禁止演绎 — 如果基于该内容改编、转换、或者再创作,您不得公开或分发被修改内容,该内容仅可供个人使用。