目前SFT领域如火如荼,出现全量微调,LoRA,QLoRA等微调方法,同时可以伴随着8bit和4bit的量化,最后分布式训练框架DeepSpeed ZeRO优化器可以选择不同的模式,这些都决定了最后的计算资源的消耗。
Notes:
- 1 LoRA和QLoRA能显著降低SFT过程中的显存消耗。
- 2 量化只是减少了model加载的显存消耗,同时还引入了量化和反量化的操作,增加了训练时间。
- 4 ZeRO优化器提供数据并行和模型并行,ZeRO-0为DDP,其他的数字每增加则显存消耗降低,训练时间增加。
我们来探索一下上述选项的组合吧,方便从业人员就本身的计算资源选择合适的组合方式。
1 实验准备
模型我们采用LLaVA,由1.6B的CLIP-large-336和7B的Qwen1.5-7B模型组成,batch_size_per_device=8。
以下我们来列举可供选择的选项:
- 微调方法: FT / LoRA / QLoRA
- 量化: 8bit / 4bit
- ZeRO优化器: 0 / 1 / 2 / 3
2 Experiment
2.1 FT
2.1.1 FT / ZeRO-0
出现OOM。
- 1 模型参数: 17G
- 2 反向传播梯度:17G
- 3 优化器:
- fp32梯度:34G
- 一阶矩:34G
- 二阶矩:34G
- 拷贝模型参数fp32:34G
故FT情况下至少需要17*20 = 340G。
2.1.2 FT / ZeRO-3
出现OOM。
ZeRO-3会分布式的分配模型参数,反向梯度和优化器。
此处我们这里只有2台A100 40G,则每一台机器需要的显存为:(模型参数+反向传播梯度+优化器(fp32梯度,一阶矩,二阶矩))=284/2=142G。
2.2 LoRA
2.2.1 LoRA / ZeRO-0
batch_size_per_device: 1
显存消耗:22G
计算时间:36h
batch_size_per_device: 4
显存消耗:30G
计算时间:24h
我们来探究一下为什么batch_size_per_device从1变成4,显存增加了8G。
trainable params: 170,196,992 || all params: 8,216,008,704 || trainable%: 2.071528866773703
- 1 由于我们设置了gradient_checkpoint,故中间激活值不进行存储,临时计算。
- 2 LoRA可训练的参数为0.17b,则反向传播的显存增加为2*0.17*batch_size=2*0.17*3=1.02G。
- 2 优化器显存增加:12*0.17*batch_size=12*0.17*3=6.12G。
- 3 其他显存(Pytorch或者CUDA消耗)
如果显存足够的话,LoRA+无量化+ZeRO-0是训练速度最快的组合。