import torch
import torch.nn as nn
import torch.optim as optim
from itertools import chain

# Parts of the code are modifications of Pytorch's AdamW optimizer
# Parts of the code are modifications of code from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/galore_projector.py

class SOAP(optim.Optimizer):
    """
    Implements SOAP algorithm (https://arxiv.org/abs/2409.11321).

    Parameters:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*, defaults to 0.003):
            The learning rate to use.
        betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):
            Adam's betas parameters (b1, b2).
        shampoo_beta (`float`, *optional*, defaults to -1):
            If >= 0, use this beta for the preconditioner (L and R in paper, state['GG'] below) moving average instead of betas[1].
        eps (`float`, *optional*, defaults to 1e-08):
            Adam's epsilon for numerical stability.
        weight_decay (`float`, *optional*, defaults to 0.01):
            Weight decay coefficient.
        precondition_frequency (`int`, *optional*, defaults to 10):
            How often to update the preconditioner.
        max_precond_dim (`int`, *optional*, defaults to 10000):
            Maximum dimension of the preconditioner.
            Set to 10000, so that we exclude most common vocab sizes while including layers.
        merge_dims (`bool`, *optional*, defaults to `False`):
            Whether or not to merge dimensions of the preconditioner.
        precondition_1d (`bool`, *optional*, defaults to `False`):
            Whether or not to precondition 1D gradients.
        normalize_grads (`bool`, *optional*, defaults to `False`):
            Whether or not to normalize gradients per layer.
            Helps at large precondition_frequency (~100 in our experiments),
            but hurts performance at small precondition_frequency (~10 in our experiments).
        data_format (`str`, *optional*, defaults to `channels_first`):
            Data format of the input for convolutional layers.
            Should be "channels_last" for data_format of NHWC and "channels_first" for NCHW.
        correct_bias (`bool`, *optional*, defaults to `True`):
            Whether or not to use bias correction in Adam.
    """

    def __init__(
        self,
        params,
        lr: float = 3e-3,
        betas=(0.95, 0.95),
        shampoo_beta: float = -1,
        eps: float = 1e-8,
        weight_decay: float = 0.01,
        precondition_frequency: int = 10,
        max_precond_dim: int = 10000,
        merge_dims: bool = False, # Merge dimensions till the product of the dimensions is less than or equal to max_precond_dim.
        precondition_1d: bool = False,
        normalize_grads: bool = False,
        data_format: str = "channels_first",
        correct_bias: bool = True,
    ):
        defaults = {
            "lr": lr,
            "betas": betas,
            "shampoo_beta": shampoo_beta,
            "eps": eps,
            "weight_decay": weight_decay,
            "precondition_frequency": precondition_frequency,
            "max_precond_dim": max_precond_dim,
            "merge_dims": merge_dims,
            "precondition_1d": precondition_1d,
            "normalize_grads": normalize_grads,
            "correct_bias": correct_bias,
        }
        super().__init__(params, defaults)
        self._data_format = data_format

    def merge_dims(self, grad, max_precond_dim):
        """
        Merges dimensions of the gradient tensor till the product of the dimensions is less than or equal to max_precond_dim.
        """
        assert self._data_format in ["channels_first", "channels_last"]
        if self._data_format == "channels_last" and grad.dim() == 4:
            grad = grad.permute(0, 3, 1, 2)
        shape = grad.shape
        new_shape = []
        curr_shape = 1
        for sh in shape:
            temp_shape = curr_shape * sh
            if temp_shape > max_precond_dim:
                if curr_shape > 1:
                    new_shape.append(curr_shape)
                    curr_shape = sh
                else:
                    new_shape.append(sh)
                    curr_shape = 1
            else:
                curr_shape = temp_shape
        if curr_shape > 1 or len(new_shape) == 0:
            new_shape.append(curr_shape)
        new_grad = grad.reshape(new_shape)
        return new_grad

    @torch.no_grad()
    def step(self):
        """
        Performs a single optimization step.

        Arguments:
            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
        """
        loss = None

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]

                if "step" not in state:
                    state["step"] = 0

                # State initialization
                if "exp_avg" not in state:
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(grad)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(grad)

                if 'Q' not in state:
                    self.init_preconditioner(
                        grad,
                        state,
                        precondition_frequency=group['precondition_frequency'],
                        precondition_1d=group['precondition_1d'],
                        shampoo_beta=(group['shampoo_beta'] if group['shampoo_beta'] >= 0 else group["betas"][1]),
                        max_precond_dim=group['max_precond_dim'],
                        merge_dims=group["merge_dims"],
                    )
                    self.update_preconditioner(grad, state,
                                               max_precond_dim=group['max_precond_dim'],
                                               merge_dims=group["merge_dims"],
                                               precondition_1d=group["precondition_1d"])
                    continue # first step is skipped so that we never use the current gradients in the projection.

                # Projecting gradients to the eigenbases of Shampoo's preconditioner
                # i.e. projecting to the eigenbases of matrices in state['GG']
                grad_projected = self.project(grad, state, merge_dims=group["merge_dims"],
                                              max_precond_dim=group['max_precond_dim'])

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
                exp_avg_sq.mul_(beta2).add_(grad_projected.square(), alpha=(1.0 - beta2))

                denom = exp_avg_sq.sqrt().add_(group["eps"])

                # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
                # i.e. projecting to the eigenbases of matrices in state['GG']
                exp_avg_projected = self.project(exp_avg, state, merge_dims=group["merge_dims"],
                                                 max_precond_dim=group['max_precond_dim'])

                step_size = group["lr"]
                if group["correct_bias"]:
                    bias_correction1 = 1.0 - beta1 ** (state["step"])
                    bias_correction2 = 1.0 - beta2 ** (state["step"])
                    step_size = step_size * (bias_correction2 ** .5) / bias_correction1

                # Projecting back the preconditioned (by Adam) exponential moving average of gradients
                # to the original space
                norm_grad = self.project_back(exp_avg_projected / denom, state, merge_dims=group["merge_dims"],
                                              max_precond_dim=group['max_precond_dim'])

                if group["normalize_grads"]:
                    norm_grad = norm_grad / (1e-30 + torch.mean(norm_grad ** 2) ** 0.5)

                p.add_(norm_grad, alpha=-step_size)

                # From AdamW code: Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                # Add weight decay at the end (fixed version)
                if group["weight_decay"] > 0.0:
                    p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))

                # Update is done after the gradient step to avoid using current gradients in the projection.
                self.update_preconditioner(grad, state,
                                           max_precond_dim=group['max_precond_dim'],
                                           merge_dims=group["merge_dims"],
                                           precondition_1d=group["precondition_1d"])

        return loss

    def init_preconditioner(self, grad, state, precondition_frequency=10,
                            shampoo_beta=0.95, max_precond_dim=10000, precondition_1d=False,
                            merge_dims=False):
        """
        Initializes the preconditioner matrices (L and R in the paper).
        """
        state['GG'] = [] # Will hold all the preconditioner matrices (L and R in the paper).
        if grad.dim() == 1:
            if not precondition_1d or grad.shape[0] > max_precond_dim:
                state['GG'].append([])
            else:
                state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device))
        else:
            if merge_dims:
                grad = self.merge_dims(grad, max_precond_dim)

            for sh in grad.shape:
                if sh > max_precond_dim:
                    state['GG'].append([])
                else:
                    state['GG'].append(torch.zeros(sh, sh, device=grad.device))

        state['Q'] = None # Will hold all the eigenbases of the preconditioner.
        state['precondition_frequency'] = precondition_frequency
        state['shampoo_beta'] = shampoo_beta

    def project(self, grad, state, merge_dims=False, max_precond_dim=10000):
        """
        Projects the gradient to the eigenbases of the preconditioner.
        """
        original_shape = grad.shape
        if merge_dims:
            if grad.dim() == 4 and self._data_format == 'channels_last':
                permuted_shape = grad.permute(0, 3, 1, 2).shape
            grad = self.merge_dims(grad, max_precond_dim)

        for mat in state['Q']:
            if len(mat) > 0:
                grad = torch.tensordot(
                    grad,
                    mat,
                    dims=[[0], [0]],
                )
            else:
                permute_order = list(range(1, len(grad.shape))) + [0]
                grad = grad.permute(permute_order)

        if merge_dims:
            if self._data_format == 'channels_last' and len(original_shape) == 4:
                grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
            else:
                grad = grad.reshape(original_shape)
        return grad

    def update_preconditioner(self, grad, state,
                              max_precond_dim=10000, merge_dims=False, precondition_1d=False):
        """
        Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
        """
        if grad.dim() == 1:
            if precondition_1d and grad.shape[0] <= max_precond_dim:
                state['GG'][0].lerp_(grad.unsqueeze(1) @ grad.unsqueeze(0), 1-state['shampoo_beta'])
        else:
            if merge_dims:
                new_grad = self.merge_dims(grad, max_precond_dim)
                for idx, sh in enumerate(new_grad.shape):
                    if sh <= max_precond_dim:
                        outer_product = torch.tensordot(
                            new_grad,
                            new_grad,
                            dims=[[*chain(range(idx), range(idx + 1, len(new_grad.shape)))]] * 2,
                        )
                        state['GG'][idx].lerp_(outer_product, 1-state['shampoo_beta'])
            else:
                for idx, sh in enumerate(grad.shape):
                    if sh <= max_precond_dim:
                        outer_product = torch.tensordot(
                            grad,
                            grad,
                            # Contracts across all dimensions except for k.
                            dims=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]] * 2,
                        )
                        state['GG'][idx].lerp_(outer_product, 1-state['shampoo_beta'])

        if state['Q'] is None:
            state['Q'] = self.get_orthogonal_matrix(state['GG'])
        if state['step'] > 0 and state['step'] % state['precondition_frequency'] == 0:
            state['Q'] = self.get_orthogonal_matrix_QR(state, max_precond_dim, merge_dims)

    def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):
        """
        Projects the gradient back to the original space.
        """
        original_shape = grad.shape
        if merge_dims:
            if self._data_format == 'channels_last' and grad.dim() == 4:
                permuted_shape = grad.permute(0, 3, 1, 2).shape
            grad = self.merge_dims(grad, max_precond_dim)

        for mat in state['Q']:
            if len(mat) > 0:
                grad = torch.tensordot(
                    grad,
                    mat,
                    dims=[[0], [1]],
                )
            else:
                permute_order = list(range(1, len(grad.shape))) + [0]
                grad = grad.permute(permute_order)

        if merge_dims:
            if self._data_format == 'channels_last' and len(original_shape) == 4:
                grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
            else:
                grad = grad.reshape(original_shape)
        return grad

    def get_orthogonal_matrix(self, mat):
        """
        Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
        """
        matrix = []
        for m in mat:
            if len(m) == 0:
                matrix.append([])
                continue
            if m.data.dtype != torch.float:
                float_data = False
                original_type = m.data.dtype
                original_device = m.data.device
                matrix.append(m.data.float())
            else:
                float_data = True
                matrix.append(m.data)

        final = []
        for m in matrix:
            if len(m) == 0:
                final.append([])
                continue
            try:
                _, Q = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device))
            except:
                _, Q = torch.linalg.eigh(m.to(torch.float64) + 1e-30 * torch.eye(m.shape[0], device=m.device))
                Q = Q.to(m.dtype)
            Q = torch.flip(Q, [1])

            if not float_data:
                Q = Q.to(original_device).type(original_type)
            final.append(Q)
        return final

    def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=False):
        """
        Computes the eigenbases of the preconditioner using one round of power iteration
        followed by torch.linalg.qr decomposition.
        """
        precond_list = state['GG']
        orth_list = state['Q']

        matrix = []
        orth_matrix = []
        for m, o in zip(precond_list, orth_list):
            if len(m) == 0:
                matrix.append([])
                orth_matrix.append([])
                continue
            if m.data.dtype != torch.float:
                float_data = False
                original_type = m.data.dtype
                original_device = m.data.device
                matrix.append(m.data.float())
                orth_matrix.append(o.data.float())
            else:
                float_data = True
                matrix.append(m.data.float())
                orth_matrix.append(o.data.float())

        orig_shape = state['exp_avg_sq'].shape
        if self._data_format == 'channels_last' and len(orig_shape) == 4:
            permuted_shape = state['exp_avg_sq'].permute(0, 3, 1, 2).shape
        if merge_dims:
            exp_avg_sq = self.merge_dims(state['exp_avg_sq'], max_precond_dim)
        else:
            exp_avg_sq = state['exp_avg_sq']

        final = []
        for ind, (m, o) in enumerate(zip(matrix, orth_matrix)):
            if len(m) == 0:
                final.append([])
                continue
            est_eig = torch.diag(o.T @ m @ o)
            sort_idx = torch.argsort(est_eig, descending=True)
            exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
            o = o[:, sort_idx]
            power_iter = m @ o
            Q, _ = torch.linalg.qr(power_iter)

            if not float_data:
                Q = Q.to(original_device).type(original_type)
            final.append(Q)

        if merge_dims:
            if self._data_format == 'channels_last' and len(orig_shape) == 4:
                exp_avg_sq = exp_avg_sq.reshape(permuted_shape).permute(0, 2, 3, 1)
            else:
                exp_avg_sq = exp_avg_sq.reshape(orig_shape)

        state['exp_avg_sq'] = exp_avg_sq
        return final
import os
import sys
with open(sys.argv[0]) as f:
    code = f.read() # read the code of this file ASAP, for logging
import uuid
import glob
import time
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
import torch._inductor.config as config
from torch.nn.parallel import DistributedDataParallel as DDP

# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the GPT-2 model

class Rotary(torch.nn.Module):

    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x):
        seq_len = x.shape[1]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq).to(x.device)
            self.cos_cached = freqs.cos()
            self.sin_cached = freqs.sin()
        return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]

def apply_rotary_emb(x, cos, sin):
    assert x.ndim == 4 # multihead attention
    d = x.shape[3] // 2
    x1 = x[..., :d]
    x2 = x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], 3)

def rmsnorm(x0, eps=1e-6):
    x = x0.float()
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x.type_as(x0)

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        self.c_q = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_embd, bias=False)
        # output projection
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.rotary = Rotary(self.head_dim)

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        q, k, v = self.c_q(x), self.c_k(x), self.c_v(x)
        k = k.view(B, T, self.n_head, self.head_dim)
        q = q.view(B, T, self.n_head, self.head_dim)
        v = v.view(B, T, self.n_head, self.head_dim)
        cos, sin = self.rotary(q)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        return y

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)

    def forward(self, x):
        x = self.c_fc(x)
        x = F.gelu(x)
        x = self.c_proj(x)
        return x

class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)
        self.attn_scale = (1 / (2 * config.n_layer)**0.5)

    def forward(self, x):
        x = x + self.attn_scale * self.attn(rmsnorm(x))
        x = x + self.mlp(rmsnorm(x))
        return x

# -----------------------------------------------------------------------------
# The main GPT-2 model

@dataclass
class GPTConfig:
    vocab_size : int = 50257
    n_layer : int = 12
    n_head : int = 12
    n_embd : int = 768
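# Back-of-the-envelope parameter count for the default GPTConfig above. Illustrative helper,
# never called by the training script; it assumes the weight-tied wte/lm_head is counted once,
# and relies on all Linear layers being bias-free and rmsnorm having no learnable weights.
def _approx_param_count(cfg: GPTConfig = GPTConfig()) -> int:
    embed = cfg.vocab_size * cfg.n_embd          # wte, shared with lm_head (~38.6M)
    attn = 4 * cfg.n_embd * cfg.n_embd           # c_q, c_k, c_v, c_proj per block
    mlp = 2 * cfg.n_embd * (4 * cfg.n_embd)      # c_fc, c_proj per block
    return embed + cfg.n_layer * (attn + mlp)    # ~124M with the defaults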
class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

    def forward(self, idx, targets=None, return_logits=True):
        b, t = idx.size()
        pos = torch.arange(0, t, dtype=torch.long, device=idx.device) # shape (t)

        # forward the GPT model itself
        x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        for block in self.transformer.h:
            x = block(x)
        x = rmsnorm(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            logits = logits.float() # use tf32/fp32 for logits
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            logits = logits.float() # use tf32/fp32 for logits
            loss = None

        # there are performance reasons why not returning logits is prudent, if not needed
        if not return_logits:
            logits = None

        return logits, loss

# -----------------------------------------------------------------------------
# Our own simple Distributed Data Loader

def _peek_data_shard(filename):
    # only reads the header, returns header data
    with open(filename, "rb") as f:
        # first read the header, which is 256 int32 integers (4 bytes each)
        header = np.frombuffer(f.read(256*4), dtype=np.int32)
    if header[0] != 20240520:
        print("ERROR: magic number mismatch in the data .bin file!")
        print("---> HINT: Are you passing in a correct file with --input_bin?")
        print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README")
        print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try")
        exit(1)
    assert header[1] == 1, "unsupported version"
    ntok = header[2] # number of tokens (claimed)
    return ntok # for now just return the number of tokens

def _load_data_shard(filename):
    with open(filename, "rb") as f:
        # first read the header, which is 256 int32 integers (4 bytes each)
        header = np.frombuffer(f.read(256*4), dtype=np.int32)
        assert header[0] == 20240520, "magic number mismatch in the data .bin file"
        assert header[1] == 1, "unsupported version"
        ntok = header[2] # number of tokens (claimed)
        # the rest of it are tokens, stored as uint16
        tokens = np.frombuffer(f.read(), dtype=np.uint16)
    assert len(tokens) == ntok, "number of tokens read does not match header?"
    return tokens
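# Sketch of the on-disk shard layout that _peek_data_shard/_load_data_shard expect:
# a 256-entry int32 header (magic 20240520, version 1, token count), followed by the tokens
# stored as uint16. This writer is an illustrative helper only; the real shards are produced
# by the dataset preprocessing scripts, and this function is never called here.
def _write_data_shard_example(filename, tokens):
    header = np.zeros(256, dtype=np.int32)
    header[0] = 20240520           # magic number checked by the readers above
    header[1] = 1                  # version
    header[2] = len(tokens)        # number of tokens that follow the header
    with open(filename, "wb") as f:
        f.write(header.tobytes())
        f.write(np.asarray(tokens, dtype=np.uint16).tobytes())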
class DistributedDataLoader:
    def __init__(self, filename_pattern, B, T, process_rank, num_processes):
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.B = B
        self.T = T

        # glob files that match the pattern
        self.files = sorted(glob.glob(filename_pattern))
        assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"

        # load and validate all data shards, count number of tokens in total
        ntok_total = 0
        for fname in self.files:
            shard_ntok = _peek_data_shard(fname)
            assert shard_ntok >= num_processes * B * T + 1
            ntok_total += int(shard_ntok)
        self.ntok_total = ntok_total

        # kick things off
        self.reset()

    def reset(self):
        self.current_shard = 0
        self.current_position = self.process_rank * self.B * self.T
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def advance(self): # advance to next data shard
        self.current_shard = (self.current_shard + 1) % len(self.files)
        self.current_position = self.process_rank * self.B * self.T
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def next_batch(self):
        B = self.B
        T = self.T
        buf = self.tokens[self.current_position : self.current_position+B*T+1]
        buf = torch.tensor(buf.astype(np.int32), dtype=torch.long)
        x = (buf[:-1]).view(B, T) # inputs
        y = (buf[1:]).view(B, T) # targets
        # advance current position and load next shard if necessary
        self.current_position += B * T * self.num_processes
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            self.advance()
        return x.cuda(), y.cuda()

# -----------------------------------------------------------------------------
# int main

@dataclass
class Hyperparameters:
    # data hyperparams
    input_bin : str = 'data/fineweb10B/fineweb_train_*.bin' # input .bin to train on
    input_val_bin : str = 'data/fineweb10B/fineweb_val_*.bin' # input .bin to eval validation loss on
    # optimization hyperparams
    batch_size : int = 8*64 # batch size, in sequences, across all devices
    device_batch_size : int = 64 # batch size, in sequences, per device
    sequence_length : int = 1024 # sequence length, in tokens
    num_iterations : int = 6000 # number of iterations to run
    learning_rate : float = 0.0036
    warmup_iters : int = 250
    warmdown_iters : int = 1800 # number of iterations of linear warmup/warmdown for triangular or trapezoidal schedule
    # evaluation and logging hyperparams
    val_loss_every : int = 125 # every how many steps to evaluate val loss? 0 for only at the end
    val_tokens : int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
    save_every : int = 0 # every how many steps to save the checkpoint? 0 for only at the end
args = Hyperparameters()
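# Token budget implied by the defaults above (back-of-the-envelope arithmetic, not used by the code):
#   tokens per optimizer step = batch_size * sequence_length = 512 * 1024 = 524,288
#   total training tokens     ~ num_iterations * 524,288 = 6000 * 524,288 ~ 3.15B
#   validation tokens         = val_tokens = 10,485,760 = 10 * 2**20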
# set up DDP (distributed data parallel). torchrun sets this env variable
assert torch.cuda.is_available()
dist.init_process_group(backend='nccl')
ddp_rank = int(os.environ['RANK'])
ddp_local_rank = int(os.environ['LOCAL_RANK'])
ddp_world_size = int(os.environ['WORLD_SIZE'])
device = f'cuda:{ddp_local_rank}'
torch.cuda.set_device(device)
print(f"using device: {device}")
master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc.

# convenience variables
B, T = args.device_batch_size, args.sequence_length
# calculate the number of steps to take in the val loop.
assert args.val_tokens % (B * T * ddp_world_size) == 0
val_steps = args.val_tokens // (B * T * ddp_world_size)
# calculate the steps of gradient accumulation required to attain the desired global batch size.
assert args.batch_size % (B * ddp_world_size) == 0
train_accumulation_steps = args.batch_size // (B * ddp_world_size)

# load tokens
train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size)
val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)
if master_process:
    print(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
    print(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
x, y = train_loader.next_batch()

# init the model from scratch
num_vocab = 50257
model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=12, n_embd=768))
model = model.cuda()
if hasattr(config, "coordinate_descent_tuning"):
    config.coordinate_descent_tuning = True # suggested by @Chillee
model = torch.compile(model)
# here we wrap model into DDP container
model = DDP(model, device_ids=[ddp_local_rank])
raw_model = model.module # always contains the "raw" unwrapped model
ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)

# init the optimizer(s)
optimizer1 = torch.optim.AdamW(raw_model.lm_head.parameters(), lr=args.learning_rate, betas=(0.9, 0.95),
                               weight_decay=0, fused=True)
optimizer2 = SOAP(raw_model.transformer.h.parameters(), lr=0.5*args.learning_rate, betas=(.95, .95),
                  weight_decay=0, precondition_frequency=10)
optimizers = [optimizer1, optimizer2]

# learning rate decay scheduler (linear warmup and warmdown)
def get_lr(it):
    assert it <= args.num_iterations
    # 1) linear warmup for warmup_iters steps
    if it < args.warmup_iters:
        return (it+1) / args.warmup_iters
    # 2) constant lr for a while
    elif it < args.num_iterations - args.warmdown_iters:
        return 1.0
    # 3) linear warmdown
    else:
        decay_ratio = (args.num_iterations - it) / args.warmdown_iters
        return decay_ratio
schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]
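# Shape of the trapezoidal schedule above with the default hyperparameters
# (warmup_iters=250, warmdown_iters=1800, num_iterations=6000); the values are the LambdaLR
# multipliers applied to each optimizer's base lr, worked out here for illustration only:
#   it=0    -> 1/250 = 0.004              (start of linear warmup)
#   it=249  -> 250/250 = 1.0              (end of warmup)
#   it=3000 -> 1.0                        (constant phase)
#   it=5100 -> (6000-5100)/1800 = 0.5     (halfway through the linear warmdown)
#   it=6000 -> 0.0                        (end of training)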
# begin logging
if master_process:
    run_id = str(uuid.uuid4())
    logdir = 'logs/%s/' % run_id
    os.makedirs(logdir, exist_ok=True)
    logfile = 'logs/%s.txt' % run_id
    # create the log file
    with open(logfile, "w") as f:
        # begin the log by printing this file (the Python code)
        f.write('='*100 + '\n')
        f.write(code)
        f.write('='*100 + '\n')
        # log information about the hardware/software environment this is running on
        # and print the full `nvidia-smi` to file
        f.write(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:\n")
        import subprocess
        result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        f.write(f'{result.stdout}\n')
        f.write('='*100 + '\n')

training_time_ms = 0
# start the clock
torch.cuda.synchronize()
t0 = time.time()
# begin training
train_loader.reset()
for step in range(args.num_iterations + 1):
    last_step = (step == args.num_iterations)
    # This effectively ignores timing first 10 steps, which are slower for weird reasons.
    # Alternately, and slightly more correctly in terms of benchmarking, we could do 10
    # steps with dummy data first, and then re-initialize the model and reset the loader.
    if step == 10:
        training_time_ms = 0
        t0 = time.time()
    timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val

    # once in a while evaluate the validation dataset
    if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)):
        # stop the clock
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.time() - t0)
        # run validation batches
        model.eval()
        val_loader.reset()
        val_loss = 0.0
        for _ in range(val_steps):
            x_val, y_val = val_loader.next_batch()
            with torch.no_grad(): # of course, we'd like to use ctx here too, but that creates a torch.compile error for some reason
                _, loss = model(x_val, y_val, return_logits=False)
                val_loss += loss
        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
        val_loss /= val_steps
        # log val loss to console and to logfile
        if master_process:
            print(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms')
            with open(logfile, "a") as f:
                f.write(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms\n')
        # start the clock again
        torch.cuda.synchronize()
        t0 = time.time()

    if master_process and (last_step or (args.save_every > 0 and step % args.save_every == 0)):
        # stop the clock
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.time() - t0)
        # save the state of the training process
        log = dict(step=step, code=code, model=raw_model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
        torch.save(log, 'logs/%s/state_step%06d.pt' % (run_id, step))
        # start the clock again
        torch.cuda.synchronize()
        t0 = time.time()

    # bit confusing: we want to make sure to eval on 0th iteration
    # but also after the very last iteration. so we loop for step <= num_iterations
    # instead of just < num_iterations (one extra due to <=), only to do
    # the validation/sampling one last time, and then we break right here as we're done.
    if last_step:
        break

    # --------------- TRAINING SECTION BEGIN -----------------
    model.train()
    for i in range(1, train_accumulation_steps+1):
        # forward pass
        with ctx:
            _, loss = model(x, y, return_logits=False)
            train_loss = loss.detach()
        # advance the dataset for the next batch
        x, y = train_loader.next_batch()
        # backward pass
        if i < train_accumulation_steps:
            with model.no_sync(): # there's no need to sync gradients every accumulation step
                loss.backward()
        else:
            loss.backward() # just sync on the last step
    for p in model.parameters():
        p.grad /= train_accumulation_steps
    # step the optimizers and schedulers
    for opt, sched in zip(optimizers, schedulers):
        opt.step()
        sched.step()
    # null the gradients
    model.zero_grad(set_to_none=True)
    # --------------- TRAINING SECTION END -------------------
    # everything that follows now is just diagnostics, prints, logging, etc.

    #dist.all_reduce(train_loss, op=dist.ReduceOp.AVG) # all-reducing the training loss would be more correct in terms of logging, but slower
    if master_process:
        approx_time = training_time_ms + 1000 * (time.time() - t0)
        print(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms")
        with open(logfile, "a") as f:
            f.write(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms\n")
if master_process:
    print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")

# -------------------------------------------------------------------------
# clean up nice
dist.destroy_process_group()
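# Example launch command (illustrative; assumes a single node with 8 GPUs and that this
# file is saved as train_gpt2.py -- adjust --nproc_per_node and the filename to your setup):
#
#   torchrun --standalone --nproc_per_node=8 train_gpt2.py
#
# torchrun sets the RANK / LOCAL_RANK / WORLD_SIZE environment variables read above.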