Building a multimodal dataset for a model like LLaVA is essentially the same as building one for a text-only LLM. The differences are:
- 1 The tokenizer gains an extra <image> token (a sketch follows this list).
- 2 The dataset must additionally return "pixel_values" data.
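For the first point, below is a minimal sketch of registering the <image> token with a Hugging Face tokenizer; the Llama checkpoint name is illustrative, and note that the model's embedding matrix must be resized to cover the new token id.

from transformers import AutoTokenizer, AutoModelForCausalLM

# Illustrative checkpoint; substitute the backbone you actually use.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# Register <image> as a special token so it is never split by the tokenizer.
num_added = tokenizer.add_tokens(["<image>"], special_tokens=True)
if num_added > 0:
    # Grow the embedding matrix to match the enlarged vocabulary.
    model.resize_token_embeddings(len(tokenizer))

image_token_id = tokenizer.convert_tokens_to_ids("<image>")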
1 Dataset
Points to watch in the Dataset:
- 1 Encode both the image and the text inside the Dataset itself, so that the Collator can read the batch's maximum length directly.
- 2 Functions a Dataset must implement:
- 2.1 __init__()
- 2.2 __len__()
- 2.3 __getitem__()
import json
import os.path as osp
from PIL import Image
from torch.utils.data import Dataset


class LlavaDataset(Dataset):
    def __init__(self, processor, data_path, image_dir, ignore_index):
        super().__init__()
        self.image_dir = image_dir
        self.processor = processor
        self.ignore_index = ignore_index
        with open(data_path, "r") as f:
            self.data_list = json.load(f)

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, item):
        data = self.data_list[item]
        conversations = data["conversations"]
        image_path = osp.join(self.image_dir, data["image"])
        # Single-turn sample: the human prompt (which carries the <image>
        # placeholder) followed by the assistant answer.
        human = conversations[0]["value"]
        gpt = conversations[1]["value"]
        human_input_ids = self.processor.tokenizer(human)["input_ids"]
        # Don't let the tokenizer prepend BOS again to the answer segment,
        # and close the answer with EOS so the model learns to stop.
        gpt_input_ids = self.processor.tokenizer(gpt, add_special_tokens=False)["input_ids"]
        gpt_input_ids = gpt_input_ids + [self.processor.tokenizer.eos_token_id]
        input_ids = human_input_ids + gpt_input_ids
        attention_mask = [1] * len(input_ids)
        # Mask prompt tokens so the loss is computed only on the answer.
        labels = [self.ignore_index] * len(human_input_ids) + gpt_input_ids
        pixel_values = self.processor.image_processor(Image.open(image_path))["pixel_values"][0]
        return (input_ids, attention_mask, labels, pixel_values)
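A quick sanity check of the Dataset. This assumes the LLaVA processor from the llava-hf/llava-1.5-7b-hf checkpoint and a LLaVA-style instruct JSON plus image folder; the paths are illustrative.

from transformers import AutoProcessor

# Illustrative checkpoint and paths; substitute your own.
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
dataset = LlavaDataset(
    processor=processor,
    data_path="data/llava_instruct.json",
    image_dir="data/images",
    ignore_index=-100,  # the default ignore_index of PyTorch's CrossEntropyLoss
)

input_ids, attention_mask, labels, pixel_values = dataset[0]
print(len(input_ids), len(labels), pixel_values.shape)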
2 DataCollator
The standard Collator flow:
- 1 Take the smaller of the batch's max_seq_len and our configured max_length.
- 2 Convert the batch's Python lists into tensors in the format the model expects:
- 2.1 unpack input_ids, attention_mask, labels, pixel_values
- 2.2 padding
- 2.3 truncate
- 2.4 batch
- 2.5 tensor
import numpy as np
import torch


class LlaVAForTrainCollator:
    def __init__(self, max_length, ignore_index, pad_token_id):
        self.max_length = max_length
        self.ignore_index = ignore_index
        self.pad_token_id = pad_token_id

    def __call__(self, batch):
        # Pad to the longest sequence in the batch, capped at max_length.
        batch_lengths = [len(sample[0]) for sample in batch]
        final_max_length = min(max(batch_lengths), self.max_length)
        batch_input_ids, batch_attention_mask, batch_labels, batch_pixel_values = [], [], [], []
        for sample in batch:
            # 1 unpack input_ids, attention_mask, labels, pixel_values
            input_ids, attention_mask, labels, pixel_values = sample
            # 2 padding (on the left; padded positions are masked out via
            # attention_mask and ignore_index)
            padding_len = final_max_length - len(input_ids)
            input_ids = [self.pad_token_id] * padding_len + input_ids
            attention_mask = [0] * padding_len + attention_mask
            labels = [self.ignore_index] * padding_len + labels
            # 3 truncate (only triggers for samples longer than max_length,
            # since padding_len is non-positive for those)
            input_ids = input_ids[:self.max_length]
            attention_mask = attention_mask[:self.max_length]
            labels = labels[:self.max_length]
            # 4 batch
            batch_input_ids.append(input_ids)
            batch_attention_mask.append(attention_mask)
            batch_labels.append(labels)
            batch_pixel_values.append(pixel_values)
        # 5 tensor
        batch_input_ids = torch.tensor(batch_input_ids, dtype=torch.long)
        batch_attention_mask = torch.tensor(batch_attention_mask, dtype=torch.long)
        batch_labels = torch.tensor(batch_labels, dtype=torch.long)
        # Stack the numpy arrays first; building a tensor from a list of
        # arrays is slow and raises a warning.
        batch_pixel_values = torch.tensor(np.stack(batch_pixel_values), dtype=torch.float)
        return {
            "input_ids": batch_input_ids, "attention_mask": batch_attention_mask,
            "labels": batch_labels, "pixel_values": batch_pixel_values,
        }
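Wiring the two pieces together with a DataLoader, reusing the illustrative processor and dataset from the snippet above and assuming the tokenizer defines a pad token:

from torch.utils.data import DataLoader

collator = LlaVAForTrainCollator(
    max_length=2048,  # illustrative cap
    ignore_index=-100,
    pad_token_id=processor.tokenizer.pad_token_id,
)
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collator)

batch = next(iter(loader))
print({k: v.shape for k, v in batch.items()})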