Building a multimodal dataset for a model like LLaVA is essentially the same as building one for a text-only LLM.

The differences are:

  • 1 The tokenizer gets one new <image> token (a minimal sketch of registering it follows this list).
  • 2 The dataset must additionally provide "pixel_values" data for the image.
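
As a minimal sketch of point 1, assuming a Hugging Face tokenizer/model pair (the checkpoint name below is a placeholder, not part of the original):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder backbone; substitute your own LLM checkpoint.
model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Register <image> as a special token so the tokenizer never splits it.
num_added = tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})

# Grow the embedding matrix to cover the newly added token id.
if num_added > 0:
    model.resize_token_embeddings(len(tokenizer))

image_token_id = tokenizer.convert_tokens_to_ids("<image>")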

1 Dataset

Points to note for the Dataset:

  • 1 It is best to encode both the image and the text inside the Dataset, so the Collator can read each sample's length directly and compute the batch's max length.
  • 2 A Dataset must implement these methods:
    • 2.1 __init__()
    • 2.2 __len__()
    • 2.3 __getitem__()
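
The code below assumes a LLaVA-Instruct-style annotation file: a JSON list whose entries pair an image filename with a single-turn conversation, roughly like this (the keys are what __getitem__ reads; the values are illustrative):

example_entry = {
    "image": "000000123456.jpg",
    "conversations": [
        {"from": "human", "value": "<image>\nWhat is in the picture?"},
        {"from": "gpt", "value": "A dog playing with a ball on the grass."},
    ],
}
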
import json
import os.path as osp
from PIL import Image
from torch.utils.data import Dataset


class LlavaDataset(Dataset):
    def __init__(self, processor, data_path, image_dir, ignore_index):
        super().__init__()
        self.image_dir = image_dir
        self.processor = processor
        self.ignore_index = ignore_index
        with open(data_path, "r") as f:
            self.data_list = json.load(f)

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, item):
        data = self.data_list[item]
        conversations = data["conversations"]
        image_path = osp.join(self.image_dir, data["image"])
        human = conversations[0]["value"]
        gpt = conversations[1]["value"]
        # Tokenize prompt and answer separately so we know where the answer
        # starts; skip special tokens on the answer to avoid a second BOS,
        # and close the sequence with EOS so the model learns to stop.
        human_input_ids = self.processor.tokenizer(human)["input_ids"]
        gpt_input_ids = self.processor.tokenizer(gpt, add_special_tokens=False)["input_ids"]
        gpt_input_ids = gpt_input_ids + [self.processor.tokenizer.eos_token_id]
        input_ids = human_input_ids + gpt_input_ids
        attention_mask = [1] * len(input_ids)
        # Only the answer tokens contribute to the loss; the prompt is masked out.
        labels = [self.ignore_index] * len(human_input_ids) + gpt_input_ids
        image = Image.open(image_path).convert("RGB")
        pixel_values = self.processor.image_processor(image, return_tensors="pt")["pixel_values"][0]
        return input_ids, attention_mask, labels, pixel_values
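
A quick smoke test, assuming a Hugging Face LlavaProcessor (the checkpoint name and paths are placeholders):

from transformers import LlavaProcessor

processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
dataset = LlavaDataset(
    processor=processor,
    data_path="data/llava_instruct.json",  # placeholder path
    image_dir="data/images",               # placeholder path
    ignore_index=-100,  # the default ignore_index of PyTorch's cross-entropy loss
)

input_ids, attention_mask, labels, pixel_values = dataset[0]
print(len(input_ids), len(labels), pixel_values.shape)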

2 DataCollator

The Collator's standard workflow:

  • 1 Take the batch's max sequence length and cap it with the configured max_length (i.e. take the min of the two).
  • 2 Turn each sample's Python lists into properly shaped tensors:
    • 2.1 unpack input_ids, attention_mask, labels, pixel_values
    • 2.2 pad
    • 2.3 truncate
    • 2.4 batch
    • 2.5 convert to tensors
import torch


class LlaVAForTrainCollator:
    def __init__(self, max_length, ignore_index, pad_token_id):
        self.max_length = max_length
        self.ignore_index = ignore_index
        self.pad_token_id = pad_token_id

    def __call__(self, batch):
        # Cap the batch's longest sequence at the configured max_length.
        batch_lengths = [len(sample[0]) for sample in batch]
        final_max_length = min(max(batch_lengths), self.max_length)
        batch_input_ids, batch_attention_mask, batch_labels, batch_pixel_values = [], [], [], []
        for sample in batch:
            # 1 unpack input_ids, attention_mask, labels, pixel_values
            input_ids, attention_mask, labels, pixel_values = sample
            # 2 left-pad shorter sequences up to final_max_length
            # (padding_len <= 0 leaves an over-long sequence unchanged)
            padding_len = final_max_length - len(input_ids)
            input_ids = [self.pad_token_id] * padding_len + input_ids
            attention_mask = [0] * padding_len + attention_mask
            labels = [self.ignore_index] * padding_len + labels
            # 3 truncate over-long sequences from the tail
            input_ids = input_ids[:self.max_length]
            attention_mask = attention_mask[:self.max_length]
            labels = labels[:self.max_length]
            # 4 collect into the batch
            batch_input_ids.append(input_ids)
            batch_attention_mask.append(attention_mask)
            batch_labels.append(labels)
            batch_pixel_values.append(pixel_values)
        # 5 convert to tensors; pixel_values already come out of the Dataset
        # as tensors, so stack them instead of rebuilding from lists
        batch_input_ids = torch.tensor(batch_input_ids, dtype=torch.long)
        batch_attention_mask = torch.tensor(batch_attention_mask, dtype=torch.long)
        batch_labels = torch.tensor(batch_labels, dtype=torch.long)
        batch_pixel_values = torch.stack(batch_pixel_values).float()

        return {
            "input_ids": batch_input_ids, "attention_mask": batch_attention_mask,
            "labels": batch_labels, "pixel_values": batch_pixel_values
        }
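
Wiring the two together with a standard PyTorch DataLoader (a sketch; processor, dataset, and hyperparameters as assumed above):

from torch.utils.data import DataLoader

# LLaMA-style tokenizers may lack a pad token; fall back to EOS in that case.
tok = processor.tokenizer
pad_token_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id

collator = LlaVAForTrainCollator(max_length=2048, ignore_index=-100, pad_token_id=pad_token_id)
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collator)

batch = next(iter(loader))
print(batch["input_ids"].shape)     # (4, min(batch max length, 2048))
print(batch["pixel_values"].shape)  # (4, 3, H, W)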