Training the model basically breaks down into six parts: defining the model, data, and training arguments; loading the model; loading the processor; building the dataset and collator; setting up the Trainer; and configuring the DeepSpeed launch.

1 Loading the model, data, and training arguments

from dataclasses import dataclass, field
from typing import Optional

import transformers

@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="model/llava")
    attn_implementation: str = "flash_attention_2"
    torch_dtype: str = "float16"  # dtype name passed straight to from_pretrained


@dataclass
class DataArguments:
    data_path: Optional[str] = field(default="model/data")
    image_dir: Optional[str] = field(default="model/images")
    ignore_index: int = -100
    max_length: int = 1024


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    optim: str = field(default="adamw_torch")
    model_save_dir: str = ""
    bits: int = field(default=16)
    double_quant: bool = field(default=True)
    quant_type: str = "nf4"
    compute_dtype: str = "bfloat16"  # compute dtype name for the 4-bit layers ("float16"/"bfloat16")
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"

parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
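
Every field of these dataclasses becomes a command-line flag of the same name, which is what the DeepSpeed launch command in section 5 relies on. A minimal sanity check of the parsing (my own illustration, not part of the original script; the explicit args list is just for demonstration):

m_args, d_args, t_args = parser.parse_args_into_dataclasses(
    args=["--bits", "4", "--lora_enable", "True", "--output_dir", "./checkpoints"]
)
assert t_args.bits == 4 and t_args.lora_enable is True
assert m_args.model_name_or_path == "model/llava"  # untouched fields keep their defaults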

2 Loading the model and the processor

2.1 Loading the model

Loading the model involves five steps:

  • 1 Building the quantization arguments
  • 2 Loading the LLM
  • 3 Quantization stability setup
  • 4 Enabling gradients on the inputs
  • 5 LoRA configuration

The model-loading code is as follows:

import os

from peft import prepare_model_for_kbit_training
from transformers import LlavaForConditionalGeneration

# 1 LLM
# 1.1 Build the quantization arguments
bnb_model_for_training_args = {}
if training_args.bits in [4, 8]:
    from transformers import BitsAndBytesConfig
    bnb_model_for_training_args.update(dict(
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=training_args.bits == 4,
            load_in_8bit=training_args.bits == 8,
            # the HF Llava implementation names its projector multi_modal_projector
            llm_int8_skip_modules=["multi_modal_projector"],
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=training_args.compute_dtype,  # dtype name, e.g. "bfloat16"
            bnb_4bit_use_double_quant=training_args.double_quant,
            bnb_4bit_quant_type=training_args.quant_type  # {'fp4', 'nf4'}
        )
    ))
# 1.2 Load the LLM
# device_map is incompatible with DeepSpeed ZeRO-3, so skip it when --deepspeed is used
device_map = None if training_args.deepspeed else {"": int(os.environ.get("LOCAL_RANK", 0))}
model = LlavaForConditionalGeneration.from_pretrained(
    model_args.model_name_or_path,
    torch_dtype=model_args.torch_dtype,
    attn_implementation=model_args.attn_implementation,
    device_map=device_map,
    local_files_only=True,
    **bnb_model_for_training_args
)
model.config.use_cache = False
# 1.3 Quantization stability setup (freeze base weights, cast norms to fp32)
if training_args.bits in [4, 8]:
    model = prepare_model_for_kbit_training(
        model, use_gradient_checkpointing=training_args.gradient_checkpointing)
# 1.4 Make the inputs require grad (needed when gradient checkpointing meets frozen embeddings)
if training_args.gradient_checkpointing:
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:
        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)
        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
# 1.5 Apply LoRA
if training_args.lora_enable:
    from peft import LoraConfig, get_peft_model
    lora_config = LoraConfig(
        r=training_args.lora_r,
        lora_alpha=training_args.lora_alpha,
        target_modules=find_all_linear_names(model),
        lora_dropout=training_args.lora_dropout,
        bias=training_args.lora_bias,
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
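
The find_all_linear_names helper used for target_modules is not shown in this post. The usual assumption (borrowed from the LLaVA training code) is that it collects the names of every nn.Linear layer in the language model while skipping the vision tower, the multimodal projector, and lm_head; a minimal sketch under that assumption:

import torch

def find_all_linear_names(model):
    # Collect the leaf names of all nn.Linear modules, excluding the multimodal parts,
    # so that LoRA is only attached to the language model's linear layers.
    skip_keywords = ["multi_modal_projector", "vision_tower"]
    lora_module_names = set()
    for name, module in model.named_modules():
        if any(keyword in name for keyword in skip_keywords):
            continue
        if isinstance(module, torch.nn.Linear):
            lora_module_names.add(name.split(".")[-1])
    lora_module_names.discard("lm_head")  # keep the output head out of LoRA
    return list(lora_module_names)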

2.2 Loading the processor

from transformers import LlavaProcessor

processor = LlavaProcessor.from_pretrained(
    model_args.model_name_or_path,
    local_files_only=True,
)
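
A quick smoke test of the processor (my own example; the prompt template and the dummy image size are assumptions and depend on the base model used):

from PIL import Image

image = Image.new("RGB", (336, 336))  # dummy image
prompt = "USER: <image>\nDescribe the picture. ASSISTANT:"
batch = processor(text=prompt, images=image, return_tensors="pt")
print(batch["input_ids"].shape, batch["pixel_values"].shape)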

3 Loading the Dataset and Collator

train_dataset = LlavaDataset(processor, data_args.data_path, data_args.image_dir, data_args.ignore_index)
collator = LlaVAForTrainCollator(data_args.max_length, data_args.ignore_index, processor.tokenizer.pad_token_id)
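
LlavaDataset and LlaVAForTrainCollator are the custom classes from the data-preparation part of this series and are not defined in this post. Roughly, the collator is expected to pad input_ids and labels to a common length, mask padding positions in labels with ignore_index, and batch the pixel values. A rough sketch of that interface, under those assumptions:

import torch

class LlaVAForTrainCollator:
    def __init__(self, max_length, ignore_index, pad_token_id):
        self.max_length = max_length
        self.ignore_index = ignore_index
        self.pad_token_id = pad_token_id

    def __call__(self, features):
        # each feature is assumed to hold 1-D input_ids / labels tensors and (1, C, H, W) pixel_values
        input_ids = [f["input_ids"][: self.max_length] for f in features]
        labels = [f["labels"][: self.max_length] for f in features]
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=self.ignore_index)
        return {
            "input_ids": input_ids,
            "attention_mask": input_ids.ne(self.pad_token_id),
            "labels": labels,
            "pixel_values": torch.cat([f["pixel_values"] for f in features], dim=0),
        }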

4 Setting up the Trainer

trainer = transformers.Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=None,
    data_collator=collator,
    args=training_args
)
trainer.train()
# trainer.save_state()
trainer.save_model(output_dir=training_args.model_save_dir)
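
One optional follow-up (my addition, not in the original script): when lora_enable is set, the save above stores the PEFT adapter rather than a standalone checkpoint; the adapter can be merged back into the base weights afterwards. A sketch, assuming training was not sharded with ZeRO-3 (otherwise the weights need to be gathered first):

if training_args.lora_enable:
    merged_model = model.merge_and_unload()  # fold the LoRA weights into the base model
    merged_model.save_pretrained(training_args.model_save_dir)
    processor.save_pretrained(training_args.model_save_dir)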

5 DeepSpeed configuration

The DeepSpeed launch command is as follows:

deepspeed --include localhost:2,3 llava/Trainer.py \
    --deepspeed ./scripts/zero3.json \
    --model_name_or_path ./model/llava \
    --data_path /public/MountData/xcx/mllm/pretrain/LLaVA/blip_laion_cc_sbu_558k.json \
    --image_dir /public/MountData/xcx/mllm/pretrain/LLaVA/images \
    --bf16 True \
    --model_save_dir ./model/qwen1.5-0.5b-pretrain \
    --output_dir ./checkpoints/qwen1.5-0.5b-pretrain \
    --num_train_epochs 1 \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 4 \
    --save_strategy "steps" \
    --optim "adamw_torch" \
    --adam_beta1 0.9 \
    --adam_beta2 0.95 \
    --ddp_find_unused_parameters False \
    --save_steps 1000 \
    --save_total_limit 2 \
    --learning_rate 1e-3 \
    --weight_decay 0. \
    --warmup_ratio 0.05 \
    --lr_scheduler_type "cosine" \
    --logging_steps 10 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --logging_dir ./logs \
    --report_to tensorboard
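
The ./scripts/zero3.json file referenced above is not shown in this post. A typical ZeRO-3 config in the style of the standard DeepSpeed examples looks roughly like the following ("auto" values are filled in by the HF Trainer from the command-line arguments above); adjust it to your own setup:

{
  "bf16": {"enabled": "auto"},
  "fp16": {"enabled": "auto"},
  "train_micro_batch_size_per_gpu": "auto",
  "train_batch_size": "auto",
  "gradient_accumulation_steps": "auto",
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "stage3_gather_16bit_weights_on_model_save": true
  }
}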