import os import sys with open(sys.argv[0]) as f: code = f.read() # read the code of this file ASAP, for logging import uuid import glob import time import contextlib from dataclasses import dataclass import numpy as np import torch from torch import nn import torch.nn.functional as F import torch.distributed as dist import torch._inductor.config as config from torch.nn.parallel import DistributedDataParallel as DDP # Use of FlexAttention contributed by @KoszarskyB from torch.nn.attention.flex_attention import flex_attention, create_block_mask flex_attention = torch.compile(flex_attention, dynamic=False) create_block_mask = torch.compile(create_block_mask, dynamic=False) # ----------------------------------------------------------------------------- # Muon optimizer def zeropower_via_svd(G, steps=None): U, S, V = G.svd() return U @ V.T @torch.compile def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): """ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose of minimizing steps, it turns out to be empirically effective to keep increasing the slope at zero even beyond the point where the iteration no longer converges all the way to one everywhere on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model performance at all relative to UV^T, where USV^T = G is the SVD. """ assert len(G.shape) == 2 a, b, c = (3.4445, -4.7750, 2.0315) X = G.bfloat16() X /= (X.norm() + eps) # ensure top singular value <= 1 if G.size(0) > G.size(1): X = X.T for _ in range(steps): A = X @ X.T B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng X = a * X + B @ X if G.size(0) > G.size(1): X = X.T return X zeropower_backends = dict(svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5) class Muon(torch.optim.Optimizer): """ Muon - MomentUm Orthogonalized by Newton-schulz Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- processing step, in which each 2D parameter's update is replaced with the nearest orthogonal matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has the advantage that it can be stably run in bfloat16 on the GPU. Some warnings: - This optimizer assumes that all parameters passed in are 2D. - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. - We believe it is unlikely to work well for training with small batch size. - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M). Arguments: lr: The learning rate used by the internal SGD. momentum: The momentum used by the internal SGD. nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5') backend_steps: The number of iteration steps to use in the backend, if it is iterative. """ def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, backend='newtonschulz5', backend_steps=5): defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps) super().__init__(params, defaults) def step(self): for group in self.param_groups: lr = group['lr'] momentum = group['momentum'] zeropower_backend = zeropower_backends[group['backend']] # generate weight updates in distributed fashion total_params = sum(p.numel() for p in group['params']) updates_flat = torch.zeros(total_params, device='cuda', dtype=torch.bfloat16) curr_idx = 0 for i, p in enumerate(group['params']): # luckily this will perfectly distribute a transformer with multiple of 4 layers to 8 GPUs if i % int(os.environ['WORLD_SIZE']) == int(os.environ['RANK']): g = p.grad assert g is not None state = self.state[p] if 'momentum_buffer' not in state: state['momentum_buffer'] = torch.zeros_like(g) buf = state['momentum_buffer'] buf.mul_(momentum).add_(g) g = g.add(buf, alpha=momentum) if group['nesterov'] else buf g = zeropower_backend(g, steps=group['backend_steps']) g *= max(1, g.size(0)/g.size(1))**0.5 updates_flat[curr_idx:curr_idx+p.numel()] = g.flatten() curr_idx += p.numel() # sync updates across devices. we are not memory-constrained so can do this simple deserialization dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) # deserialize and apply updates curr_idx = 0 for p in group['params']: g = updates_flat[curr_idx:curr_idx+p.numel()].view_as(p.data).type_as(p.data) p.data.add_(g, alpha=-lr) curr_idx += p.numel() # ----------------------------------------------------------------------------- # PyTorch nn.Module definitions for the GPT-2 model def norm(x): return F.rms_norm(x, (x.size(-1),)) class CastedLinear(nn.Linear): def __init__(self, in_features, out_features): super().__init__(in_features, out_features, bias=False) def forward(self, x): return F.linear(x, self.weight.to(x.dtype)) class Rotary(torch.nn.Module): def __init__(self, dim, base=10000): super().__init__() self.register_buffer('inv_freq', (1 / base) ** (torch.arange(0, dim, 2) / dim)) self.seq_len_cached = None self.cos_cached = None self.sin_cached = None def forward(self, x): seq_len = x.shape[1] if seq_len != self.seq_len_cached: t = torch.arange(seq_len, device=x.device) freqs = torch.outer(t, self.inv_freq) self.seq_len_cached = seq_len self.cos_cached = freqs.cos() self.sin_cached = freqs.sin() cos, sin = self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :] # apply_rotary_emb(x, cos, sin) x1, x2 = x.chunk(2, dim=3) y1 = x1 * cos + x2 * sin y2 = x1 * (-sin) + x2 * cos return torch.cat((y1, y2), 3).type_as(x) class CausalSelfAttention(nn.Module): def __init__(self, dim, n_head): super().__init__() assert dim % n_head == 0 self.n_head = n_head self.c_q = CastedLinear(dim, dim) self.c_k = CastedLinear(dim, dim) self.c_v = CastedLinear(dim, dim) # value residual lambda self.lamb = nn.Parameter(torch.tensor(0.5)) # @Grad62304977 # rotary embeddings self.rotary = Rotary(dim // n_head) # dim // n_head = head_dim # output projection self.c_proj = CastedLinear(dim, dim) self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977 def forward(self, x, vi, block_mask): B, T = x.size(0), x.size(1) # batch size, sequence length assert B == 1, "Must use batch size = 1 for FlexAttention" q = self.c_q(x).view(B, T, self.n_head, -1) k = self.c_k(x).view(B, T, self.n_head, -1) v = self.c_v(x).view(B, T, self.n_head, -1) v = (1 - self.lamb) * v + self.lamb * vi.view_as(v) # @Grad62304977 q, k = norm(q), norm(k) # QK norm suggested by @Grad62304977 q, k = self.rotary(q), self.rotary(k) y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask) y = y.transpose(1, 2).contiguous().view_as(x) # re-assemble all head outputs side by side y = self.c_proj(y) return y class MLP(nn.Module): def __init__(self, dim): super().__init__() self.c_fc = CastedLinear(dim, 4 * dim) self.c_proj = CastedLinear(4 * dim, dim) self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977 def forward(self, x): x = self.c_fc(x) x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 x = self.c_proj(x) return x class Block(nn.Module): def __init__(self, config): super().__init__() self.attn = CausalSelfAttention(config.n_embd, config.n_head) self.mlp = MLP(config.n_embd) self.lambdas = nn.Parameter(torch.tensor([1., 0.])) def forward(self, x, vi, x0, block_mask): x = self.lambdas[0] * x + self.lambdas[1] * x0 x = x + self.attn(norm(x), vi, block_mask) x = x + self.mlp(norm(x)) return x # ----------------------------------------------------------------------------- # The main GPT-2 model @dataclass class GPTConfig: vocab_size : int = 50304 n_layer : int = 12 n_head : int = 6 # head dim 128 suggested by @Grad62304977 n_embd : int = 768 class GPT(nn.Module): def __init__(self, config): super().__init__() # U-net design by @brendanh0gan self.num_encoder_layers = config.n_layer // 2 # Half of the layers for encoder self.num_decoder_layers = config.n_layer - self.num_encoder_layers # Remaining for decoder # Add learnable skip connection weights for decoder layers self.skip_weights = nn.Parameter(torch.ones(self.num_decoder_layers)) self.transformer = nn.ModuleDict(dict( wte = nn.Embedding(config.vocab_size, config.n_embd), # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual learning vte = nn.Embedding(config.vocab_size, config.n_embd*12), h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), )) self.lm_head = CastedLinear(config.n_embd, config.vocab_size) self.lm_head.weight.data.zero_() # @Grad62304977 def forward(self, idx, target, attn_blocksize): docs = (idx == 50256).cumsum(0) def document_causal_mask(b, h, q_idx, kv_idx): causal_mask = q_idx >= kv_idx document_mask = docs[q_idx] == docs[kv_idx] window_mask = q_idx - kv_idx < attn_blocksize return causal_mask & document_mask & window_mask S = len(idx) block_mask = create_block_mask(document_causal_mask, None, None, S, S, device="cuda", _compile=True) # forward the GPT model itself x = self.transformer.wte(idx[None]) # token embeddings of shape (b, t, n_embd) x = norm(x) # @Grad62304977 x0 = x vi = self.transformer.vte(idx[None]).chunk(12, dim=-1) # Store outputs for U-Net skip connections skip_connections = [] # Encoder pass - process only the first half of the blocks for i in range(self.num_encoder_layers): x = self.transformer.h[i](x, vi[i], x0, block_mask) skip_connections.append(x) # Decoder pass - process the remaining blocks with weighted skip connections for i in range(self.num_decoder_layers): x = x + self.skip_weights[i] * skip_connections.pop() x = self.transformer.h[self.num_encoder_layers + i](x, vi[self.num_encoder_layers+i], x0, block_mask) x = norm(x) logits = self.lm_head(x) logits = 30 * torch.tanh(logits / 30) # @Grad62304977 logits = logits.float() loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target.view(-1)) return loss # ----------------------------------------------------------------------------- # Our own simple Distributed Data Loader def _peek_data_shard(filename): # only reads the header, returns header data with open(filename, "rb") as f: # first read the header, which is 256 int32 integers (4 bytes each) header = np.frombuffer(f.read(256*4), dtype=np.int32) if header[0] != 20240520: print("ERROR: magic number mismatch in the data .bin file!") print("---> HINT: Are you passing in a correct file with --input_bin?") print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README") print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try") exit(1) assert header[1] == 1, "unsupported version" ntok = header[2] # number of tokens (claimed) return ntok # for now just return the number of tokens def _load_data_shard(filename): with open(filename, "rb") as f: # first read the header, which is 256 int32 integers (4 bytes each) header = np.frombuffer(f.read(256*4), dtype=np.int32) assert header[0] == 20240520, "magic number mismatch in the data .bin file" assert header[1] == 1, "unsupported version" ntok = header[2] # number of tokens (claimed) # the rest of it are tokens, stored as uint16 tokens = np.frombuffer(f.read(), dtype=np.uint16) assert len(tokens) == ntok, "number of tokens read does not match header?" return tokens class DistributedDataLoader: def __init__(self, filename_pattern, T, process_rank, num_processes): self.process_rank = process_rank self.num_processes = num_processes self.T = T # glob files that match the pattern self.files = sorted(glob.glob(filename_pattern)) assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}" # load and validate all data shards, count number of tokens in total ntok_total = 0 for fname in self.files: shard_ntok = _peek_data_shard(fname) assert shard_ntok >= num_processes * T + 1 ntok_total += int(shard_ntok) self.ntok_total = ntok_total self.reset() def reset(self): self.current_shard = -1 self.advance() def advance(self): # advance to next data shard self.current_shard = (self.current_shard + 1) % len(self.files) self.current_position = self.process_rank * self.T self.tokens = _load_data_shard(self.files[self.current_shard]) def next_batch(self): batch_size = self.T * self.num_processes buf = self.tokens[self.current_position:self.current_position+self.T+1] buf = torch.tensor(buf.astype(np.int32), dtype=torch.long) x = buf[:-1] # inputs y = buf[1:] # targets # advance current position and load next shard if necessary self.current_position += batch_size if self.current_position + batch_size >= len(self.tokens): self.advance() return x.cuda(), y.cuda() # ----------------------------------------------------------------------------- # int main @dataclass class Hyperparameters: # data hyperparams input_bin : str = 'data/fineweb10B/fineweb_train_*.bin' # input .bin to train on input_val_bin : str = 'data/fineweb10B/fineweb_val_*.bin' # input .bin to eval validation loss on # optimization hyperparams batch_size : int = 8 # batch size, in sequences, across all devices sequence_length : int = 64*1024 # sequence length, in tokens num_iterations : int = 1530 # number of iterations to run warmup_iters : int = 0 cooldown_iters : int = 600 # number of iterations of linear warmup/cooldown for triangular or trapezoidal schedule weight_decay : float = 0 # evaluation and logging hyperparams val_loss_every : int = 125 # every how many steps to evaluate val loss? 0 for only at the end val_tokens : int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons save_every : int = 0 # every how many steps to save the checkpoint? 0 for only at the end args = Hyperparameters() # set up DDP (distributed data parallel). torchrun sets this env variable assert torch.cuda.is_available() dist.init_process_group(backend='nccl') ddp_rank = int(os.environ['RANK']) ddp_local_rank = int(os.environ['LOCAL_RANK']) ddp_world_size = int(os.environ['WORLD_SIZE']) device = f'cuda:{ddp_local_rank}' torch.cuda.set_device(device) print(f"using device: {device}") master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc. # begin logging logfile = None if master_process: run_id = str(uuid.uuid4()) logdir = 'logs/%s/' % run_id os.makedirs(logdir, exist_ok=True) logfile = 'logs/%s.txt' % run_id # create the log file with open(logfile, "w") as f: # begin the log by printing this file (the Python code) f.write(code) f.write('='*100 + '\n') def print0(s, logonly=False): if master_process: with open(logfile, "a") as f: if not logonly: print(s) f.write(s+'\n') # log information about the hardware/software environment this is running on # and print the full `nvidia-smi` to file print0(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:") import subprocess result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) print0(f'{result.stdout}', logonly=True) print0('='*100, logonly=True) # convenience variables T = args.sequence_length # calculate the number of steps to take in the val loop. assert args.val_tokens % (T * ddp_world_size) == 0 val_steps = args.val_tokens // (T * ddp_world_size) # calculate the steps of gradient accumulation required to attain the desired global batch size. assert args.batch_size % (ddp_world_size) == 0 train_accumulation_steps = args.batch_size // ddp_world_size # load tokens train_loader = DistributedDataLoader(args.input_bin, T, ddp_rank, ddp_world_size) val_loader = DistributedDataLoader(args.input_val_bin, T, ddp_rank, ddp_world_size) print0(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files") print0(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files") print0('='*100, logonly=True) x, y = train_loader.next_batch() # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. suggested to me by @Grad62304977. # this originates from Karpathy's experiments. num_vocab = 50304 model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=6, n_embd=768)) model = model.cuda().bfloat16() for m in model.modules(): if isinstance(m, CastedLinear): m.float() if hasattr(config, "coordinate_descent_tuning"): config.coordinate_descent_tuning = True # suggested by @Chillee model = torch.compile(model) # here we wrap model into DDP container model = DDP(model, device_ids=[ddp_local_rank]) raw_model = model.module # always contains the "raw" unwrapped model # init the optimizer(s) optimizer1 = torch.optim.Adam([raw_model.transformer.wte.weight, raw_model.transformer.vte.weight], lr=0.6, betas=(0.8, 0.95), fused=True) optimizer2 = torch.optim.Adam([raw_model.lm_head.weight], lr=0.008, betas=(0.8, 0.95), fused=True) params = list(raw_model.transformer.h.parameters()) matrix_params = [p for p in params if p.ndim == 2] scalar_params = [p for p in params if p.ndim < 2] + [raw_model.skip_weights] optimizer3 = Muon(matrix_params, lr=0.05, momentum=0.95) optimizer4 = torch.optim.Adam(scalar_params, lr=0.04, betas=(0.8, 0.95), fused=True) # note that this learning rate is neither sensitive nor tuned optimizers = [optimizer1, optimizer2, optimizer3, optimizer4] # learning rate decay scheduler (linear warmup and cooldown) def get_lr(it): assert it <= args.num_iterations # 1) linear warmup for warmup_iters steps if it < args.warmup_iters: return (it+1) / args.warmup_iters # 2) constant lr for a while elif it < args.num_iterations - args.cooldown_iters: return 1.0 # 3) linear cooldown else: decay_ratio = (args.num_iterations - it) / args.cooldown_iters return decay_ratio schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers] # Start training loop training_time_ms = 0 # start the clock torch.cuda.synchronize() t0 = time.time() # begin training for step in range(args.num_iterations + 1): last_step = (step == args.num_iterations) # This effectively ignores timing first 10 steps, which are slower for weird reasons. # Alternately, and slightly more correctly in terms of benchmarking, we could do 10 # steps with dummy data first, and then re-initialize the model and reset the loader. if step == 10: training_time_ms = 0 t0 = time.time() timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val # Set the attention blocksize for the current step, in chunks of 64. By @fernbear.bsky.social attn_blocksize = torch.tensor(64*((step/args.num_iterations * (1792 - 64) + 64)//64), dtype=torch.int, device='cuda') # once in a while evaluate the validation dataset if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)): # stop the clock torch.cuda.synchronize() training_time_ms += 1000 * (time.time() - t0) # run validation batches model.eval() val_loader.reset() val_loss = 0.0 for _ in range(val_steps): with torch.no_grad(): x_val, y_val = val_loader.next_batch() val_loss += model(x_val, y_val, attn_blocksize=attn_blocksize) dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) val_loss /= val_steps # log val loss to console and to logfile print0(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms') # start the clock again torch.cuda.synchronize() t0 = time.time() if master_process and (last_step or (args.save_every > 0 and step % args.save_every == 0)): # stop the clock torch.cuda.synchronize() training_time_ms += 1000 * (time.time() - t0) # save the state of the training process log = dict(step=step, code=code, model=raw_model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) torch.save(log, 'logs/%s/state_step%06d.pt' % (run_id, step)) # start the clock again torch.cuda.synchronize() t0 = time.time() # bit confusing: we want to make sure to eval on 0th iteration # but also after the very last iteration. so we loop for step <= num_iterations # instead of just < num_iterations (one extra due to <=), only to do # the validation/sampling one last time, and then we break right here as we're done. if last_step: break # --------------- TRAINING SECTION BEGIN ----------------- model.train() for i in range(1, train_accumulation_steps+1): ctx = model.no_sync() if i < train_accumulation_steps else contextlib.nullcontext() with ctx: # there's no need to sync gradients every accumulation step # forward pass loss = model(x, y, attn_blocksize=attn_blocksize) # advance the dataset for the next batch x, y = train_loader.next_batch() # backward pass loss.backward() train_loss = loss.detach() for p in model.parameters(): p.grad /= train_accumulation_steps # momentum warmup for Muon frac = min(step/300, 1) optimizer3.param_groups[0]['momentum'] = (1 - frac) * 0.85 + frac * 0.95 # step the optimizers and schedulers for opt, sched in zip(optimizers, schedulers): opt.step() sched.step() # null the gradients model.zero_grad(set_to_none=True) # --------------- TRAINING SECTION END ------------------- # everything that follows now is just diagnostics, prints, logging, etc. #dist.all_reduce(train_loss, op=dist.ReduceOp.AVG) # all-reducing the training loss would be more correct in terms of logging, but slower approx_time = training_time_ms + 1000 * (time.time() - t0) print0(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms") if master_process: print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") # ------------------------------------------------------------------------- # clean up nice dist.destroy_process_group() ==================================================================================================== Running pytorch 2.6.0.dev20241203+cu124 compiled for CUDA 12.4 nvidia-smi: Thu Dec 5 03:53:59 2024 +---------------------------------------------------------------------------------------+ | NVIDIA-SMI 535.183.06 Driver Version: 535.183.06 CUDA Version: 12.2 | |-----------------------------------------+----------------------+----------------------+ | GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | | Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | | | | MIG M. | |=========================================+======================+======================| | 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | | N/A 38C P0 75W / 700W | 3MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+----------------------+----------------------+ | 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | | N/A 30C P0 115W / 700W | 529MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+----------------------+----------------------+ | 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | | N/A 31C P0 96W / 700W | 22MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+----------------------+----------------------+ | 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | | N/A 38C P0 113W / 700W | 23MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+----------------------+----------------------+ | 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | | N/A 39C P0 123W / 700W | 529MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+----------------------+----------------------+ | 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | | N/A 29C P0 99W / 700W | 23MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+----------------------+----------------------+ | 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | | N/A 39C P0 109W / 700W | 22MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+----------------------+----------------------+ | 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | | N/A 30C P0 119W / 700W | 529MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+----------------------+----------------------+ +---------------------------------------------------------------------------------------+ | Processes: | | GPU GI CI PID Type Process name GPU Memory | | ID ID Usage | |=======================================================================================| +---------------------------------------------------------------------------------------+ ==================================================================================================== Training DataLoader: total number of tokens: 1100000000 across 11 files Validation DataLoader: total number of tokens: 100000000 across 1 files ==================================================================================================== step:0/1530 val_loss:10.8258 train_time:0ms step_avg:nanms step:1/1530 train_loss:10.8258 train_time:31726ms step_avg:nanms step:2/1530 train_loss:10.0757 train_time:31835ms step_avg:nanms step:3/1530 train_loss:8.3762 train_time:31996ms step_avg:nanms step:4/1530 train_loss:7.6191 train_time:32157ms step_avg:nanms step:5/1530 train_loss:7.4363 train_time:32318ms step_avg:nanms step:6/1530 train_loss:6.9771 train_time:32478ms step_avg:nanms step:7/1530 train_loss:7.2237 train_time:32639ms step_avg:nanms step:8/1530 train_loss:6.7541 train_time:32799ms step_avg:nanms step:9/1530 train_loss:6.6450 train_time:32961ms step_avg:nanms step:10/1530 train_loss:6.5076 train_time:33121ms step_avg:nanms step:11/1530 train_loss:6.4169 train_time:115ms step_avg:nanms step:12/1530 train_loss:6.3626 train_time:274ms step_avg:nanms step:13/1530 train_loss:6.2889 train_time:435ms step_avg:144.95ms step:14/1530 train_loss:6.2455 train_time:595ms step_avg:148.79ms step:15/1530 train_loss:6.1513 train_time:755ms step_avg:150.96ms step:16/1530 train_loss:6.1032 train_time:915ms step_avg:152.56ms step:17/1530 train_loss:6.1653 train_time:1075ms step_avg:153.59ms step:18/1530 train_loss:5.9902 train_time:1236ms step_avg:154.49ms step:19/1530 train_loss:5.9660 train_time:1396ms step_avg:155.14ms step:20/1530 train_loss:5.6972 train_time:1556ms step_avg:155.63ms step:21/1530 train_loss:5.9465 train_time:1717ms step_avg:156.12ms step:22/1530 train_loss:6.1615 train_time:1877ms step_avg:156.43ms step:23/1530 train_loss:5.8394 train_time:2038ms step_avg:156.76ms step:24/1530 train_loss:6.0311 train_time:2198ms step_avg:157.01ms step:25/1530 train_loss:5.6768 train_time:2358ms step_avg:157.20ms step:26/1530 train_loss:5.5814 train_time:2519ms step_avg:157.45ms step:27/1530 train_loss:5.7729 train_time:2679ms step_avg:157.57ms step:28/1530 train_loss:5.4105 train_time:2839ms step_avg:157.70ms step:29/1530 train_loss:5.6548 train_time:2999ms step_avg:157.85ms step:30/1530 train_loss:5.4713 train_time:3160ms step_avg:158.01ms step:31/1530 train_loss:5.4303 train_time:3320ms step_avg:158.10ms step:32/1530 train_loss:5.2821 train_time:3480ms step_avg:158.17ms step:33/1530 train_loss:5.5655 train_time:3641ms step_avg:158.29ms step:34/1530 train_loss:5.4880 train_time:3801ms step_avg:158.39ms step:35/1530 train_loss:5.6038 train_time:3962ms step_avg:158.47ms step:36/1530 train_loss:5.5378 train_time:4123ms step_avg:158.58ms step:37/1530 train_loss:5.4441 train_time:4284ms step_avg:158.66ms step:38/1530 train_loss:5.3051 train_time:4443ms step_avg:158.69ms step:39/1530 train_loss:5.3268 train_time:4604ms step_avg:158.77ms step:40/1530 train_loss:5.2411 train_time:4765ms step_avg:158.84ms step:41/1530 train_loss:5.2492 train_time:4926ms step_avg:158.91ms step:42/1530 train_loss:5.1910 train_time:5087ms step_avg:158.97ms step:43/1530 train_loss:5.2826 train_time:5247ms step_avg:158.99ms step:44/1530 train_loss:5.2321 train_time:5407ms step_avg:159.03ms step:45/1530 train_loss:5.3760 train_time:5568ms step_avg:159.10ms step:46/1530 train_loss:5.1772 train_time:5728ms step_avg:159.10ms step:47/1530 train_loss:5.0751 train_time:5888ms step_avg:159.14ms step:48/1530 train_loss:5.2039 train_time:6048ms step_avg:159.17ms step:49/1530 train_loss:5.1359 train_time:6208ms step_avg:159.19ms step:50/1530 train_loss:5.2461 train_time:6369ms step_avg:159.23ms step:51/1530 train_loss:5.1408 train_time:6530ms step_avg:159.28ms step:52/1530 train_loss:5.0356 train_time:6691ms step_avg:159.32ms step:53/1530 train_loss:5.1752 train_time:6852ms step_avg:159.35ms step:54/1530 train_loss:4.9979 train_time:7013ms step_avg:159.38ms step:55/1530 train_loss:5.4167 train_time:7172ms step_avg:159.38ms step:56/1530 train_loss:5.0339 train_time:7333ms step_avg:159.41ms step:57/1530 train_loss:4.8919 train_time:7493ms step_avg:159.42ms step:58/1530 train_loss:5.0326 train_time:7653ms step_avg:159.43ms step:59/1530 train_loss:5.0168 train_time:7813ms step_avg:159.45ms step:60/1530 train_loss:5.1372 train_time:7972ms step_avg:159.45ms step:61/1530 train_loss:4.8538 train_time:8133ms step_avg:159.47ms step:62/1530 train_loss:4.9729 train_time:8293ms step_avg:159.49ms step:63/1530 train_loss:4.9735 train_time:8453ms step_avg:159.49ms step:64/1530 train_loss:5.0418 train_time:8614ms step_avg:159.52ms step:65/1530 train_loss:4.8100 train_time:8773ms step_avg:159.51ms step:66/1530 train_loss:4.9145 train_time:8934ms step_avg:159.54ms step:67/1530 train_loss:4.8167 train_time:9095ms step_avg:159.55ms step:68/1530 train_loss:5.0829 train_time:9255ms step_avg:159.57ms step:69/1530 train_loss:4.7349 train_time:9415ms step_avg:159.58ms step:70/1530 train_loss:4.8479 train_time:9574ms step_avg:159.57ms step:71/1530 train_loss:4.9598 train_time:9736ms step_avg:159.60ms step:72/1530 train_loss:4.8812 train_time:9896ms step_avg:159.61ms step:73/1530 train_loss:4.7797 train_time:10056ms step_avg:159.63ms step:74/1530 train_loss:4.9233 train_time:10217ms step_avg:159.64ms step:75/1530 train_loss:4.8668 train_time:10377ms step_avg:159.64ms step:76/1530 train_loss:4.7935 train_time:10537ms step_avg:159.66ms step:77/1530 train_loss:4.9136 train_time:10698ms step_avg:159.67ms step:78/1530 train_loss:5.1370 train_time:10858ms step_avg:159.67ms step:79/1530 train_loss:4.8215 train_time:11018ms step_avg:159.68ms step:80/1530 train_loss:4.8621 train_time:11178ms step_avg:159.69ms step:81/1530 train_loss:4.6528 train_time:11339ms step_avg:159.70ms step:82/1530 train_loss:4.8240 train_time:11499ms step_avg:159.71ms step:83/1530 train_loss:4.7662 train_time:11659ms step_avg:159.71ms step:84/1530 train_loss:4.7635 train_time:11819ms step_avg:159.72ms step:85/1530 train_loss:4.6308 train_time:11980ms step_avg:159.73ms step:86/1530 train_loss:4.8405 train_time:12141ms step_avg:159.75ms step:87/1530 train_loss:4.7606 train_time:12301ms step_avg:159.76ms step:88/1530 train_loss:4.7623 train_time:12462ms step_avg:159.77ms step:89/1530 train_loss:4.7148 train_time:12624ms step_avg:159.80ms step:90/1530 train_loss:4.6508 train_time:12784ms step_avg:159.80ms step:91/1530 train_loss:4.6249 train_time:12945ms step_avg:159.81ms step:92/1530 train_loss:4.7989 train_time:13105ms step_avg:159.82ms step:93/1530 train_loss:4.6251 train_time:13266ms step_avg:159.83ms step:94/1530 train_loss:4.6449 train_time:13425ms step_avg:159.83ms step:95/1530 train_loss:4.6950 train_time:13586ms step_avg:159.83ms step:96/1530 train_loss:4.6048 train_time:13745ms step_avg:159.83ms step:97/1530 train_loss:4.6596 train_time:13905ms step_avg:159.83ms step:98/1530 train_loss:4.5872 train_time:14066ms step_avg:159.84ms step:99/1530 train_loss:4.6694 train_time:14227ms step_avg:159.85ms step:100/1530 train_loss:4.6808 train_time:14388ms step_avg:159.87ms step:101/1530 train_loss:4.5378 train_time:14548ms step_avg:159.87ms step:102/1530 train_loss:4.7094 train_time:14710ms step_avg:159.89ms step:103/1530 train_loss:4.5833 train_time:14871ms step_avg:159.90ms step:104/1530 train_loss:4.5457 train_time:15031ms step_avg:159.90ms step:105/1530 train_loss:4.5710 train_time:15192ms step_avg:159.92ms step:106/1530 train_loss:4.6741 train_time:15353ms step_avg:159.93ms step:107/1530 train_loss:4.5306 train_time:15513ms step_avg:159.93ms step:108/1530 train_loss:4.3762 train_time:15673ms step_avg:159.93ms step:109/1530 train_loss:4.4942 train_time:15833ms step_avg:159.93ms step:110/1530 train_loss:4.4909 train_time:15994ms step_avg:159.94ms step:111/1530 train_loss:4.4303 train_time:16154ms step_avg:159.94ms step:112/1530 train_loss:4.6024 train_time:16315ms step_avg:159.95ms step:113/1530 train_loss:4.5091 train_time:16474ms step_avg:159.94ms step:114/1530 train_loss:4.3860 train_time:16635ms step_avg:159.95ms step:115/1530 train_loss:4.5238 train_time:16797ms step_avg:159.97ms step:116/1530 train_loss:4.4675 train_time:16961ms step_avg:160.01ms step:117/1530 train_loss:4.3803 train_time:17125ms step_avg:160.05ms step:118/1530 train_loss:4.6102 train_time:17290ms step_avg:160.09ms step:119/1530 train_loss:4.4743 train_time:17455ms step_avg:160.14ms step:120/1530 train_loss:4.3470 train_time:17618ms step_avg:160.17ms step:121/1530 train_loss:4.3027 train_time:17782ms step_avg:160.20ms step:122/1530 train_loss:4.4497 train_time:17946ms step_avg:160.23ms step:123/1530 train_loss:4.2878 train_time:18110ms step_avg:160.27ms step:124/1530 train_loss:4.5965 train_time:18274ms step_avg:160.30ms step:125/1530 train_loss:4.4601 train_time:18438ms step_avg:160.33ms step:125/1530 val_loss:4.4160 train_time:18485ms step_avg:160.73ms step:126/1530 train_loss:4.4322 train_time:18605ms step_avg:160.39ms step:127/1530 train_loss:4.4615 train_time:18770ms step_avg:160.43ms step:128/1530 train_loss:4.4082 train_time:18934ms step_avg:160.46ms step:129/1530 train_loss:4.7061 train_time:19099ms step_avg:160.49ms step:130/1530 train_loss:4.3576 train_time:19262ms step_avg:160.51ms step:131/1530 train_loss:4.4159 train_time:19425ms step_avg:160.53ms step:132/1530 train_loss:4.3593 train_time:19589ms step_avg:160.56ms step:133/1530 train_loss:4.4524 train_time:19752ms step_avg:160.58ms step:134/1530 train_loss:4.2765 train_time:19916ms step_avg:160.62ms step:135/1530 train_loss:4.4603 train_time:20081ms step_avg:160.65ms step:136/1530 train_loss:4.2250 train_time:20243ms step_avg:160.66ms step:137/1530 train_loss:4.3782 train_time:20407ms step_avg:160.69ms step:138/1530 train_loss:4.2991 train_time:20572ms step_avg:160.72ms step:139/1530 train_loss:4.3958 train_time:20737ms step_avg:160.75ms step:140/1530 train_loss:4.4765 train_time:20901ms step_avg:160.78ms step:141/1530 train_loss:4.3186 train_time:21064ms step_avg:160.80ms step:142/1530 train_loss:4.3130 train_time:21227ms step_avg:160.81ms step:143/1530 train_loss:4.2577 train_time:21394ms step_avg:160.86ms step:144/1530 train_loss:4.3530 train_time:21558ms step_avg:160.88ms step:145/1530 train_loss:4.3115 train_time:21722ms step_avg:160.90ms step:146/1530 train_loss:4.1817 train_time:21886ms step_avg:160.93ms step:147/1530 train_loss:4.3446 train_time:22050ms step_avg:160.95ms step:148/1530 train_loss:4.3775 train_time:22214ms step_avg:160.97ms step:149/1530 train_loss:4.3192 train_time:22379ms step_avg:161.00ms step:150/1530 train_loss:4.4442 train_time:22542ms step_avg:161.01ms step:151/1530 train_loss:4.2719 train_time:22706ms step_avg:161.03ms step:152/1530 train_loss:4.2799 train_time:22870ms step_avg:161.06ms step:153/1530 train_loss:4.3647 train_time:23035ms step_avg:161.08ms step:154/1530 train_loss:4.3702 train_time:23200ms step_avg:161.11ms step:155/1530 train_loss:4.2799 train_time:23364ms step_avg:161.13ms step:156/1530 train_loss:4.3479 train_time:23526ms step_avg:161.14ms step:157/1530 train_loss:4.4106 train_time:23693ms step_avg:161.18ms step:158/1530 train_loss:4.2481 train_time:23857ms step_avg:161.20ms step:159/1530 train_loss:4.3084 train_time:24020ms step_avg:161.21ms step:160/1530 train_loss:4.1437 train_time:24184ms step_avg:161.22ms step:161/1530 train_loss:4.3592 train_time:24347ms step_avg:161.24ms step:162/1530 train_loss:4.3632 train_time:24511ms step_avg:161.26ms step:163/1530 train_loss:4.3556 train_time:24675ms step_avg:161.28ms step:164/1530 train_loss:4.2047 train_time:24839ms step_avg:161.29ms step:165/1530 train_loss:4.2908 train_time:25002ms step_avg:161.31ms step:166/1530 train_loss:4.3439 train_time:25165ms step_avg:161.32ms step:167/1530 train_loss:4.1948 train_time:25329ms step_avg:161.33ms step:168/1530 train_loss:4.2914 train_time:25493ms step_avg:161.35ms step:169/1530 train_loss:4.1710 train_time:25658ms step_avg:161.37ms step:170/1530 train_loss:4.0307 train_time:25822ms step_avg:161.39ms step:171/1530 train_loss:4.2100 train_time:25985ms step_avg:161.40ms step:172/1530 train_loss:4.2132 train_time:26148ms step_avg:161.41ms step:173/1530 train_loss:4.2643 train_time:26310ms step_avg:161.41ms step:174/1530 train_loss:4.4182 train_time:26473ms step_avg:161.42ms step:175/1530 train_loss:4.2429 train_time:26637ms step_avg:161.44ms step:176/1530 train_loss:4.0917 train_time:26799ms step_avg:161.44ms step:177/1530 train_loss:4.0697 train_time:26961ms step_avg:161.44ms step:178/1530 train_loss:4.1809 train_time:27123ms step_avg:161.45ms step:179/1530 train_loss:4.1202 train_time:27287ms step_avg:161.46ms step:180/1530 train_loss:4.1058 train_time:27448ms step_avg:161.46ms step:181/1530 train_loss:4.2952 train_time:27612ms step_avg:161.48ms step:182/1530 train_loss:4.1582 train_time:27776ms step_avg:161.49ms step:183/1530 train_loss:4.1449 train_time:27938ms step_avg:161.49ms step:184/1530 train_loss:4.1239 train_time:28101ms step_avg:161.50ms step:185/1530 train_loss:4.2078 train_time:28264ms step_avg:161.51ms step:186/1530 train_loss:4.1662 train_time:28425ms step_avg:161.51ms step:187/1530 train_loss:4.2342 train_time:28589ms step_avg:161.52ms step:188/1530 train_loss:4.1738 train_time:28884ms step_avg:162.27ms step:189/1530 train_loss:4.1263 train_time:29229ms step_avg:163.29ms step:190/1530 train_loss:4.2146 train_time:29391ms step_avg:163.29ms step:191/1530 train_loss:4.0843 train_time:29553ms step_avg:163.28ms step:192/1530 train_loss:4.0483 train_time:29717ms step_avg:163.28ms step:193/1530 train_loss:4.2562 train_time:29879ms step_avg:163.27ms step:194/1530 train_loss:4.1743 train_time:30040ms step_avg:163.26ms step:195/1530 train_loss:4.3514 train_time:30203ms step_avg:163.26ms step:196/1530 train_loss:4.1699 train_time:30366ms step_avg:163.26ms step:197/1530 train_loss:4.0463 train_time:30529ms step_avg:163.26ms step:198/1530 train_loss:4.1772 train_time:30694ms step_avg:163.26ms step:199/1530 train_loss:4.0398 train_time:30857ms step_avg:163.26ms step:200/1530 train_loss:4.1187 train_time:31019ms step_avg:163.26ms step:201/1530 train_loss:4.0198 train_time:31182ms step_avg:163.26ms step:202/1530 train_loss:4.2549 train_time:31345ms step_avg:163.25ms step:203/1530 train_loss:4.0645 train_time:31507ms step_avg:163.25ms step:204/1530 train_loss:4.1868 train_time:31670ms step_avg:163.25ms step:205/1530 train_loss:4.2494 train_time:31834ms step_avg:163.25ms step:206/1530 train_loss:3.9449 train_time:31996ms step_avg:163.25ms step:207/1530 train_loss:4.0828 train_time:32160ms step_avg:163.25ms step:208/1530 train_loss:4.0969 train_time:32322ms step_avg:163.24ms step:209/1530 train_loss:4.2297 train_time:32485ms step_avg:163.24ms step:210/1530 train_loss:4.1912 train_time:32648ms step_avg:163.24ms step:211/1530 train_loss:4.0612 train_time:32814ms step_avg:163.25ms step:212/1530 train_loss:4.1178 train_time:32977ms step_avg:163.25ms step:213/1530 train_loss:4.0522 train_time:33139ms step_avg:163.24ms step:214/1530 train_loss:4.1160 train_time:33302ms step_avg:163.24ms step:215/1530 train_loss:3.9645 train_time:33464ms step_avg:163.24ms step:216/1530 train_loss:4.0060 train_time:33628ms step_avg:163.24ms step:217/1530 train_loss:4.0125 train_time:33792ms step_avg:163.25ms step:218/1530 train_loss:4.0726 train_time:33956ms step_avg:163.25ms step:219/1530 train_loss:4.0610 train_time:34119ms step_avg:163.25ms step:220/1530 train_loss:4.0762 train_time:34282ms step_avg:163.25ms step:221/1530 train_loss:4.0876 train_time:34444ms step_avg:163.24ms step:222/1530 train_loss:4.0079 train_time:34607ms step_avg:163.24ms step:223/1530 train_loss:3.9896 train_time:34771ms step_avg:163.24ms step:224/1530 train_loss:4.2995 train_time:34936ms step_avg:163.25ms step:225/1530 train_loss:3.9251 train_time:35100ms step_avg:163.25ms step:226/1530 train_loss:3.9905 train_time:35262ms step_avg:163.25ms step:227/1530 train_loss:3.9785 train_time:35424ms step_avg:163.25ms step:228/1530 train_loss:4.1437 train_time:35590ms step_avg:163.26ms step:229/1530 train_loss:3.9204 train_time:35757ms step_avg:163.28ms step:230/1530 train_loss:4.0416 train_time:35922ms step_avg:163.28ms step:231/1530 train_loss:3.8978 train_time:36088ms step_avg:163.30ms step:232/1530 train_loss:3.9621 train_time:36254ms step_avg:163.31ms step:233/1530 train_loss:4.0896 train_time:36420ms step_avg:163.32ms step:234/1530 train_loss:4.0380 train_time:36586ms step_avg:163.33ms step:235/1530 train_loss:3.8923 train_time:36752ms step_avg:163.34ms step:236/1530 train_loss:4.0804 train_time:36919ms step_avg:163.36ms step:237/1530 train_loss:4.0741 train_time:37084ms step_avg:163.37ms step:238/1530 train_loss:3.9352 train_time:37250ms step_avg:163.38ms step:239/1530 train_loss:4.0774 train_time:37417ms step_avg:163.39ms step:240/1530 train_loss:4.1092 train_time:37583ms step_avg:163.40ms step:241/1530 train_loss:3.9586 train_time:37748ms step_avg:163.41ms step:242/1530 train_loss:4.1345 train_time:37916ms step_avg:163.43ms step:243/1530 train_loss:4.0090 train_time:38082ms step_avg:163.44ms step:244/1530 train_loss:4.0727 train_time:38247ms step_avg:163.45ms step:245/1530 train_loss:4.1347 train_time:38413ms step_avg:163.46ms step:246/1530 train_loss:4.0640 train_time:38580ms step_avg:163.48ms step:247/1530 train_loss:4.0216 train_time:38745ms step_avg:163.48ms step:248/1530 train_loss:4.1001 train_time:38912ms step_avg:163.49ms step:249/1530 train_loss:3.9136 train_time:39078ms step_avg:163.51ms step:250/1530 train_loss:3.9707 train_time:39243ms step_avg:163.51ms step:250/1530 val_loss:4.0063 train_time:39291ms step_avg:163.71ms step:251/1530 train_loss:4.0754 train_time:39413ms step_avg:163.54ms step:252/1530 train_loss:4.1701 train_time:39579ms step_avg:163.55ms step:253/1530 train_loss:3.9266 train_time:39745ms step_avg:163.56ms step:254/1530 train_loss:3.8776 train_time:39911ms step_avg:163.57ms step:255/1530 train_loss:4.0718 train_time:40077ms step_avg:163.58ms step:256/1530 train_loss:3.9838 train_time:40242ms step_avg:163.58ms step:257/1530 train_loss:3.9921 train_time:40408ms step_avg:163.60ms step:258/1530 train_loss:3.9937 train_time:40575ms step_avg:163.61ms step:259/1530 train_loss:4.0258 train_time:40742ms step_avg:163.62ms step:260/1530 train_loss:4.0484 train_time:40907ms step_avg:163.63ms step:261/1530 train_loss:4.0143 train_time:41075ms step_avg:163.65ms step:262/1530 train_loss:3.9921 train_time:41242ms step_avg:163.66ms step:263/1530 train_loss:3.8910 train_time:41407ms step_avg:163.67ms step:264/1530 train_loss:3.9813 train_time:41574ms step_avg:163.68ms step:265/1530 train_loss:3.8649 train_time:41740ms step_avg:163.69ms step:266/1530 train_loss:3.9227 train_time:41906ms step_avg:163.70ms step:267/1530 train_loss:3.9318 train_time:42072ms step_avg:163.71ms step:268/1530 train_loss:3.9575 train_time:42238ms step_avg:163.71ms step:269/1530 train_loss:3.8498 train_time:42402ms step_avg:163.72ms step:270/1530 train_loss:4.0985 train_time:42569ms step_avg:163.73ms step:271/1530 train_loss:3.9645 train_time:42736ms step_avg:163.74ms step:272/1530 train_loss:3.9146 train_time:42901ms step_avg:163.74ms step:273/1530 train_loss:3.9429 train_time:43067ms step_avg:163.75ms step:274/1530 train_loss:4.0353 train_time:43234ms step_avg:163.77ms step:275/1530 train_loss:4.0576 train_time:43399ms step_avg:163.77ms step:276/1530 train_loss:4.2236 train_time:43565ms step_avg:163.78ms step:277/1530 train_loss:4.0390 train_time:43732ms step_avg:163.79ms step:278/1530 train_loss:4.0837 train_time:43898ms step_avg:163.80ms step:279/1530 train_loss:3.9983 train_time:44064ms step_avg:163.81ms step:280/1530 train_loss:4.2142 train_time:44231ms step_avg:163.82ms step:281/1530 train_loss:3.9709 train_time:44397ms step_avg:163.83ms step:282/1530 train_loss:3.9430 train_time:44564ms step_avg:163.84ms step:283/1530 train_loss:3.9073 train_time:44731ms step_avg:163.85ms step:284/1530 train_loss:4.0410 train_time:44897ms step_avg:163.86ms step:285/1530 train_loss:4.0548 train_time:45062ms step_avg:163.86ms step:286/1530 train_loss:4.0908 train_time:45226ms step_avg:163.86ms step:287/1530 train_loss:3.9131 train_time:45393ms step_avg:163.87ms step:288/1530 train_loss:4.0086 train_time:45557ms step_avg:163.87ms step:289/1530 train_loss:3.8766 train_time:45723ms step_avg:163.88ms step:290/1530 train_loss:3.8541 train_time:45888ms step_avg:163.89ms step:291/1530 train_loss:3.9036 train_time:46054ms step_avg:163.89ms step:292/1530 train_loss:3.8629 train_time:46218ms step_avg:163.89ms step:293/1530 train_loss:3.9012 train_time:46383ms step_avg:163.90ms step:294/1530 train_loss:3.9328 train_time:46550ms step_avg:163.91ms step:295/1530 train_loss:3.8447 train_time:46714ms step_avg:163.91ms step:296/1530 train_loss:3.8624 train_time:46879ms step_avg:163.91ms step:297/1530 train_loss:3.8648 train_time:47045ms step_avg:163.92ms step:298/1530 train_loss:3.9697 train_time:47212ms step_avg:163.93ms step:299/1530 train_loss:3.8201 train_time:47376ms step_avg:163.93ms step:300/1530 train_loss:3.9668 train_time:47541ms step_avg:163.94ms step:301/1530 train_loss:3.9605 train_time:47707ms step_avg:163.94ms step:302/1530 train_loss:3.9323 train_time:47872ms step_avg:163.95ms step:303/1530 train_loss:3.9813 train_time:48038ms step_avg:163.95ms step:304/1530 train_loss:3.9561 train_time:48202ms step_avg:163.95ms step:305/1530 train_loss:4.4451 train_time:48368ms step_avg:163.96ms step:306/1530 train_loss:3.9351 train_time:48534ms step_avg:163.97ms step:307/1530 train_loss:3.8352 train_time:48699ms step_avg:163.97ms step:308/1530 train_loss:3.9756 train_time:48863ms step_avg:163.97ms step:309/1530 train_loss:3.8787 train_time:49031ms step_avg:163.98ms step:310/1530 train_loss:4.0783 train_time:49196ms step_avg:163.99ms step:311/1530 train_loss:3.9239 train_time:49361ms step_avg:163.99ms step:312/1530 train_loss:3.8597 train_time:49527ms step_avg:164.00ms step:313/1530 train_loss:3.9339 train_time:49693ms step_avg:164.00ms step:314/1530 train_loss:4.0613 train_time:49859ms step_avg:164.01ms step:315/1530 train_loss:3.9397 train_time:50023ms step_avg:164.01ms step:316/1530 train_loss:3.7911 train_time:50190ms step_avg:164.02ms step:317/1530 train_loss:3.8732 train_time:50354ms step_avg:164.02ms step:318/1530 train_loss:3.9178 train_time:50520ms step_avg:164.03ms step:319/1530 train_loss:3.8874 train_time:50686ms step_avg:164.03ms step:320/1530 train_loss:4.0141 train_time:50854ms step_avg:164.05ms step:321/1530 train_loss:3.9601 train_time:51019ms step_avg:164.05ms step:322/1530 train_loss:3.9266 train_time:51184ms step_avg:164.05ms step:323/1530 train_loss:4.0110 train_time:51350ms step_avg:164.06ms step:324/1530 train_loss:3.9404 train_time:51515ms step_avg:164.06ms step:325/1530 train_loss:4.0138 train_time:51680ms step_avg:164.06ms step:326/1530 train_loss:3.8898 train_time:51846ms step_avg:164.07ms step:327/1530 train_loss:4.3932 train_time:52012ms step_avg:164.08ms step:328/1530 train_loss:4.0716 train_time:52177ms step_avg:164.08ms step:329/1530 train_loss:3.7912 train_time:52342ms step_avg:164.08ms step:330/1530 train_loss:3.7438 train_time:52507ms step_avg:164.08ms step:331/1530 train_loss:3.9749 train_time:52673ms step_avg:164.09ms step:332/1530 train_loss:3.9097 train_time:52839ms step_avg:164.10ms step:333/1530 train_loss:3.8883 train_time:53003ms step_avg:164.10ms step:334/1530 train_loss:3.8379 train_time:53169ms step_avg:164.10ms step:335/1530 train_loss:4.0116 train_time:53334ms step_avg:164.11ms step:336/1530 train_loss:3.9668 train_time:53498ms step_avg:164.10ms step:337/1530 train_loss:4.4240 train_time:53663ms step_avg:164.11ms step:338/1530 train_loss:3.9328 train_time:53829ms step_avg:164.11ms step:339/1530 train_loss:3.8575 train_time:53994ms step_avg:164.12ms step:340/1530 train_loss:3.9333 train_time:54159ms step_avg:164.12ms step:341/1530 train_loss:3.8551 train_time:54325ms step_avg:164.12ms step:342/1530 train_loss:3.8070 train_time:54493ms step_avg:164.14ms step:343/1530 train_loss:3.8359 train_time:54661ms step_avg:164.15ms step:344/1530 train_loss:3.9947 train_time:54829ms step_avg:164.16ms step:345/1530 train_loss:3.8160 train_time:54997ms step_avg:164.17ms step:346/1530 train_loss:3.7681 train_time:55165ms step_avg:164.18ms step:347/1530 train_loss:3.7904 train_time:55335ms step_avg:164.20ms step:348/1530 train_loss:3.8514 train_time:55502ms step_avg:164.21ms step:349/1530 train_loss:3.8220 train_time:55671ms step_avg:164.22ms step:350/1530 train_loss:3.5712 train_time:55840ms step_avg:164.24ms step:351/1530 train_loss:3.8266 train_time:56008ms step_avg:164.24ms step:352/1530 train_loss:4.1802 train_time:56175ms step_avg:164.26ms step:353/1530 train_loss:3.6712 train_time:56343ms step_avg:164.26ms step:354/1530 train_loss:3.9218 train_time:56510ms step_avg:164.27ms step:355/1530 train_loss:3.7808 train_time:56679ms step_avg:164.29ms step:356/1530 train_loss:3.8784 train_time:56847ms step_avg:164.30ms step:357/1530 train_loss:3.7525 train_time:57016ms step_avg:164.31ms step:358/1530 train_loss:3.8700 train_time:57183ms step_avg:164.32ms step:359/1530 train_loss:3.8006 train_time:57354ms step_avg:164.34ms step:360/1530 train_loss:3.4228 train_time:57523ms step_avg:164.35ms step:361/1530 train_loss:4.0122 train_time:57691ms step_avg:164.36ms step:362/1530 train_loss:3.9090 train_time:57861ms step_avg:164.38ms step:363/1530 train_loss:3.8324 train_time:58028ms step_avg:164.38ms step:364/1530 train_loss:3.7378 train_time:58196ms step_avg:164.39ms step:365/1530 train_loss:3.9109 train_time:58364ms step_avg:164.40ms step:366/1530 train_loss:3.8541 train_time:58533ms step_avg:164.42ms step:367/1530 train_loss:3.8519 train_time:58699ms step_avg:164.42ms step:368/1530 train_loss:3.8445 train_time:58867ms step_avg:164.43ms step:369/1530 train_loss:3.7441 train_time:59035ms step_avg:164.44ms step:370/1530 train_loss:3.8743 train_time:59202ms step_avg:164.45ms step:371/1530 train_loss:3.7307 train_time:59371ms step_avg:164.46ms step:372/1530 train_loss:3.6949 train_time:59541ms step_avg:164.48ms step:373/1530 train_loss:3.9099 train_time:59709ms step_avg:164.49ms step:374/1530 train_loss:3.8275 train_time:59877ms step_avg:164.50ms step:375/1530 train_loss:3.8016 train_time:60044ms step_avg:164.50ms step:375/1530 val_loss:3.8237 train_time:60092ms step_avg:164.64ms step:376/1530 train_loss:3.8634 train_time:60214ms step_avg:164.52ms step:377/1530 train_loss:3.7874 train_time:60514ms step_avg:164.89ms step:378/1530 train_loss:3.8406 train_time:60693ms step_avg:164.93ms step:379/1530 train_loss:3.8690 train_time:61020ms step_avg:165.37ms step:380/1530 train_loss:3.9641 train_time:61188ms step_avg:165.37ms step:381/1530 train_loss:3.8381 train_time:61355ms step_avg:165.38ms step:382/1530 train_loss:3.7965 train_time:61524ms step_avg:165.39ms step:383/1530 train_loss:3.7926 train_time:61693ms step_avg:165.40ms step:384/1530 train_loss:3.8664 train_time:61860ms step_avg:165.40ms step:385/1530 train_loss:3.7892 train_time:62028ms step_avg:165.41ms step:386/1530 train_loss:3.8885 train_time:62196ms step_avg:165.42ms step:387/1530 train_loss:4.0626 train_time:62364ms step_avg:165.42ms step:388/1530 train_loss:3.7898 train_time:62532ms step_avg:165.43ms step:389/1530 train_loss:3.7889 train_time:62698ms step_avg:165.43ms step:390/1530 train_loss:3.8910 train_time:62868ms step_avg:165.44ms step:391/1530 train_loss:3.8070 train_time:63035ms step_avg:165.45ms step:392/1530 train_loss:3.9180 train_time:63202ms step_avg:165.45ms step:393/1530 train_loss:3.7587 train_time:63369ms step_avg:165.46ms step:394/1530 train_loss:3.8786 train_time:63536ms step_avg:165.46ms step:395/1530 train_loss:3.6336 train_time:63704ms step_avg:165.47ms step:396/1530 train_loss:3.8347 train_time:63875ms step_avg:165.48ms step:397/1530 train_loss:3.8552 train_time:64043ms step_avg:165.49ms step:398/1530 train_loss:3.8664 train_time:64211ms step_avg:165.49ms step:399/1530 train_loss:3.7662 train_time:64378ms step_avg:165.50ms step:400/1530 train_loss:3.8303 train_time:64547ms step_avg:165.50ms step:401/1530 train_loss:3.9215 train_time:64714ms step_avg:165.51ms step:402/1530 train_loss:3.8434 train_time:64881ms step_avg:165.51ms step:403/1530 train_loss:3.9592 train_time:65048ms step_avg:165.52ms step:404/1530 train_loss:3.6691 train_time:65216ms step_avg:165.52ms step:405/1530 train_loss:3.7773 train_time:65383ms step_avg:165.53ms step:406/1530 train_loss:4.0949 train_time:65552ms step_avg:165.53ms step:407/1530 train_loss:3.7748 train_time:65719ms step_avg:165.54ms step:408/1530 train_loss:3.8138 train_time:65885ms step_avg:165.54ms step:409/1530 train_loss:3.8519 train_time:66052ms step_avg:165.54ms step:410/1530 train_loss:3.7558 train_time:66218ms step_avg:165.55ms step:411/1530 train_loss:3.7573 train_time:66385ms step_avg:165.55ms step:412/1530 train_loss:4.1877 train_time:66553ms step_avg:165.55ms step:413/1530 train_loss:3.6852 train_time:66720ms step_avg:165.56ms step:414/1530 train_loss:4.0068 train_time:66887ms step_avg:165.56ms step:415/1530 train_loss:3.7511 train_time:67053ms step_avg:165.56ms step:416/1530 train_loss:3.7635 train_time:67220ms step_avg:165.57ms step:417/1530 train_loss:3.9567 train_time:67389ms step_avg:165.57ms step:418/1530 train_loss:3.6907 train_time:67556ms step_avg:165.58ms step:419/1530 train_loss:3.7982 train_time:67722ms step_avg:165.58ms step:420/1530 train_loss:3.7003 train_time:67891ms step_avg:165.59ms step:421/1530 train_loss:3.6467 train_time:68057ms step_avg:165.59ms step:422/1530 train_loss:3.7812 train_time:68224ms step_avg:165.59ms step:423/1530 train_loss:3.8670 train_time:68393ms step_avg:165.60ms step:424/1530 train_loss:3.6099 train_time:68560ms step_avg:165.60ms step:425/1530 train_loss:3.7953 train_time:68727ms step_avg:165.61ms step:426/1530 train_loss:3.6580 train_time:68895ms step_avg:165.61ms step:427/1530 train_loss:3.8874 train_time:69062ms step_avg:165.62ms step:428/1530 train_loss:3.8105 train_time:69230ms step_avg:165.62ms step:429/1530 train_loss:3.7573 train_time:69397ms step_avg:165.62ms step:430/1530 train_loss:3.7043 train_time:69565ms step_avg:165.63ms step:431/1530 train_loss:3.6208 train_time:69732ms step_avg:165.64ms step:432/1530 train_loss:3.7634 train_time:69899ms step_avg:165.64ms step:433/1530 train_loss:3.8066 train_time:70067ms step_avg:165.64ms step:434/1530 train_loss:3.7673 train_time:70233ms step_avg:165.64ms step:435/1530 train_loss:3.7954 train_time:70399ms step_avg:165.65ms step:436/1530 train_loss:3.8208 train_time:70567ms step_avg:165.65ms step:437/1530 train_loss:3.7102 train_time:70734ms step_avg:165.65ms step:438/1530 train_loss:3.7012 train_time:70901ms step_avg:165.66ms step:439/1530 train_loss:3.6997 train_time:71069ms step_avg:165.66ms step:440/1530 train_loss:3.8891 train_time:71236ms step_avg:165.66ms step:441/1530 train_loss:3.7590 train_time:71402ms step_avg:165.67ms step:442/1530 train_loss:3.7339 train_time:71573ms step_avg:165.68ms step:443/1530 train_loss:3.6183 train_time:71739ms step_avg:165.68ms step:444/1530 train_loss:3.9293 train_time:71906ms step_avg:165.68ms step:445/1530 train_loss:3.8426 train_time:72073ms step_avg:165.68ms step:446/1530 train_loss:3.8351 train_time:72239ms step_avg:165.69ms step:447/1530 train_loss:3.7459 train_time:72407ms step_avg:165.69ms step:448/1530 train_loss:3.8475 train_time:72575ms step_avg:165.70ms step:449/1530 train_loss:3.6843 train_time:72741ms step_avg:165.70ms step:450/1530 train_loss:3.7198 train_time:72910ms step_avg:165.70ms step:451/1530 train_loss:3.5849 train_time:73077ms step_avg:165.71ms step:452/1530 train_loss:3.7121 train_time:73243ms step_avg:165.71ms step:453/1530 train_loss:3.6702 train_time:73412ms step_avg:165.72ms step:454/1530 train_loss:3.6308 train_time:73579ms step_avg:165.72ms step:455/1530 train_loss:3.8411 train_time:73748ms step_avg:165.73ms step:456/1530 train_loss:3.7248 train_time:73917ms step_avg:165.73ms step:457/1530 train_loss:3.7756 train_time:74086ms step_avg:165.74ms step:458/1530 train_loss:3.8238 train_time:74256ms step_avg:165.75ms step:459/1530 train_loss:3.6293 train_time:74427ms step_avg:165.76ms step:460/1530 train_loss:3.7836 train_time:74596ms step_avg:165.77ms step:461/1530 train_loss:3.6795 train_time:74767ms step_avg:165.78ms step:462/1530 train_loss:3.7346 train_time:74935ms step_avg:165.79ms step:463/1530 train_loss:3.7721 train_time:75105ms step_avg:165.79ms step:464/1530 train_loss:3.7143 train_time:75275ms step_avg:165.80ms step:465/1530 train_loss:3.7133 train_time:75443ms step_avg:165.81ms step:466/1530 train_loss:3.7942 train_time:75614ms step_avg:165.82ms step:467/1530 train_loss:3.8221 train_time:75785ms step_avg:165.83ms step:468/1530 train_loss:3.7903 train_time:75954ms step_avg:165.84ms step:469/1530 train_loss:3.6821 train_time:76125ms step_avg:165.85ms step:470/1530 train_loss:3.7644 train_time:76295ms step_avg:165.86ms step:471/1530 train_loss:3.8070 train_time:76464ms step_avg:165.87ms step:472/1530 train_loss:3.7774 train_time:76636ms step_avg:165.88ms step:473/1530 train_loss:3.7124 train_time:76804ms step_avg:165.88ms step:474/1530 train_loss:3.5881 train_time:76975ms step_avg:165.89ms step:475/1530 train_loss:4.0127 train_time:77144ms step_avg:165.90ms step:476/1530 train_loss:3.7471 train_time:77314ms step_avg:165.91ms step:477/1530 train_loss:3.5815 train_time:77483ms step_avg:165.92ms step:478/1530 train_loss:3.8205 train_time:77652ms step_avg:165.92ms step:479/1530 train_loss:3.7677 train_time:77822ms step_avg:165.93ms step:480/1530 train_loss:3.9197 train_time:77993ms step_avg:165.94ms step:481/1530 train_loss:3.7180 train_time:78162ms step_avg:165.95ms step:482/1530 train_loss:3.5202 train_time:78332ms step_avg:165.96ms step:483/1530 train_loss:3.8000 train_time:78501ms step_avg:165.96ms step:484/1530 train_loss:3.6587 train_time:78673ms step_avg:165.98ms step:485/1530 train_loss:3.6524 train_time:78842ms step_avg:165.98ms step:486/1530 train_loss:3.5673 train_time:79014ms step_avg:165.99ms step:487/1530 train_loss:3.6796 train_time:79182ms step_avg:166.00ms step:488/1530 train_loss:3.8704 train_time:79352ms step_avg:166.01ms step:489/1530 train_loss:3.7052 train_time:79521ms step_avg:166.02ms step:490/1530 train_loss:3.5860 train_time:79692ms step_avg:166.03ms step:491/1530 train_loss:3.6079 train_time:79861ms step_avg:166.03ms step:492/1530 train_loss:3.7297 train_time:80030ms step_avg:166.04ms step:493/1530 train_loss:3.5710 train_time:80200ms step_avg:166.05ms step:494/1530 train_loss:3.6967 train_time:80370ms step_avg:166.05ms step:495/1530 train_loss:3.6535 train_time:80540ms step_avg:166.06ms step:496/1530 train_loss:3.5032 train_time:80713ms step_avg:166.08ms step:497/1530 train_loss:3.7334 train_time:80881ms step_avg:166.08ms step:498/1530 train_loss:3.7808 train_time:81051ms step_avg:166.09ms step:499/1530 train_loss:3.8206 train_time:81220ms step_avg:166.09ms step:500/1530 train_loss:3.7296 train_time:81393ms step_avg:166.11ms step:500/1530 val_loss:3.6993 train_time:81442ms step_avg:166.21ms step:501/1530 train_loss:3.8004 train_time:81563ms step_avg:166.12ms step:502/1530 train_loss:3.7418 train_time:81734ms step_avg:166.13ms step:503/1530 train_loss:3.7715 train_time:81905ms step_avg:166.14ms step:504/1530 train_loss:3.7106 train_time:82073ms step_avg:166.14ms step:505/1530 train_loss:3.8001 train_time:82243ms step_avg:166.15ms step:506/1530 train_loss:3.6389 train_time:82412ms step_avg:166.15ms step:507/1530 train_loss:3.7627 train_time:82581ms step_avg:166.16ms step:508/1530 train_loss:3.8209 train_time:82750ms step_avg:166.16ms step:509/1530 train_loss:3.7716 train_time:82921ms step_avg:166.17ms step:510/1530 train_loss:3.5782 train_time:83090ms step_avg:166.18ms step:511/1530 train_loss:3.7715 train_time:83260ms step_avg:166.19ms step:512/1530 train_loss:3.7149 train_time:83430ms step_avg:166.19ms step:513/1530 train_loss:3.6598 train_time:83600ms step_avg:166.20ms step:514/1530 train_loss:3.8606 train_time:83768ms step_avg:166.21ms step:515/1530 train_loss:3.7316 train_time:83937ms step_avg:166.21ms step:516/1530 train_loss:4.0725 train_time:84107ms step_avg:166.22ms step:517/1530 train_loss:3.6883 train_time:84276ms step_avg:166.23ms step:518/1530 train_loss:3.7609 train_time:84445ms step_avg:166.23ms step:519/1530 train_loss:3.6533 train_time:84615ms step_avg:166.24ms step:520/1530 train_loss:3.6774 train_time:84784ms step_avg:166.24ms step:521/1530 train_loss:3.6669 train_time:84954ms step_avg:166.25ms step:522/1530 train_loss:3.6464 train_time:85124ms step_avg:166.26ms step:523/1530 train_loss:4.2813 train_time:85294ms step_avg:166.27ms step:524/1530 train_loss:3.7347 train_time:85462ms step_avg:166.27ms step:525/1530 train_loss:3.6762 train_time:85630ms step_avg:166.27ms step:526/1530 train_loss:3.6913 train_time:85801ms step_avg:166.28ms step:527/1530 train_loss:3.6484 train_time:85969ms step_avg:166.28ms step:528/1530 train_loss:3.6260 train_time:86139ms step_avg:166.29ms step:529/1530 train_loss:3.8451 train_time:86308ms step_avg:166.30ms step:530/1530 train_loss:3.6410 train_time:86478ms step_avg:166.30ms step:531/1530 train_loss:3.9160 train_time:86648ms step_avg:166.31ms step:532/1530 train_loss:3.7270 train_time:86817ms step_avg:166.32ms step:533/1530 train_loss:3.6503 train_time:86986ms step_avg:166.32ms step:534/1530 train_loss:3.6636 train_time:87154ms step_avg:166.32ms step:535/1530 train_loss:3.6063 train_time:87324ms step_avg:166.33ms step:536/1530 train_loss:3.7473 train_time:87495ms step_avg:166.34ms step:537/1530 train_loss:3.7180 train_time:87664ms step_avg:166.34ms step:538/1530 train_loss:3.6220 train_time:87835ms step_avg:166.35ms step:539/1530 train_loss:4.1017 train_time:88007ms step_avg:166.36ms step:540/1530 train_loss:3.6719 train_time:88176ms step_avg:166.37ms step:541/1530 train_loss:3.7798 train_time:88345ms step_avg:166.37ms step:542/1530 train_loss:3.5791 train_time:88516ms step_avg:166.38ms step:543/1530 train_loss:3.5847 train_time:88683ms step_avg:166.39ms step:544/1530 train_loss:3.6290 train_time:88852ms step_avg:166.39ms step:545/1530 train_loss:3.5887 train_time:89021ms step_avg:166.39ms step:546/1530 train_loss:3.6153 train_time:89189ms step_avg:166.40ms step:547/1530 train_loss:3.6377 train_time:89359ms step_avg:166.40ms step:548/1530 train_loss:3.6035 train_time:89528ms step_avg:166.41ms step:549/1530 train_loss:3.7189 train_time:89697ms step_avg:166.41ms step:550/1530 train_loss:3.6109 train_time:89865ms step_avg:166.42ms step:551/1530 train_loss:3.6245 train_time:90033ms step_avg:166.42ms step:552/1530 train_loss:3.9297 train_time:90204ms step_avg:166.43ms step:553/1530 train_loss:3.7529 train_time:90374ms step_avg:166.43ms step:554/1530 train_loss:3.7083 train_time:90542ms step_avg:166.44ms step:555/1530 train_loss:3.6214 train_time:90711ms step_avg:166.44ms step:556/1530 train_loss:3.6916 train_time:90880ms step_avg:166.45ms step:557/1530 train_loss:3.3035 train_time:91049ms step_avg:166.45ms step:558/1530 train_loss:3.6099 train_time:91218ms step_avg:166.46ms step:559/1530 train_loss:3.6401 train_time:91387ms step_avg:166.46ms step:560/1530 train_loss:3.6882 train_time:91557ms step_avg:166.47ms step:561/1530 train_loss:3.6103 train_time:91725ms step_avg:166.47ms step:562/1530 train_loss:3.5474 train_time:91894ms step_avg:166.48ms step:563/1530 train_loss:3.7521 train_time:92063ms step_avg:166.48ms step:564/1530 train_loss:3.5649 train_time:92233ms step_avg:166.49ms step:565/1530 train_loss:3.6807 train_time:92402ms step_avg:166.49ms step:566/1530 train_loss:3.6092 train_time:92702ms step_avg:166.73ms step:567/1530 train_loss:3.5983 train_time:92882ms step_avg:166.75ms step:568/1530 train_loss:3.6737 train_time:93052ms step_avg:166.76ms step:569/1530 train_loss:3.6448 train_time:93386ms step_avg:167.06ms step:570/1530 train_loss:3.6865 train_time:93557ms step_avg:167.07ms step:571/1530 train_loss:3.7502 train_time:93726ms step_avg:167.07ms step:572/1530 train_loss:3.7211 train_time:93899ms step_avg:167.08ms step:573/1530 train_loss:3.7301 train_time:94070ms step_avg:167.09ms step:574/1530 train_loss:3.7740 train_time:94244ms step_avg:167.10ms step:575/1530 train_loss:3.7230 train_time:94416ms step_avg:167.11ms step:576/1530 train_loss:3.7569 train_time:94586ms step_avg:167.11ms step:577/1530 train_loss:3.6608 train_time:94757ms step_avg:167.12ms step:578/1530 train_loss:3.6696 train_time:94928ms step_avg:167.13ms step:579/1530 train_loss:3.6664 train_time:95101ms step_avg:167.14ms step:580/1530 train_loss:3.5857 train_time:95271ms step_avg:167.14ms step:581/1530 train_loss:3.6281 train_time:95443ms step_avg:167.15ms step:582/1530 train_loss:3.8424 train_time:95614ms step_avg:167.16ms step:583/1530 train_loss:3.6156 train_time:95785ms step_avg:167.16ms step:584/1530 train_loss:3.5802 train_time:95958ms step_avg:167.17ms step:585/1530 train_loss:3.7830 train_time:96127ms step_avg:167.18ms step:586/1530 train_loss:3.5086 train_time:96302ms step_avg:167.19ms step:587/1530 train_loss:3.6590 train_time:96471ms step_avg:167.19ms step:588/1530 train_loss:3.6379 train_time:96643ms step_avg:167.20ms step:589/1530 train_loss:3.9890 train_time:96815ms step_avg:167.21ms step:590/1530 train_loss:3.7797 train_time:96986ms step_avg:167.22ms step:591/1530 train_loss:3.4989 train_time:97159ms step_avg:167.23ms step:592/1530 train_loss:3.5292 train_time:97331ms step_avg:167.24ms step:593/1530 train_loss:3.4956 train_time:97505ms step_avg:167.25ms step:594/1530 train_loss:3.5465 train_time:97676ms step_avg:167.25ms step:595/1530 train_loss:3.9169 train_time:97849ms step_avg:167.26ms step:596/1530 train_loss:3.6438 train_time:98022ms step_avg:167.27ms step:597/1530 train_loss:3.5796 train_time:98193ms step_avg:167.28ms step:598/1530 train_loss:3.6508 train_time:98362ms step_avg:167.28ms step:599/1530 train_loss:3.4750 train_time:98533ms step_avg:167.29ms step:600/1530 train_loss:3.5922 train_time:98704ms step_avg:167.30ms step:601/1530 train_loss:3.6462 train_time:98878ms step_avg:167.31ms step:602/1530 train_loss:3.6634 train_time:99050ms step_avg:167.31ms step:603/1530 train_loss:3.7795 train_time:99222ms step_avg:167.32ms step:604/1530 train_loss:3.6051 train_time:99393ms step_avg:167.33ms step:605/1530 train_loss:3.6105 train_time:99566ms step_avg:167.34ms step:606/1530 train_loss:3.5625 train_time:99739ms step_avg:167.35ms step:607/1530 train_loss:3.8340 train_time:99910ms step_avg:167.35ms step:608/1530 train_loss:3.6301 train_time:100083ms step_avg:167.36ms step:609/1530 train_loss:3.6129 train_time:100253ms step_avg:167.37ms step:610/1530 train_loss:3.6966 train_time:100424ms step_avg:167.37ms step:611/1530 train_loss:3.5943 train_time:100594ms step_avg:167.38ms step:612/1530 train_loss:3.5650 train_time:100764ms step_avg:167.38ms step:613/1530 train_loss:3.7550 train_time:100936ms step_avg:167.39ms step:614/1530 train_loss:3.7000 train_time:101106ms step_avg:167.39ms step:615/1530 train_loss:3.7005 train_time:101277ms step_avg:167.40ms step:616/1530 train_loss:3.6239 train_time:101446ms step_avg:167.40ms step:617/1530 train_loss:3.5558 train_time:101621ms step_avg:167.41ms step:618/1530 train_loss:3.6804 train_time:101789ms step_avg:167.42ms step:619/1530 train_loss:3.5408 train_time:101962ms step_avg:167.42ms step:620/1530 train_loss:3.5912 train_time:102132ms step_avg:167.43ms step:621/1530 train_loss:3.9149 train_time:102305ms step_avg:167.44ms step:622/1530 train_loss:3.5663 train_time:102478ms step_avg:167.45ms step:623/1530 train_loss:3.5999 train_time:102649ms step_avg:167.45ms step:624/1530 train_loss:3.6828 train_time:102821ms step_avg:167.46ms step:625/1530 train_loss:3.7023 train_time:102991ms step_avg:167.47ms step:625/1530 val_loss:3.6167 train_time:103040ms step_avg:167.55ms step:626/1530 train_loss:3.7281 train_time:103164ms step_avg:167.47ms step:627/1530 train_loss:3.7128 train_time:103335ms step_avg:167.48ms step:628/1530 train_loss:3.7572 train_time:103504ms step_avg:167.48ms step:629/1530 train_loss:3.5896 train_time:103674ms step_avg:167.49ms step:630/1530 train_loss:3.7176 train_time:103845ms step_avg:167.49ms step:631/1530 train_loss:3.7382 train_time:104013ms step_avg:167.49ms step:632/1530 train_loss:3.6418 train_time:104186ms step_avg:167.50ms step:633/1530 train_loss:3.6015 train_time:104357ms step_avg:167.51ms step:634/1530 train_loss:3.6949 train_time:104527ms step_avg:167.51ms step:635/1530 train_loss:3.9452 train_time:104698ms step_avg:167.52ms step:636/1530 train_loss:3.5373 train_time:104869ms step_avg:167.52ms step:637/1530 train_loss:3.3484 train_time:105042ms step_avg:167.53ms step:638/1530 train_loss:3.5863 train_time:105212ms step_avg:167.53ms step:639/1530 train_loss:3.6289 train_time:105382ms step_avg:167.54ms step:640/1530 train_loss:3.5571 train_time:105552ms step_avg:167.54ms step:641/1530 train_loss:3.5776 train_time:105723ms step_avg:167.55ms step:642/1530 train_loss:3.6251 train_time:105894ms step_avg:167.55ms step:643/1530 train_loss:3.5889 train_time:106065ms step_avg:167.56ms step:644/1530 train_loss:3.5539 train_time:106235ms step_avg:167.56ms step:645/1530 train_loss:3.7661 train_time:106406ms step_avg:167.57ms step:646/1530 train_loss:3.6661 train_time:106578ms step_avg:167.57ms step:647/1530 train_loss:3.6516 train_time:106749ms step_avg:167.58ms step:648/1530 train_loss:3.7062 train_time:106922ms step_avg:167.59ms step:649/1530 train_loss:3.7553 train_time:107092ms step_avg:167.59ms step:650/1530 train_loss:3.6146 train_time:107265ms step_avg:167.60ms step:651/1530 train_loss:3.7608 train_time:107435ms step_avg:167.61ms step:652/1530 train_loss:3.5835 train_time:107605ms step_avg:167.61ms step:653/1530 train_loss:3.6556 train_time:107776ms step_avg:167.61ms step:654/1530 train_loss:3.4206 train_time:107947ms step_avg:167.62ms step:655/1530 train_loss:3.5742 train_time:108117ms step_avg:167.62ms step:656/1530 train_loss:3.5656 train_time:108286ms step_avg:167.63ms step:657/1530 train_loss:3.4939 train_time:108456ms step_avg:167.63ms step:658/1530 train_loss:3.6806 train_time:108627ms step_avg:167.63ms step:659/1530 train_loss:3.5798 train_time:108799ms step_avg:167.64ms step:660/1530 train_loss:3.6816 train_time:108969ms step_avg:167.65ms step:661/1530 train_loss:3.7412 train_time:109143ms step_avg:167.65ms step:662/1530 train_loss:3.6597 train_time:109312ms step_avg:167.66ms step:663/1530 train_loss:3.5500 train_time:109482ms step_avg:167.66ms step:664/1530 train_loss:3.6049 train_time:109652ms step_avg:167.66ms step:665/1530 train_loss:3.4888 train_time:109824ms step_avg:167.67ms step:666/1530 train_loss:3.7748 train_time:109994ms step_avg:167.67ms step:667/1530 train_loss:3.5994 train_time:110165ms step_avg:167.68ms step:668/1530 train_loss:3.6423 train_time:110336ms step_avg:167.68ms step:669/1530 train_loss:3.4860 train_time:110507ms step_avg:167.69ms step:670/1530 train_loss:3.6029 train_time:110677ms step_avg:167.69ms step:671/1530 train_loss:3.5553 train_time:110847ms step_avg:167.70ms step:672/1530 train_loss:3.5622 train_time:111018ms step_avg:167.70ms step:673/1530 train_loss:3.8505 train_time:111188ms step_avg:167.70ms step:674/1530 train_loss:3.6223 train_time:111359ms step_avg:167.71ms step:675/1530 train_loss:3.7024 train_time:111530ms step_avg:167.71ms step:676/1530 train_loss:3.4851 train_time:111701ms step_avg:167.72ms step:677/1530 train_loss:3.6033 train_time:111872ms step_avg:167.72ms step:678/1530 train_loss:3.5465 train_time:112044ms step_avg:167.73ms step:679/1530 train_loss:3.6732 train_time:112215ms step_avg:167.74ms step:680/1530 train_loss:3.5833 train_time:112385ms step_avg:167.74ms step:681/1530 train_loss:3.6096 train_time:112556ms step_avg:167.74ms step:682/1530 train_loss:3.6575 train_time:112732ms step_avg:167.76ms step:683/1530 train_loss:3.7322 train_time:112905ms step_avg:167.76ms step:684/1530 train_loss:3.6471 train_time:113078ms step_avg:167.77ms step:685/1530 train_loss:3.6805 train_time:113251ms step_avg:167.78ms step:686/1530 train_loss:3.6307 train_time:113425ms step_avg:167.79ms step:687/1530 train_loss:3.6549 train_time:113596ms step_avg:167.79ms step:688/1530 train_loss:3.1995 train_time:113772ms step_avg:167.81ms step:689/1530 train_loss:3.4030 train_time:113946ms step_avg:167.81ms step:690/1530 train_loss:3.5333 train_time:114121ms step_avg:167.83ms step:691/1530 train_loss:3.4088 train_time:114293ms step_avg:167.83ms step:692/1530 train_loss:3.6197 train_time:114465ms step_avg:167.84ms step:693/1530 train_loss:3.6405 train_time:114638ms step_avg:167.84ms step:694/1530 train_loss:3.5461 train_time:114810ms step_avg:167.85ms step:695/1530 train_loss:3.5267 train_time:114981ms step_avg:167.86ms step:696/1530 train_loss:3.8501 train_time:115153ms step_avg:167.86ms step:697/1530 train_loss:3.5821 train_time:115328ms step_avg:167.87ms step:698/1530 train_loss:3.6378 train_time:115499ms step_avg:167.88ms step:699/1530 train_loss:3.7577 train_time:115672ms step_avg:167.88ms step:700/1530 train_loss:3.5645 train_time:115846ms step_avg:167.89ms step:701/1530 train_loss:3.5359 train_time:116018ms step_avg:167.90ms step:702/1530 train_loss:3.5070 train_time:116190ms step_avg:167.90ms step:703/1530 train_loss:3.4964 train_time:116364ms step_avg:167.91ms step:704/1530 train_loss:3.5669 train_time:116536ms step_avg:167.92ms step:705/1530 train_loss:3.5572 train_time:116712ms step_avg:167.93ms step:706/1530 train_loss:3.5738 train_time:116888ms step_avg:167.94ms step:707/1530 train_loss:3.6367 train_time:117063ms step_avg:167.95ms step:708/1530 train_loss:3.5958 train_time:117235ms step_avg:167.96ms step:709/1530 train_loss:3.5842 train_time:117408ms step_avg:167.97ms step:710/1530 train_loss:3.5380 train_time:117580ms step_avg:167.97ms step:711/1530 train_loss:3.5875 train_time:117752ms step_avg:167.98ms step:712/1530 train_loss:3.6415 train_time:117928ms step_avg:167.99ms step:713/1530 train_loss:3.6505 train_time:118104ms step_avg:168.00ms step:714/1530 train_loss:3.5553 train_time:118277ms step_avg:168.01ms step:715/1530 train_loss:3.5639 train_time:118450ms step_avg:168.01ms step:716/1530 train_loss:3.5778 train_time:118623ms step_avg:168.02ms step:717/1530 train_loss:3.6967 train_time:118797ms step_avg:168.03ms step:718/1530 train_loss:3.5932 train_time:118968ms step_avg:168.03ms step:719/1530 train_loss:3.6724 train_time:119143ms step_avg:168.04ms step:720/1530 train_loss:3.8550 train_time:119316ms step_avg:168.05ms step:721/1530 train_loss:3.4616 train_time:119489ms step_avg:168.06ms step:722/1530 train_loss:3.7347 train_time:119662ms step_avg:168.06ms step:723/1530 train_loss:3.7660 train_time:119832ms step_avg:168.07ms step:724/1530 train_loss:3.5584 train_time:120006ms step_avg:168.08ms step:725/1530 train_loss:3.6507 train_time:120180ms step_avg:168.08ms step:726/1530 train_loss:3.5252 train_time:120353ms step_avg:168.09ms step:727/1530 train_loss:3.5743 train_time:120529ms step_avg:168.10ms step:728/1530 train_loss:3.7219 train_time:120703ms step_avg:168.11ms step:729/1530 train_loss:3.6643 train_time:120875ms step_avg:168.12ms step:730/1530 train_loss:3.6626 train_time:121049ms step_avg:168.12ms step:731/1530 train_loss:3.5476 train_time:121222ms step_avg:168.13ms step:732/1530 train_loss:3.5927 train_time:121394ms step_avg:168.14ms step:733/1530 train_loss:3.8275 train_time:121568ms step_avg:168.14ms step:734/1530 train_loss:3.5558 train_time:121744ms step_avg:168.15ms step:735/1530 train_loss:3.6129 train_time:121916ms step_avg:168.16ms step:736/1530 train_loss:3.7292 train_time:122088ms step_avg:168.16ms step:737/1530 train_loss:3.6657 train_time:122261ms step_avg:168.17ms step:738/1530 train_loss:3.5993 train_time:122433ms step_avg:168.18ms step:739/1530 train_loss:3.4952 train_time:122605ms step_avg:168.18ms step:740/1530 train_loss:4.1114 train_time:122780ms step_avg:168.19ms step:741/1530 train_loss:3.4840 train_time:122952ms step_avg:168.20ms step:742/1530 train_loss:3.5479 train_time:123125ms step_avg:168.20ms step:743/1530 train_loss:3.5702 train_time:123298ms step_avg:168.21ms step:744/1530 train_loss:3.6443 train_time:123470ms step_avg:168.22ms step:745/1530 train_loss:3.5833 train_time:123645ms step_avg:168.22ms step:746/1530 train_loss:3.5911 train_time:123817ms step_avg:168.23ms step:747/1530 train_loss:3.6353 train_time:123992ms step_avg:168.24ms step:748/1530 train_loss:3.5614 train_time:124169ms step_avg:168.25ms step:749/1530 train_loss:3.5559 train_time:124343ms step_avg:168.26ms step:750/1530 train_loss:3.5845 train_time:124513ms step_avg:168.26ms step:750/1530 val_loss:3.5581 train_time:124562ms step_avg:168.33ms step:751/1530 train_loss:3.5637 train_time:124684ms step_avg:168.26ms step:752/1530 train_loss:3.6128 train_time:124856ms step_avg:168.27ms step:753/1530 train_loss:3.6123 train_time:125031ms step_avg:168.28ms step:754/1530 train_loss:3.5913 train_time:125204ms step_avg:168.29ms step:755/1530 train_loss:3.6803 train_time:125508ms step_avg:168.47ms step:756/1530 train_loss:3.4569 train_time:125693ms step_avg:168.49ms step:757/1530 train_loss:3.7211 train_time:125867ms step_avg:168.50ms step:758/1530 train_loss:3.6432 train_time:126037ms step_avg:168.50ms step:759/1530 train_loss:3.5813 train_time:126372ms step_avg:168.72ms step:760/1530 train_loss:3.6996 train_time:126543ms step_avg:168.72ms step:761/1530 train_loss:3.3956 train_time:126715ms step_avg:168.73ms step:762/1530 train_loss:3.5503 train_time:126888ms step_avg:168.73ms step:763/1530 train_loss:3.6528 train_time:127060ms step_avg:168.74ms step:764/1530 train_loss:3.3160 train_time:127234ms step_avg:168.75ms step:765/1530 train_loss:3.7255 train_time:127409ms step_avg:168.75ms step:766/1530 train_loss:3.5647 train_time:127581ms step_avg:168.76ms step:767/1530 train_loss:3.5667 train_time:127752ms step_avg:168.76ms step:768/1530 train_loss:3.5599 train_time:127927ms step_avg:168.77ms step:769/1530 train_loss:3.5741 train_time:128099ms step_avg:168.77ms step:770/1530 train_loss:3.6359 train_time:128270ms step_avg:168.78ms step:771/1530 train_loss:3.8760 train_time:128443ms step_avg:168.78ms step:772/1530 train_loss:3.4506 train_time:128614ms step_avg:168.79ms step:773/1530 train_loss:3.6236 train_time:128786ms step_avg:168.79ms step:774/1530 train_loss:3.6361 train_time:128957ms step_avg:168.79ms step:775/1530 train_loss:3.5974 train_time:129130ms step_avg:168.80ms step:776/1530 train_loss:3.3975 train_time:129301ms step_avg:168.80ms step:777/1530 train_loss:3.3782 train_time:129474ms step_avg:168.81ms step:778/1530 train_loss:3.4835 train_time:129646ms step_avg:168.81ms step:779/1530 train_loss:3.5750 train_time:129818ms step_avg:168.81ms step:780/1530 train_loss:3.5808 train_time:129991ms step_avg:168.82ms step:781/1530 train_loss:3.6603 train_time:130164ms step_avg:168.83ms step:782/1530 train_loss:3.5800 train_time:130335ms step_avg:168.83ms step:783/1530 train_loss:3.5624 train_time:130507ms step_avg:168.83ms step:784/1530 train_loss:3.5987 train_time:130679ms step_avg:168.84ms step:785/1530 train_loss:3.5549 train_time:130849ms step_avg:168.84ms step:786/1530 train_loss:3.4349 train_time:131022ms step_avg:168.84ms step:787/1530 train_loss:3.7638 train_time:131194ms step_avg:168.85ms step:788/1530 train_loss:3.4957 train_time:131368ms step_avg:168.85ms step:789/1530 train_loss:3.5442 train_time:131538ms step_avg:168.86ms step:790/1530 train_loss:3.6199 train_time:131714ms step_avg:168.86ms step:791/1530 train_loss:3.7642 train_time:131888ms step_avg:168.87ms step:792/1530 train_loss:3.7530 train_time:132061ms step_avg:168.88ms step:793/1530 train_loss:3.4468 train_time:132232ms step_avg:168.88ms step:794/1530 train_loss:3.5896 train_time:132406ms step_avg:168.88ms step:795/1530 train_loss:3.6684 train_time:132579ms step_avg:168.89ms step:796/1530 train_loss:3.7278 train_time:132755ms step_avg:168.90ms step:797/1530 train_loss:3.5176 train_time:132929ms step_avg:168.91ms step:798/1530 train_loss:3.6435 train_time:133104ms step_avg:168.91ms step:799/1530 train_loss:3.5263 train_time:133279ms step_avg:168.92ms step:800/1530 train_loss:3.5240 train_time:133452ms step_avg:168.93ms step:801/1530 train_loss:3.6238 train_time:133627ms step_avg:168.93ms step:802/1530 train_loss:3.4950 train_time:133803ms step_avg:168.94ms step:803/1530 train_loss:3.4962 train_time:133976ms step_avg:168.95ms step:804/1530 train_loss:3.6186 train_time:134151ms step_avg:168.96ms step:805/1530 train_loss:3.5121 train_time:134329ms step_avg:168.97ms step:806/1530 train_loss:3.5556 train_time:134502ms step_avg:168.97ms step:807/1530 train_loss:3.6325 train_time:134674ms step_avg:168.98ms step:808/1530 train_loss:3.5441 train_time:134850ms step_avg:168.98ms step:809/1530 train_loss:3.4905 train_time:135023ms step_avg:168.99ms step:810/1530 train_loss:3.5529 train_time:135195ms step_avg:168.99ms step:811/1530 train_loss:3.5678 train_time:135369ms step_avg:169.00ms step:812/1530 train_loss:3.5903 train_time:135541ms step_avg:169.00ms step:813/1530 train_loss:3.6227 train_time:135713ms step_avg:169.01ms step:814/1530 train_loss:3.5604 train_time:135889ms step_avg:169.02ms step:815/1530 train_loss:3.5589 train_time:136061ms step_avg:169.02ms step:816/1530 train_loss:3.6777 train_time:136236ms step_avg:169.03ms step:817/1530 train_loss:3.7645 train_time:136411ms step_avg:169.03ms step:818/1530 train_loss:3.5185 train_time:136583ms step_avg:169.04ms step:819/1530 train_loss:3.7172 train_time:136758ms step_avg:169.05ms step:820/1530 train_loss:3.4900 train_time:136934ms step_avg:169.05ms step:821/1530 train_loss:3.5619 train_time:137109ms step_avg:169.06ms step:822/1530 train_loss:3.6932 train_time:137282ms step_avg:169.07ms step:823/1530 train_loss:3.5697 train_time:137456ms step_avg:169.07ms step:824/1530 train_loss:3.5024 train_time:137630ms step_avg:169.08ms step:825/1530 train_loss:3.6100 train_time:137806ms step_avg:169.09ms step:826/1530 train_loss:3.4680 train_time:137980ms step_avg:169.09ms step:827/1530 train_loss:3.7330 train_time:138155ms step_avg:169.10ms step:828/1530 train_loss:3.6162 train_time:138329ms step_avg:169.11ms step:829/1530 train_loss:3.6232 train_time:138506ms step_avg:169.12ms step:830/1530 train_loss:3.5330 train_time:138680ms step_avg:169.12ms step:831/1530 train_loss:3.5965 train_time:138856ms step_avg:169.13ms step:832/1530 train_loss:3.5113 train_time:139033ms step_avg:169.14ms step:833/1530 train_loss:3.6490 train_time:139208ms step_avg:169.15ms step:834/1530 train_loss:3.4723 train_time:139380ms step_avg:169.15ms step:835/1530 train_loss:3.4534 train_time:139556ms step_avg:169.16ms step:836/1530 train_loss:3.7121 train_time:139733ms step_avg:169.17ms step:837/1530 train_loss:3.3950 train_time:139909ms step_avg:169.18ms step:838/1530 train_loss:3.5921 train_time:140082ms step_avg:169.18ms step:839/1530 train_loss:3.4139 train_time:140258ms step_avg:169.19ms step:840/1530 train_loss:3.4618 train_time:140431ms step_avg:169.19ms step:841/1530 train_loss:3.5715 train_time:140604ms step_avg:169.20ms step:842/1530 train_loss:3.5753 train_time:140777ms step_avg:169.20ms step:843/1530 train_loss:3.5541 train_time:140952ms step_avg:169.21ms step:844/1530 train_loss:3.4240 train_time:141125ms step_avg:169.21ms step:845/1530 train_loss:3.6588 train_time:141299ms step_avg:169.22ms step:846/1530 train_loss:3.5078 train_time:141475ms step_avg:169.23ms step:847/1530 train_loss:3.4855 train_time:141651ms step_avg:169.24ms step:848/1530 train_loss:3.6381 train_time:141823ms step_avg:169.24ms step:849/1530 train_loss:3.4861 train_time:141998ms step_avg:169.25ms step:850/1530 train_loss:3.4420 train_time:142173ms step_avg:169.25ms step:851/1530 train_loss:3.7258 train_time:142347ms step_avg:169.26ms step:852/1530 train_loss:3.4343 train_time:142521ms step_avg:169.26ms step:853/1530 train_loss:3.5669 train_time:142693ms step_avg:169.27ms step:854/1530 train_loss:3.6468 train_time:142868ms step_avg:169.28ms step:855/1530 train_loss:3.5033 train_time:143041ms step_avg:169.28ms step:856/1530 train_loss:3.5448 train_time:143217ms step_avg:169.29ms step:857/1530 train_loss:3.6069 train_time:143392ms step_avg:169.29ms step:858/1530 train_loss:3.4685 train_time:143568ms step_avg:169.30ms step:859/1530 train_loss:3.5548 train_time:143741ms step_avg:169.31ms step:860/1530 train_loss:3.5820 train_time:143912ms step_avg:169.31ms step:861/1530 train_loss:3.6232 train_time:144090ms step_avg:169.32ms step:862/1530 train_loss:3.5933 train_time:144267ms step_avg:169.33ms step:863/1530 train_loss:3.5666 train_time:144442ms step_avg:169.33ms step:864/1530 train_loss:3.3775 train_time:144616ms step_avg:169.34ms step:865/1530 train_loss:3.5946 train_time:144788ms step_avg:169.34ms step:866/1530 train_loss:3.8720 train_time:144965ms step_avg:169.35ms step:867/1530 train_loss:3.4545 train_time:145138ms step_avg:169.36ms step:868/1530 train_loss:3.6392 train_time:145310ms step_avg:169.36ms step:869/1530 train_loss:3.6122 train_time:145483ms step_avg:169.36ms step:870/1530 train_loss:3.4463 train_time:145657ms step_avg:169.37ms step:871/1530 train_loss:3.3965 train_time:145832ms step_avg:169.37ms step:872/1530 train_loss:3.6419 train_time:146008ms step_avg:169.38ms step:873/1530 train_loss:3.4614 train_time:146180ms step_avg:169.39ms step:874/1530 train_loss:3.2188 train_time:146358ms step_avg:169.40ms step:875/1530 train_loss:3.6294 train_time:146532ms step_avg:169.40ms step:875/1530 val_loss:3.5146 train_time:146582ms step_avg:169.46ms step:876/1530 train_loss:3.4350 train_time:146706ms step_avg:169.41ms step:877/1530 train_loss:3.6168 train_time:146880ms step_avg:169.41ms step:878/1530 train_loss:3.4603 train_time:147055ms step_avg:169.42ms step:879/1530 train_loss:3.6428 train_time:147229ms step_avg:169.42ms step:880/1530 train_loss:3.2961 train_time:147401ms step_avg:169.43ms step:881/1530 train_loss:3.4677 train_time:147574ms step_avg:169.43ms step:882/1530 train_loss:3.6963 train_time:147747ms step_avg:169.43ms step:883/1530 train_loss:3.8381 train_time:147921ms step_avg:169.44ms step:884/1530 train_loss:3.5597 train_time:148095ms step_avg:169.44ms step:885/1530 train_loss:3.4939 train_time:148267ms step_avg:169.45ms step:886/1530 train_loss:3.5624 train_time:148439ms step_avg:169.45ms step:887/1530 train_loss:4.0887 train_time:148615ms step_avg:169.46ms step:888/1530 train_loss:3.8323 train_time:148795ms step_avg:169.47ms step:889/1530 train_loss:3.5086 train_time:148967ms step_avg:169.47ms step:890/1530 train_loss:3.5293 train_time:149138ms step_avg:169.47ms step:891/1530 train_loss:3.3571 train_time:149313ms step_avg:169.48ms step:892/1530 train_loss:3.7143 train_time:149487ms step_avg:169.49ms step:893/1530 train_loss:3.4169 train_time:149661ms step_avg:169.49ms step:894/1530 train_loss:3.6501 train_time:149837ms step_avg:169.50ms step:895/1530 train_loss:3.6711 train_time:150011ms step_avg:169.50ms step:896/1530 train_loss:3.4912 train_time:150184ms step_avg:169.51ms step:897/1530 train_loss:3.5420 train_time:150357ms step_avg:169.51ms step:898/1530 train_loss:3.5815 train_time:150534ms step_avg:169.52ms step:899/1530 train_loss:3.4675 train_time:150708ms step_avg:169.53ms step:900/1530 train_loss:3.4177 train_time:150880ms step_avg:169.53ms step:901/1530 train_loss:3.6105 train_time:151052ms step_avg:169.53ms step:902/1530 train_loss:3.6322 train_time:151227ms step_avg:169.54ms step:903/1530 train_loss:3.5335 train_time:151402ms step_avg:169.54ms step:904/1530 train_loss:3.4902 train_time:151575ms step_avg:169.55ms step:905/1530 train_loss:3.5018 train_time:151746ms step_avg:169.55ms step:906/1530 train_loss:3.7020 train_time:151921ms step_avg:169.55ms step:907/1530 train_loss:3.5075 train_time:152095ms step_avg:169.56ms step:908/1530 train_loss:3.5629 train_time:152268ms step_avg:169.56ms step:909/1530 train_loss:3.4516 train_time:152444ms step_avg:169.57ms step:910/1530 train_loss:3.5244 train_time:152624ms step_avg:169.58ms step:911/1530 train_loss:3.6417 train_time:152800ms step_avg:169.59ms step:912/1530 train_loss:3.6074 train_time:152980ms step_avg:169.60ms step:913/1530 train_loss:3.4632 train_time:153159ms step_avg:169.61ms step:914/1530 train_loss:3.7418 train_time:153337ms step_avg:169.62ms step:915/1530 train_loss:3.5306 train_time:153517ms step_avg:169.63ms step:916/1530 train_loss:3.6120 train_time:153694ms step_avg:169.64ms step:917/1530 train_loss:3.5983 train_time:153869ms step_avg:169.65ms step:918/1530 train_loss:4.8406 train_time:154050ms step_avg:169.66ms step:919/1530 train_loss:3.4976 train_time:154231ms step_avg:169.67ms step:920/1530 train_loss:3.5832 train_time:154406ms step_avg:169.68ms step:921/1530 train_loss:3.5451 train_time:154582ms step_avg:169.68ms step:922/1530 train_loss:3.5765 train_time:154759ms step_avg:169.69ms step:923/1530 train_loss:3.6103 train_time:154936ms step_avg:169.70ms step:924/1530 train_loss:3.6778 train_time:155113ms step_avg:169.71ms step:925/1530 train_loss:3.6438 train_time:155289ms step_avg:169.71ms step:926/1530 train_loss:3.5505 train_time:155463ms step_avg:169.72ms step:927/1530 train_loss:3.5464 train_time:155637ms step_avg:169.72ms step:928/1530 train_loss:3.7849 train_time:155814ms step_avg:169.73ms step:929/1530 train_loss:3.6062 train_time:155989ms step_avg:169.74ms step:930/1530 train_loss:3.3993 train_time:156165ms step_avg:169.74ms step:931/1530 train_loss:3.4872 train_time:156338ms step_avg:169.75ms step:932/1530 train_loss:3.6408 train_time:156516ms step_avg:169.76ms step:933/1530 train_loss:3.3613 train_time:156692ms step_avg:169.76ms step:934/1530 train_loss:3.5816 train_time:156870ms step_avg:169.77ms step:935/1530 train_loss:3.4327 train_time:157047ms step_avg:169.78ms step:936/1530 train_loss:3.5150 train_time:157226ms step_avg:169.79ms step:937/1530 train_loss:3.6176 train_time:157402ms step_avg:169.80ms step:938/1530 train_loss:3.5352 train_time:157576ms step_avg:169.80ms step:939/1530 train_loss:3.6688 train_time:157758ms step_avg:169.81ms step:940/1530 train_loss:3.4794 train_time:157933ms step_avg:169.82ms step:941/1530 train_loss:3.5409 train_time:158107ms step_avg:169.83ms step:942/1530 train_loss:3.3572 train_time:158283ms step_avg:169.83ms step:943/1530 train_loss:3.7070 train_time:158464ms step_avg:169.84ms step:944/1530 train_loss:3.4007 train_time:158773ms step_avg:169.99ms step:945/1530 train_loss:3.4225 train_time:158956ms step_avg:170.01ms step:946/1530 train_loss:5.0742 train_time:159138ms step_avg:170.02ms step:947/1530 train_loss:3.5940 train_time:159313ms step_avg:170.02ms step:948/1530 train_loss:3.4752 train_time:159488ms step_avg:170.03ms step:949/1530 train_loss:3.3638 train_time:159825ms step_avg:170.21ms step:950/1530 train_loss:3.4395 train_time:160000ms step_avg:170.21ms step:951/1530 train_loss:3.3999 train_time:160180ms step_avg:170.22ms step:952/1530 train_loss:3.4770 train_time:160355ms step_avg:170.23ms step:953/1530 train_loss:3.5593 train_time:160533ms step_avg:170.24ms step:954/1530 train_loss:3.4409 train_time:160713ms step_avg:170.25ms step:955/1530 train_loss:3.4694 train_time:160887ms step_avg:170.25ms step:956/1530 train_loss:3.4375 train_time:161063ms step_avg:170.26ms step:957/1530 train_loss:3.4868 train_time:161241ms step_avg:170.26ms step:958/1530 train_loss:3.5017 train_time:161421ms step_avg:170.27ms step:959/1530 train_loss:3.5058 train_time:161596ms step_avg:170.28ms step:960/1530 train_loss:3.4047 train_time:161774ms step_avg:170.29ms step:961/1530 train_loss:3.6444 train_time:161949ms step_avg:170.29ms step:962/1530 train_loss:3.5870 train_time:162122ms step_avg:170.30ms step:963/1530 train_loss:3.6834 train_time:162298ms step_avg:170.30ms step:964/1530 train_loss:3.4282 train_time:162475ms step_avg:170.31ms step:965/1530 train_loss:3.4757 train_time:162648ms step_avg:170.31ms step:966/1530 train_loss:3.7052 train_time:162823ms step_avg:170.32ms step:967/1530 train_loss:3.5196 train_time:162996ms step_avg:170.32ms step:968/1530 train_loss:3.5117 train_time:163172ms step_avg:170.33ms step:969/1530 train_loss:3.5825 train_time:163347ms step_avg:170.33ms step:970/1530 train_loss:3.3666 train_time:163520ms step_avg:170.33ms step:971/1530 train_loss:3.5272 train_time:163693ms step_avg:170.34ms step:972/1530 train_loss:3.4846 train_time:163867ms step_avg:170.34ms step:973/1530 train_loss:3.5334 train_time:164040ms step_avg:170.34ms step:974/1530 train_loss:3.5928 train_time:164217ms step_avg:170.35ms step:975/1530 train_loss:3.4630 train_time:164393ms step_avg:170.36ms step:976/1530 train_loss:3.6657 train_time:164568ms step_avg:170.36ms step:977/1530 train_loss:3.5634 train_time:164741ms step_avg:170.36ms step:978/1530 train_loss:3.3533 train_time:164916ms step_avg:170.37ms step:979/1530 train_loss:3.6230 train_time:165093ms step_avg:170.37ms step:980/1530 train_loss:3.4154 train_time:165269ms step_avg:170.38ms step:981/1530 train_loss:3.5739 train_time:165447ms step_avg:170.39ms step:982/1530 train_loss:3.5374 train_time:165620ms step_avg:170.39ms step:983/1530 train_loss:3.5166 train_time:165796ms step_avg:170.40ms step:984/1530 train_loss:3.4897 train_time:165970ms step_avg:170.40ms step:985/1530 train_loss:3.5680 train_time:166147ms step_avg:170.41ms step:986/1530 train_loss:3.4079 train_time:166323ms step_avg:170.41ms step:987/1530 train_loss:3.4796 train_time:166494ms step_avg:170.41ms step:988/1530 train_loss:3.4629 train_time:166669ms step_avg:170.42ms step:989/1530 train_loss:3.4123 train_time:166842ms step_avg:170.42ms step:990/1530 train_loss:3.6529 train_time:167020ms step_avg:170.43ms step:991/1530 train_loss:3.4645 train_time:167195ms step_avg:170.43ms step:992/1530 train_loss:3.4356 train_time:167376ms step_avg:170.44ms step:993/1530 train_loss:3.4946 train_time:167555ms step_avg:170.45ms step:994/1530 train_loss:3.5944 train_time:167730ms step_avg:170.46ms step:995/1530 train_loss:3.5248 train_time:167903ms step_avg:170.46ms step:996/1530 train_loss:3.4557 train_time:168077ms step_avg:170.46ms step:997/1530 train_loss:3.7536 train_time:168252ms step_avg:170.47ms step:998/1530 train_loss:3.4379 train_time:168425ms step_avg:170.47ms step:999/1530 train_loss:3.5844 train_time:168597ms step_avg:170.47ms step:1000/1530 train_loss:3.4347 train_time:168776ms step_avg:170.48ms step:1000/1530 val_loss:3.4627 train_time:168828ms step_avg:170.53ms step:1001/1530 train_loss:3.4936 train_time:168953ms step_avg:170.49ms step:1002/1530 train_loss:3.3736 train_time:169126ms step_avg:170.49ms step:1003/1530 train_loss:3.5540 train_time:169302ms step_avg:170.50ms step:1004/1530 train_loss:3.5997 train_time:169477ms step_avg:170.50ms step:1005/1530 train_loss:3.3862 train_time:169653ms step_avg:170.51ms step:1006/1530 train_loss:3.4608 train_time:169829ms step_avg:170.51ms step:1007/1530 train_loss:3.4358 train_time:170003ms step_avg:170.51ms step:1008/1530 train_loss:3.5528 train_time:170178ms step_avg:170.52ms step:1009/1530 train_loss:3.6563 train_time:170356ms step_avg:170.53ms step:1010/1530 train_loss:3.5568 train_time:170530ms step_avg:170.53ms step:1011/1530 train_loss:3.5301 train_time:170704ms step_avg:170.53ms step:1012/1530 train_loss:3.3873 train_time:170878ms step_avg:170.54ms step:1013/1530 train_loss:3.5271 train_time:171054ms step_avg:170.54ms step:1014/1530 train_loss:3.6173 train_time:171231ms step_avg:170.55ms step:1015/1530 train_loss:3.3231 train_time:171410ms step_avg:170.56ms step:1016/1530 train_loss:3.4052 train_time:171583ms step_avg:170.56ms step:1017/1530 train_loss:3.3932 train_time:171761ms step_avg:170.57ms step:1018/1530 train_loss:3.3899 train_time:171937ms step_avg:170.57ms step:1019/1530 train_loss:3.5150 train_time:172113ms step_avg:170.58ms step:1020/1530 train_loss:3.3697 train_time:172290ms step_avg:170.58ms step:1021/1530 train_loss:3.3493 train_time:172464ms step_avg:170.59ms step:1022/1530 train_loss:3.4751 train_time:172639ms step_avg:170.59ms step:1023/1530 train_loss:3.4984 train_time:172816ms step_avg:170.60ms step:1024/1530 train_loss:3.4705 train_time:172993ms step_avg:170.60ms step:1025/1530 train_loss:3.4724 train_time:173170ms step_avg:170.61ms step:1026/1530 train_loss:3.6153 train_time:173345ms step_avg:170.62ms step:1027/1530 train_loss:3.3135 train_time:173522ms step_avg:170.62ms step:1028/1530 train_loss:3.3913 train_time:173702ms step_avg:170.63ms step:1029/1530 train_loss:3.3017 train_time:173882ms step_avg:170.64ms step:1030/1530 train_loss:3.5333 train_time:174059ms step_avg:170.65ms step:1031/1530 train_loss:3.5014 train_time:174236ms step_avg:170.65ms step:1032/1530 train_loss:3.6888 train_time:174419ms step_avg:170.66ms step:1033/1530 train_loss:3.4849 train_time:174594ms step_avg:170.67ms step:1034/1530 train_loss:3.3958 train_time:174770ms step_avg:170.67ms step:1035/1530 train_loss:3.4361 train_time:174949ms step_avg:170.68ms step:1036/1530 train_loss:3.4753 train_time:175127ms step_avg:170.69ms step:1037/1530 train_loss:3.7828 train_time:175303ms step_avg:170.69ms step:1038/1530 train_loss:3.6116 train_time:175481ms step_avg:170.70ms step:1039/1530 train_loss:3.5056 train_time:175663ms step_avg:170.71ms step:1040/1530 train_loss:3.4055 train_time:175838ms step_avg:170.72ms step:1041/1530 train_loss:3.4792 train_time:176016ms step_avg:170.72ms step:1042/1530 train_loss:3.5195 train_time:176189ms step_avg:170.73ms step:1043/1530 train_loss:3.4404 train_time:176364ms step_avg:170.73ms step:1044/1530 train_loss:3.4537 train_time:176539ms step_avg:170.73ms step:1045/1530 train_loss:3.5103 train_time:176718ms step_avg:170.74ms step:1046/1530 train_loss:3.4231 train_time:176894ms step_avg:170.75ms step:1047/1530 train_loss:3.6292 train_time:177069ms step_avg:170.75ms step:1048/1530 train_loss:3.4913 train_time:177246ms step_avg:170.76ms step:1049/1530 train_loss:3.3969 train_time:177422ms step_avg:170.76ms step:1050/1530 train_loss:3.3907 train_time:177599ms step_avg:170.77ms step:1051/1530 train_loss:3.4910 train_time:177777ms step_avg:170.77ms step:1052/1530 train_loss:3.3588 train_time:177955ms step_avg:170.78ms step:1053/1530 train_loss:3.6842 train_time:178134ms step_avg:170.79ms step:1054/1530 train_loss:3.5325 train_time:178314ms step_avg:170.80ms step:1055/1530 train_loss:3.3820 train_time:178488ms step_avg:170.80ms step:1056/1530 train_loss:3.4904 train_time:178663ms step_avg:170.81ms step:1057/1530 train_loss:3.5750 train_time:178841ms step_avg:170.81ms step:1058/1530 train_loss:3.2954 train_time:179018ms step_avg:170.82ms step:1059/1530 train_loss:3.3644 train_time:179199ms step_avg:170.83ms step:1060/1530 train_loss:3.4320 train_time:179374ms step_avg:170.83ms step:1061/1530 train_loss:3.4129 train_time:179549ms step_avg:170.84ms step:1062/1530 train_loss:3.3776 train_time:179727ms step_avg:170.84ms step:1063/1530 train_loss:3.4553 train_time:179900ms step_avg:170.85ms step:1064/1530 train_loss:3.3756 train_time:180075ms step_avg:170.85ms step:1065/1530 train_loss:3.3575 train_time:180253ms step_avg:170.86ms step:1066/1530 train_loss:3.4123 train_time:180432ms step_avg:170.86ms step:1067/1530 train_loss:3.2719 train_time:180609ms step_avg:170.87ms step:1068/1530 train_loss:3.4282 train_time:180785ms step_avg:170.87ms step:1069/1530 train_loss:3.2920 train_time:180965ms step_avg:170.88ms step:1070/1530 train_loss:3.5651 train_time:181141ms step_avg:170.89ms step:1071/1530 train_loss:3.5046 train_time:181321ms step_avg:170.90ms step:1072/1530 train_loss:3.4359 train_time:181495ms step_avg:170.90ms step:1073/1530 train_loss:3.5139 train_time:181668ms step_avg:170.90ms step:1074/1530 train_loss:3.4246 train_time:181846ms step_avg:170.91ms step:1075/1530 train_loss:3.3955 train_time:182023ms step_avg:170.91ms step:1076/1530 train_loss:3.7865 train_time:182200ms step_avg:170.92ms step:1077/1530 train_loss:3.4254 train_time:182374ms step_avg:170.92ms step:1078/1530 train_loss:3.0769 train_time:182559ms step_avg:170.94ms step:1079/1530 train_loss:3.5307 train_time:182735ms step_avg:170.94ms step:1080/1530 train_loss:3.4199 train_time:182913ms step_avg:170.95ms step:1081/1530 train_loss:3.4930 train_time:183087ms step_avg:170.95ms step:1082/1530 train_loss:3.5877 train_time:183263ms step_avg:170.95ms step:1083/1530 train_loss:3.4855 train_time:183438ms step_avg:170.96ms step:1084/1530 train_loss:3.4587 train_time:183614ms step_avg:170.96ms step:1085/1530 train_loss:3.4267 train_time:183790ms step_avg:170.97ms step:1086/1530 train_loss:3.6285 train_time:183965ms step_avg:170.97ms step:1087/1530 train_loss:3.4980 train_time:184139ms step_avg:170.97ms step:1088/1530 train_loss:3.3689 train_time:184317ms step_avg:170.98ms step:1089/1530 train_loss:3.3679 train_time:184497ms step_avg:170.99ms step:1090/1530 train_loss:3.4732 train_time:184676ms step_avg:171.00ms step:1091/1530 train_loss:3.2758 train_time:184853ms step_avg:171.00ms step:1092/1530 train_loss:3.4765 train_time:185030ms step_avg:171.01ms step:1093/1530 train_loss:3.5949 train_time:185209ms step_avg:171.01ms step:1094/1530 train_loss:3.4395 train_time:185384ms step_avg:171.02ms step:1095/1530 train_loss:3.4104 train_time:185559ms step_avg:171.02ms step:1096/1530 train_loss:3.4164 train_time:185736ms step_avg:171.03ms step:1097/1530 train_loss:3.4829 train_time:185915ms step_avg:171.03ms step:1098/1530 train_loss:3.5567 train_time:186094ms step_avg:171.04ms step:1099/1530 train_loss:3.5195 train_time:186270ms step_avg:171.05ms step:1100/1530 train_loss:3.4184 train_time:186451ms step_avg:171.06ms step:1101/1530 train_loss:3.2875 train_time:186629ms step_avg:171.06ms step:1102/1530 train_loss:3.3132 train_time:186809ms step_avg:171.07ms step:1103/1530 train_loss:3.4360 train_time:186990ms step_avg:171.08ms step:1104/1530 train_loss:3.3171 train_time:187167ms step_avg:171.08ms step:1105/1530 train_loss:4.0614 train_time:187344ms step_avg:171.09ms step:1106/1530 train_loss:3.2204 train_time:187520ms step_avg:171.09ms step:1107/1530 train_loss:3.5597 train_time:187695ms step_avg:171.10ms step:1108/1530 train_loss:3.3407 train_time:187868ms step_avg:171.10ms step:1109/1530 train_loss:3.4976 train_time:188043ms step_avg:171.10ms step:1110/1530 train_loss:3.4203 train_time:188217ms step_avg:171.11ms step:1111/1530 train_loss:3.4801 train_time:188392ms step_avg:171.11ms step:1112/1530 train_loss:3.5535 train_time:188571ms step_avg:171.12ms step:1113/1530 train_loss:3.4245 train_time:188755ms step_avg:171.13ms step:1114/1530 train_loss:3.3647 train_time:188935ms step_avg:171.14ms step:1115/1530 train_loss:3.2303 train_time:189114ms step_avg:171.14ms step:1116/1530 train_loss:3.4221 train_time:189288ms step_avg:171.15ms step:1117/1530 train_loss:3.5864 train_time:189467ms step_avg:171.15ms step:1118/1530 train_loss:3.6159 train_time:189644ms step_avg:171.16ms step:1119/1530 train_loss:3.4711 train_time:189819ms step_avg:171.16ms step:1120/1530 train_loss:3.4830 train_time:189996ms step_avg:171.17ms step:1121/1530 train_loss:3.3793 train_time:190174ms step_avg:171.17ms step:1122/1530 train_loss:3.4545 train_time:190349ms step_avg:171.18ms step:1123/1530 train_loss:3.5738 train_time:190525ms step_avg:171.18ms step:1124/1530 train_loss:3.3323 train_time:190702ms step_avg:171.19ms step:1125/1530 train_loss:3.2370 train_time:190878ms step_avg:171.19ms step:1125/1530 val_loss:3.4029 train_time:190928ms step_avg:171.24ms step:1126/1530 train_loss:3.4662 train_time:191055ms step_avg:171.20ms step:1127/1530 train_loss:3.6683 train_time:191235ms step_avg:171.20ms step:1128/1530 train_loss:3.2232 train_time:191412ms step_avg:171.21ms step:1129/1530 train_loss:3.5544 train_time:191594ms step_avg:171.22ms step:1130/1530 train_loss:3.3735 train_time:191773ms step_avg:171.23ms step:1131/1530 train_loss:3.3975 train_time:191957ms step_avg:171.24ms step:1132/1530 train_loss:3.3581 train_time:192130ms step_avg:171.24ms step:1133/1530 train_loss:3.4830 train_time:192439ms step_avg:171.36ms step:1134/1530 train_loss:3.4423 train_time:192625ms step_avg:171.37ms step:1135/1530 train_loss:3.5182 train_time:192800ms step_avg:171.38ms step:1136/1530 train_loss:3.5602 train_time:192977ms step_avg:171.38ms step:1137/1530 train_loss:3.4553 train_time:193153ms step_avg:171.39ms step:1138/1530 train_loss:3.3471 train_time:193333ms step_avg:171.39ms step:1139/1530 train_loss:3.6503 train_time:193680ms step_avg:171.55ms step:1140/1530 train_loss:3.4501 train_time:193855ms step_avg:171.55ms step:1141/1530 train_loss:3.5915 train_time:194037ms step_avg:171.56ms step:1142/1530 train_loss:3.4418 train_time:194214ms step_avg:171.57ms step:1143/1530 train_loss:3.3587 train_time:194393ms step_avg:171.57ms step:1144/1530 train_loss:3.4365 train_time:194569ms step_avg:171.58ms step:1145/1530 train_loss:3.5841 train_time:194743ms step_avg:171.58ms step:1146/1530 train_loss:3.5529 train_time:194925ms step_avg:171.59ms step:1147/1530 train_loss:3.4791 train_time:195103ms step_avg:171.59ms step:1148/1530 train_loss:3.4889 train_time:195279ms step_avg:171.60ms step:1149/1530 train_loss:3.3208 train_time:195459ms step_avg:171.61ms step:1150/1530 train_loss:3.3707 train_time:195635ms step_avg:171.61ms step:1151/1530 train_loss:3.3148 train_time:195814ms step_avg:171.62ms step:1152/1530 train_loss:3.3944 train_time:195995ms step_avg:171.62ms step:1153/1530 train_loss:3.4279 train_time:196174ms step_avg:171.63ms step:1154/1530 train_loss:3.5171 train_time:196349ms step_avg:171.63ms step:1155/1530 train_loss:3.3171 train_time:196531ms step_avg:171.64ms step:1156/1530 train_loss:3.5318 train_time:196714ms step_avg:171.65ms step:1157/1530 train_loss:3.4894 train_time:196895ms step_avg:171.66ms step:1158/1530 train_loss:3.2468 train_time:197070ms step_avg:171.66ms step:1159/1530 train_loss:3.3450 train_time:197248ms step_avg:171.67ms step:1160/1530 train_loss:3.3310 train_time:197423ms step_avg:171.67ms step:1161/1530 train_loss:3.0883 train_time:197602ms step_avg:171.68ms step:1162/1530 train_loss:3.4165 train_time:197779ms step_avg:171.68ms step:1163/1530 train_loss:3.3877 train_time:197958ms step_avg:171.69ms step:1164/1530 train_loss:3.2895 train_time:198134ms step_avg:171.69ms step:1165/1530 train_loss:3.2380 train_time:198309ms step_avg:171.70ms step:1166/1530 train_loss:3.3829 train_time:198488ms step_avg:171.70ms step:1167/1530 train_loss:3.4054 train_time:198663ms step_avg:171.71ms step:1168/1530 train_loss:3.7160 train_time:198839ms step_avg:171.71ms step:1169/1530 train_loss:3.3708 train_time:199016ms step_avg:171.71ms step:1170/1530 train_loss:3.3860 train_time:199192ms step_avg:171.72ms step:1171/1530 train_loss:3.2822 train_time:199369ms step_avg:171.72ms step:1172/1530 train_loss:3.4218 train_time:199544ms step_avg:171.72ms step:1173/1530 train_loss:3.5363 train_time:199726ms step_avg:171.73ms step:1174/1530 train_loss:3.3798 train_time:199909ms step_avg:171.74ms step:1175/1530 train_loss:3.3573 train_time:200089ms step_avg:171.75ms step:1176/1530 train_loss:3.4179 train_time:200270ms step_avg:171.76ms step:1177/1530 train_loss:3.4474 train_time:200453ms step_avg:171.77ms step:1178/1530 train_loss:3.4922 train_time:200630ms step_avg:171.77ms step:1179/1530 train_loss:3.3971 train_time:200807ms step_avg:171.78ms step:1180/1530 train_loss:3.3527 train_time:200995ms step_avg:171.79ms step:1181/1530 train_loss:3.3321 train_time:201172ms step_avg:171.80ms step:1182/1530 train_loss:3.3643 train_time:201351ms step_avg:171.80ms step:1183/1530 train_loss:3.3293 train_time:201529ms step_avg:171.81ms step:1184/1530 train_loss:3.5079 train_time:201707ms step_avg:171.81ms step:1185/1530 train_loss:3.5378 train_time:201889ms step_avg:171.82ms step:1186/1530 train_loss:3.3569 train_time:202069ms step_avg:171.83ms step:1187/1530 train_loss:3.4136 train_time:202253ms step_avg:171.84ms step:1188/1530 train_loss:3.4400 train_time:202430ms step_avg:171.84ms step:1189/1530 train_loss:3.2754 train_time:202611ms step_avg:171.85ms step:1190/1530 train_loss:3.4376 train_time:202790ms step_avg:171.86ms step:1191/1530 train_loss:3.5763 train_time:202970ms step_avg:171.86ms step:1192/1530 train_loss:3.3839 train_time:203146ms step_avg:171.87ms step:1193/1530 train_loss:3.2670 train_time:203322ms step_avg:171.87ms step:1194/1530 train_loss:3.5528 train_time:203499ms step_avg:171.87ms step:1195/1530 train_loss:3.3686 train_time:203679ms step_avg:171.88ms step:1196/1530 train_loss:3.3805 train_time:203865ms step_avg:171.89ms step:1197/1530 train_loss:3.2876 train_time:204043ms step_avg:171.90ms step:1198/1530 train_loss:3.2951 train_time:204228ms step_avg:171.91ms step:1199/1530 train_loss:3.3388 train_time:204408ms step_avg:171.92ms step:1200/1530 train_loss:3.4491 train_time:204584ms step_avg:171.92ms step:1201/1530 train_loss:3.4738 train_time:204761ms step_avg:171.92ms step:1202/1530 train_loss:3.6079 train_time:204951ms step_avg:171.94ms step:1203/1530 train_loss:3.4026 train_time:205130ms step_avg:171.94ms step:1204/1530 train_loss:3.2987 train_time:205311ms step_avg:171.95ms step:1205/1530 train_loss:3.4347 train_time:205488ms step_avg:171.96ms step:1206/1530 train_loss:3.4676 train_time:205664ms step_avg:171.96ms step:1207/1530 train_loss:3.5086 train_time:205843ms step_avg:171.97ms step:1208/1530 train_loss:3.3893 train_time:206019ms step_avg:171.97ms step:1209/1530 train_loss:3.2416 train_time:206197ms step_avg:171.97ms step:1210/1530 train_loss:3.2982 train_time:206377ms step_avg:171.98ms step:1211/1530 train_loss:3.3883 train_time:206556ms step_avg:171.99ms step:1212/1530 train_loss:3.3900 train_time:206733ms step_avg:171.99ms step:1213/1530 train_loss:3.4048 train_time:206911ms step_avg:172.00ms step:1214/1530 train_loss:3.2415 train_time:207092ms step_avg:172.00ms step:1215/1530 train_loss:3.3912 train_time:207266ms step_avg:172.01ms step:1216/1530 train_loss:3.3259 train_time:207444ms step_avg:172.01ms step:1217/1530 train_loss:3.3136 train_time:207623ms step_avg:172.02ms step:1218/1530 train_loss:3.4012 train_time:207800ms step_avg:172.02ms step:1219/1530 train_loss:3.2521 train_time:207983ms step_avg:172.03ms step:1220/1530 train_loss:3.4702 train_time:208160ms step_avg:172.03ms step:1221/1530 train_loss:3.4968 train_time:208335ms step_avg:172.04ms step:1222/1530 train_loss:3.4276 train_time:208510ms step_avg:172.04ms step:1223/1530 train_loss:3.2922 train_time:208688ms step_avg:172.04ms step:1224/1530 train_loss:3.2472 train_time:208868ms step_avg:172.05ms step:1225/1530 train_loss:3.3558 train_time:209046ms step_avg:172.05ms step:1226/1530 train_loss:3.3266 train_time:209227ms step_avg:172.06ms step:1227/1530 train_loss:3.2719 train_time:209407ms step_avg:172.07ms step:1228/1530 train_loss:3.4419 train_time:209583ms step_avg:172.07ms step:1229/1530 train_loss:3.3652 train_time:209760ms step_avg:172.08ms step:1230/1530 train_loss:3.3952 train_time:209944ms step_avg:172.09ms step:1231/1530 train_loss:3.5686 train_time:210125ms step_avg:172.09ms step:1232/1530 train_loss:3.4929 train_time:210304ms step_avg:172.10ms step:1233/1530 train_loss:3.4191 train_time:210481ms step_avg:172.10ms step:1234/1530 train_loss:3.5797 train_time:210657ms step_avg:172.11ms step:1235/1530 train_loss:3.3182 train_time:210840ms step_avg:172.11ms step:1236/1530 train_loss:3.2826 train_time:211016ms step_avg:172.12ms step:1237/1530 train_loss:3.2650 train_time:211193ms step_avg:172.12ms step:1238/1530 train_loss:3.2720 train_time:211375ms step_avg:172.13ms step:1239/1530 train_loss:3.3313 train_time:211554ms step_avg:172.14ms step:1240/1530 train_loss:3.3816 train_time:211733ms step_avg:172.14ms step:1241/1530 train_loss:3.4187 train_time:211912ms step_avg:172.15ms step:1242/1530 train_loss:3.2997 train_time:212089ms step_avg:172.15ms step:1243/1530 train_loss:3.4043 train_time:212268ms step_avg:172.16ms step:1244/1530 train_loss:3.3985 train_time:212442ms step_avg:172.16ms step:1245/1530 train_loss:3.4082 train_time:212617ms step_avg:172.16ms step:1246/1530 train_loss:3.2385 train_time:212798ms step_avg:172.17ms step:1247/1530 train_loss:3.3704 train_time:212973ms step_avg:172.17ms step:1248/1530 train_loss:3.4270 train_time:213151ms step_avg:172.17ms step:1249/1530 train_loss:3.4190 train_time:213331ms step_avg:172.18ms step:1250/1530 train_loss:3.3046 train_time:213509ms step_avg:172.18ms step:1250/1530 val_loss:3.3500 train_time:213564ms step_avg:172.23ms step:1251/1530 train_loss:3.4847 train_time:213694ms step_avg:172.19ms step:1252/1530 train_loss:3.3566 train_time:213869ms step_avg:172.20ms step:1253/1530 train_loss:3.3070 train_time:214046ms step_avg:172.20ms step:1254/1530 train_loss:3.4075 train_time:214227ms step_avg:172.21ms step:1255/1530 train_loss:3.5119 train_time:214416ms step_avg:172.22ms step:1256/1530 train_loss:3.2979 train_time:214600ms step_avg:172.23ms step:1257/1530 train_loss:3.3697 train_time:214778ms step_avg:172.24ms step:1258/1530 train_loss:3.3601 train_time:214962ms step_avg:172.24ms step:1259/1530 train_loss:3.3216 train_time:215140ms step_avg:172.25ms step:1260/1530 train_loss:3.2009 train_time:215319ms step_avg:172.25ms step:1261/1530 train_loss:3.2999 train_time:215498ms step_avg:172.26ms step:1262/1530 train_loss:3.3230 train_time:215681ms step_avg:172.27ms step:1263/1530 train_loss:3.2359 train_time:215864ms step_avg:172.28ms step:1264/1530 train_loss:3.4345 train_time:216038ms step_avg:172.28ms step:1265/1530 train_loss:3.4218 train_time:216212ms step_avg:172.28ms step:1266/1530 train_loss:3.4336 train_time:216393ms step_avg:172.29ms step:1267/1530 train_loss:3.3653 train_time:216573ms step_avg:172.29ms step:1268/1530 train_loss:3.4034 train_time:216754ms step_avg:172.30ms step:1269/1530 train_loss:3.2513 train_time:216941ms step_avg:172.31ms step:1270/1530 train_loss:3.0968 train_time:217118ms step_avg:172.32ms step:1271/1530 train_loss:3.3984 train_time:217297ms step_avg:172.32ms step:1272/1530 train_loss:3.3474 train_time:217472ms step_avg:172.32ms step:1273/1530 train_loss:3.3741 train_time:217652ms step_avg:172.33ms step:1274/1530 train_loss:3.3560 train_time:217832ms step_avg:172.34ms step:1275/1530 train_loss:3.4238 train_time:218008ms step_avg:172.34ms step:1276/1530 train_loss:3.4667 train_time:218182ms step_avg:172.34ms step:1277/1530 train_loss:3.4083 train_time:218360ms step_avg:172.34ms step:1278/1530 train_loss:3.4020 train_time:218537ms step_avg:172.35ms step:1279/1530 train_loss:3.2567 train_time:218718ms step_avg:172.35ms step:1280/1530 train_loss:3.3657 train_time:218905ms step_avg:172.37ms step:1281/1530 train_loss:3.4234 train_time:219081ms step_avg:172.37ms step:1282/1530 train_loss:3.4662 train_time:219256ms step_avg:172.37ms step:1283/1530 train_loss:3.3285 train_time:219435ms step_avg:172.38ms step:1284/1530 train_loss:3.3664 train_time:219614ms step_avg:172.38ms step:1285/1530 train_loss:3.3565 train_time:219793ms step_avg:172.39ms step:1286/1530 train_loss:3.3324 train_time:219971ms step_avg:172.39ms step:1287/1530 train_loss:3.4837 train_time:220148ms step_avg:172.39ms step:1288/1530 train_loss:3.2894 train_time:220329ms step_avg:172.40ms step:1289/1530 train_loss:3.3819 train_time:220515ms step_avg:172.41ms step:1290/1530 train_loss:3.4549 train_time:220701ms step_avg:172.42ms step:1291/1530 train_loss:3.3801 train_time:220882ms step_avg:172.43ms step:1292/1530 train_loss:3.4771 train_time:221064ms step_avg:172.44ms step:1293/1530 train_loss:3.5037 train_time:221245ms step_avg:172.44ms step:1294/1530 train_loss:3.4578 train_time:221426ms step_avg:172.45ms step:1295/1530 train_loss:3.2805 train_time:221606ms step_avg:172.46ms step:1296/1530 train_loss:3.3706 train_time:221786ms step_avg:172.46ms step:1297/1530 train_loss:3.2688 train_time:221966ms step_avg:172.47ms step:1298/1530 train_loss:3.2682 train_time:222146ms step_avg:172.47ms step:1299/1530 train_loss:3.3912 train_time:222324ms step_avg:172.48ms step:1300/1530 train_loss:3.3959 train_time:222500ms step_avg:172.48ms step:1301/1530 train_loss:3.3973 train_time:222676ms step_avg:172.48ms step:1302/1530 train_loss:3.5719 train_time:222860ms step_avg:172.49ms step:1303/1530 train_loss:3.3059 train_time:223043ms step_avg:172.50ms step:1304/1530 train_loss:3.5076 train_time:223226ms step_avg:172.51ms step:1305/1530 train_loss:3.2557 train_time:223403ms step_avg:172.51ms step:1306/1530 train_loss:3.4534 train_time:223582ms step_avg:172.52ms step:1307/1530 train_loss:3.4496 train_time:223757ms step_avg:172.52ms step:1308/1530 train_loss:3.2790 train_time:223936ms step_avg:172.52ms step:1309/1530 train_loss:3.3077 train_time:224115ms step_avg:172.53ms step:1310/1530 train_loss:3.2810 train_time:224293ms step_avg:172.53ms step:1311/1530 train_loss:3.2946 train_time:224471ms step_avg:172.54ms step:1312/1530 train_loss:3.3680 train_time:224651ms step_avg:172.54ms step:1313/1530 train_loss:3.3407 train_time:224827ms step_avg:172.55ms step:1314/1530 train_loss:3.0402 train_time:225009ms step_avg:172.55ms step:1315/1530 train_loss:3.2698 train_time:225187ms step_avg:172.56ms step:1316/1530 train_loss:3.3902 train_time:225362ms step_avg:172.56ms step:1317/1530 train_loss:3.4138 train_time:225540ms step_avg:172.56ms step:1318/1530 train_loss:3.2992 train_time:225725ms step_avg:172.57ms step:1319/1530 train_loss:3.4230 train_time:225905ms step_avg:172.58ms step:1320/1530 train_loss:3.4559 train_time:226087ms step_avg:172.59ms step:1321/1530 train_loss:3.3604 train_time:226266ms step_avg:172.59ms step:1322/1530 train_loss:3.3211 train_time:226576ms step_avg:172.70ms step:1323/1530 train_loss:3.3110 train_time:226765ms step_avg:172.71ms step:1324/1530 train_loss:3.4314 train_time:226945ms step_avg:172.71ms step:1325/1530 train_loss:3.4854 train_time:227129ms step_avg:172.72ms step:1326/1530 train_loss:3.2096 train_time:227310ms step_avg:172.73ms step:1327/1530 train_loss:3.1631 train_time:227485ms step_avg:172.73ms step:1328/1530 train_loss:3.4923 train_time:227666ms step_avg:172.74ms step:1329/1530 train_loss:3.2927 train_time:228014ms step_avg:172.87ms step:1330/1530 train_loss:3.4229 train_time:228195ms step_avg:172.88ms step:1331/1530 train_loss:3.3267 train_time:228370ms step_avg:172.88ms step:1332/1530 train_loss:3.7431 train_time:228553ms step_avg:172.88ms step:1333/1530 train_loss:3.4780 train_time:228735ms step_avg:172.89ms step:1334/1530 train_loss:3.3655 train_time:228913ms step_avg:172.90ms step:1335/1530 train_loss:3.2832 train_time:229092ms step_avg:172.90ms step:1336/1530 train_loss:3.2934 train_time:229277ms step_avg:172.91ms step:1337/1530 train_loss:3.5478 train_time:229457ms step_avg:172.91ms step:1338/1530 train_loss:3.5195 train_time:229634ms step_avg:172.92ms step:1339/1530 train_loss:3.3337 train_time:229814ms step_avg:172.92ms step:1340/1530 train_loss:3.2799 train_time:229992ms step_avg:172.93ms step:1341/1530 train_loss:3.5887 train_time:230170ms step_avg:172.93ms step:1342/1530 train_loss:3.3543 train_time:230349ms step_avg:172.93ms step:1343/1530 train_loss:3.3621 train_time:230526ms step_avg:172.94ms step:1344/1530 train_loss:3.4088 train_time:230707ms step_avg:172.94ms step:1345/1530 train_loss:3.3827 train_time:230889ms step_avg:172.95ms step:1346/1530 train_loss:3.2927 train_time:231066ms step_avg:172.95ms step:1347/1530 train_loss:3.2752 train_time:231244ms step_avg:172.96ms step:1348/1530 train_loss:3.3462 train_time:231421ms step_avg:172.96ms step:1349/1530 train_loss:3.2654 train_time:231597ms step_avg:172.96ms step:1350/1530 train_loss:3.3861 train_time:231778ms step_avg:172.97ms step:1351/1530 train_loss:3.2412 train_time:231954ms step_avg:172.97ms step:1352/1530 train_loss:3.2980 train_time:232134ms step_avg:172.98ms step:1353/1530 train_loss:3.3973 train_time:232314ms step_avg:172.98ms step:1354/1530 train_loss:3.2550 train_time:232492ms step_avg:172.98ms step:1355/1530 train_loss:3.1847 train_time:232669ms step_avg:172.99ms step:1356/1530 train_loss:3.5097 train_time:232850ms step_avg:172.99ms step:1357/1530 train_loss:3.4199 train_time:233033ms step_avg:173.00ms step:1358/1530 train_loss:3.1839 train_time:233212ms step_avg:173.01ms step:1359/1530 train_loss:3.4371 train_time:233392ms step_avg:173.01ms step:1360/1530 train_loss:3.3418 train_time:233572ms step_avg:173.02ms step:1361/1530 train_loss:3.1200 train_time:233757ms step_avg:173.03ms step:1362/1530 train_loss:3.3886 train_time:233939ms step_avg:173.03ms step:1363/1530 train_loss:3.2822 train_time:234127ms step_avg:173.04ms step:1364/1530 train_loss:3.2973 train_time:234305ms step_avg:173.05ms step:1365/1530 train_loss:3.3108 train_time:234483ms step_avg:173.05ms step:1366/1530 train_loss:3.4206 train_time:234664ms step_avg:173.06ms step:1367/1530 train_loss:3.3937 train_time:234842ms step_avg:173.06ms step:1368/1530 train_loss:3.3441 train_time:235022ms step_avg:173.06ms step:1369/1530 train_loss:3.2720 train_time:235210ms step_avg:173.08ms step:1370/1530 train_loss:3.6039 train_time:235389ms step_avg:173.08ms step:1371/1530 train_loss:3.3078 train_time:235572ms step_avg:173.09ms step:1372/1530 train_loss:3.3643 train_time:235757ms step_avg:173.10ms step:1373/1530 train_loss:3.3654 train_time:235937ms step_avg:173.10ms step:1374/1530 train_loss:3.1474 train_time:236119ms step_avg:173.11ms step:1375/1530 train_loss:3.5287 train_time:236299ms step_avg:173.11ms step:1375/1530 val_loss:3.3078 train_time:236350ms step_avg:173.15ms step:1376/1530 train_loss:3.3464 train_time:236478ms step_avg:173.12ms step:1377/1530 train_loss:3.4758 train_time:236658ms step_avg:173.12ms step:1378/1530 train_loss:3.4716 train_time:236835ms step_avg:173.13ms step:1379/1530 train_loss:3.1178 train_time:237017ms step_avg:173.13ms step:1380/1530 train_loss:3.3082 train_time:237196ms step_avg:173.14ms step:1381/1530 train_loss:3.6937 train_time:237382ms step_avg:173.14ms step:1382/1530 train_loss:3.2053 train_time:237561ms step_avg:173.15ms step:1383/1530 train_loss:3.3874 train_time:237744ms step_avg:173.16ms step:1384/1530 train_loss:3.4707 train_time:237930ms step_avg:173.17ms step:1385/1530 train_loss:3.4021 train_time:238103ms step_avg:173.17ms step:1386/1530 train_loss:3.3337 train_time:238283ms step_avg:173.17ms step:1387/1530 train_loss:3.1977 train_time:238464ms step_avg:173.18ms step:1388/1530 train_loss:3.3444 train_time:238640ms step_avg:173.18ms step:1389/1530 train_loss:3.3142 train_time:238822ms step_avg:173.19ms step:1390/1530 train_loss:3.5614 train_time:238999ms step_avg:173.19ms step:1391/1530 train_loss:3.2883 train_time:239178ms step_avg:173.19ms step:1392/1530 train_loss:3.2800 train_time:239357ms step_avg:173.20ms step:1393/1530 train_loss:3.2335 train_time:239538ms step_avg:173.20ms step:1394/1530 train_loss:3.4940 train_time:239714ms step_avg:173.20ms step:1395/1530 train_loss:3.3921 train_time:239893ms step_avg:173.21ms step:1396/1530 train_loss:3.4005 train_time:240071ms step_avg:173.21ms step:1397/1530 train_loss:3.3086 train_time:240248ms step_avg:173.21ms step:1398/1530 train_loss:3.2548 train_time:240424ms step_avg:173.22ms step:1399/1530 train_loss:3.3124 train_time:240602ms step_avg:173.22ms step:1400/1530 train_loss:3.3164 train_time:240784ms step_avg:173.23ms step:1401/1530 train_loss:3.3441 train_time:240961ms step_avg:173.23ms step:1402/1530 train_loss:3.2898 train_time:241140ms step_avg:173.23ms step:1403/1530 train_loss:3.4907 train_time:241324ms step_avg:173.24ms step:1404/1530 train_loss:3.2767 train_time:241502ms step_avg:173.24ms step:1405/1530 train_loss:3.3091 train_time:241684ms step_avg:173.25ms step:1406/1530 train_loss:3.3153 train_time:241866ms step_avg:173.26ms step:1407/1530 train_loss:3.1679 train_time:242040ms step_avg:173.26ms step:1408/1530 train_loss:3.3094 train_time:242219ms step_avg:173.26ms step:1409/1530 train_loss:3.2988 train_time:242407ms step_avg:173.27ms step:1410/1530 train_loss:3.2832 train_time:242584ms step_avg:173.27ms step:1411/1530 train_loss:3.3625 train_time:242761ms step_avg:173.28ms step:1412/1530 train_loss:3.3295 train_time:242937ms step_avg:173.28ms step:1413/1530 train_loss:3.3550 train_time:243117ms step_avg:173.28ms step:1414/1530 train_loss:3.3208 train_time:243297ms step_avg:173.29ms step:1415/1530 train_loss:3.4054 train_time:243482ms step_avg:173.30ms step:1416/1530 train_loss:3.2258 train_time:243671ms step_avg:173.31ms step:1417/1530 train_loss:3.2733 train_time:243854ms step_avg:173.32ms step:1418/1530 train_loss:3.3840 train_time:244033ms step_avg:173.32ms step:1419/1530 train_loss:3.3380 train_time:244215ms step_avg:173.33ms step:1420/1530 train_loss:3.3598 train_time:244397ms step_avg:173.33ms step:1421/1530 train_loss:3.3693 train_time:244577ms step_avg:173.34ms step:1422/1530 train_loss:3.3283 train_time:244755ms step_avg:173.34ms step:1423/1530 train_loss:3.3111 train_time:244933ms step_avg:173.34ms step:1424/1530 train_loss:3.3309 train_time:245117ms step_avg:173.35ms step:1425/1530 train_loss:3.1899 train_time:245303ms step_avg:173.36ms step:1426/1530 train_loss:3.3183 train_time:245481ms step_avg:173.36ms step:1427/1530 train_loss:3.2790 train_time:245665ms step_avg:173.37ms step:1428/1530 train_loss:3.3743 train_time:245843ms step_avg:173.37ms step:1429/1530 train_loss:3.3504 train_time:246020ms step_avg:173.38ms step:1430/1530 train_loss:3.2564 train_time:246203ms step_avg:173.38ms step:1431/1530 train_loss:3.3169 train_time:246384ms step_avg:173.39ms step:1432/1530 train_loss:3.3353 train_time:246565ms step_avg:173.39ms step:1433/1530 train_loss:3.1323 train_time:246749ms step_avg:173.40ms step:1434/1530 train_loss:3.2883 train_time:246932ms step_avg:173.41ms step:1435/1530 train_loss:3.1080 train_time:247111ms step_avg:173.41ms step:1436/1530 train_loss:3.2320 train_time:247290ms step_avg:173.42ms step:1437/1530 train_loss:3.4046 train_time:247469ms step_avg:173.42ms step:1438/1530 train_loss:3.3773 train_time:247644ms step_avg:173.42ms step:1439/1530 train_loss:3.3097 train_time:247824ms step_avg:173.42ms step:1440/1530 train_loss:3.1924 train_time:247999ms step_avg:173.43ms step:1441/1530 train_loss:3.3330 train_time:248180ms step_avg:173.43ms step:1442/1530 train_loss:3.3817 train_time:248366ms step_avg:173.44ms step:1443/1530 train_loss:3.4840 train_time:248553ms step_avg:173.45ms step:1444/1530 train_loss:3.4439 train_time:248729ms step_avg:173.45ms step:1445/1530 train_loss:3.3348 train_time:248904ms step_avg:173.45ms step:1446/1530 train_loss:3.1956 train_time:249084ms step_avg:173.46ms step:1447/1530 train_loss:3.2933 train_time:249264ms step_avg:173.46ms step:1448/1530 train_loss:3.2909 train_time:249441ms step_avg:173.46ms step:1449/1530 train_loss:3.3935 train_time:249622ms step_avg:173.47ms step:1450/1530 train_loss:3.3836 train_time:249804ms step_avg:173.48ms step:1451/1530 train_loss:3.2005 train_time:249984ms step_avg:173.48ms step:1452/1530 train_loss:3.3204 train_time:250163ms step_avg:173.48ms step:1453/1530 train_loss:3.2556 train_time:250338ms step_avg:173.48ms step:1454/1530 train_loss:3.2880 train_time:250515ms step_avg:173.49ms step:1455/1530 train_loss:3.3288 train_time:250698ms step_avg:173.49ms step:1456/1530 train_loss:3.2826 train_time:250876ms step_avg:173.50ms step:1457/1530 train_loss:3.1515 train_time:251053ms step_avg:173.50ms step:1458/1530 train_loss:3.4197 train_time:251231ms step_avg:173.50ms step:1459/1530 train_loss:3.2717 train_time:251411ms step_avg:173.51ms step:1460/1530 train_loss:3.3130 train_time:251590ms step_avg:173.51ms step:1461/1530 train_loss:3.4254 train_time:251770ms step_avg:173.52ms step:1462/1530 train_loss:3.2584 train_time:251947ms step_avg:173.52ms step:1463/1530 train_loss:3.4630 train_time:252131ms step_avg:173.52ms step:1464/1530 train_loss:3.3564 train_time:252310ms step_avg:173.53ms step:1465/1530 train_loss:3.3521 train_time:252490ms step_avg:173.53ms step:1466/1530 train_loss:3.2868 train_time:252667ms step_avg:173.54ms step:1467/1530 train_loss:3.3933 train_time:252845ms step_avg:173.54ms step:1468/1530 train_loss:3.2846 train_time:253023ms step_avg:173.54ms step:1469/1530 train_loss:3.2695 train_time:253201ms step_avg:173.54ms step:1470/1530 train_loss:3.3298 train_time:253385ms step_avg:173.55ms step:1471/1530 train_loss:3.2591 train_time:253570ms step_avg:173.56ms step:1472/1530 train_loss:3.2446 train_time:253754ms step_avg:173.57ms step:1473/1530 train_loss:3.4365 train_time:253932ms step_avg:173.57ms step:1474/1530 train_loss:3.3083 train_time:254114ms step_avg:173.58ms step:1475/1530 train_loss:3.1484 train_time:254300ms step_avg:173.58ms step:1476/1530 train_loss:3.2657 train_time:254478ms step_avg:173.59ms step:1477/1530 train_loss:3.2351 train_time:254666ms step_avg:173.60ms step:1478/1530 train_loss:3.3053 train_time:254851ms step_avg:173.60ms step:1479/1530 train_loss:3.3905 train_time:255034ms step_avg:173.61ms step:1480/1530 train_loss:3.2648 train_time:255211ms step_avg:173.61ms step:1481/1530 train_loss:3.4528 train_time:255393ms step_avg:173.62ms step:1482/1530 train_loss:3.3676 train_time:255578ms step_avg:173.63ms step:1483/1530 train_loss:3.2732 train_time:255767ms step_avg:173.64ms step:1484/1530 train_loss:3.2592 train_time:255953ms step_avg:173.65ms step:1485/1530 train_loss:3.2795 train_time:256132ms step_avg:173.65ms step:1486/1530 train_loss:3.2214 train_time:256316ms step_avg:173.66ms step:1487/1530 train_loss:3.3385 train_time:256499ms step_avg:173.66ms step:1488/1530 train_loss:3.2399 train_time:256682ms step_avg:173.67ms step:1489/1530 train_loss:3.3085 train_time:256861ms step_avg:173.67ms step:1490/1530 train_loss:3.2496 train_time:257043ms step_avg:173.68ms step:1491/1530 train_loss:3.1522 train_time:257224ms step_avg:173.68ms step:1492/1530 train_loss:3.2653 train_time:257404ms step_avg:173.69ms step:1493/1530 train_loss:3.4303 train_time:257583ms step_avg:173.69ms step:1494/1530 train_loss:3.2930 train_time:257763ms step_avg:173.69ms step:1495/1530 train_loss:3.0281 train_time:257947ms step_avg:173.70ms step:1496/1530 train_loss:3.3578 train_time:258130ms step_avg:173.71ms step:1497/1530 train_loss:3.3056 train_time:258312ms step_avg:173.71ms step:1498/1530 train_loss:3.3392 train_time:258497ms step_avg:173.72ms step:1499/1530 train_loss:3.3054 train_time:258686ms step_avg:173.73ms step:1500/1530 train_loss:3.2946 train_time:258879ms step_avg:173.74ms step:1500/1530 val_loss:3.2760 train_time:258935ms step_avg:173.78ms step:1501/1530 train_loss:3.0809 train_time:259071ms step_avg:173.76ms step:1502/1530 train_loss:3.3529 train_time:259263ms step_avg:173.77ms step:1503/1530 train_loss:3.2404 train_time:259442ms step_avg:173.77ms step:1504/1530 train_loss:3.2400 train_time:259623ms step_avg:173.78ms step:1505/1530 train_loss:3.2158 train_time:259802ms step_avg:173.78ms step:1506/1530 train_loss:3.2769 train_time:259984ms step_avg:173.79ms step:1507/1530 train_loss:3.1759 train_time:260177ms step_avg:173.80ms step:1508/1530 train_loss:3.4787 train_time:260359ms step_avg:173.80ms step:1509/1530 train_loss:3.2732 train_time:260538ms step_avg:173.81ms step:1510/1530 train_loss:3.2644 train_time:260716ms step_avg:173.81ms step:1511/1530 train_loss:3.4159 train_time:261027ms step_avg:173.90ms step:1512/1530 train_loss:3.4148 train_time:261212ms step_avg:173.91ms step:1513/1530 train_loss:3.2670 train_time:261398ms step_avg:173.92ms step:1514/1530 train_loss:3.0817 train_time:261581ms step_avg:173.92ms step:1515/1530 train_loss:3.2376 train_time:261761ms step_avg:173.93ms step:1516/1530 train_loss:3.2531 train_time:261945ms step_avg:173.93ms step:1517/1530 train_loss:3.2941 train_time:262127ms step_avg:173.94ms step:1518/1530 train_loss:3.2037 train_time:262309ms step_avg:173.95ms step:1519/1530 train_loss:3.4980 train_time:262650ms step_avg:174.06ms step:1520/1530 train_loss:3.1246 train_time:262834ms step_avg:174.06ms step:1521/1530 train_loss:3.1987 train_time:263010ms step_avg:174.06ms step:1522/1530 train_loss:3.3507 train_time:263194ms step_avg:174.07ms step:1523/1530 train_loss:3.2224 train_time:263371ms step_avg:174.07ms step:1524/1530 train_loss:3.3431 train_time:263552ms step_avg:174.08ms step:1525/1530 train_loss:3.3365 train_time:263740ms step_avg:174.09ms step:1526/1530 train_loss:3.2717 train_time:263930ms step_avg:174.10ms step:1527/1530 train_loss:3.2886 train_time:264111ms step_avg:174.10ms step:1528/1530 train_loss:3.4040 train_time:264292ms step_avg:174.11ms step:1529/1530 train_loss:3.4074 train_time:264471ms step_avg:174.11ms step:1530/1530 train_loss:3.2356 train_time:264650ms step_avg:174.11ms step:1530/1530 val_loss:3.2735 train_time:264705ms step_avg:174.15ms