LoRA 是主流的微调方法,此外还有一类基于 Prompt 的微调方法,例如 Prompt-Tuning、P-Tuning、Prefix-Tuning 和 P-TuningV2。
由于 PEFT 库封装良好,使用这类 Prompt 微调方法时,基本上只需要替换模型的加载方式(即换用对应的 Config,再调用 get_peft_model 包装模型)。
四种微调方式的 Config 中共同的参数:
- 1 task_type=TaskType.CAUSAL_LM(注意拼写是 CAUSAL,不是 CASUAL)
- 2 num_virtual_tokens=xx(虚拟 token 的个数)
1 Prompt-Tuning
from peft import PromptTuningConfig, PromptTuningInit, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

# `path` is the base-model checkpoint (e.g. "bigscience/bloomz-560m"); it must be
# defined before this snippet runs.
model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", trust_remote_code=True)
# Prompt-Tuning: learn `num_virtual_tokens` soft-prompt embeddings. With
# PromptTuningInit.TEXT they are initialised from `prompt_tuning_init_text`,
# tokenized with the tokenizer found at `tokenizer_name_or_path`.
peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    num_virtual_tokens=8,
    prompt_tuning_init=PromptTuningInit.TEXT,
    prompt_tuning_init_text="Classify if the tweet is a complaint or not:",
    tokenizer_name_or_path=path,
)
model = get_peft_model(model, peft_config)
2 P-Tuning
from peft import get_peft_model, PromptEncoderConfig, TaskType, PromptEncoderReparameterizationType
from transformers import AutoModelForCausalLM

# `path` is the base-model checkpoint (e.g. "bigscience/bloomz-560m"); it must be
# defined before this snippet runs.
model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", trust_remote_code=True)
# P-Tuning: the virtual-token embeddings are produced by a small prompt encoder
# (an LSTM here, with hidden size `encoder_hidden_size`).
peft_config = PromptEncoderConfig(
    task_type=TaskType.CAUSAL_LM,
    num_virtual_tokens=20,
    encoder_hidden_size=128,
    encoder_reparameterization_type=PromptEncoderReparameterizationType.LSTM,
)
model = get_peft_model(model, peft_config)
3 Prefix-Tuning
from peft import get_peft_model, PrefixTuningConfig, TaskType
from transformers import AutoModelForCausalLM

# `path` is the base-model checkpoint (e.g. "bigscience/bloomz-560m"); it must be
# defined before this snippet runs.
model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", trust_remote_code=True)
# Prefix-Tuning: learn `num_virtual_tokens` prefix vectors; prefix_projection=True
# reparameterises the prefix through an MLP.
peft_config = PrefixTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    num_virtual_tokens=30,
    prefix_projection=True,
)
model = get_peft_model(model, peft_config)
4 P-TuningV2
P-TuningV2 的实现和 Prefix-Tuning 基本一致,区别在于 P-TuningV2 不使用 MLP 对前缀进行重参数化,即在 PrefixTuningConfig 中设置 prefix_projection=False(这也是该参数的默认值)。
from peft import get_peft_model, PrefixTuningConfig, TaskType
from transformers import AutoModelForCausalLM

# `path` is the base-model checkpoint (e.g. "bigscience/bloomz-560m"); it must be
# defined before this snippet runs.
model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", trust_remote_code=True)
# P-TuningV2: same config class as Prefix-Tuning, but without the MLP
# reparameterisation. prefix_projection=False is spelled out (it is also the
# default) because it is exactly what distinguishes this setup.
peft_config = PrefixTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    num_virtual_tokens=30,
    prefix_projection=False,
)
model = get_peft_model(model, peft_config)
5 完整代码
基于P-TuningV2的完整微调代码:
import torch
import warnings
from tqdm import tqdm
warnings.filterwarnings(action="ignore")
from transformers import AutoTokenizer, AutoModelForCausalLM, default_data_collator, get_linear_schedule_with_warmup
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
from peft import get_peft_model, PrefixTuningConfig, TaskType
# ---- Settings ----------------------------------------------------------------
text_column = "Tweet text"   # raw text column used to build the prompt
label_column = "text_label"  # string-label column added by the map below
max_length = 64              # target length of every tokenized row
lr = 3e-4
num_epochs = 10
batch_size = 8
path = "bigscience/bloomz-560m"
dataset_name = "twitter_complaints"

# ---- Dataset -----------------------------------------------------------------
dataset = load_dataset("ought/raft", dataset_name)
classes = dataset["train"].features["Label"].names


def _add_text_label(batch):
    """Attach the class-name string for each integer ``Label`` in a batch."""
    return {label_column: [classes[idx] for idx in batch["Label"]]}


dataset = dataset.map(_add_text_label, batched=True, num_proc=1)

# ---- Tokenizer ---------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
# Reuse the EOS id for padding when the tokenizer has no pad token of its own.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
def preprocess_function(examples):
    """Turn a batch of tweets into fixed-length causal-LM training rows.

    Each row is ``prompt + label tokens + [pad_token_id]``, where the prompt is
    ``"<text_column> is <tweet> and label is "``. Prompt positions are labelled
    -100 (the ignore index of the Hugging Face causal-LM loss), so only the
    label tokens and the trailing pad marker are trained on.

    Short rows are left-padded to ``max_length``: pad id in ``input_ids``, 0 in
    ``attention_mask`` and -100 in ``labels``. Long rows are truncated from the
    LEFT, because the label sits at the end of the row. Keeping the head would
    silently drop the label and leave a sample with nothing to learn.

    Reads the module-level ``tokenizer``, ``text_column``, ``label_column`` and
    ``max_length`` (a positive int).

    Args:
        examples: A batched ``datasets`` row dict holding ``text_column`` and
            ``label_column`` lists of equal length.

    Returns:
        A dict with ``input_ids``, ``attention_mask`` and ``labels``. Each is a
        list of 1-D integer tensors of length ``max_length``.
    """
    prompts = [f"{text_column} is {text} and label is " for text in examples[text_column]]
    targets = [str(target) for target in examples[label_column]]
    prompt_ids = tokenizer(prompts)["input_ids"]
    target_ids = tokenizer(targets)["input_ids"]
    pad_id = tokenizer.pad_token_id

    all_input_ids, all_attention_mask, all_labels = [], [], []
    for p_ids, t_ids in zip(prompt_ids, target_ids):
        # The appended pad token marks where the answer ends.
        answer = list(t_ids) + [pad_id]
        input_ids = list(p_ids) + answer
        labels = [-100] * len(p_ids) + answer
        attention_mask = [1] * len(input_ids)

        # Left-pad short rows; overlong rows get no padding (never a negative count).
        pad_len = max(max_length - len(input_ids), 0)
        input_ids = [pad_id] * pad_len + input_ids
        labels = [-100] * pad_len + labels
        attention_mask = [0] * pad_len + attention_mask

        # Keep the LAST max_length tokens so the label at the end survives truncation.
        all_input_ids.append(torch.tensor(input_ids[-max_length:]))
        all_attention_mask.append(torch.tensor(attention_mask[-max_length:]))
        all_labels.append(torch.tensor(labels[-max_length:]))

    return {
        "input_ids": all_input_ids,
        "attention_mask": all_attention_mask,
        "labels": all_labels,
    }
# Tokenize every split with preprocess_function. The raw columns are removed, so
# only the columns returned by preprocess_function remain.
processed_datasets = dataset.map(
    preprocess_function,
    batched=True,
    num_proc=1,
    remove_columns=dataset["train"].column_names,
    load_from_cache_file=False,
    desc="Running tokenizer on dataset",
)

# Training and evaluation use the same data (the train split); only the
# training loader shuffles it.
train_dataset = eval_dataset = processed_datasets["train"]

loader_kwargs = {
    "collate_fn": default_data_collator,
    "batch_size": batch_size,
    "pin_memory": True,
}
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
eval_dataloader = DataLoader(eval_dataset, **loader_kwargs)
# Choose the device once instead of hard-coding .cuda(), so the script can also
# run (slowly) on a CPU-only machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the base model normally and move the wrapped PEFT model to `device` in a
# single step (no device_map="auto" followed by a second .cuda() placement).
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True)
# P-TuningV2 = prefix tuning without the MLP reparameterisation
# (prefix_projection=False, which is also the PrefixTuningConfig default).
peft_config = PrefixTuningConfig(
    task_type=TaskType.CAUSAL_LM, num_virtual_tokens=30, prefix_projection=False
)
model = get_peft_model(model, peft_config)
# print_trainable_parameters() prints by itself and returns None, so it is
# called directly rather than wrapped in print().
model.print_trainable_parameters()
model = model.to(device)

# get_peft_model freezes the base model; hand only the trainable parameters
# to the optimizer.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=lr)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=len(train_dataloader) * num_epochs,
)

for epoch in range(num_epochs):
    # ---- train ----
    model.train()
    total_loss = 0
    progress = tqdm(train_dataloader, desc=f"epoch {epoch}")
    for batch in progress:
        batch = {k: v.to(device) for k, v in batch.items()}
        optimizer.zero_grad()
        loss = model(**batch).loss
        total_loss += loss.detach().float()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        # Show the per-step loss on the progress bar instead of printing a line
        # per step, which breaks the tqdm display.
        progress.set_postfix(loss=f"{loss.item():.4f}")

    # ---- evaluate (teacher-forced loss only) ----
    model.eval()
    eval_loss = 0
    with torch.no_grad():
        for batch in tqdm(eval_dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            eval_loss += model(**batch).loss.detach().float()

    eval_epoch_loss = eval_loss / len(eval_dataloader)
    eval_ppl = torch.exp(eval_epoch_loss)
    train_epoch_loss = total_loss / len(train_dataloader)
    train_ppl = torch.exp(train_epoch_loss)
    print(f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}")

# Output of a previous run, kept for reference:
# epoch=9: train_ppl=tensor(2.5512e+22, device='cuda:0') train_epoch_loss=tensor(51.5934, device='cuda:0')
# eval_ppl=tensor(9.9348e+21, device='cuda:0') eval_epoch_loss=tensor(50.6503, device='cuda:0')
# NOTE(review): a mean token loss of ~50 is far above ln(vocab_size) (about 12
# for a ~250k-token vocabulary), i.e. worse than a uniform guess, so that run
# did not train correctly. Worth checking that the installed peft/transformers
# versions support prefix tuning for this model architecture.