大模型的参数载入的方式非常多种,例如json,yaml,dict,class。

此处介绍一种配合DeepSpeed启动非常便携的参数载入方式:

  • 利用transformers.HfArugument同时解析model,data,training的参数。

1 参数载入说明

具体实现方式如下:

  • 1 分别将model, data, training的参数定义为class,方便下面写代码的时候IDE可以直接利用属性进行获取。
  • 2 使用transformers.HfArugument

该方式的好处:

  • 引入参数的时候不需要指定顺序,transformers.HfArugument会自动根据参数解析到对应的类中。

2 配置类定义

@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="model/llava")
    attn_implementation: str = "flash_attention_2"
    torch_dtype: str = torch.float16


@dataclass
class DataArguments:
    data_path: Optional[str] = field(default="model/data")
    image_dir: Optional[str] = field(default="model/images")
    ignore_index: int = -100
    max_length: int = 1024


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    optim: str = field(default="adamw_torch")
    model_save_dir: str = ""
    bits: int = field(default=16)
    double_quant: bool = field(default=True)
    quant_type: str = "nf4"
    compute_type: str = "bf16"
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"

3 参数载入

#!/bin/bash

deepspeed  --include localhost:2,3 llava/Trainer.py \
    --deepspeed ./scripts/zero3.json \
    --model_name_or_path ./model/llava \
    --data_path /public/MountData/xcx/mllm/pretrain/LLaVA/blip_laion_cc_sbu_558k.json \
    --image_dir /public/MountData/xcx/mllm/pretrain/LLaVA/images \
    --bf16 True \
    --model_save_dir ./model/qwen1.5-0.5b-pretrain \
    --output_dir ./checkpoints/qwen1.5-0.5b-pretrain \
    --num_train_epochs 1 \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 4 \
    --save_strategy "steps" \
    --optim "adamw_torch" \
    --adam_beta1 0.9 \
    --adam_beta2 0.95 \
    --ddp_find_unused_parameter False \
    --save_steps 1000 \
    --save_total_limit 2 \
    --learning_rate 1e-3 \
    --weight_decay 0. \
    --warmup_ratio 0.05 \
    --lr_scheduler_type "cosine" \
    --logging_steps 10 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --logging_dir ./logs \
    --report_to tensorboard