目前SFT领域如火如荼,出现全量微调,LoRA,QLoRA等微调方法,同时可以伴随着8bit和4bit的量化,最后分布式训练框架DeepSpeed ZeRO优化器可以选择不同的模式,这些都决定了最后的计算资源的消耗。

Notes:

  • 1 LoRA和QLoRA能显著降低SFT过程中的显存消耗。
  • 2 量化只是减少了model加载的显存消耗,同时还引入了量化和反量化的操作,增加了训练时间。
  • 4 ZeRO优化器提供数据并行和模型并行,ZeRO-0为DDP,其他的数字每增加则显存消耗降低,训练时间增加。

我们来探索一下上述选项的组合吧,方便从业人员就本身的计算资源选择合适的组合方式。

1 实验准备

模型我们采用LLaVA,由1.6B的CLIP-large-336和7B的Qwen1.5-7B模型组成,batch_size_per_device=8。

以下我们来列举可供选择的选项:

  • 微调方法: FT / LoRA / QLoRA
  • 量化: 8bit  / 4bit
  • ZeRO优化器: 0 / 1 / 2 / 3

2 Experiment

2.1 FT

2.1.1 FT / ZeRO-0

出现OOM。

  • 1 模型参数: 17G
  • 2 反向传播梯度:17G
  • 3 优化器:
    • fp32梯度:34G
    • 一阶矩:34G
    • 二阶矩:34G
    • 拷贝模型参数fp32:34G

故FT情况下至少需要17*20 = 340G。

2.1.2 FT / ZeRO-3

出现OOM。

ZeRO-3会分布式的分配模型参数,反向梯度和优化器。

此处我们这里只有2台A100 40G,则每一台机器需要的显存为:(模型参数+反向传播梯度+优化器(fp32梯度,一阶矩,二阶矩))=284/2=142G。

2.2 LoRA

2.2.1 LoRA / ZeRO-0

batch_size_per_device: 1

显存消耗:22G

计算时间:36h

batch_size_per_device: 4

显存消耗:30G

计算时间:24h

我们来探究一下为什么batch_size_per_device从1变成4,显存增加了8G。

trainable params: 170,196,992 || all params: 8,216,008,704 || trainable%: 2.071528866773703

  • 1 由于我们设置了gradient_checkpoint,故中间激活值不进行存储,临时计算。
  • 2 LoRA可训练的参数为0.17b,则反向传播的显存增加为2*0.17*batch_size=2*0.17*3=1.02G。
  • 2 优化器显存增加:12*0.17*batch_size=12*0.17*3=6.12G。
  • 3 其他显存(Pytorch或者CUDA消耗)

如果显存足够的话,LoRA+无量化+ZeRO-0是训练速度最快的组合。