Following the approach of nanoGPT, this post pretrains a GPT-2 model on the OpenWebText dataset.

Before diving into nanoGPT, we cover some prerequisite knowledge.

1 Prerequisites

Before walking through GPT-2 pretraining, we first introduce how to use PyTorch's DDP (distributed data parallel) training and AMP (automatic mixed precision) training.

1.1 AMP

Using AMP involves the following steps:

  • 1 Mixed precision: run model(x) inside the autocast context
  • 2 Loss scaling: scale the loss with a GradScaler before calling backward
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

x = torch.rand(size=(1000, 64)).cuda()
y = torch.randint(0, 2, size=(1000, )).cuda()


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(64, 32)
        self.fc2 = nn.Linear(32, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

model = Model().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
scaler = GradScaler()  # scales the loss to prevent fp16 gradient underflow
grad_clip = 1.0

for step in range(1000):
    optimizer.zero_grad()
    with autocast():                      # forward pass runs in fp16/fp32 mixed precision
        output = model(x)
        loss = criterion(output, y)
    scaler.scale(loss).backward()         # scale the loss, then backprop
    if grad_clip:
        scaler.unscale_(optimizer)        # unscale gradients so clipping uses true magnitudes
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)
    scaler.update()
    print(f"step is {step} loss is {loss}")

1.2 DDP

For DDP distributed training, the concrete steps are as follows:

  • 1 Initialize with init_process_group and call torch.cuda.set_device(local_rank)
  • 2 Wrap the model with DDP
  • 3 Use a distributed sampler in the dataloader
  • 4 During training, call sampler.set_epoch each epoch, move the data to the corresponding local_rank, and all-reduce the loss across ranks (see the sketch after this list)
  • 5 During evaluation, gather the loss across ranks (see the sketch after this list)
  • 6 Tear down with destroy_process_group
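
The loss all-reduce (step 4) and gather (step 5) are not spelled out in the example code below, so here is a minimal sketch of my own; it assumes the process group has already been initialized and that loss is a scalar tensor on the current rank:

import torch
import torch.distributed as dist

def reduce_mean(loss: torch.Tensor) -> torch.Tensor:
    """Step 4: average a scalar loss across all ranks (e.g. for logging)."""
    loss_t = loss.detach().clone()
    dist.all_reduce(loss_t, op=dist.ReduceOp.SUM)
    return loss_t / dist.get_world_size()

def gather_losses(loss: torch.Tensor) -> torch.Tensor:
    """Step 5: collect every rank's evaluation loss into one tensor on each rank."""
    gathered = [torch.zeros_like(loss) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, loss)
    return torch.stack(gathered)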

DDP can be launched in two ways:

  • 1 With torch.multiprocessing.spawn, in which case the master address and port must be specified explicitly.
  • 2 With torchrun, in which case everything is read from environment variables.

Launching with torch.multiprocessing.spawn: python ddp_gpus.py --max_epochs 5 --batch_size 32

import os, sys
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group


def ddp_setup(rank, world_size):
    """
    Args:
        rank: Unique identifier of each process
        world_size: Total number of processes
    """
    # MASTER_ADDR / MASTER_PORT tell every process where the rank-0 process lives
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    # nccl: NVIDIA Collective Communication Library, used for inter-GPU communication
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


class Trainer:
    def __init__(self,
                 model: torch.nn.Module,
                 train_dataloader: DataLoader,
                 optimizer: torch.optim.Optimizer,
                 gpu_id: int) -> None:
        self.gpu_id = gpu_id
        self.model = model.to(gpu_id)
        self.train_dataloader = train_dataloader
        self.optimizer = optimizer
        self.model = DDP(model, device_ids=[gpu_id])

    def _run_batch(self, xs, ys):
        self.optimizer.zero_grad()
        output = self.model(xs)
        loss = F.cross_entropy(output, ys)
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        batch_size = len(next(iter(self.train_dataloader))[0])
        print(f'[GPU: {self.gpu_id}] Epoch: {epoch} | Batchsize: {batch_size} | Steps: {len(self.train_dataloader)}')
        self.train_dataloader.sampler.set_epoch(epoch)
        for xs, ys in self.train_dataloader:
            xs = xs.to(self.gpu_id)
            ys = ys.to(self.gpu_id)
            self._run_batch(xs, ys)

    def train(self, max_epoch: int):
        for epoch in range(max_epoch):
            self._run_epoch(epoch)


class MyTrainDataset(Dataset):
    def __init__(self, size):
        self.size = size
        self.data = [(torch.rand(20), torch.rand(1)) for _ in range(size)]

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        return self.data[index]


def main(rank: int, world_size: int, max_epochs: int, batch_size: int):
    ddp_setup(rank, world_size)

    train_dataset = MyTrainDataset(2048)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  pin_memory=True,
                                  shuffle=False,
                                  # DistributedSampler splits each batch across GPUs, with no overlapping samples between them
                                  sampler=DistributedSampler(train_dataset))
    model = torch.nn.Linear(20, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    trainer = Trainer(model=model, gpu_id=rank, optimizer=optimizer, train_dataloader=train_dataloader)
    trainer.train(max_epochs)

    destroy_process_group()


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='simple distributed training job')
    parser.add_argument('--max_epochs', type=int, help='Total epochs to train the model')
    parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
    args = parser.parse_args()

    world_size = torch.cuda.device_count()
    mp.spawn(main, args=(world_size, args.max_epochs, args.batch_size), nprocs=world_size)

Launching with torchrun:

  • Single GPU: torchrun ddp_gpus_torchrun.py --max_epochs 5 --batch_size 32
  • Multi-GPU:
    • torchrun --nproc-per-node=2 ddp_gpus_torchrun.py --max_epochs 5 --batch_size 32
    • python -m torch.distributed.launch --use-env --nproc-per-node=2 ddp_gpus_torchrun.py --max_epochs 5 --batch_size 32

The torchrun version of the code is as follows:

import os, sys
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group


def ddp_setup():
    # With torchrun there is no need to set MASTER_ADDR / MASTER_PORT manually:
    # RANK, WORLD_SIZE, LOCAL_RANK and the master address/port all come from
    # environment variables set by the launcher.
    # nccl: NVIDIA Collective Communication Library, used for inter-GPU communication
    init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))


class Trainer:
    def __init__(self,
                 model: torch.nn.Module,
                 train_dataloader: DataLoader,
                 optimizer: torch.optim.Optimizer,
                 ) -> None:
        self.gpu_id = int(os.environ['LOCAL_RANK'])
        self.model = model.to(self.gpu_id)
        self.train_dataloader = train_dataloader
        self.optimizer = optimizer
        self.model = DDP(model, device_ids=[self.gpu_id])

    def _run_batch(self, xs, ys):
        self.optimizer.zero_grad()
        output = self.model(xs)
        loss = F.cross_entropy(output, ys)
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        batch_size = len(next(iter(self.train_dataloader))[0])
        print(f'[GPU: {self.gpu_id}] Epoch: {epoch} | Batchsize: {batch_size} | Steps: {len(self.train_dataloader)}')
        self.train_dataloader.sampler.set_epoch(epoch)
        for xs, ys in self.train_dataloader:
            xs = xs.to(self.gpu_id)
            ys = ys.to(self.gpu_id)
            self._run_batch(xs, ys)

    def train(self, max_epoch: int):
        for epoch in range(max_epoch):
            self._run_epoch(epoch)


class MyTrainDataset(Dataset):
    def __init__(self, size):
        self.size = size
        self.data = [(torch.rand(20), torch.rand(1)) for _ in range(size)]

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        return self.data[index]


def main(max_epochs: int, batch_size: int):
    ddp_setup()

    train_dataset = MyTrainDataset(2048)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  pin_memory=True,
                                  shuffle=False,
                                  # DistributedSampler splits each batch across GPUs, with no overlapping samples between them
                                  sampler=DistributedSampler(train_dataset))
    model = torch.nn.Linear(20, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    trainer = Trainer(model=model, optimizer=optimizer, train_dataloader=train_dataloader)
    trainer.train(max_epochs)

    destroy_process_group()


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='simple distributed training job')
    parser.add_argument('--max_epochs', type=int, help='Total epochs to train the model')
    parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
    args = parser.parse_args()

    #    world_size = torch.cuda.device_count()
    main(args.max_epochs, args.batch_size)

2 Data Loading

Now for the exciting part. The project processes the dataset in two steps:
  • 1 Processing with Hugging Face datasets
  • 2 Writing and reading with a numpy memmap

2.1 Hugging Face datasets

The classic three-step routine (a minimal sketch follows the list):
  • 1 load_dataset to load the dataset
  • 2 dataset.train_test_split to split it into training and test sets
  • 3 dataset.map to process the examples
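
The following toy example (my own, using a small in-memory dataset rather than OpenWebText) just illustrates the three calls; the real pipeline appears in the next subsection:

from datasets import Dataset

# 1) load_dataset would normally load from the Hub or a local script;
#    a tiny in-memory dataset keeps this sketch self-contained
dataset = Dataset.from_dict({"text": ["hello world", "nanoGPT is fun", "pretraining GPT-2"]})

# 2) split into train/test subsets
split = dataset.train_test_split(test_size=0.34, seed=2357, shuffle=True)

# 3) map a processing function over every example
split = split.map(lambda ex: {"n_chars": len(ex["text"])}, remove_columns=["text"])
print(split)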

2.2 numpy memmap

The numpy side works as follows (a standalone round-trip sketch follows the list):
  • 1 Open a numpy memmap in "w+" mode
  • 2 Split the OpenWebText data into 1024 shards with dataset.shard and write them in one by one
  • 3 Call array.flush()
  • 4 Reopen the memmap in "r" mode for reading
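
To make the memmap mechanics concrete, here is a small write/read round trip of my own (not part of the original code):

import numpy as np

# write side: pre-allocate the file in "w+" mode, fill it in chunks, then flush
arr = np.memmap("demo.bin", dtype=np.uint16, mode="w+", shape=(10,))
arr[:5] = np.arange(5, dtype=np.uint16)
arr[5:] = np.arange(5, 10, dtype=np.uint16)
arr.flush()  # make sure everything is on disk

# read side: reopen the same file in read-only "r" mode
readback = np.memmap("demo.bin", dtype=np.uint16, mode="r")
print(readback[:10])  # [0 1 2 3 4 5 6 7 8 9]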
The data-processing code is as follows:
import tiktoken
import numpy as np
from tqdm import tqdm
from datasets import load_dataset, load_from_disk
import os
os.environ["HF_DATASETS_OFFLINE"]="1"
num_proc = 16
enc = tiktoken.get_encoding("gpt2")


def process(example):
    ids = enc.encode_ordinary(example["text"])
    ids.append(enc.eot_token)
    out = {"ids":ids, "len":len(ids)}
    return out


if __name__ == "__main__":
    path = r"F:\openwebtext\openwebtext.py"
    dataset = load_dataset(path, num_proc=num_proc)
    split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True)
    split_dataset["val"] = split_dataset.pop("test")
    split_dataset = split_dataset.map(process, num_proc=num_proc, remove_columns=["text"])

    for split, dataset in split_dataset.items():
        array_len = np.sum(dataset["len"], dtype=np.uint64)
        # one .bin file per split: train.bin and val.bin
        filename = f"{split}.bin"
        dtype = np.uint16
        array = np.memmap(filename, dtype=dtype, mode="w+", shape=(array_len, ))
        # write the split into the memmap in 1024 shards
        total_batch = 1024
        idx = 0
        for batch_idx in tqdm(range(total_batch)):
            # take the batch_idx-th of 1024 contiguous shards, in numpy format
            batch = dataset.shard(num_shards=total_batch, index=batch_idx, contiguous=True).with_format("numpy")
            # concatenate all token-id sequences in this shard
            array_batch = np.concatenate(batch["ids"])
            # copy the shard into the memmap
            array[idx:idx+len(array_batch)] = array_batch
            idx += len(array_batch)
        array.flush()

# reading side (used by the training loop in section 4, where osp is os.path):
# memory-map the .bin file in read-only "r" mode and sample random token windows
def get_batch(data_dir, split):
    if split == "train":
        data = np.memmap(osp.join(data_dir, "train.bin"), dtype=np.uint16, mode="r")
    else:
        data = np.memmap(osp.join(data_dir, "val.bin"), dtype=np.uint16, mode="r")
    ix = torch.randint(len(data) - config_args.max_seq_len, (config_args.batch_size, ))
    # increase a new dim on dim 0
    x = torch.stack([torch.from_numpy(data[i:i+config_args.max_seq_len].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i+1:i+1+config_args.max_seq_len].astype(np.int64)) for i in ix])
    if config_args.device == "cuda":
        x, y = x.pin_memory().to(config_args.device, non_blocking=True), y.pin_memory().to(config_args.device, non_blocking=True)
    else:
        x, y = x.to(config_args.device), y.to(config_args.device)
    return x, y

3 Model

The model keeps the overall GPT-2 architecture and optionally swaps in several LLaMA-style components (RoPE, SwiGLU, RMSNorm). The code is as follows:
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
import config


# TODO 2 Maybe is ok
# RoPE blogs : https://zhuanlan.zhihu.com/p/645263524
class RoPeEmbedding(nn.Module):
    def __init__(self, config):
        super(RoPeEmbedding, self).__init__()

    def get_freqs_cis(self, d_model, max_seq_len, device):
        # per-dimension rotation frequencies, shape (d_model/2, )
        freqs = 1. / (10000 ** (torch.arange(0, d_model, 2).float() / d_model))
        # position indices, shape (max_seq_len, )
        t = torch.arange(0, max_seq_len)
        # rotation angle for every (position, dimension) pair, shape (max_seq_len, d_model/2)
        freqs = torch.outer(t, freqs)
        # unit complex numbers e^{i*theta}, shape (max_seq_len, d_model/2)
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis.to(device)

    def forward(self, xq, xk):
        xq_device = xq.device
        max_seq_len, d_model = xq.shape[1], xq.shape[2]
        # view xq, xk as complex numbers:
        # (batch_size, max_seq_len, d_model) -> (batch_size, max_seq_len, d_model/2, 2)
        # -> complex (batch_size, max_seq_len, d_model/2)
        xq_out = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_out = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        # multiply xq, xk by the rotation complex numbers
        freqs_cis = self.get_freqs_cis(d_model, max_seq_len, xq_device)
        # back to real and flatten the last two dims
        xq_out = torch.view_as_real(xq_out * freqs_cis).flatten(-2)
        xk_out = torch.view_as_real(xk_out * freqs_cis).flatten(-2)
        return xq_out.type_as(xq), xk_out.type_as(xk)


# TODO 3 Maybe is ok
class SwishGLU(nn.Module):
    def __init__(self, config):
        super(SwishGLU, self).__init__()
        d_model, dropout = config.d_model, config.dropout
        self.w1 = nn.Linear(d_model, d_model*2)
        self.w3 = nn.Linear(d_model, d_model*2)
        self.w2 = nn.Linear(d_model*2, d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.w2(F.silu(self.w1(x)) * self.w3(x))
        x = self.drop(x)
        return x


# TODO 4 Maybe is ok
class RMSNorm(nn.Module):
    def __init__(self, config):
        super(RMSNorm, self).__init__()
        d_model, bias = config.d_model, config.bias
        self.weight = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model)) if bias else None
        self.eps = 1e-5

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x):
        x = self.weight * self._norm(x)
        # bias is optional: config.bias=False leaves self.bias as None
        if self.bias is not None:
            x = x + self.bias
        return x


class LayerNorm(nn.Module):
    def __init__(self, config):
        d_model, bias = config.d_model, config.bias
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model)) if bias else None

    def forward(self, x):
        return F.layer_norm(x, normalized_shape=self.weight.shape, weight=self.weight, bias=self.bias, eps=1e-5)


class SelfAttention(nn.Module):
    def __init__(self, config):
        super(SelfAttention, self).__init__()
        in_channel, out_channel = config.d_model, config.d_model//config.n_head
        self.pe_encoder = config.pe_encoder
        self.rope = RoPeEmbedding(config)
        self.Wq = nn.Linear(in_channel, out_channel)
        self.Wk = nn.Linear(in_channel, out_channel)
        self.Wv = nn.Linear(in_channel, out_channel)

    def forward(self, x):
        q, k, v = self.Wq(x), self.Wk(x), self.Wv(x)
        if self.pe_encoder == "rope":
            q, k = self.rope(q, k)
        attention = q @ k.transpose(-1, -2)
        sqrt_hidden = torch.sqrt(torch.tensor(q.shape[-1], device=q.device))
        mask = torch.tril(torch.ones(size=(q.shape[1], q.shape[1]), device=q.device), diagonal=0)
        attention = (attention / sqrt_hidden).masked_fill(mask == 0, float("-inf"))
        attention = F.softmax(attention, dim=-1)
        value = attention @ v
        return value


class MultiSelfAttention(nn.Module):
    def __init__(self, config):
        super(MultiSelfAttention, self).__init__()
        d_model, n_head = config.d_model, config.n_head
        self.Wo = nn.Linear(d_model, d_model)
        self.multi_self_attention = nn.ModuleList([SelfAttention(config) for _ in range(n_head)])

    def forward(self, x):
        value = self.Wo(torch.cat([self_attention(x) for self_attention in self.multi_self_attention], dim=-1))
        return value


class FFN(nn.Module):
    def __init__(self, config):
        super(FFN, self).__init__()
        d_model, dropout = config.d_model, config.dropout
        self.config = config
        if config.ffn == "origin":
            self.fc1 = nn.Linear(d_model, d_model * 4)
            self.fc2 = nn.Linear(d_model * 4, d_model)
            self.gelu = nn.GELU()
            self.dropout = nn.Dropout(dropout)
        elif config.ffn == "swishglu":
            self.swishglu = SwishGLU(config)

    def forward(self, x):
        if self.config.ffn == "origin":
            x = self.gelu(self.fc1(x))
            x = self.dropout(self.fc2(x))
        elif self.config.ffn == "swishglu":
            x = self.swishglu(x)
        return x


class TransformerLayer(nn.Module):
    def __init__(self, config):
        super(TransformerLayer, self).__init__()
        if config.norm == "origin":
            self.norm1 = LayerNorm(config)
            self.norm2 = LayerNorm(config)
        elif config.norm == "rms":
            self.norm1 = RMSNorm(config)
            self.norm2 = RMSNorm(config)
        self.multi_self_attention = MultiSelfAttention(config)
        self.ffn = FFN(config)


    def forward(self, x):
        x = x + self.multi_self_attention(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x


class GPT2(nn.Module):
    def __init__(self, config):
        super(GPT2, self).__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            vocab = nn.Embedding(config.vocab_size, config.d_model),
            pe = nn.Embedding(config.max_seq_len, config.d_model) if config.pe_encoder == "origin" else None,
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([TransformerLayer(config) for _ in range(config.n_layer)]),
            norm = LayerNorm(config) if config.norm == "origin" else RMSNorm(config)
        ))
        self.lm_head = nn.Linear(config.d_model, config.vocab_size)
        self.apply(self._init_weights)
        for name, parameters in self.transformer.named_parameters():
            if name.endswith("fc2.weight"):
                torch.nn.init.normal_(parameters, 0, (0.02 / math.sqrt(2 * config.n_layer)))

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, 0, 0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        if isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, 0, 0.02)

    def forward(self, x, target=None):
        token_emb = self.transformer.vocab(x)
        if self.config.pe_encoder == "origin":
            pos = torch.arange(0, x.shape[1], dtype=torch.long, device=x.device)
            pos_emb = self.transformer.pe(pos)
            h = self.transformer.drop(token_emb + pos_emb)
        else:
            h = self.transformer.drop(token_emb)
        for block in self.transformer.h:
            h = block(h)
        h = self.transformer.norm(h)
        if target is not None:
            logits = self.lm_head(h)
            loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), target.view(-1), ignore_index=-1)
        else:
            logits = self.lm_head(h[:, -1, :])
            loss = None
        return logits, loss

    def generate(self, idx, max_tokens, t=0.1, topk=None):
        for _ in range(max_tokens):
            # crop the context to max_seq_len so the positional embedding never overflows
            idx_cond = idx if idx.shape[1] <= self.config.max_seq_len else idx[:, -self.config.max_seq_len:]
            # forward returns (logits, loss); with no target, logits are already
            # the last position's logits with shape (batch, vocab_size)
            logits, _ = self(idx_cond)
            logit = logits / t
            if topk is not None:
                # keep only the top-k logits, set the rest to -inf
                v, _ = torch.topk(logit, min(topk, logit.shape[-1]))
                logit[logit < v[:, [-1]]] = -float('inf')
            probs = F.softmax(logit, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=-1)
        return idx

# TODO 1 add an inference-time K/V cache, following LLaMA
# TODO 2 add RoPE, following LLaMA          OK
# TODO 3 switch the FFN to SwishGLU         OK
# TODO 4 RMSNorm                            OK
# TODO 5 add two more sampling methods: top_p and beam_search
# TODO 6 FlashAttention, following LLaMA-moe



if __name__ == "__main__":
    config = config.config_args
    gpt2 = GPT2(config).to("cuda")
    num_parameters = np.sum([p.numel() for p in gpt2.parameters()]) / 1024 / 1024
    print(f"parameters: {num_parameters} M")
    x = torch.randint(0, 1000, size=(2, 10), dtype=torch.long).to("cuda")
    print(f"x is {x}")
    # output = gpt2.generate(x, config.max_tokens, t=0.1, topk=5)
    # print(f"output is {output}")
    logits, loss = gpt2(x)
    print(logits)
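
For reference, here is a minimal sampling sketch of my own using the generate method above. It assumes the model code lives in model.py, that config.config_args is the same config object used in the __main__ block, and that a CUDA device is available; with an untrained model the output is of course gibberish:

import tiktoken
import torch
import config
from model import GPT2

enc = tiktoken.get_encoding("gpt2")
cfg = config.config_args
gpt2 = GPT2(cfg).to("cuda").eval()

prompt = "Hello, my name is"
idx = torch.tensor([enc.encode_ordinary(prompt)], dtype=torch.long, device="cuda")
with torch.no_grad():
    out = gpt2.generate(idx, max_tokens=20, t=0.8, topk=50)
print(enc.decode(out[0].tolist()))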

4 Training

With the background above, the training loop should be straightforward to follow. The code is as follows:
import os
import time
import math
import pickle
from tqdm import trange
import numpy as np
import torch
from utils import *
from model import *
from contextlib import nullcontext
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from config import config_args

ddp = int(os.environ.get("RANK", -1)) != -1
if ddp:
    init_process_group(backend=config_args.backend)
    rank = int(os.environ.get("RANK"))
    local_rank = int(os.environ.get("LOCAL_RANK"))
    world_size = int(os.environ.get("WORLD_SIZE"))
    device = f"cuda:{local_rank}"
    torch.cuda.set_device(device)
    master_process = rank == 0
    seed_offset = local_rank
    config_args.gradient_accumulation_steps = config_args.gradient_accumulation_steps // world_size
else:
    world_size = 1
    master_process = True
    seed_offset = 0

# 12 * 1024 * 5 * 8 = 0.5M token
token_per_iter = config_args.batch_size * config_args.max_seq_len * config_args.gradient_accumulation_steps * world_size
print(f"batch_size is {config_args.batch_size}")
print(f"max_seq_len is {config_args.max_seq_len}")
print(f"token per iter is {token_per_iter}")

ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[config_args.dtype]
ctx = nullcontext() if config_args.device == "cpu" else torch.amp.autocast(device_type=config_args.device, dtype=ptdtype)
print(f"train type is {ptdtype}")

model = GPT2(config_args).to(config_args.device)
# the GradScaler is only needed for float16; bfloat16 shares fp32's exponent range
# and does not need loss scaling (note that ptdtype is a torch dtype, not a string)
scaler = torch.cuda.amp.GradScaler(enabled=(config_args.dtype == "float16"))
use_fused = True if config_args.device == "cuda" else False
# TODO: different learning rates / parameter groups
optimizer = torch.optim.Adam(model.parameters(), lr=config_args.lr, betas=(config_args.beta1, config_args.beta2), fused=use_fused,
                              weight_decay=config_args.weight_decay)

if config_args.compile:
    model = torch.compile(model)
if ddp:
    model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)
raw_model = model.module if ddp else model

# training loop
dataset_data_dir = get_dataset_path()
x, y = get_batch(dataset_data_dir, "train")
start_time = time.time()
best_val_loss = 1e9
for n_iter in trange(1, config_args.max_iters):
    # get_lr (from utils) returns the learning rate for this iteration; apply it to the optimizer
    lr = get_lr(n_iter)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    # evaluate
    if n_iter % config_args.eval_n_iter == 0 and master_process:
        loss_dict = estimate_loss(model, ctx, config_args.eval_iter)
        if loss_dict["valid"] < best_val_loss:
            best_val_loss = loss_dict["valid"]
            if n_iter > 0:
                checkpoint = {
                    "model": raw_model.state_dict(),
                    "optimzier": optimizer.state_dict(),
                    "config_args": config_args,
                    "n_iter": n_iter,
                    "best_val_loss": best_val_loss
                }
                print(f"saving checkpoint to {config_args.out_dir}")
                torch.save(checkpoint, osp.join(config_args.out_dir, "ckpt.pt"))
    for micro_step in range(config_args.gradient_accumulation_steps):
        if ddp:
            model.require_backward_grad_sync = (micro_step == config_args.gradient_accumulation_steps - 1)
        with ctx:
            logits, loss = model(x, y)
        print(f"step is {n_iter} train loss is {loss}")
        # average gradients over the accumulation steps
        loss = loss / config_args.gradient_accumulation_steps
        # fetch the next batch while the current one is still being processed
        x, y = get_batch(dataset_data_dir, "train")
        scaler.scale(loss).backward()
    if config_args.grad_clip != 0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), config_args.grad_clip)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

if ddp:
    destroy_process_group()
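
The loop above relies on get_dataset_path, get_batch, get_lr and estimate_loss from utils; only get_batch was shown earlier. As a rough sketch of what get_lr and estimate_loss might look like, in the nanoGPT style of linear warmup followed by cosine decay (the config fields warmup_iters, lr_decay_iters and min_lr are my assumptions, and the sketch presumes these helpers live in the same module as get_batch):

import math
import torch

def get_lr(n_iter):
    # linear warmup for the first warmup_iters steps
    if n_iter < config_args.warmup_iters:
        return config_args.lr * n_iter / config_args.warmup_iters
    # after the decay horizon, hold the minimum learning rate
    if n_iter > config_args.lr_decay_iters:
        return config_args.min_lr
    # cosine decay from lr down to min_lr in between
    ratio = (n_iter - config_args.warmup_iters) / (config_args.lr_decay_iters - config_args.warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * ratio))
    return config_args.min_lr + coeff * (config_args.lr - config_args.min_lr)

@torch.no_grad()
def estimate_loss(model, ctx, eval_iter):
    # average the loss over eval_iter random batches for each split
    out = {}
    model.eval()
    for split in ("train", "valid"):
        losses = torch.zeros(eval_iter)
        for k in range(eval_iter):
            x, y = get_batch(get_dataset_path(), split)
            with ctx:
                _, loss = model(x, y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out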