Following the approach of nanoGPT, this article pretrains a GPT-2 model on the OpenWebText dataset.
Before diving into nanoGPT itself, we first cover some prerequisite knowledge.
1 Prerequisites
Before walking through the GPT-2 pretraining process, we first introduce PyTorch's DDP distributed training and AMP mixed-precision training.
1.1 AMP
Using AMP involves the following steps:
- 1 Mixed precision: the forward pass model(x) runs inside the autocast context
- 2 Loss scaling: the loss is scaled with a GradScaler before backward
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

x = torch.rand(size=(1000, 64)).cuda()
y = torch.randint(0, 2, size=(1000, )).cuda()

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(64, 32)
        self.fc2 = nn.Linear(32, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

model = Model().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
scaler = GradScaler()
grad_clip = 1.0

for step in range(1000):
    optimizer.zero_grad()
    # 1. run the forward pass under autocast so eligible ops use fp16
    with autocast():
        output = model(x)
        loss = criterion(output, y)
    # 2. scale the loss before backward to avoid fp16 gradient underflow
    scaler.scale(loss).backward()
    if grad_clip:
        # unscale first so clipping operates on the true gradient magnitudes
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)
    scaler.update()
    print(f"step is {step} loss is {loss}")
1.2 DDP
For DDP distributed training, the concrete steps are as follows:
- 1 Call init_process_group and set torch.cuda.set_device(local_rank)
- 2 Wrap the model with DDP
- 3 Use a DistributedSampler in the DataLoader
- 4 During training, call sampler.set_epoch each epoch, move the data to the corresponding local_rank, and all-reduce the loss across ranks (see the sketch after this list)
- 5 During evaluation, gather the loss from all ranks (see the sketch after this list)
- 6 Call destroy_process_group at the end
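Steps 4 and 5 are not shown in the example scripts below, so here is a minimal sketch of how the loss can be reduced or gathered across ranks, assuming the process group has already been initialized and `loss` is a scalar tensor on the current rank's GPU:

import torch
import torch.distributed as dist

def reduce_mean_loss(loss: torch.Tensor) -> torch.Tensor:
    # step 4: average the per-rank training loss so every rank logs the same value
    loss = loss.detach().clone()
    dist.all_reduce(loss, op=dist.ReduceOp.SUM)
    return loss / dist.get_world_size()

def gather_losses(loss: torch.Tensor) -> list:
    # step 5: collect the per-rank evaluation losses onto every rank
    gathered = [torch.zeros_like(loss) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, loss)
    return gathered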
Launching DDP:
- 1 Launch with torch.multiprocessing.spawn; the master address and port must be specified manually.
- 2 Launch with torchrun; all of that information is read from environment variables (see the note after the torchrun launch commands below).
Launching with torch.multiprocessing.spawn: python ddp_gpus.py --max_epochs 5 --batch_size 32
import os, sys
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

def ddp_setup(rank, world_size):
    """
    Args:
        rank: Unique identifier of each process
        world_size: Total number of processes
    """
    # the rank-0 process acts as the master
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    # nccl: NVIDIA Collective Communication Library, used for GPU-to-GPU communication
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

class Trainer:
    def __init__(self,
                 model: torch.nn.Module,
                 train_dataloader: DataLoader,
                 optimizer: torch.optim.Optimizer,
                 gpu_id: int) -> None:
        self.gpu_id = gpu_id
        self.model = model.to(gpu_id)
        self.train_dataloader = train_dataloader
        self.optimizer = optimizer
        self.model = DDP(model, device_ids=[gpu_id])

    def _run_batch(self, xs, ys):
        self.optimizer.zero_grad()
        output = self.model(xs)
        loss = F.cross_entropy(output, ys)
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        batch_size = len(next(iter(self.train_dataloader))[0])
        print(f'[GPU: {self.gpu_id}] Epoch: {epoch} | Batchsize: {batch_size} | Steps: {len(self.train_dataloader)}')
        self.train_dataloader.sampler.set_epoch(epoch)
        for xs, ys in self.train_dataloader:
            xs = xs.to(self.gpu_id)
            ys = ys.to(self.gpu_id)
            self._run_batch(xs, ys)

    def train(self, max_epoch: int):
        for epoch in range(max_epoch):
            self._run_epoch(epoch)

class MyTrainDataset(Dataset):
    def __init__(self, size):
        self.size = size
        self.data = [(torch.rand(20), torch.rand(1)) for _ in range(size)]

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        return self.data[index]

def main(rank: int, world_size: int, max_epochs: int, batch_size: int):
    ddp_setup(rank, world_size)
    train_dataset = MyTrainDataset(2048)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  pin_memory=True,
                                  shuffle=False,
                                  # each batch is split across GPUs with no overlapping samples between them
                                  sampler=DistributedSampler(train_dataset))
    model = torch.nn.Linear(20, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    trainer = Trainer(model=model, gpu_id=rank, optimizer=optimizer, train_dataloader=train_dataloader)
    trainer.train(max_epochs)
    destroy_process_group()

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='simple distributed training job')
    parser.add_argument('--max_epochs', type=int, help='Total epochs to train the model')
    parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
    args = parser.parse_args()
    world_size = torch.cuda.device_count()
    mp.spawn(main, args=(world_size, args.max_epochs, args.batch_size), nprocs=world_size)
Launching with torchrun:
- Single GPU: torchrun ddp_gpus_torchrun.py --max_epochs 5 --batch_size 32
- Multi-GPU:
- torchrun --nproc-per-node=2 ddp_gpus_torchrun.py --max_epochs 5 --batch_size 32
- python -m torch.distributed.launch --use-env --nproc-per-node=2 ddp_gpus_torchrun.py --max_epochs 5 --batch_size 32
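For reference, torchrun injects RANK, LOCAL_RANK and WORLD_SIZE (plus MASTER_ADDR and MASTER_PORT) into each worker's environment, which is why the script below reads os.environ instead of taking a rank argument. A minimal sketch of reading them:

import os

# environment variables set by torchrun for every worker process
rank = int(os.environ["RANK"])              # global rank across all nodes
local_rank = int(os.environ["LOCAL_RANK"])  # rank within the current node
world_size = int(os.environ["WORLD_SIZE"])  # total number of worker processes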
The torchrun version of the code is as follows:
import os, sys
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

def ddp_setup():
    """
    torchrun sets RANK, LOCAL_RANK and WORLD_SIZE (plus MASTER_ADDR / MASTER_PORT)
    in the environment, so nothing has to be passed in explicitly.
    """
    # os.environ["MASTER_ADDR"] = "localhost"
    # os.environ["MASTER_PORT"] = "12355"
    # nccl: NVIDIA Collective Communication Library, used for GPU-to-GPU communication
    init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))

class Trainer:
    def __init__(self,
                 model: torch.nn.Module,
                 train_dataloader: DataLoader,
                 optimizer: torch.optim.Optimizer,
                 ) -> None:
        self.gpu_id = int(os.environ['LOCAL_RANK'])
        self.model = model.to(self.gpu_id)
        self.train_dataloader = train_dataloader
        self.optimizer = optimizer
        self.model = DDP(model, device_ids=[self.gpu_id])

    def _run_batch(self, xs, ys):
        self.optimizer.zero_grad()
        output = self.model(xs)
        loss = F.cross_entropy(output, ys)
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        batch_size = len(next(iter(self.train_dataloader))[0])
        print(f'[GPU: {self.gpu_id}] Epoch: {epoch} | Batchsize: {batch_size} | Steps: {len(self.train_dataloader)}')
        self.train_dataloader.sampler.set_epoch(epoch)
        for xs, ys in self.train_dataloader:
            xs = xs.to(self.gpu_id)
            ys = ys.to(self.gpu_id)
            self._run_batch(xs, ys)

    def train(self, max_epoch: int):
        for epoch in range(max_epoch):
            self._run_epoch(epoch)

class MyTrainDataset(Dataset):
    def __init__(self, size):
        self.size = size
        self.data = [(torch.rand(20), torch.rand(1)) for _ in range(size)]

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        return self.data[index]

def main(max_epochs: int, batch_size: int):
    ddp_setup()
    train_dataset = MyTrainDataset(2048)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  pin_memory=True,
                                  shuffle=False,
                                  # each batch is split across GPUs with no overlapping samples between them
                                  sampler=DistributedSampler(train_dataset))
    model = torch.nn.Linear(20, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    trainer = Trainer(model=model, optimizer=optimizer, train_dataloader=train_dataloader)
    trainer.train(max_epochs)
    destroy_process_group()

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='simple distributed training job')
    parser.add_argument('--max_epochs', type=int, help='Total epochs to train the model')
    parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
    args = parser.parse_args()
    # world_size = torch.cuda.device_count()
    main(args.max_epochs, args.batch_size)
2 Data Loading
Now for the exciting part. The project processes the dataset in two steps:
- 1 Hugging Face datasets processing
- 2 numpy memmap processing
2.1 Hugging Face datasets
The classic three-step recipe (a toy sketch follows the list):
- 1 load_dataset to load the dataset
- 2 dataset.train_test_split to split it into training and test sets
- 3 dataset.map to process every example
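A toy sketch of the three calls; the local file corpus.txt and the character-level tokenizer here are only placeholders, the real OpenWebText script is given in 2.2:

from datasets import load_dataset

# 1. load a plain-text corpus (assumes a local file corpus.txt exists)
dataset = load_dataset("text", data_files={"train": "corpus.txt"})
# 2. hold out a small test split
split = dataset["train"].train_test_split(test_size=0.001, seed=42)
# 3. map a (toy, character-level) tokenization function over every example
def tokenize(example):
    ids = [ord(c) for c in example["text"]]
    return {"ids": ids, "len": len(ids)}
split = split.map(tokenize, remove_columns=["text"])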
2.2 numpy memmap
The numpy side works as follows (a toy round-trip example follows the list):
- 1 Open a numpy memmap in w+ mode
- 2 Shard the OpenWebText data into 1024 pieces with dataset.shard and write them in one by one
- 3 Call array.flush to write everything to disk
- 4 Re-open the file with numpy.memmap in r mode for reading
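As referenced in the list above, here is a toy round-trip that shows the w+ / flush / r pattern in isolation (demo.bin is just a scratch file), before the full preprocessing script:

import numpy as np

# write: pre-allocate the file on disk and fill it in place
arr = np.memmap("demo.bin", dtype=np.uint16, mode="w+", shape=(1000,))
arr[:] = np.arange(1000, dtype=np.uint16)
arr.flush()  # make sure everything is written to disk

# read: the file is mapped lazily, nothing is loaded into RAM up front
read_back = np.memmap("demo.bin", dtype=np.uint16, mode="r")
print(read_back[:5])  # [0 1 2 3 4]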
import tiktoken
import numpy as np
from tqdm import tqdm
from datasets import load_dataset, load_from_disk
import os

os.environ["HF_DATASETS_OFFLINE"] = "1"
num_proc = 16
enc = tiktoken.get_encoding("gpt2")

def process(example):
    ids = enc.encode_ordinary(example["text"])
    ids.append(enc.eot_token)
    out = {"ids": ids, "len": len(ids)}
    return out

if __name__ == "__main__":
    path = r"F:\openwebtext\openwebtext.py"
    dataset = load_dataset(path, num_proc=num_proc)
    split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True)
    split_dataset["val"] = split_dataset.pop("test")
    split_dataset = split_dataset.map(process, num_proc=num_proc, remove_columns=["text"])
    for split, dataset in split_dataset.items():
        array_len = np.sum(dataset["len"], dtype=np.uint64)
        filename = f"{split}.bin"  # one .bin file per split (train.bin / val.bin)
        dtype = np.uint16          # the GPT-2 vocab (50257) fits in uint16
        array = np.memmap(filename, dtype, "w+", shape=(array_len, ))
        # write the split in 1024 shards to keep memory usage bounded
        total_batch = 1024
        idx = 0
        for batch_idx in tqdm(range(total_batch)):
            # take the batch_idx-th shard and convert it to numpy format
            batch = dataset.shard(num_shards=total_batch, index=batch_idx, contiguous=True).with_format("numpy")
            # concatenate the token ids of every example in the shard
            array_batch = np.concatenate(batch["ids"])
            # copy the shard into the memmapped file
            array[idx:idx + len(array_batch)] = array_batch
            idx += len(array_batch)
        array.flush()
On the reading side, batches are sampled directly from the .bin files through np.memmap in read mode. The helper below lives alongside the training code and relies on config_args and the osp alias for os.path imported there:
def get_batch(data_dir, split):
    # memory-map the pre-tokenized .bin file instead of loading it into RAM
    if split == "train":
        data = np.memmap(osp.join(data_dir, "train.bin"), dtype=np.uint16, mode="r")
    else:
        data = np.memmap(osp.join(data_dir, "val.bin"), dtype=np.uint16, mode="r")
    # sample a random starting offset for each sequence in the batch
    ix = torch.randint(len(data) - config_args.max_seq_len, (config_args.batch_size, ))
    # stack the sequences along a new batch dimension; y is x shifted by one token
    x = torch.stack([torch.from_numpy(data[i:i+config_args.max_seq_len].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i+1:i+1+config_args.max_seq_len].astype(np.int64)) for i in ix])
    if config_args.device == "cuda":
        x, y = x.pin_memory().to(config_args.device, non_blocking=True), y.pin_memory().to(config_args.device, non_blocking=True)
    else:
        x, y = x.to(config_args.device), y.to(config_args.device)
    return x, y
3 Model
The model starts from the original GPT-2 architecture and adds several of LLaMA's innovations (RoPE, SwiGLU, RMSNorm). The code is as follows:
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
import config
# TODO 2 Maybe is ok
# RoPE blogs : https://zhuanlan.zhihu.com/p/645263524
class RoPeEmbedding(nn.Module):
    def __init__(self, config):
        super(RoPeEmbedding, self).__init__()

    def get_freqs_cis(self, d_model, max_seq_len, device):
        # per-dimension rotation frequencies, shape (d_model/2, )
        freqs = 1. / (10000 ** (torch.arange(0, d_model, 2).float() / d_model))
        # position indices, shape (max_seq_len, )
        t = torch.arange(0, max_seq_len)
        # angle for every (position, dimension) pair, shape (max_seq_len, d_model/2)
        freqs = torch.outer(t, freqs)
        # complex rotations e^{i*theta}, shape (max_seq_len, d_model/2)
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis.to(device)

    def forward(self, xq, xk):
        xq_device = xq.device
        max_seq_len, d_model = xq.shape[1], xq.shape[2]
        # view xq, xk as complex numbers:
        # (batch_size, max_seq_len, d_model) -> (batch_size, max_seq_len, d_model/2, 2)
        # -> (batch_size, max_seq_len, d_model/2)
        xq_out = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_out = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        # rotate xq, xk by complex multiplication with the rotation matrix
        freqs_cis = self.get_freqs_cis(d_model, max_seq_len, xq_device)
        # back to real numbers and flatten the last two dimensions
        xq_out = torch.view_as_real(xq_out * freqs_cis).flatten(-2)
        xk_out = torch.view_as_real(xk_out * freqs_cis).flatten(-2)
        return xq_out.type_as(xq), xk_out.type_as(xk)
# TODO 3 Maybe is ok
class SwishGLU(nn.Module):
    def __init__(self, config):
        super(SwishGLU, self).__init__()
        d_model, dropout = config.d_model, config.dropout
        self.w1 = nn.Linear(d_model, d_model*2)
        self.w3 = nn.Linear(d_model, d_model*2)
        self.w2 = nn.Linear(d_model*2, d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.w2(F.silu(self.w1(x)) * self.w3(x))
        x = self.drop(x)
        return x
# TODO 4 Maybe is ok
class RMSNorm(nn.Module):
    def __init__(self, config):
        super(RMSNorm, self).__init__()
        d_model, bias = config.d_model, config.bias
        self.weight = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model)) if bias else None
        self.eps = 1e-5

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x):
        x = self.weight * self._norm(x)
        if self.bias is not None:
            x = x + self.bias
        return x
class LayerNorm(nn.Module):
    def __init__(self, config):
        d_model, bias = config.d_model, config.bias
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model)) if bias else None

    def forward(self, x):
        return F.layer_norm(x, normalized_shape=self.weight.shape, weight=self.weight, bias=self.bias, eps=1e-5)
class SelfAttention(nn.Module):
    def __init__(self, config):
        super(SelfAttention, self).__init__()
        in_channel, out_channel = config.d_model, config.d_model//config.n_head
        self.pe_encoder = config.pe_encoder
        self.rope = RoPeEmbedding(config)
        self.Wq = nn.Linear(in_channel, out_channel)
        self.Wk = nn.Linear(in_channel, out_channel)
        self.Wv = nn.Linear(in_channel, out_channel)

    def forward(self, x):
        q, k, v = self.Wq(x), self.Wk(x), self.Wv(x)
        if self.pe_encoder == "rope":
            q, k = self.rope(q, k)
        attention = q @ k.transpose(-1, -2)
        sqrt_hidden = torch.sqrt(torch.tensor(q.shape[-1], device=q.device))
        # causal mask: each position can only attend to itself and earlier positions
        mask = torch.tril(torch.ones(size=(q.shape[1], q.shape[1]), device=q.device), diagonal=0)
        attention = (attention / sqrt_hidden).masked_fill(mask == 0, float("-inf"))
        attention = F.softmax(attention, dim=-1)
        value = attention @ v
        return value
class MultiSelfAttention(nn.Module):
    def __init__(self, config):
        super(MultiSelfAttention, self).__init__()
        d_model, n_head = config.d_model, config.n_head
        self.Wo = nn.Linear(d_model, d_model)
        self.multi_self_attention = nn.ModuleList([SelfAttention(config) for _ in range(n_head)])

    def forward(self, x):
        value = self.Wo(torch.cat([self_attention(x) for self_attention in self.multi_self_attention], dim=-1))
        return value
class FFN(nn.Module):
    def __init__(self, config):
        super(FFN, self).__init__()
        d_model, dropout = config.d_model, config.dropout
        self.config = config
        if config.ffn == "origin":
            self.fc1 = nn.Linear(d_model, d_model * 4)
            self.fc2 = nn.Linear(d_model * 4, d_model)
            self.gelu = nn.GELU()
            self.dropout = nn.Dropout(dropout)
        elif config.ffn == "swishglu":
            self.swishglu = SwishGLU(config)

    def forward(self, x):
        if self.config.ffn == "origin":
            x = self.gelu(self.fc1(x))
            x = self.dropout(self.fc2(x))
        elif self.config.ffn == "swishglu":
            x = self.swishglu(x)
        return x
class TransformerLayer(nn.Module):
    def __init__(self, config):
        super(TransformerLayer, self).__init__()
        if config.norm == "origin":
            self.norm1 = LayerNorm(config)
            self.norm2 = LayerNorm(config)
        elif config.norm == "rms":
            self.norm1 = RMSNorm(config)
            self.norm2 = RMSNorm(config)
        self.multi_self_attention = MultiSelfAttention(config)
        self.ffn = FFN(config)

    def forward(self, x):
        x = x + self.multi_self_attention(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x
class GPT2(nn.Module):
    def __init__(self, config):
        super(GPT2, self).__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            vocab = nn.Embedding(config.vocab_size, config.d_model),
            pe = nn.Embedding(config.max_seq_len, config.d_model) if config.pe_encoder == "origin" else None,
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([TransformerLayer(config) for _ in range(config.n_layer)]),
            norm = LayerNorm(config) if config.norm == "origin" else RMSNorm(config)
        ))
        self.lm_head = nn.Linear(config.d_model, config.vocab_size)
        self.apply(self._init_weights)
        # scaled init for the residual projections, as in GPT-2
        for name, parameters in self.transformer.named_parameters():
            if name.endswith("fc2.weight"):
                torch.nn.init.normal_(parameters, 0, (0.02 / math.sqrt(2 * config.n_layer)))

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, 0, 0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        if isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, 0, 0.02)

    def forward(self, x, target=None):
        token_emb = self.transformer.vocab(x)
        if self.config.pe_encoder == "origin":
            pos = torch.arange(0, x.shape[1], dtype=torch.long, device=x.device)
            pos_emb = self.transformer.pe(pos)
            h = self.transformer.drop(token_emb + pos_emb)
        else:
            h = self.transformer.drop(token_emb)
        for block in self.transformer.h:
            h = block(h)
        h = self.transformer.norm(h)
        if target is not None:
            logits = self.lm_head(h)
            loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), target.view(-1), ignore_index=-1)
        else:
            # inference: only the last position is needed; keep the sequence dim so generate() can index it
            logits = self.lm_head(h[:, [-1], :])
            loss = None
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_tokens, t=0.1, topk=None):
        for _ in range(max_tokens):
            # crop the context to max_seq_len so the positional embedding never overflows
            idx_cond = idx if idx.shape[1] <= self.config.max_seq_len else idx[:, -self.config.max_seq_len:]
            logits, _ = self(idx_cond)
            logit = logits[:, -1, :] / t
            if topk is not None:
                # keep only the top-k logits, set the rest to -inf
                v, _ = torch.topk(logit, min(topk, logit.shape[-1]))
                logit[logit < v[:, [-1]]] = -float('inf')
            probs = F.softmax(logit, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=-1)
        return idx
# TODO 1: add inference-time K/V cache, following LLaMA
# TODO 2: add RoPE, following LLaMA -- done
# TODO 3: switch to SwiGLU -- done
# TODO 4: RMSNorm -- done
# TODO 5: add two more decoding methods: top_p and beam search
# TODO 6: FlashAttention, following LLaMA-MoE
if __name__ == "__main__":
    config = config.config_args
    gpt2 = GPT2(config).to("cuda")
    num_parameters = np.sum([p.numel() for p in gpt2.parameters()]) / 1024 / 1024
    print(f"parameters is {num_parameters} M")
    x = torch.randint(0, 1000, size=(2, 10), dtype=torch.long).to("cuda")
    print(f"x is {x}")
    # output = gpt2.generate(x, config.max_tokens, t=0.1, topk=5)
    # print(f"output is {output}")
    logits, loss = gpt2(x)
    print(logits)
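TODO 5 in the code above mentions top-p (nucleus) sampling. A sketch of how it could sit next to the existing top-k branch in generate, assuming logit is the (batch_size, vocab_size) tensor of last-position logits already divided by the temperature:

import torch
import torch.nn.functional as F

def top_p_filter(logit: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
    # sort logits in descending order and compute the cumulative probability mass
    sorted_logit, sorted_idx = torch.sort(logit, descending=True, dim=-1)
    sorted_probs = F.softmax(sorted_logit, dim=-1)
    cum_probs = torch.cumsum(sorted_probs, dim=-1)
    # drop tokens whose preceding cumulative mass already exceeds top_p
    # (subtracting sorted_probs guarantees the most likely token is always kept)
    remove = cum_probs - sorted_probs > top_p
    sorted_logit[remove] = -float("inf")
    # scatter the filtered logits back to the original vocabulary order
    return torch.full_like(logit, -float("inf")).scatter(-1, sorted_idx, sorted_logit)

# usage inside generate(): probs = F.softmax(top_p_filter(logit), dim=-1)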
4 Training
With the background covered above, the training loop should be straightforward to follow. The code is as follows:
import os
import time
import math
import pickle
from tqdm import trange
import numpy as np
import torch
from utils import *
from model import *
from contextlib import nullcontext
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from config import config_args
ddp = int(os.environ.get("RANK", -1)) != -1
if ddp:
    init_process_group(backend=config_args.backend)
    rank = int(os.environ.get("RANK"))
    local_rank = int(os.environ.get("LOCAL_RANK"))
    world_size = int(os.environ.get("WORLD_SIZE"))
    device = f"cuda:{local_rank}"
    torch.cuda.set_device(device)
    master_process = rank == 0
    seed_offset = local_rank
    config_args.gradient_accumulation_steps = config_args.gradient_accumulation_steps // world_size
else:
    world_size = 1
    master_process = True
    seed_offset = 0

# 12 * 1024 * 5 * 8 ≈ 0.5M tokens per iteration
token_per_iter = config_args.batch_size * config_args.max_seq_len * config_args.gradient_accumulation_steps * world_size
print(f"batch_size is {config_args.batch_size}")
print(f"max_seq_len is {config_args.max_seq_len}")
print(f"token per iter is {token_per_iter}")
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[config_args.dtype]
ctx = nullcontext() if config_args.device == "cpu" else torch.amp.autocast(device_type=config_args.device, dtype=ptdtype)
print(f"train type is {ptdtype}")
model = GPT2(config_args).to(config_args.device)
# loss scaling is only needed for float16; bfloat16/float32 run with the scaler disabled
scaler = torch.cuda.amp.GradScaler(enabled=(config_args.dtype == "float16"))
use_fused = True if config_args.device == "cuda" else False
# TODO: separate parameter groups with different learning rates / weight decay
optimizer = torch.optim.Adam(model.parameters(), lr=config_args.lr, betas=(config_args.beta1, config_args.beta2), fused=use_fused,
                             weight_decay=config_args.weight_decay)
if config_args.compile:
    model = torch.compile(model)
if ddp:
    model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)
raw_model = model.module if ddp else model

# training loop
dataset_data_dir = get_dataset_path()
x, y = get_batch(dataset_data_dir, "train")
start_time = time.time()
best_val_loss = 1e9
for n_iter in trange(1, config_args.max_iters):
    # set the learning rate for this iteration
    lr = get_lr(n_iter)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    # evaluate and checkpoint
    if n_iter % config_args.eval_n_iter == 0 and master_process:
        loss_dict = estimate_loss(model, ctx, config_args.eval_iter)
        if loss_dict["valid"] < best_val_loss:
            best_val_loss = loss_dict["valid"]
            if n_iter > 0:
                checkpoint = {
                    "model": raw_model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "config_args": config_args,
                    "n_iter": n_iter,
                    "best_val_loss": best_val_loss
                }
                print(f"saving checkpoint to {config_args.out_dir}")
                torch.save(checkpoint, osp.join(config_args.out_dir, "ckpt.pt"))
    for micro_step in range(config_args.gradient_accumulation_steps):
        if ddp:
            # only synchronize gradients on the last micro step of the accumulation window
            model.require_backward_grad_sync = (micro_step == config_args.gradient_accumulation_steps - 1)
        with ctx:
            logits, loss = model(x, y)
            # average over the accumulation steps so the effective loss scale is unchanged
            loss = loss / config_args.gradient_accumulation_steps
        print(f"step is {n_iter} train loss is {loss}")
        # fetch the next batch while the current one is being processed
        x, y = get_batch(dataset_data_dir, "train")
        scaler.scale(loss).backward()
    if config_args.grad_clip != 0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), config_args.grad_clip)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

if ddp:
    destroy_process_group()
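The loop above calls get_lr and estimate_loss from utils, which are not listed in this post. For completeness, here is a sketch of a warmup-plus-cosine-decay schedule in the style of nanoGPT; the config fields warmup_iters, lr_decay_iters, min_lr and lr are assumed to exist in config_args:

import math
from config import config_args

def get_lr(n_iter: int) -> float:
    # linear warmup for the first warmup_iters steps
    if n_iter < config_args.warmup_iters:
        return config_args.lr * n_iter / config_args.warmup_iters
    # after lr_decay_iters, stay at the minimum learning rate
    if n_iter > config_args.lr_decay_iters:
        return config_args.min_lr
    # cosine decay from lr down to min_lr in between
    decay_ratio = (n_iter - config_args.warmup_iters) / (config_args.lr_decay_iters - config_args.warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return config_args.min_lr + coeff * (config_args.lr - config_args.min_lr)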