====================================================================================================
import os
import sys
with open(sys.argv[0]) as f:
    code = f.read() # read the code of this file ASAP, for logging
import uuid
import glob
import time
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
import torch._inductor.config as config
from torch.nn.parallel import DistributedDataParallel as DDP

# -----------------------------------------------------------------------------
# Muon optimizer

def zeropower_via_svd(G, steps=None):
    """
    Orthogonalize G exactly via its reduced SVD: for G = U S V^T, return U V^T.

    This is the exact (slow) reference backend for Muon. `steps` is accepted only so the
    signature matches the iterative backends; it is ignored.
    """
    # torch.svd is deprecated; torch.linalg.svd with full_matrices=False yields the same reduced
    # factors but returns Vh (= V^T for real input) directly instead of V.
    U, S, Vh = torch.linalg.svd(G, full_matrices=False)
    return U @ Vh

@torch.compile
def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
    r"""
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.

    Args:
        G: 2-D matrix to orthogonalize. It is never modified.
        steps: number of quintic iterations to run.
        eps: added to the Frobenius norm before normalizing, to avoid division by zero.

    Returns:
        A bfloat16 matrix with the same shape as G.
    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750,  2.0315)
    X = G.bfloat16()
    # Out-of-place on purpose: when G is already bfloat16, G.bfloat16() returns G itself, so an
    # in-place `X /= ...` would silently rescale the caller's tensor.
    X = X / (X.norm() + eps) # ensure top singular value <= 1
    transposed = G.size(0) > G.size(1)
    if transposed:
        # iterate on the wide orientation so the Gram matrix X @ X.T is the smaller one
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = A @ X
        X = a * X + b * B + c * A @ B
    if transposed:
        X = X.T
    return X

# Maps Muon's `backend` option to the function that orthogonalizes an update matrix.
zeropower_backends = {
    'svd': zeropower_via_svd,
    'newtonschulz5': zeropower_via_newtonschulz5,
}

class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - This optimizer assumes that all parameters passed in are 2D.
    - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
    parameters; those should all be optimized by a standard method (e.g., AdamW).
    - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
    - We believe it is unlikely to work well for training with small batch size.
    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
    - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).

    Arguments:
        lr: The learning rate used by the internal SGD.
        momentum: The momentum used by the internal SGD.
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5')
        backend_steps: The number of iteration steps to use in the backend, if it is iterative.
        rank: Index of this process. Parameter i of a group is orthogonalized only on the process
            with rank == i % world_size.
        world_size: Number of processes sharing the orthogonalization work. When > 1, the flat
            update buffers are summed with torch.distributed.all_reduce, which requires an
            initialized process group.
    """
    def __init__(self, params, lr=3e-4, momentum=0.95, nesterov=True,
                 backend='newtonschulz5', backend_steps=5,
                 rank=0, world_size=1):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps)
        super().__init__(params, defaults)
        self.rank = rank
        self.world_size = world_size

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform one optimization step over every param group.

        Each rank computes the orthogonalized momentum update for its share of the parameters
        and writes it into a flat bfloat16 buffer. The buffers are summed across ranks so that
        every rank applies the full set of updates.

        Args:
            closure: Optional callable that re-evaluates the model and returns the loss, as in
                the standard torch.optim.Optimizer API.

        Returns:
            The value returned by `closure`, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params = group['params']
            if not params:
                continue

            lr = group['lr']
            momentum = group['momentum']
            zeropower_backend = zeropower_backends[group['backend']]

            # generate weight updates in distributed fashion
            total_params = sum(p.numel() for p in params)
            # allocate on the parameters' device instead of assuming 'cuda'
            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
            curr_idx = 0
            for i, p in enumerate(params):
                # luckily this will perfectly distribute a transformer with multiple of 4 layers to 8 GPUs
                g = p.grad
                # Parameters owned by another rank, or without a gradient, contribute zeros. The
                # offset below must still advance for them; otherwise every later update would
                # be written into (and read back from) the wrong slice of updates_flat.
                if i % self.world_size == self.rank and g is not None:
                    state = self.state[p]
                    if 'momentum_buffer' not in state:
                        state['momentum_buffer'] = torch.zeros_like(g)
                    buf = state['momentum_buffer']
                    buf.mul_(momentum).add_(g)
                    if group['nesterov']:
                        g = g.add(buf, alpha=momentum)
                    g = zeropower_backend(g, steps=group['backend_steps'])
                    g *= max(g.size(0), g.size(1))**0.5 # scale to have update.square().mean() == 1
                    updates_flat[curr_idx:curr_idx+p.numel()] = g.flatten()
                curr_idx += p.numel()

            # sync updates across devices. we are not memory-constrained so can do this simple deserialization.
            # With a single process the buffer is already complete (and no process group may exist).
            if self.world_size > 1:
                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)

            # deserialize and apply updates
            curr_idx = 0
            for p in params:
                g = updates_flat[curr_idx:curr_idx+p.numel()].view_as(p.data).type_as(p.data)
                p.data.add_(g, alpha=-lr)
                curr_idx += p.numel()

        return loss

# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the GPT-2 model

class Rotary(torch.nn.Module):
    """
    Build rotary-embedding cos/sin tables for the sequence length of the input.

    The tables are cached on the module and rebuilt only when the sequence length
    (x.shape[1]) differs from the one used last time.
    """

    def __init__(self, dim, base=10000):
        super().__init__()
        # one inverse frequency per channel pair: base ** (-k / dim) for k = 0, 2, 4, ...
        exponents = torch.arange(0, dim, 2).float() / dim
        self.inv_freq = 1.0 / (base ** exponents)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x):
        """Return (cos, sin) as bfloat16 tensors of shape (1, seq_len, 1, len(inv_freq))."""
        seq_len = x.shape[1]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            angles = torch.outer(positions, self.inv_freq).to(x.device)
            self.cos_cached = angles.cos().bfloat16()
            self.sin_cached = angles.sin().bfloat16()
        # broadcast over the batch and head axes
        cos = self.cos_cached[None, :, None, :]
        sin = self.sin_cached[None, :, None, :]
        return cos, sin

def apply_rotary_emb(x, cos, sin):
    """
    Rotate the two halves of the last dimension of a 4-D tensor by the given angles.

    With x split along dim 3 into halves (x1, x2), the result is
    (x1*cos + x2*sin, -x1*sin + x2*cos), concatenated back along dim 3 and cast to x's dtype.
    """
    assert x.ndim == 4 # multihead attention
    half = x.shape[3] // 2
    x1, x2 = x[..., :half], x[..., half:]
    rotated = (x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos)
    return torch.cat(rotated, 3).type_as(x)

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with QK RMS-norm and rotary position embeddings."""

    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        width = self.n_embd
        # separate bias-free projections for queries, keys and values
        self.c_q = nn.Linear(width, width, bias=False)
        self.c_k = nn.Linear(width, width, bias=False)
        self.c_v = nn.Linear(width, width, bias=False)
        # output projection, zero-initialized so the module initially outputs zeros
        # (zero init suggested by @Grad62304977)
        self.c_proj = nn.Linear(width, width, bias=False)
        self.c_proj.weight.data.zero_()
        self.rotary = Rotary(self.head_dim)

    def forward(self, x):
        """Map x of shape (batch, seq_len, n_embd) to a tensor of the same shape."""
        B, T, _ = x.size()
        per_head = (B, T, self.n_head, self.head_dim)
        q = self.c_q(x).view(per_head)
        k = self.c_k(x).view(per_head)
        v = self.c_v(x).view(per_head)
        cos, sin = self.rotary(q)
        # QK norm (suggested by @Grad62304977), followed by the rotary embedding
        q = apply_rotary_emb(F.rms_norm(q, (q.size(-1),)), cos, sin)
        k = apply_rotary_emb(F.rms_norm(k, (k.size(-1),)), cos, sin)
        # scaled_dot_product_attention works on (batch, heads, seq, head_dim)
        y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
        # back to (batch, seq, heads, head_dim), heads laid side by side in the last dim
        y = y.transpose(1, 2).contiguous().view_as(x)
        return self.c_proj(y)

class MLP(nn.Module):
    """Feed-forward sublayer: Linear(n_embd -> 4*n_embd), squared ReLU, Linear back to n_embd."""

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=False)
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=False)
        # zero init suggested by @Grad62304977: the sublayer outputs zeros at initialization
        self.c_proj.weight.data.zero_()

    def forward(self, x):
        # squared ReLU: https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
        activated = F.relu(self.c_fc(x)).square()
        return self.c_proj(activated)

class Block(nn.Module):
    """Pre-norm transformer block: residual attention sublayer, then residual MLP sublayer."""

    def __init__(self, config):
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, x):
        # each sublayer sees an RMS-normalized copy of the residual stream and adds its output back
        for sublayer in (self.attn, self.mlp):
            x = x + sublayer(F.rms_norm(x, (x.size(-1),)))
        return x

# -----------------------------------------------------------------------------
# The main GPT-2 model

@dataclass
class GPTConfig:
    """Architecture hyperparameters of the GPT model."""
    vocab_size: int = 50304  # number of token ids (embedding rows / logits)
    n_layer: int = 12        # number of transformer blocks
    n_head: int = 6          # attention heads; 768 / 6 = head dim 128, suggested by @Grad62304977
    n_embd: int = 768        # width of the residual stream

class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

    def forward(self, idx, targets=None, return_logits=True):

        # forward the GPT model itself
        x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        for block in self.transformer.h:
            x = block(x)
        x = F.rms_norm(x, (x.size(-1),))

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            logits = logits.float() # use tf32/fp32 for logits
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            logits = logits.float() # use tf32/fp32 for logits
            loss = None

        # there are performance reasons why not returning logits is prudent, if not needed
        if not return_logits:
            logits = None

        return logits, loss

# -----------------------------------------------------------------------------
# Our own simple Distributed Data Loader

def _peek_data_shard(filename):
    # only reads the header, returns header data
    with open(filename, "rb") as f:
        # first read the header, which is 256 int32 integers (4 bytes each)
        header = np.frombuffer(f.read(256*4), dtype=np.int32)
    if header[0] != 20240520:
        print("ERROR: magic number mismatch in the data .bin file!")
        print("---> HINT: Are you passing in a correct file with --input_bin?")
        print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README")
        print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try")
        exit(1)
    assert header[1] == 1, "unsupported version"
    ntok = header[2] # number of tokens (claimed)
    return ntok # for now just return the number of tokens

def _load_data_shard(filename):
    with open(filename, "rb") as f:
        # first read the header, which is 256 int32 integers (4 bytes each)
        header = np.frombuffer(f.read(256*4), dtype=np.int32)
        assert header[0] == 20240520, "magic number mismatch in the data .bin file"
        assert header[1] == 1, "unsupported version"
        ntok = header[2] # number of tokens (claimed)
        # the rest of it are tokens, stored as uint16
        tokens = np.frombuffer(f.read(), dtype=np.uint16)
    assert len(tokens) == ntok, "number of tokens read does not match header?"
    return tokens

class DistributedDataLoader:
    """
    Serve (B, T) input/target batches from a set of token shards, split across processes.

    Each next_batch() call reads B*T+1 consecutive tokens starting at this process's position.
    Process r starts r*B*T tokens after process 0, and every call advances all processes by
    B*T*num_processes tokens. When the following call would read past the end of the current
    shard, all processes move to the next shard together, wrapping around after the last one.
    """
    def __init__(self, filename_pattern, B, T, process_rank, num_processes):
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.B = B
        self.T = T

        # glob files that match the pattern
        self.files = sorted(glob.glob(filename_pattern))
        assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"

        # validate every shard via its header and count the total number of tokens
        ntok_total = 0
        for fname in self.files:
            shard_ntok = _peek_data_shard(fname)
            # each shard must hold at least one full step for all processes, plus the last target token
            assert shard_ntok >= num_processes * B * T + 1
            ntok_total += int(shard_ntok)
        self.ntok_total = ntok_total

        # kick things off
        self.reset()

    def reset(self):
        """Rewind to this process's starting position in the first shard."""
        self.current_shard = 0
        self.current_position = self.process_rank * self.B * self.T
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def advance(self): # advance to next data shard
        """Move to this process's starting position in the next shard (wrapping around)."""
        self.current_shard = (self.current_shard + 1) % len(self.files)
        self.current_position = self.process_rank * self.B * self.T
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def next_batch(self):
        """Return (x, y) as int64 CUDA tensors of shape (B, T); y is x shifted by one token."""
        B = self.B
        T = self.T
        buf = self.tokens[self.current_position : self.current_position+B*T+1]
        buf = torch.tensor(buf.astype(np.int32), dtype=torch.long)
        x = (buf[:-1]).view(B, T) # inputs
        y = (buf[1:]).view(B, T) # targets
        # advance current position and load next shard if necessary
        self.current_position += B * T * self.num_processes
        # Decide from the rank-independent start of the next step (rank 0's position), so all
        # processes switch shards on the same call. Using this process's own position, which is
        # offset by process_rank*B*T, would make higher ranks switch shards earlier than rank 0.
        step_start = self.current_position - self.process_rank * B * T
        if step_start + (B * T * self.num_processes + 1) > len(self.tokens):
            self.advance()
        return x.cuda(), y.cuda()

# -----------------------------------------------------------------------------
# int main

@dataclass
class Hyperparameters:
    """Settings for this training run: data, optimization, evaluation and logging."""
    # --- data ---
    input_bin: str = 'data/fineweb10B/fineweb_train_*.bin' # glob pattern of training shards
    input_val_bin: str = 'data/fineweb10B/fineweb_val_*.bin' # glob pattern of validation shards
    # --- optimization ---
    batch_size: int = 8*64 # global batch size in sequences, summed over all devices
    device_batch_size: int = 64 # per-device batch size in sequences
    sequence_length: int = 1024 # tokens per sequence
    num_iterations: int = 5100 # number of training iterations
    learning_rate: float = 0.0036
    warmup_iters: int = 0 # iterations of linear warmup at the start
    warmdown_iters: int = 1450 # iterations of linear warmdown at the end (trapezoidal schedule)
    weight_decay: float = 0
    # --- evaluation and logging ---
    val_loss_every: int = 125 # evaluate val loss every this many steps; 0 means only at the end
    val_tokens: int = 10485760 # validation tokens per evaluation; keep fixed for comparable results
    save_every: int = 0 # save a checkpoint every this many steps; 0 means only at the end
args = Hyperparameters()

# Set up DDP (distributed data parallel); torchrun provides RANK, LOCAL_RANK and WORLD_SIZE.
assert torch.cuda.is_available()
dist.init_process_group(backend='nccl')
ddp_rank, ddp_local_rank, ddp_world_size = (
    int(os.environ[name]) for name in ('RANK', 'LOCAL_RANK', 'WORLD_SIZE')
)
device = f'cuda:{ddp_local_rank}'
torch.cuda.set_device(device)
print(f"using device: {device}")
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.

# convenience variables
B, T = args.device_batch_size, args.sequence_length
# Tokens consumed by one micro-step summed over all ranks; the validation budget must be a whole
# number of these.
tokens_per_step = B * T * ddp_world_size
assert args.val_tokens % tokens_per_step == 0
val_steps = args.val_tokens // tokens_per_step
# Gradient-accumulation micro-steps needed to reach the desired global batch size (in sequences).
sequences_per_step = B * ddp_world_size
assert args.batch_size % sequences_per_step == 0
train_accumulation_steps = args.batch_size // sequences_per_step

# load tokens: one loader per split, sharing this process's rank and the world size
train_loader, val_loader = (
    DistributedDataLoader(pattern, B, T, ddp_rank, ddp_world_size)
    for pattern in (args.input_bin, args.input_val_bin)
)
if master_process:
    print(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
    print(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
# fetch an initial training batch
x, y = train_loader.next_batch()

# there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. suggested to me by @Grad62304977.
# this originates from Karpathy's experiments.
num_vocab = 50304
# Keep a handle on the plain nn.Module *before* compiling. torch.compile returns a wrapper module
# around the original, so the DDP container's `.module` is that compiled wrapper rather than the
# raw model, and the wrapper's state_dict() keys carry an extra `_orig_mod.` prefix.
raw_model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=6, n_embd=768))
raw_model = raw_model.cuda()
if hasattr(config, "coordinate_descent_tuning"):
    config.coordinate_descent_tuning = True # suggested by @Chillee
model = torch.compile(raw_model)
# here we wrap model into DDP container
model = DDP(model, device_ids=[ddp_local_rank])
ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)

# init the optimizer(s): AdamW for the LM head (whose weight is tied to the token embedding),
# Muon for all parameters inside the transformer blocks
optimizer1 = torch.optim.AdamW(
    raw_model.lm_head.parameters(),
    lr=args.learning_rate,
    betas=(0.9, 0.95),
    weight_decay=args.weight_decay,
    fused=True,
)
optimizer2 = Muon(
    raw_model.transformer.h.parameters(),
    lr=0.1*args.learning_rate,
    momentum=0.95,
    rank=ddp_rank,
    world_size=ddp_world_size,
)
optimizers = [optimizer1, optimizer2]
# learning rate decay scheduler (linear warmup and warmdown)
def get_lr(it):
    """
    Return the learning-rate multiplier for iteration `it` (used as a LambdaLR lambda).

    Trapezoidal schedule: linear warmup over `args.warmup_iters` iterations, then a constant
    1.0, then a linear decay reaching 0 at `args.num_iterations` over the final
    `args.warmdown_iters` iterations. With `warmdown_iters == 0` there is no decay.
    """
    assert it <= args.num_iterations
    # 1) linear warmup for warmup_iters steps
    if it < args.warmup_iters:
        return (it+1) / args.warmup_iters
    # 2) constant lr for a while
    if it < args.num_iterations - args.warmdown_iters:
        return 1.0
    # 3) linear warmdown. With warmdown disabled, only it == num_iterations (the scheduler step
    # after the final iteration) reaches this point and would otherwise divide by zero.
    if args.warmdown_iters <= 0:
        return 1.0
    return (args.num_iterations - it) / args.warmdown_iters
# one LambdaLR per optimizer, all driven by the same get_lr multiplier
schedulers = [torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=get_lr) for optimizer in optimizers]

# begin logging (master process only)
if master_process:
    run_id = str(uuid.uuid4())
    logdir = 'logs/%s/' % run_id
    os.makedirs(logdir, exist_ok=True)
    logfile = 'logs/%s.txt' % run_id
    separator = '='*100 + '\n'
    # create the log file: this script's own source, then the software/hardware environment
    with open(logfile, "w") as f:
        f.write(separator)
        f.write(code)
        f.write(separator)
        f.write(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:\n")
        import subprocess
        nvidia_smi = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        f.write(f'{nvidia_smi.stdout}\n')
        f.write(separator)

training_time_ms = 0
# start the clock
torch.cuda.synchronize()
t0 = time.time()
# begin training
train_loader.reset()
# reset() rewinds the loader to its first batch, but (x, y) already holds that batch from the
# next_batch() call made right after the loaders were created. Re-fetch so the first training
# step consumes the first batch exactly once; otherwise the loop's first next_batch() call would
# return that same batch again and it would be trained on twice.
x, y = train_loader.next_batch()
for step in range(args.num_iterations + 1):
    last_step = (step == args.num_iterations)
    # This effectively ignores timing first 10 steps, which are slower for weird reasons.
    # Alternately, and slightly more correctly in terms of benchmarking, we could do 10
    # steps with dummy data first, and then re-initialize the model and reset the loader.
    if step == 10:
        training_time_ms = 0
        t0 = time.time()
    # number of timed training steps once this iteration's training step has finished
    # (steps 10..step inclusive); NaN while too few steps have been timed
    timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val

    # once in a while evaluate the validation dataset
    if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)):
        # stop the clock
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.time() - t0)
        # run validation batches
        model.eval()
        val_loader.reset()
        val_loss = 0.0
        for _ in range(val_steps):
            x_val, y_val = val_loader.next_batch()
            with ctx: # of course, we'd like to use no_grad() here too, but that creates a torch.compile error for some reason
                _, loss = model(x_val, y_val, return_logits=False)
                val_loss += loss.detach()
                del loss
        # average across ranks, then across validation micro-steps
        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
        val_loss /= val_steps
        # log val loss to console and to logfile. Validation runs before this iteration's training
        # step, so one fewer step has been timed than timed_steps counts (hence timed_steps-1).
        if master_process:
            print(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms')
            with open(logfile, "a") as f:
                f.write(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms\n')
        # start the clock again
        torch.cuda.synchronize()
        t0 = time.time()

    if master_process and (last_step or (args.save_every > 0 and step % args.save_every == 0)):
        # stop the clock
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.time() - t0)
        # save the state of the training process
        log = dict(step=step, code=code, model=raw_model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
        torch.save(log, 'logs/%s/state_step%06d.pt' % (run_id, step))
        # start the clock again
        torch.cuda.synchronize()
        t0 = time.time()

    # bit confusing: we want to make sure to eval on 0th iteration
    # but also after the very last iteration. so we loop for step <= num_iterations
    # instead of just < num_iterations (one extra due to <=), only to do
    # the validation/sampling one last time, and then we break right here as we're done.
    if last_step:
        break

    # --------------- TRAINING SECTION BEGIN -----------------
    model.train()
    for i in range(1, train_accumulation_steps+1):
        # forward pass
        with ctx:
            _, loss = model(x, y, return_logits=False)
            train_loss = loss.detach()
        # advance the dataset for the next batch
        x, y = train_loader.next_batch()
        # backward pass
        if i < train_accumulation_steps:
            with model.no_sync(): # there's no need to sync gradients every accumulation step
                loss.backward()
        else:
            loss.backward() # just sync on the last step
    # .grad accumulated a sum over the micro-steps; turn it into a mean
    for p in model.parameters():
        p.grad /= train_accumulation_steps
    # step the optimizers and schedulers
    for opt, sched in zip(optimizers, schedulers):
        opt.step()
        sched.step()
    # null the gradients
    model.zero_grad(set_to_none=True)
    # --------------- TRAINING SECTION END -------------------
    # everything that follows now is just diagnostics, prints, logging, etc.

    #dist.all_reduce(train_loss, op=dist.ReduceOp.AVG) # all-reducing the training loss would be more correct in terms of logging, but slower
    if master_process:
        approx_time = training_time_ms + 1000 * (time.time() - t0)
        print(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms")
        with open(logfile, "a") as f:
            f.write(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms\n")

if master_process:
    peak_mib = torch.cuda.max_memory_allocated() // 1024 // 1024
    print(f"peak memory consumption: {peak_mib} MiB")

# -------------------------------------------------------------------------
# clean up nice: tear down the process group created at startup
dist.destroy_process_group()
====================================================================================================
Running pytorch 2.5.0+cu124 compiled for CUDA 12.4
nvidia-smi:
Fri Oct 18 06:08:02 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.06              Driver Version: 555.42.06      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          Off |   00000000:18:00.0 Off |                    0 |
| N/A   33C    P0            141W /  700W |    4860MiB /  81559MiB |      7%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          Off |   00000000:2A:00.0 Off |                    0 |
| N/A   33C    P0            121W /  700W |    4908MiB /  81559MiB |      5%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          Off |   00000000:3A:00.0 Off |                    0 |
| N/A   34C    P0            123W /  700W |    4908MiB /  81559MiB |      1%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          Off |   00000000:5D:00.0 Off |                    0 |
| N/A   31C    P0            124W /  700W |    4908MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          Off |   00000000:9A:00.0 Off |                    0 |
| N/A   33C    P0            133W /  700W |    4908MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          Off |   00000000:AB:00.0 Off |                    0 |
| N/A   34C    P0            128W /  700W |    4908MiB /  81559MiB |      7%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          Off |   00000000:BA:00.0 Off |                    0 |
| N/A   33C    P0            126W /  700W |    4908MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          Off |   00000000:DB:00.0 Off |                    0 |
| N/A   33C    P0            132W /  700W |    4668MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A       958      C   /usr/bin/python3                                0MiB |
|    1   N/A  N/A       959      C   /usr/bin/python3                                0MiB |
|    2   N/A  N/A       960      C   /usr/bin/python3                                0MiB |
|    3   N/A  N/A       961      C   /usr/bin/python3                                0MiB |
|    4   N/A  N/A       962      C   /usr/bin/python3                                0MiB |
|    5   N/A  N/A       963      C   /usr/bin/python3                                0MiB |
|    6   N/A  N/A       964      C   /usr/bin/python3                                0MiB |
|    7   N/A  N/A       965      C   /usr/bin/python3                                0MiB |
+-----------------------------------------------------------------------------------------+

====================================================================================================
step:0/5100 val_loss:16.0280 train_time:475ms step_avg:nanms
step:1/5100 train_loss:16.0268 train_time:35679ms step_avg:nanms
step:2/5100 train_loss:9.5039 train_time:35770ms step_avg:nanms
step:3/5100 train_loss:8.6998 train_time:35910ms step_avg:nanms
step:4/5100 train_loss:8.0755 train_time:36048ms step_avg:nanms
step:5/5100 train_loss:7.5667 train_time:36186ms step_avg:nanms
step:6/5100 train_loss:7.5466 train_time:36322ms step_avg:nanms
step:7/5100 train_loss:7.2854 train_time:36462ms step_avg:nanms
step:8/5100 train_loss:7.5010 train_time:36600ms step_avg:nanms
step:9/5100 train_loss:7.3365 train_time:36745ms step_avg:nanms
step:10/5100 train_loss:7.0160 train_time:36885ms step_avg:nanms
step:11/5100 train_loss:6.9676 train_time:83ms step_avg:nanms
step:12/5100 train_loss:6.8712 train_time:221ms step_avg:nanms
step:13/5100 train_loss:6.6573 train_time:359ms step_avg:119.69ms
step:14/5100 train_loss:6.6428 train_time:498ms step_avg:124.60ms
step:15/5100 train_loss:6.6117 train_time:638ms step_avg:127.56ms
step:16/5100 train_loss:6.5329 train_time:780ms step_avg:130.05ms
step:17/5100 train_loss:6.5378 train_time:921ms step_avg:131.57ms
step:18/5100 train_loss:6.5676 train_time:1061ms step_avg:132.63ms
step:19/5100 train_loss:6.3968 train_time:1208ms step_avg:134.17ms
step:20/5100 train_loss:6.4218 train_time:1339ms step_avg:133.88ms
step:21/5100 train_loss:6.0760 train_time:1478ms step_avg:134.34ms
step:22/5100 train_loss:6.4488 train_time:1618ms step_avg:134.79ms
step:23/5100 train_loss:6.6571 train_time:1759ms step_avg:135.33ms
step:24/5100 train_loss:6.3400 train_time:1906ms step_avg:136.14ms
step:25/5100 train_loss:6.4840 train_time:2042ms step_avg:136.11ms
step:26/5100 train_loss:6.1829 train_time:2180ms step_avg:136.23ms
step:27/5100 train_loss:6.0988 train_time:2320ms step_avg:136.48ms
step:28/5100 train_loss:6.2481 train_time:2459ms step_avg:136.59ms
step:29/5100 train_loss:5.9280 train_time:2601ms step_avg:136.87ms
step:30/5100 train_loss:6.2061 train_time:2743ms step_avg:137.13ms
step:31/5100 train_loss:6.0375 train_time:2881ms step_avg:137.17ms
step:32/5100 train_loss:6.0071 train_time:3021ms step_avg:137.31ms
step:33/5100 train_loss:5.8334 train_time:3159ms step_avg:137.37ms
step:34/5100 train_loss:6.1077 train_time:3303ms step_avg:137.62ms
step:35/5100 train_loss:6.0470 train_time:3440ms step_avg:137.58ms
step:36/5100 train_loss:6.1838 train_time:3580ms step_avg:137.67ms
step:37/5100 train_loss:6.1180 train_time:3719ms step_avg:137.76ms
step:38/5100 train_loss:6.0241 train_time:3860ms step_avg:137.86ms
step:39/5100 train_loss:5.9013 train_time:4001ms step_avg:137.96ms
step:40/5100 train_loss:5.9212 train_time:4144ms step_avg:138.13ms
step:41/5100 train_loss:5.8390 train_time:4280ms step_avg:138.08ms
step:42/5100 train_loss:5.8646 train_time:4427ms step_avg:138.34ms
step:43/5100 train_loss:5.7445 train_time:4559ms step_avg:138.17ms
step:44/5100 train_loss:5.8480 train_time:4700ms step_avg:138.24ms
step:45/5100 train_loss:5.8083 train_time:4841ms step_avg:138.30ms
step:46/5100 train_loss:5.9675 train_time:4982ms step_avg:138.38ms
step:47/5100 train_loss:5.7653 train_time:5120ms step_avg:138.39ms
step:48/5100 train_loss:5.6361 train_time:5260ms step_avg:138.43ms
step:49/5100 train_loss:5.8404 train_time:5401ms step_avg:138.50ms
step:50/5100 train_loss:5.7207 train_time:5543ms step_avg:138.57ms
step:51/5100 train_loss:5.8657 train_time:5681ms step_avg:138.56ms
step:52/5100 train_loss:5.7324 train_time:5827ms step_avg:138.74ms
step:53/5100 train_loss:5.5902 train_time:5963ms step_avg:138.68ms
step:54/5100 train_loss:5.7212 train_time:6115ms step_avg:138.97ms
step:55/5100 train_loss:5.6004 train_time:6244ms step_avg:138.76ms
step:56/5100 train_loss:5.9441 train_time:6383ms step_avg:138.76ms
step:57/5100 train_loss:5.6002 train_time:6523ms step_avg:138.78ms
step:58/5100 train_loss:5.4733 train_time:6663ms step_avg:138.81ms
step:59/5100 train_loss:5.6211 train_time:6805ms step_avg:138.88ms
step:60/5100 train_loss:5.5806 train_time:6948ms step_avg:138.95ms
step:61/5100 train_loss:5.6756 train_time:7086ms step_avg:138.94ms
step:62/5100 train_loss:5.4460 train_time:7226ms step_avg:138.96ms
step:63/5100 train_loss:5.5505 train_time:7367ms step_avg:139.00ms
step:64/5100 train_loss:5.5335 train_time:7507ms step_avg:139.02ms
step:65/5100 train_loss:5.2016 train_time:7646ms step_avg:139.02ms
step:66/5100 train_loss:5.3469 train_time:7789ms step_avg:139.09ms
step:67/5100 train_loss:5.4996 train_time:7927ms step_avg:139.06ms
step:68/5100 train_loss:5.3771 train_time:8067ms step_avg:139.09ms
step:69/5100 train_loss:5.6439 train_time:8209ms step_avg:139.14ms
step:70/5100 train_loss:5.2914 train_time:8351ms step_avg:139.18ms
step:71/5100 train_loss:5.3085 train_time:8490ms step_avg:139.17ms
step:72/5100 train_loss:5.5120 train_time:8629ms step_avg:139.18ms
step:73/5100 train_loss:5.4412 train_time:8770ms step_avg:139.21ms
step:74/5100 train_loss:5.3140 train_time:8913ms step_avg:139.26ms
step:75/5100 train_loss:5.4440 train_time:9065ms step_avg:139.46ms
step:76/5100 train_loss:5.4089 train_time:9195ms step_avg:139.31ms
step:77/5100 train_loss:5.3792 train_time:9336ms step_avg:139.35ms
step:78/5100 train_loss:5.4709 train_time:9477ms step_avg:139.37ms
step:79/5100 train_loss:5.5314 train_time:9619ms step_avg:139.40ms
step:80/5100 train_loss:5.3345 train_time:9759ms step_avg:139.42ms
step:81/5100 train_loss:5.4311 train_time:9900ms step_avg:139.43ms
step:82/5100 train_loss:5.1976 train_time:10041ms step_avg:139.45ms
step:83/5100 train_loss:5.3647 train_time:10181ms step_avg:139.47ms
step:84/5100 train_loss:5.3113 train_time:10321ms step_avg:139.48ms
step:85/5100 train_loss:5.3062 train_time:10461ms step_avg:139.49ms
step:86/5100 train_loss:5.1549 train_time:10602ms step_avg:139.50ms
step:87/5100 train_loss:5.3655 train_time:10741ms step_avg:139.50ms
step:88/5100 train_loss:5.2718 train_time:10883ms step_avg:139.53ms
step:89/5100 train_loss:5.3375 train_time:11022ms step_avg:139.51ms
step:90/5100 train_loss:5.2967 train_time:11164ms step_avg:139.55ms
step:91/5100 train_loss:5.2241 train_time:11302ms step_avg:139.53ms
step:92/5100 train_loss:5.1922 train_time:11442ms step_avg:139.53ms
step:93/5100 train_loss:5.3442 train_time:11582ms step_avg:139.54ms
step:94/5100 train_loss:5.1495 train_time:11721ms step_avg:139.54ms
step:95/5100 train_loss:5.1675 train_time:11861ms step_avg:139.54ms
step:96/5100 train_loss:5.1980 train_time:12002ms step_avg:139.55ms
step:97/5100 train_loss:5.1103 train_time:12141ms step_avg:139.56ms
step:98/5100 train_loss:5.1987 train_time:12290ms step_avg:139.66ms
step:99/5100 train_loss:5.1154 train_time:12421ms step_avg:139.57ms
step:100/5100 train_loss:5.2372 train_time:12561ms step_avg:139.57ms
step:101/5100 train_loss:5.2118 train_time:12706ms step_avg:139.63ms
step:102/5100 train_loss:5.1125 train_time:12844ms step_avg:139.61ms
step:103/5100 train_loss:5.1995 train_time:12982ms step_avg:139.59ms
step:104/5100 train_loss:5.1493 train_time:13122ms step_avg:139.60ms
step:105/5100 train_loss:5.0168 train_time:13262ms step_avg:139.60ms
step:106/5100 train_loss:5.1194 train_time:13402ms step_avg:139.60ms
step:107/5100 train_loss:5.3136 train_time:13542ms step_avg:139.60ms
step:108/5100 train_loss:5.0858 train_time:13682ms step_avg:139.61ms
step:109/5100 train_loss:4.8767 train_time:13825ms step_avg:139.64ms
step:110/5100 train_loss:5.0602 train_time:13963ms step_avg:139.63ms
step:111/5100 train_loss:5.0414 train_time:14103ms step_avg:139.64ms
step:112/5100 train_loss:5.0043 train_time:14244ms step_avg:139.64ms
step:113/5100 train_loss:5.1192 train_time:14383ms step_avg:139.64ms
step:114/5100 train_loss:5.0516 train_time:14535ms step_avg:139.76ms
step:115/5100 train_loss:4.9023 train_time:14663ms step_avg:139.65ms
step:116/5100 train_loss:5.0652 train_time:14804ms step_avg:139.66ms
step:117/5100 train_loss:4.9585 train_time:14947ms step_avg:139.69ms
step:118/5100 train_loss:4.9147 train_time:15083ms step_avg:139.66ms
step:119/5100 train_loss:5.0614 train_time:15222ms step_avg:139.65ms
step:120/5100 train_loss:5.0139 train_time:15363ms step_avg:139.66ms
step:121/5100 train_loss:4.9475 train_time:15503ms step_avg:139.67ms
step:122/5100 train_loss:4.8469 train_time:15644ms step_avg:139.68ms
step:123/5100 train_loss:4.9655 train_time:15784ms step_avg:139.68ms
step:124/5100 train_loss:4.8197 train_time:15924ms step_avg:139.69ms
step:125/5100 train_loss:5.1296 train_time:16064ms step_avg:139.69ms
step:125/5100 val_loss:4.9523 train_time:16121ms step_avg:140.18ms
step:126/5100 train_loss:4.9996 train_time:16216ms step_avg:139.79ms
step:127/5100 train_loss:4.9385 train_time:16366ms step_avg:139.88ms
step:128/5100 train_loss:5.0094 train_time:16505ms step_avg:139.87ms
step:129/5100 train_loss:4.8793 train_time:16647ms step_avg:139.89ms
step:130/5100 train_loss:5.1899 train_time:16786ms step_avg:139.88ms
step:131/5100 train_loss:4.9339 train_time:16921ms step_avg:139.84ms
step:132/5100 train_loss:4.9435 train_time:17060ms step_avg:139.83ms
step:133/5100 train_loss:4.8981 train_time:17201ms step_avg:139.84ms
step:134/5100 train_loss:4.9410 train_time:17345ms step_avg:139.88ms
step:135/5100 train_loss:4.8279 train_time:17487ms step_avg:139.90ms
step:136/5100 train_loss:4.9531 train_time:17628ms step_avg:139.90ms
step:137/5100 train_loss:4.7257 train_time:17768ms step_avg:139.91ms
step:138/5100 train_loss:4.8931 train_time:17910ms step_avg:139.92ms
step:139/5100 train_loss:4.8438 train_time:18048ms step_avg:139.91ms
step:140/5100 train_loss:4.8759 train_time:18190ms step_avg:139.92ms
step:141/5100 train_loss:4.9391 train_time:18333ms step_avg:139.95ms
step:142/5100 train_loss:4.8126 train_time:18474ms step_avg:139.96ms
step:143/5100 train_loss:4.8760 train_time:18614ms step_avg:139.96ms
step:144/5100 train_loss:4.7315 train_time:18754ms step_avg:139.95ms
step:145/5100 train_loss:4.8627 train_time:18898ms step_avg:139.99ms
step:146/5100 train_loss:4.8117 train_time:19033ms step_avg:139.95ms
step:147/5100 train_loss:4.6851 train_time:19174ms step_avg:139.96ms
step:148/5100 train_loss:4.8378 train_time:19318ms step_avg:139.99ms
step:149/5100 train_loss:4.8399 train_time:19456ms step_avg:139.97ms
step:150/5100 train_loss:4.8707 train_time:19596ms step_avg:139.97ms
step:151/5100 train_loss:4.9068 train_time:19735ms step_avg:139.97ms
step:152/5100 train_loss:4.7968 train_time:19875ms step_avg:139.96ms
step:153/5100 train_loss:4.7896 train_time:20014ms step_avg:139.96ms
step:154/5100 train_loss:4.8759 train_time:20155ms step_avg:139.96ms
step:155/5100 train_loss:4.8361 train_time:20305ms step_avg:140.03ms
step:156/5100 train_loss:4.7903 train_time:20436ms step_avg:139.97ms
step:157/5100 train_loss:4.8125 train_time:20576ms step_avg:139.97ms
step:158/5100 train_loss:4.9304 train_time:20716ms step_avg:139.97ms
step:159/5100 train_loss:4.7162 train_time:20855ms step_avg:139.97ms
step:160/5100 train_loss:4.7894 train_time:20995ms step_avg:139.96ms
step:161/5100 train_loss:4.6171 train_time:21134ms step_avg:139.96ms
step:162/5100 train_loss:4.8063 train_time:21275ms step_avg:139.97ms
step:163/5100 train_loss:4.8376 train_time:21415ms step_avg:139.97ms
step:164/5100 train_loss:4.8225 train_time:21555ms step_avg:139.97ms
step:165/5100 train_loss:4.6319 train_time:21695ms step_avg:139.97ms
step:166/5100 train_loss:4.7650 train_time:21835ms step_avg:139.97ms
step:167/5100 train_loss:4.8934 train_time:21974ms step_avg:139.96ms
step:168/5100 train_loss:4.6755 train_time:22115ms step_avg:139.97ms
step:169/5100 train_loss:4.7737 train_time:22255ms step_avg:139.97ms
step:170/5100 train_loss:4.6248 train_time:22395ms step_avg:139.97ms
step:171/5100 train_loss:4.5313 train_time:22535ms step_avg:139.97ms
step:172/5100 train_loss:4.6855 train_time:22676ms step_avg:139.97ms
step:173/5100 train_loss:4.6687 train_time:22815ms step_avg:139.97ms
step:174/5100 train_loss:4.7283 train_time:22953ms step_avg:139.96ms
step:175/5100 train_loss:4.8815 train_time:23094ms step_avg:139.97ms
step:176/5100 train_loss:4.7293 train_time:23235ms step_avg:139.97ms
step:177/5100 train_loss:4.5799 train_time:23375ms step_avg:139.97ms
step:178/5100 train_loss:4.5530 train_time:23519ms step_avg:139.99ms
step:179/5100 train_loss:4.6195 train_time:23655ms step_avg:139.97ms
step:180/5100 train_loss:4.6284 train_time:23796ms step_avg:139.98ms
step:181/5100 train_loss:4.6216 train_time:23935ms step_avg:139.97ms
step:182/5100 train_loss:4.7571 train_time:24081ms step_avg:140.01ms
step:183/5100 train_loss:4.6170 train_time:24216ms step_avg:139.98ms
step:184/5100 train_loss:4.5692 train_time:24354ms step_avg:139.97ms
step:185/5100 train_loss:4.5808 train_time:24499ms step_avg:139.99ms
step:186/5100 train_loss:4.7013 train_time:24635ms step_avg:139.97ms
step:187/5100 train_loss:4.6171 train_time:24776ms step_avg:139.98ms
step:188/5100 train_loss:4.8039 train_time:24915ms step_avg:139.97ms
step:189/5100 train_loss:4.6349 train_time:25214ms step_avg:140.86ms
step:190/5100 train_loss:4.5574 train_time:25528ms step_avg:141.82ms
step:191/5100 train_loss:4.6914 train_time:25663ms step_avg:141.78ms
step:192/5100 train_loss:4.5396 train_time:25801ms step_avg:141.76ms
step:193/5100 train_loss:4.4604 train_time:25938ms step_avg:141.74ms
step:194/5100 train_loss:4.6879 train_time:26076ms step_avg:141.72ms
step:195/5100 train_loss:4.6142 train_time:26214ms step_avg:141.70ms
step:196/5100 train_loss:4.8141 train_time:26351ms step_avg:141.67ms
step:197/5100 train_loss:4.6685 train_time:26497ms step_avg:141.69ms
step:198/5100 train_loss:4.5125 train_time:26639ms step_avg:141.69ms
step:199/5100 train_loss:4.5861 train_time:26788ms step_avg:141.73ms
step:200/5100 train_loss:4.4490 train_time:26917ms step_avg:141.67ms
step:201/5100 train_loss:4.5490 train_time:27059ms step_avg:141.67ms
step:202/5100 train_loss:4.4513 train_time:27194ms step_avg:141.64ms
step:203/5100 train_loss:4.6903 train_time:27333ms step_avg:141.62ms
step:204/5100 train_loss:4.5505 train_time:27476ms step_avg:141.63ms
step:205/5100 train_loss:4.5938 train_time:27617ms step_avg:141.63ms
step:206/5100 train_loss:4.7052 train_time:27760ms step_avg:141.63ms
step:207/5100 train_loss:4.3718 train_time:27897ms step_avg:141.61ms
step:208/5100 train_loss:4.5220 train_time:28035ms step_avg:141.59ms
step:209/5100 train_loss:4.5035 train_time:28174ms step_avg:141.58ms
step:210/5100 train_loss:4.6587 train_time:28314ms step_avg:141.57ms
step:211/5100 train_loss:4.5725 train_time:28455ms step_avg:141.57ms
step:212/5100 train_loss:4.4606 train_time:28598ms step_avg:141.57ms
step:213/5100 train_loss:4.5614 train_time:28739ms step_avg:141.57ms
step:214/5100 train_loss:4.4331 train_time:28879ms step_avg:141.56ms
step:215/5100 train_loss:4.5012 train_time:29017ms step_avg:141.55ms
step:216/5100 train_loss:4.3770 train_time:29156ms step_avg:141.53ms
step:217/5100 train_loss:4.4625 train_time:29295ms step_avg:141.52ms
step:218/5100 train_loss:4.4426 train_time:29435ms step_avg:141.51ms
step:219/5100 train_loss:4.4648 train_time:29583ms step_avg:141.54ms
step:220/5100 train_loss:4.4582 train_time:29723ms step_avg:141.54ms
step:221/5100 train_loss:4.4875 train_time:29859ms step_avg:141.51ms
step:222/5100 train_loss:4.5042 train_time:29997ms step_avg:141.50ms
step:223/5100 train_loss:4.4190 train_time:30135ms step_avg:141.48ms
step:224/5100 train_loss:4.4362 train_time:30274ms step_avg:141.47ms
step:225/5100 train_loss:4.6332 train_time:30414ms step_avg:141.46ms
step:226/5100 train_loss:4.2978 train_time:30555ms step_avg:141.46ms
step:227/5100 train_loss:4.3456 train_time:30697ms step_avg:141.46ms
step:228/5100 train_loss:4.3605 train_time:30837ms step_avg:141.45ms
step:229/5100 train_loss:4.5079 train_time:30977ms step_avg:141.45ms
step:230/5100 train_loss:4.3049 train_time:31115ms step_avg:141.43ms
step:231/5100 train_loss:4.4461 train_time:31254ms step_avg:141.42ms
step:232/5100 train_loss:4.3060 train_time:31395ms step_avg:141.42ms
step:233/5100 train_loss:4.3158 train_time:31535ms step_avg:141.41ms
step:234/5100 train_loss:4.4816 train_time:31675ms step_avg:141.41ms
step:235/5100 train_loss:4.3614 train_time:31815ms step_avg:141.40ms
step:236/5100 train_loss:4.2647 train_time:31955ms step_avg:141.40ms
step:237/5100 train_loss:4.4681 train_time:32096ms step_avg:141.39ms
step:238/5100 train_loss:4.4326 train_time:32234ms step_avg:141.38ms
step:239/5100 train_loss:4.2878 train_time:32374ms step_avg:141.37ms
step:240/5100 train_loss:4.4543 train_time:32515ms step_avg:141.37ms
step:241/5100 train_loss:4.4453 train_time:32655ms step_avg:141.36ms
step:242/5100 train_loss:4.3327 train_time:32795ms step_avg:141.36ms
step:243/5100 train_loss:4.5133 train_time:32935ms step_avg:141.35ms
step:244/5100 train_loss:4.3444 train_time:33075ms step_avg:141.35ms
step:245/5100 train_loss:4.3847 train_time:33215ms step_avg:141.34ms
step:246/5100 train_loss:4.4585 train_time:33353ms step_avg:141.33ms
step:247/5100 train_loss:4.3941 train_time:33497ms step_avg:141.34ms
step:248/5100 train_loss:4.3347 train_time:33634ms step_avg:141.32ms
step:249/5100 train_loss:4.4618 train_time:33774ms step_avg:141.32ms
step:250/5100 train_loss:4.2348 train_time:33915ms step_avg:141.31ms
step:250/5100 val_loss:4.3355 train_time:33971ms step_avg:141.54ms
step:251/5100 train_loss:4.2891 train_time:34068ms step_avg:141.36ms
step:252/5100 train_loss:4.4021 train_time:34210ms step_avg:141.36ms
step:253/5100 train_loss:4.4402 train_time:34349ms step_avg:141.35ms
step:254/5100 train_loss:4.2654 train_time:34494ms step_avg:141.37ms
step:255/5100 train_loss:4.2126 train_time:34628ms step_avg:141.34ms
step:256/5100 train_loss:4.3859 train_time:34767ms step_avg:141.33ms
step:257/5100 train_loss:4.3120 train_time:34910ms step_avg:141.34ms
step:258/5100 train_loss:4.3147 train_time:35048ms step_avg:141.32ms
step:259/5100 train_loss:4.2802 train_time:35192ms step_avg:141.33ms
step:260/5100 train_loss:4.3168 train_time:35333ms step_avg:141.33ms
step:261/5100 train_loss:4.3628 train_time:35477ms step_avg:141.34ms
step:262/5100 train_loss:4.3211 train_time:35611ms step_avg:141.31ms
step:263/5100 train_loss:4.2856 train_time:35754ms step_avg:141.32ms
step:264/5100 train_loss:4.2053 train_time:35890ms step_avg:141.30ms
step:265/5100 train_loss:4.2866 train_time:36031ms step_avg:141.30ms
step:266/5100 train_loss:4.1538 train_time:36173ms step_avg:141.30ms
step:267/5100 train_loss:4.2081 train_time:36314ms step_avg:141.30ms
step:268/5100 train_loss:4.2269 train_time:36452ms step_avg:141.29ms
step:269/5100 train_loss:4.2355 train_time:36591ms step_avg:141.28ms
step:270/5100 train_loss:4.1494 train_time:36734ms step_avg:141.29ms
step:271/5100 train_loss:4.3922 train_time:36870ms step_avg:141.27ms
step:272/5100 train_loss:4.2877 train_time:37010ms step_avg:141.26ms
step:273/5100 train_loss:4.1898 train_time:37151ms step_avg:141.26ms
step:274/5100 train_loss:4.2434 train_time:37292ms step_avg:141.26ms
step:275/5100 train_loss:4.3153 train_time:37432ms step_avg:141.25ms
step:276/5100 train_loss:4.3375 train_time:37574ms step_avg:141.26ms
step:277/5100 train_loss:4.5175 train_time:37712ms step_avg:141.24ms
step:278/5100 train_loss:4.3018 train_time:37856ms step_avg:141.25ms
step:279/5100 train_loss:4.3763 train_time:37993ms step_avg:141.24ms
step:280/5100 train_loss:4.2722 train_time:38135ms step_avg:141.24ms
step:281/5100 train_loss:4.4043 train_time:38273ms step_avg:141.23ms
step:282/5100 train_loss:4.2265 train_time:38413ms step_avg:141.22ms
step:283/5100 train_loss:4.2598 train_time:38553ms step_avg:141.22ms
step:284/5100 train_loss:4.1781 train_time:38693ms step_avg:141.21ms
step:285/5100 train_loss:4.3250 train_time:38833ms step_avg:141.21ms
step:286/5100 train_loss:4.3280 train_time:38972ms step_avg:141.20ms
step:287/5100 train_loss:4.3635 train_time:39116ms step_avg:141.21ms
step:288/5100 train_loss:4.1933 train_time:39254ms step_avg:141.20ms
step:289/5100 train_loss:4.2827 train_time:39392ms step_avg:141.19ms
step:290/5100 train_loss:4.1403 train_time:39532ms step_avg:141.18ms
step:291/5100 train_loss:4.1319 train_time:39672ms step_avg:141.18ms
step:292/5100 train_loss:4.2187 train_time:39812ms step_avg:141.18ms
step:293/5100 train_loss:4.1353 train_time:39952ms step_avg:141.17ms
step:294/5100 train_loss:4.1773 train_time:40092ms step_avg:141.17ms
step:295/5100 train_loss:4.2191 train_time:40233ms step_avg:141.17ms
step:296/5100 train_loss:4.1036 train_time:40373ms step_avg:141.16ms
step:297/5100 train_loss:4.1130 train_time:40512ms step_avg:141.16ms
step:298/5100 train_loss:4.1240 train_time:40651ms step_avg:141.15ms
step:299/5100 train_loss:4.2226 train_time:40801ms step_avg:141.18ms
step:300/5100 train_loss:4.0927 train_time:40933ms step_avg:141.15ms
step:301/5100 train_loss:4.2335 train_time:41073ms step_avg:141.14ms
step:302/5100 train_loss:4.2414 train_time:41215ms step_avg:141.15ms
step:303/5100 train_loss:4.1842 train_time:41352ms step_avg:141.13ms
step:304/5100 train_loss:4.2413 train_time:41495ms step_avg:141.14ms
step:305/5100 train_loss:4.2201 train_time:41631ms step_avg:141.12ms
step:306/5100 train_loss:4.6999 train_time:41771ms step_avg:141.12ms
step:307/5100 train_loss:4.1954 train_time:41912ms step_avg:141.12ms
step:308/5100 train_loss:4.0971 train_time:42052ms step_avg:141.11ms
step:309/5100 train_loss:4.2585 train_time:42193ms step_avg:141.11ms
step:310/5100 train_loss:4.1086 train_time:42333ms step_avg:141.11ms
step:311/5100 train_loss:4.3348 train_time:42477ms step_avg:141.12ms
step:312/5100 train_loss:4.1895 train_time:42612ms step_avg:141.10ms
step:313/5100 train_loss:4.1204 train_time:42751ms step_avg:141.09ms
step:314/5100 train_loss:4.2241 train_time:42892ms step_avg:141.09ms
step:315/5100 train_loss:4.3334 train_time:43031ms step_avg:141.09ms
step:316/5100 train_loss:4.2118 train_time:43172ms step_avg:141.09ms
step:317/5100 train_loss:4.0455 train_time:43315ms step_avg:141.09ms
step:318/5100 train_loss:4.1227 train_time:43453ms step_avg:141.08ms
step:319/5100 train_loss:4.1540 train_time:43593ms step_avg:141.08ms
step:320/5100 train_loss:4.1345 train_time:43732ms step_avg:141.07ms
step:321/5100 train_loss:4.2412 train_time:43872ms step_avg:141.07ms
step:322/5100 train_loss:4.1950 train_time:44013ms step_avg:141.07ms
step:323/5100 train_loss:4.1623 train_time:44152ms step_avg:141.06ms
step:324/5100 train_loss:4.2450 train_time:44293ms step_avg:141.06ms
step:325/5100 train_loss:4.2092 train_time:44433ms step_avg:141.06ms
step:326/5100 train_loss:4.2768 train_time:44573ms step_avg:141.05ms
step:327/5100 train_loss:4.1299 train_time:44713ms step_avg:141.05ms
step:328/5100 train_loss:4.6284 train_time:44853ms step_avg:141.05ms
step:329/5100 train_loss:4.3118 train_time:44992ms step_avg:141.04ms
step:330/5100 train_loss:4.0570 train_time:45132ms step_avg:141.04ms
step:331/5100 train_loss:3.9979 train_time:45272ms step_avg:141.03ms
step:332/5100 train_loss:4.2168 train_time:45414ms step_avg:141.04ms
step:333/5100 train_loss:4.1448 train_time:45553ms step_avg:141.03ms
step:334/5100 train_loss:4.1207 train_time:45693ms step_avg:141.03ms
step:335/5100 train_loss:4.0790 train_time:45832ms step_avg:141.02ms
step:336/5100 train_loss:4.2491 train_time:45972ms step_avg:141.02ms
step:337/5100 train_loss:4.1919 train_time:46113ms step_avg:141.02ms
step:338/5100 train_loss:4.6667 train_time:46252ms step_avg:141.01ms
step:339/5100 train_loss:4.1764 train_time:46394ms step_avg:141.01ms
step:340/5100 train_loss:4.1236 train_time:46533ms step_avg:141.01ms
step:341/5100 train_loss:4.1626 train_time:46673ms step_avg:141.01ms
step:342/5100 train_loss:4.0757 train_time:46812ms step_avg:141.00ms
step:343/5100 train_loss:4.0488 train_time:46952ms step_avg:141.00ms
step:344/5100 train_loss:4.0972 train_time:47092ms step_avg:140.99ms
step:345/5100 train_loss:4.2305 train_time:47232ms step_avg:140.99ms
step:346/5100 train_loss:4.0744 train_time:47372ms step_avg:140.99ms
step:347/5100 train_loss:4.0078 train_time:47513ms step_avg:140.99ms
step:348/5100 train_loss:4.0508 train_time:47652ms step_avg:140.98ms
step:349/5100 train_loss:4.0869 train_time:47792ms step_avg:140.98ms
step:350/5100 train_loss:4.0463 train_time:47932ms step_avg:140.98ms
step:351/5100 train_loss:3.7723 train_time:48073ms step_avg:140.98ms
step:352/5100 train_loss:4.0418 train_time:48214ms step_avg:140.98ms
step:353/5100 train_loss:4.3813 train_time:48355ms step_avg:140.98ms
step:354/5100 train_loss:3.8922 train_time:48493ms step_avg:140.97ms
step:355/5100 train_loss:4.1522 train_time:48633ms step_avg:140.96ms
step:356/5100 train_loss:4.0244 train_time:48772ms step_avg:140.96ms
step:357/5100 train_loss:4.1172 train_time:48916ms step_avg:140.97ms
step:358/5100 train_loss:4.0827 train_time:49052ms step_avg:140.95ms
step:359/5100 train_loss:4.0725 train_time:49193ms step_avg:140.96ms
step:360/5100 train_loss:4.1245 train_time:49332ms step_avg:140.95ms
step:361/5100 train_loss:3.6890 train_time:49472ms step_avg:140.95ms
step:362/5100 train_loss:4.2446 train_time:49613ms step_avg:140.95ms
step:363/5100 train_loss:4.1444 train_time:49752ms step_avg:140.94ms
step:364/5100 train_loss:4.0674 train_time:49893ms step_avg:140.94ms
step:365/5100 train_loss:3.9845 train_time:50033ms step_avg:140.94ms
step:366/5100 train_loss:4.1359 train_time:50172ms step_avg:140.93ms
step:367/5100 train_loss:4.0992 train_time:50312ms step_avg:140.93ms
step:368/5100 train_loss:4.0830 train_time:50452ms step_avg:140.93ms
step:369/5100 train_loss:4.0625 train_time:50593ms step_avg:140.93ms
step:370/5100 train_loss:3.9645 train_time:50732ms step_avg:140.92ms
step:371/5100 train_loss:4.1114 train_time:50872ms step_avg:140.92ms
step:372/5100 train_loss:3.9922 train_time:51012ms step_avg:140.92ms
step:373/5100 train_loss:3.9238 train_time:51151ms step_avg:140.91ms
step:374/5100 train_loss:4.1321 train_time:51292ms step_avg:140.91ms
step:375/5100 train_loss:4.0562 train_time:51434ms step_avg:140.91ms
step:375/5100 val_loss:4.0552 train_time:51489ms step_avg:141.07ms
step:376/5100 train_loss:4.0260 train_time:51581ms step_avg:140.93ms
step:377/5100 train_loss:4.0951 train_time:51732ms step_avg:140.96ms
step:378/5100 train_loss:4.0099 train_time:52070ms step_avg:141.49ms
step:379/5100 train_loss:4.0686 train_time:52205ms step_avg:141.48ms
step:380/5100 train_loss:4.0964 train_time:52516ms step_avg:141.94ms
step:381/5100 train_loss:4.1709 train_time:52654ms step_avg:141.92ms
step:382/5100 train_loss:4.0740 train_time:52791ms step_avg:141.91ms
step:383/5100 train_loss:4.0536 train_time:52929ms step_avg:141.90ms
step:384/5100 train_loss:4.0085 train_time:53067ms step_avg:141.89ms
step:385/5100 train_loss:4.0928 train_time:53205ms step_avg:141.88ms
step:386/5100 train_loss:4.0061 train_time:53344ms step_avg:141.87ms
step:387/5100 train_loss:4.1163 train_time:53493ms step_avg:141.89ms
step:388/5100 train_loss:4.3070 train_time:53635ms step_avg:141.89ms
step:389/5100 train_loss:4.0191 train_time:53774ms step_avg:141.88ms
step:390/5100 train_loss:4.0120 train_time:53913ms step_avg:141.88ms
step:391/5100 train_loss:4.1075 train_time:54051ms step_avg:141.87ms
step:392/5100 train_loss:4.0275 train_time:54189ms step_avg:141.86ms
step:393/5100 train_loss:4.1395 train_time:54329ms step_avg:141.85ms
step:394/5100 train_loss:3.9731 train_time:54473ms step_avg:141.86ms
step:395/5100 train_loss:4.1123 train_time:54615ms step_avg:141.86ms
step:396/5100 train_loss:3.8467 train_time:54756ms step_avg:141.86ms
step:397/5100 train_loss:4.0517 train_time:54896ms step_avg:141.85ms
step:398/5100 train_loss:4.1039 train_time:55034ms step_avg:141.84ms
step:399/5100 train_loss:4.1109 train_time:55173ms step_avg:141.83ms
step:400/5100 train_loss:3.9946 train_time:55313ms step_avg:141.83ms
step:401/5100 train_loss:4.0622 train_time:55453ms step_avg:141.82ms
step:402/5100 train_loss:4.1246 train_time:55594ms step_avg:141.82ms
step:403/5100 train_loss:4.0561 train_time:55733ms step_avg:141.82ms
step:404/5100 train_loss:4.1741 train_time:55871ms step_avg:141.80ms
step:405/5100 train_loss:3.9264 train_time:56013ms step_avg:141.81ms
step:406/5100 train_loss:4.0069 train_time:56149ms step_avg:141.79ms
step:407/5100 train_loss:4.2999 train_time:56289ms step_avg:141.79ms
step:408/5100 train_loss:4.0070 train_time:56430ms step_avg:141.79ms
step:409/5100 train_loss:4.0363 train_time:56571ms step_avg:141.78ms
step:410/5100 train_loss:4.0872 train_time:56714ms step_avg:141.79ms
step:411/5100 train_loss:3.9635 train_time:56850ms step_avg:141.77ms
step:412/5100 train_loss:3.9845 train_time:56989ms step_avg:141.76ms
step:413/5100 train_loss:4.4026 train_time:57128ms step_avg:141.76ms
step:414/5100 train_loss:3.8368 train_time:57268ms step_avg:141.75ms
step:415/5100 train_loss:4.2296 train_time:57409ms step_avg:141.75ms
step:416/5100 train_loss:3.9735 train_time:57549ms step_avg:141.75ms
step:417/5100 train_loss:3.9744 train_time:57690ms step_avg:141.75ms
step:418/5100 train_loss:4.1742 train_time:57830ms step_avg:141.74ms
step:419/5100 train_loss:3.9055 train_time:57973ms step_avg:141.74ms
step:420/5100 train_loss:4.0176 train_time:58117ms step_avg:141.75ms
step:421/5100 train_loss:3.9479 train_time:58248ms step_avg:141.72ms
step:422/5100 train_loss:3.8620 train_time:58390ms step_avg:141.72ms
step:423/5100 train_loss:3.9969 train_time:58532ms step_avg:141.72ms
step:424/5100 train_loss:4.0887 train_time:58670ms step_avg:141.72ms
step:425/5100 train_loss:3.8467 train_time:58810ms step_avg:141.71ms
step:426/5100 train_loss:4.0210 train_time:58949ms step_avg:141.70ms
step:427/5100 train_loss:3.9012 train_time:59088ms step_avg:141.70ms
step:428/5100 train_loss:4.1143 train_time:59228ms step_avg:141.69ms
step:429/5100 train_loss:4.0320 train_time:59369ms step_avg:141.69ms
step:430/5100 train_loss:3.9642 train_time:59509ms step_avg:141.69ms
step:431/5100 train_loss:3.9396 train_time:59649ms step_avg:141.68ms
step:432/5100 train_loss:3.8432 train_time:59790ms step_avg:141.68ms
step:433/5100 train_loss:3.9761 train_time:59929ms step_avg:141.68ms
step:434/5100 train_loss:4.0386 train_time:60069ms step_avg:141.67ms
step:435/5100 train_loss:3.9863 train_time:60209ms step_avg:141.67ms
step:436/5100 train_loss:4.0329 train_time:60349ms step_avg:141.66ms
step:437/5100 train_loss:4.0409 train_time:60489ms step_avg:141.66ms
step:438/5100 train_loss:3.9234 train_time:60629ms step_avg:141.66ms
step:439/5100 train_loss:3.9377 train_time:60770ms step_avg:141.65ms
step:440/5100 train_loss:3.9117 train_time:60910ms step_avg:141.65ms
step:441/5100 train_loss:4.0948 train_time:61049ms step_avg:141.65ms
step:442/5100 train_loss:3.9758 train_time:61189ms step_avg:141.64ms
step:443/5100 train_loss:3.9647 train_time:61329ms step_avg:141.64ms
step:444/5100 train_loss:3.8523 train_time:61469ms step_avg:141.63ms
step:445/5100 train_loss:4.1166 train_time:61608ms step_avg:141.63ms
step:446/5100 train_loss:4.0495 train_time:61749ms step_avg:141.63ms
step:447/5100 train_loss:4.0484 train_time:61892ms step_avg:141.63ms
step:448/5100 train_loss:3.9641 train_time:62030ms step_avg:141.62ms
step:449/5100 train_loss:4.0647 train_time:62174ms step_avg:141.63ms
step:450/5100 train_loss:3.8859 train_time:62310ms step_avg:141.61ms
step:451/5100 train_loss:3.9345 train_time:62452ms step_avg:141.61ms
step:452/5100 train_loss:3.7939 train_time:62590ms step_avg:141.61ms
step:453/5100 train_loss:3.9208 train_time:62732ms step_avg:141.61ms
step:454/5100 train_loss:3.8925 train_time:62872ms step_avg:141.60ms
step:455/5100 train_loss:3.8469 train_time:63009ms step_avg:141.59ms
step:456/5100 train_loss:4.0633 train_time:63149ms step_avg:141.59ms
step:457/5100 train_loss:3.9313 train_time:63291ms step_avg:141.59ms
step:458/5100 train_loss:4.0002 train_time:63433ms step_avg:141.59ms
step:459/5100 train_loss:4.0445 train_time:63570ms step_avg:141.58ms
step:460/5100 train_loss:3.8429 train_time:63709ms step_avg:141.58ms
step:461/5100 train_loss:4.0112 train_time:63849ms step_avg:141.57ms
step:462/5100 train_loss:3.9091 train_time:63988ms step_avg:141.57ms
step:463/5100 train_loss:3.9285 train_time:64135ms step_avg:141.58ms
step:464/5100 train_loss:3.9870 train_time:64270ms step_avg:141.56ms
step:465/5100 train_loss:3.9302 train_time:64410ms step_avg:141.56ms
step:466/5100 train_loss:3.9338 train_time:64550ms step_avg:141.56ms
step:467/5100 train_loss:4.0251 train_time:64690ms step_avg:141.55ms
step:468/5100 train_loss:4.0442 train_time:64837ms step_avg:141.57ms
step:469/5100 train_loss:4.0137 train_time:64974ms step_avg:141.56ms
step:470/5100 train_loss:3.9038 train_time:65109ms step_avg:141.54ms
step:471/5100 train_loss:3.9841 train_time:65247ms step_avg:141.53ms
step:472/5100 train_loss:4.0341 train_time:65389ms step_avg:141.53ms
step:473/5100 train_loss:3.9843 train_time:65529ms step_avg:141.53ms
step:474/5100 train_loss:3.9330 train_time:65669ms step_avg:141.53ms
step:475/5100 train_loss:3.7922 train_time:65810ms step_avg:141.53ms
step:476/5100 train_loss:4.2264 train_time:65949ms step_avg:141.52ms
step:477/5100 train_loss:3.9817 train_time:66089ms step_avg:141.52ms
step:478/5100 train_loss:3.7974 train_time:66228ms step_avg:141.51ms
step:479/5100 train_loss:4.0292 train_time:66369ms step_avg:141.51ms
step:480/5100 train_loss:3.9811 train_time:66513ms step_avg:141.52ms
step:481/5100 train_loss:4.1221 train_time:66653ms step_avg:141.51ms
step:482/5100 train_loss:3.9374 train_time:66790ms step_avg:141.50ms
step:483/5100 train_loss:3.7387 train_time:66928ms step_avg:141.50ms
step:484/5100 train_loss:4.0224 train_time:67069ms step_avg:141.50ms
step:485/5100 train_loss:3.8790 train_time:67210ms step_avg:141.49ms
step:486/5100 train_loss:3.8833 train_time:67349ms step_avg:141.49ms
step:487/5100 train_loss:3.8192 train_time:67490ms step_avg:141.49ms
step:488/5100 train_loss:3.8852 train_time:67629ms step_avg:141.48ms
step:489/5100 train_loss:4.0872 train_time:67769ms step_avg:141.48ms
step:490/5100 train_loss:3.9289 train_time:67909ms step_avg:141.48ms
step:491/5100 train_loss:3.8181 train_time:68049ms step_avg:141.47ms
step:492/5100 train_loss:3.8315 train_time:68192ms step_avg:141.48ms
step:493/5100 train_loss:3.9424 train_time:68328ms step_avg:141.47ms
step:494/5100 train_loss:3.7857 train_time:68469ms step_avg:141.46ms
step:495/5100 train_loss:3.9200 train_time:68609ms step_avg:141.46ms
step:496/5100 train_loss:3.8583 train_time:68751ms step_avg:141.46ms
step:497/5100 train_loss:3.7406 train_time:68890ms step_avg:141.46ms
step:498/5100 train_loss:3.9405 train_time:69029ms step_avg:141.45ms
step:499/5100 train_loss:4.0187 train_time:69170ms step_avg:141.45ms
step:500/5100 train_loss:4.0402 train_time:69310ms step_avg:141.45ms
step:500/5100 val_loss:3.9204 train_time:69365ms step_avg:141.56ms
step:501/5100 train_loss:3.9507 train_time:69462ms step_avg:141.47ms
step:502/5100 train_loss:4.0073 train_time:69605ms step_avg:141.47ms
step:503/5100 train_loss:3.9528 train_time:69745ms step_avg:141.47ms
step:504/5100 train_loss:3.9941 train_time:69884ms step_avg:141.47ms
step:505/5100 train_loss:3.9432 train_time:70023ms step_avg:141.46ms
step:506/5100 train_loss:4.0270 train_time:70161ms step_avg:141.45ms
step:507/5100 train_loss:3.8448 train_time:70300ms step_avg:141.45ms
step:508/5100 train_loss:3.9728 train_time:70442ms step_avg:141.45ms
step:509/5100 train_loss:4.0496 train_time:70586ms step_avg:141.46ms
step:510/5100 train_loss:3.9824 train_time:70730ms step_avg:141.46ms
step:511/5100 train_loss:3.7835 train_time:70866ms step_avg:141.45ms
step:512/5100 train_loss:3.9888 train_time:71005ms step_avg:141.44ms
step:513/5100 train_loss:3.9315 train_time:71143ms step_avg:141.44ms
step:514/5100 train_loss:3.8872 train_time:71283ms step_avg:141.43ms
step:515/5100 train_loss:3.9598 train_time:71425ms step_avg:141.44ms
step:516/5100 train_loss:3.9512 train_time:71566ms step_avg:141.44ms
step:517/5100 train_loss:4.2909 train_time:71707ms step_avg:141.43ms
step:518/5100 train_loss:3.8911 train_time:71848ms step_avg:141.43ms
step:519/5100 train_loss:3.9982 train_time:71985ms step_avg:141.42ms
step:520/5100 train_loss:3.8968 train_time:72125ms step_avg:141.42ms
step:521/5100 train_loss:3.8945 train_time:72264ms step_avg:141.42ms
step:522/5100 train_loss:3.8440 train_time:72406ms step_avg:141.42ms
step:523/5100 train_loss:3.8653 train_time:72547ms step_avg:141.42ms
step:524/5100 train_loss:4.4959 train_time:72687ms step_avg:141.41ms
step:525/5100 train_loss:3.9553 train_time:72827ms step_avg:141.41ms
step:526/5100 train_loss:3.8912 train_time:72966ms step_avg:141.41ms
step:527/5100 train_loss:3.9068 train_time:73106ms step_avg:141.41ms
step:528/5100 train_loss:3.8550 train_time:73245ms step_avg:141.40ms
step:529/5100 train_loss:3.8325 train_time:73386ms step_avg:141.40ms
step:530/5100 train_loss:4.0531 train_time:73526ms step_avg:141.40ms
step:531/5100 train_loss:3.8505 train_time:73666ms step_avg:141.39ms
step:532/5100 train_loss:4.1327 train_time:73806ms step_avg:141.39ms
step:533/5100 train_loss:3.9411 train_time:73945ms step_avg:141.39ms
step:534/5100 train_loss:3.8674 train_time:74085ms step_avg:141.38ms
step:535/5100 train_loss:3.8905 train_time:74226ms step_avg:141.38ms
step:536/5100 train_loss:3.8228 train_time:74366ms step_avg:141.38ms
step:537/5100 train_loss:3.9557 train_time:74507ms step_avg:141.38ms
step:538/5100 train_loss:3.9436 train_time:74649ms step_avg:141.38ms
step:539/5100 train_loss:3.8429 train_time:74786ms step_avg:141.37ms
step:540/5100 train_loss:4.3383 train_time:74926ms step_avg:141.37ms
step:541/5100 train_loss:3.8767 train_time:75065ms step_avg:141.37ms
step:542/5100 train_loss:3.9860 train_time:75209ms step_avg:141.37ms
step:543/5100 train_loss:3.8107 train_time:75346ms step_avg:141.36ms
step:544/5100 train_loss:3.7872 train_time:75486ms step_avg:141.36ms
step:545/5100 train_loss:3.8745 train_time:75627ms step_avg:141.36ms
step:546/5100 train_loss:3.7999 train_time:75766ms step_avg:141.35ms
step:547/5100 train_loss:3.8454 train_time:75906ms step_avg:141.35ms
step:548/5100 train_loss:3.8552 train_time:76046ms step_avg:141.35ms
step:549/5100 train_loss:3.8338 train_time:76186ms step_avg:141.35ms
step:550/5100 train_loss:3.9334 train_time:76325ms step_avg:141.34ms
step:551/5100 train_loss:3.8129 train_time:76471ms step_avg:141.35ms
step:552/5100 train_loss:3.8369 train_time:76606ms step_avg:141.34ms
step:553/5100 train_loss:4.1618 train_time:76746ms step_avg:141.34ms
step:554/5100 train_loss:3.9534 train_time:76888ms step_avg:141.34ms
step:555/5100 train_loss:3.9152 train_time:77026ms step_avg:141.33ms
step:556/5100 train_loss:3.8642 train_time:77164ms step_avg:141.33ms
step:557/5100 train_loss:3.8953 train_time:77305ms step_avg:141.33ms
step:558/5100 train_loss:3.5651 train_time:77449ms step_avg:141.33ms
step:559/5100 train_loss:3.8142 train_time:77589ms step_avg:141.33ms
step:560/5100 train_loss:3.8594 train_time:77727ms step_avg:141.32ms
step:561/5100 train_loss:3.9030 train_time:77865ms step_avg:141.32ms
step:562/5100 train_loss:3.8139 train_time:78006ms step_avg:141.32ms
step:563/5100 train_loss:3.7542 train_time:78146ms step_avg:141.31ms
step:564/5100 train_loss:3.9678 train_time:78286ms step_avg:141.31ms
step:565/5100 train_loss:3.7757 train_time:78425ms step_avg:141.31ms
step:566/5100 train_loss:3.8928 train_time:78569ms step_avg:141.31ms
step:567/5100 train_loss:3.8354 train_time:78882ms step_avg:141.62ms
step:568/5100 train_loss:3.7896 train_time:79020ms step_avg:141.61ms
step:569/5100 train_loss:3.8920 train_time:79157ms step_avg:141.60ms
step:570/5100 train_loss:3.8630 train_time:79461ms step_avg:141.90ms
step:571/5100 train_loss:3.8906 train_time:79601ms step_avg:141.89ms
step:572/5100 train_loss:3.9722 train_time:79742ms step_avg:141.89ms
step:573/5100 train_loss:3.9244 train_time:79880ms step_avg:141.88ms
step:574/5100 train_loss:3.9315 train_time:80020ms step_avg:141.88ms
step:575/5100 train_loss:3.9757 train_time:80159ms step_avg:141.87ms
step:576/5100 train_loss:3.9393 train_time:80299ms step_avg:141.87ms
step:577/5100 train_loss:3.9574 train_time:80443ms step_avg:141.88ms
step:578/5100 train_loss:3.8873 train_time:80584ms step_avg:141.87ms
step:579/5100 train_loss:3.8797 train_time:80725ms step_avg:141.87ms
step:580/5100 train_loss:3.8638 train_time:80863ms step_avg:141.87ms
step:581/5100 train_loss:3.8101 train_time:81003ms step_avg:141.86ms
step:582/5100 train_loss:3.8353 train_time:81143ms step_avg:141.86ms
step:583/5100 train_loss:4.0683 train_time:81283ms step_avg:141.85ms
step:584/5100 train_loss:3.8330 train_time:81425ms step_avg:141.86ms
step:585/5100 train_loss:3.7905 train_time:81568ms step_avg:141.86ms
step:586/5100 train_loss:3.9843 train_time:81706ms step_avg:141.85ms
step:587/5100 train_loss:3.7319 train_time:81845ms step_avg:141.85ms
step:588/5100 train_loss:3.8730 train_time:81983ms step_avg:141.84ms
step:589/5100 train_loss:3.8548 train_time:82123ms step_avg:141.84ms
step:590/5100 train_loss:4.2131 train_time:82263ms step_avg:141.83ms
step:591/5100 train_loss:3.9857 train_time:82410ms step_avg:141.84ms
step:592/5100 train_loss:3.7274 train_time:82546ms step_avg:141.83ms
step:593/5100 train_loss:3.7401 train_time:82686ms step_avg:141.83ms
step:594/5100 train_loss:3.7358 train_time:82825ms step_avg:141.82ms
step:595/5100 train_loss:3.7741 train_time:82968ms step_avg:141.83ms
step:596/5100 train_loss:4.1370 train_time:83106ms step_avg:141.82ms
step:597/5100 train_loss:3.8574 train_time:83246ms step_avg:141.82ms
step:598/5100 train_loss:3.7888 train_time:83386ms step_avg:141.81ms
step:599/5100 train_loss:3.8656 train_time:83527ms step_avg:141.81ms
step:600/5100 train_loss:3.6860 train_time:83667ms step_avg:141.81ms
step:601/5100 train_loss:3.8091 train_time:83807ms step_avg:141.81ms
step:602/5100 train_loss:3.8394 train_time:83946ms step_avg:141.80ms
step:603/5100 train_loss:3.8614 train_time:84088ms step_avg:141.80ms
step:604/5100 train_loss:3.9828 train_time:84226ms step_avg:141.79ms
step:605/5100 train_loss:3.8357 train_time:84365ms step_avg:141.79ms
step:606/5100 train_loss:3.8205 train_time:84506ms step_avg:141.79ms
step:607/5100 train_loss:3.7751 train_time:84646ms step_avg:141.79ms
step:608/5100 train_loss:4.0278 train_time:84786ms step_avg:141.78ms
step:609/5100 train_loss:3.8578 train_time:84926ms step_avg:141.78ms
step:610/5100 train_loss:3.8196 train_time:85065ms step_avg:141.77ms
step:611/5100 train_loss:3.9191 train_time:85206ms step_avg:141.77ms
step:612/5100 train_loss:3.8273 train_time:85345ms step_avg:141.77ms
step:613/5100 train_loss:3.8043 train_time:85486ms step_avg:141.77ms
step:614/5100 train_loss:3.9688 train_time:85627ms step_avg:141.77ms
step:615/5100 train_loss:3.9255 train_time:85766ms step_avg:141.76ms
step:616/5100 train_loss:3.8896 train_time:85906ms step_avg:141.76ms
step:617/5100 train_loss:3.8215 train_time:86045ms step_avg:141.76ms
step:618/5100 train_loss:3.7801 train_time:86186ms step_avg:141.75ms
step:619/5100 train_loss:3.8838 train_time:86326ms step_avg:141.75ms
step:620/5100 train_loss:3.7822 train_time:86466ms step_avg:141.75ms
step:621/5100 train_loss:3.7935 train_time:86606ms step_avg:141.74ms
step:622/5100 train_loss:4.1066 train_time:86746ms step_avg:141.74ms
step:623/5100 train_loss:3.7943 train_time:86887ms step_avg:141.74ms
step:624/5100 train_loss:3.8225 train_time:87026ms step_avg:141.74ms
step:625/5100 train_loss:3.9021 train_time:87165ms step_avg:141.73ms
step:625/5100 val_loss:3.8317 train_time:87222ms step_avg:141.83ms
step:626/5100 train_loss:3.9298 train_time:87317ms step_avg:141.75ms
step:627/5100 train_loss:3.9540 train_time:87463ms step_avg:141.76ms
step:628/5100 train_loss:3.9337 train_time:87603ms step_avg:141.75ms
step:629/5100 train_loss:3.9750 train_time:87741ms step_avg:141.75ms
step:630/5100 train_loss:3.8030 train_time:87880ms step_avg:141.74ms
step:631/5100 train_loss:3.9237 train_time:88018ms step_avg:141.74ms
step:632/5100 train_loss:3.9637 train_time:88156ms step_avg:141.73ms
step:633/5100 train_loss:3.8594 train_time:88295ms step_avg:141.73ms
step:634/5100 train_loss:3.7933 train_time:88437ms step_avg:141.73ms
step:635/5100 train_loss:3.8958 train_time:88580ms step_avg:141.73ms
step:636/5100 train_loss:4.1482 train_time:88719ms step_avg:141.72ms
step:637/5100 train_loss:3.7345 train_time:88858ms step_avg:141.72ms
step:638/5100 train_loss:3.5581 train_time:88997ms step_avg:141.72ms
step:639/5100 train_loss:3.7846 train_time:89136ms step_avg:141.71ms
step:640/5100 train_loss:3.8274 train_time:89278ms step_avg:141.71ms
step:641/5100 train_loss:3.7799 train_time:89417ms step_avg:141.71ms
step:642/5100 train_loss:3.7817 train_time:89559ms step_avg:141.71ms
step:643/5100 train_loss:3.8237 train_time:89698ms step_avg:141.70ms
step:644/5100 train_loss:3.8318 train_time:89838ms step_avg:141.70ms
step:645/5100 train_loss:3.7634 train_time:89978ms step_avg:141.70ms
step:646/5100 train_loss:3.9748 train_time:90118ms step_avg:141.70ms
step:647/5100 train_loss:3.8779 train_time:90259ms step_avg:141.69ms
step:648/5100 train_loss:3.8694 train_time:90398ms step_avg:141.69ms
step:649/5100 train_loss:3.9052 train_time:90538ms step_avg:141.69ms
step:650/5100 train_loss:3.9695 train_time:90680ms step_avg:141.69ms
step:651/5100 train_loss:3.8255 train_time:90821ms step_avg:141.69ms
step:652/5100 train_loss:3.9701 train_time:90960ms step_avg:141.68ms
step:653/5100 train_loss:3.7852 train_time:91099ms step_avg:141.68ms
step:654/5100 train_loss:3.8625 train_time:91240ms step_avg:141.68ms
step:655/5100 train_loss:3.6271 train_time:91379ms step_avg:141.67ms
step:656/5100 train_loss:3.7742 train_time:91519ms step_avg:141.67ms
step:657/5100 train_loss:3.7787 train_time:91659ms step_avg:141.67ms
step:658/5100 train_loss:3.7123 train_time:91799ms step_avg:141.67ms
step:659/5100 train_loss:3.8926 train_time:91939ms step_avg:141.66ms
step:660/5100 train_loss:3.7916 train_time:92078ms step_avg:141.66ms
step:661/5100 train_loss:3.8855 train_time:92218ms step_avg:141.66ms
step:662/5100 train_loss:3.9593 train_time:92357ms step_avg:141.65ms
step:663/5100 train_loss:3.8684 train_time:92497ms step_avg:141.65ms
step:664/5100 train_loss:3.7526 train_time:92637ms step_avg:141.65ms
step:665/5100 train_loss:3.8409 train_time:92784ms step_avg:141.66ms
step:666/5100 train_loss:3.7017 train_time:92918ms step_avg:141.64ms
step:667/5100 train_loss:3.9923 train_time:93058ms step_avg:141.64ms
step:668/5100 train_loss:3.8199 train_time:93197ms step_avg:141.64ms
step:669/5100 train_loss:3.8310 train_time:93336ms step_avg:141.63ms
step:670/5100 train_loss:3.6792 train_time:93480ms step_avg:141.64ms
step:671/5100 train_loss:3.7969 train_time:93617ms step_avg:141.63ms
step:672/5100 train_loss:3.7625 train_time:93757ms step_avg:141.63ms
step:673/5100 train_loss:3.7755 train_time:93897ms step_avg:141.62ms
step:674/5100 train_loss:4.0541 train_time:94036ms step_avg:141.62ms
step:675/5100 train_loss:3.8490 train_time:94175ms step_avg:141.62ms
step:676/5100 train_loss:3.9114 train_time:94315ms step_avg:141.61ms
step:677/5100 train_loss:3.6913 train_time:94455ms step_avg:141.61ms
step:678/5100 train_loss:3.8017 train_time:94594ms step_avg:141.61ms
step:679/5100 train_loss:3.7409 train_time:94732ms step_avg:141.60ms
step:680/5100 train_loss:3.8821 train_time:94872ms step_avg:141.60ms
step:681/5100 train_loss:3.7904 train_time:95013ms step_avg:141.60ms
step:682/5100 train_loss:3.8192 train_time:95152ms step_avg:141.60ms
step:683/5100 train_loss:3.8926 train_time:95292ms step_avg:141.59ms
step:684/5100 train_loss:3.9384 train_time:95431ms step_avg:141.59ms
step:685/5100 train_loss:3.8270 train_time:95569ms step_avg:141.58ms
step:686/5100 train_loss:3.9083 train_time:95711ms step_avg:141.58ms
step:687/5100 train_loss:3.8354 train_time:95850ms step_avg:141.58ms
step:688/5100 train_loss:3.8802 train_time:95990ms step_avg:141.58ms
step:689/5100 train_loss:3.4959 train_time:96132ms step_avg:141.58ms
step:690/5100 train_loss:3.6178 train_time:96271ms step_avg:141.57ms
step:691/5100 train_loss:3.7622 train_time:96410ms step_avg:141.57ms
step:692/5100 train_loss:3.6332 train_time:96550ms step_avg:141.57ms
step:693/5100 train_loss:3.8498 train_time:96691ms step_avg:141.57ms
step:694/5100 train_loss:3.8689 train_time:96831ms step_avg:141.57ms
step:695/5100 train_loss:3.7536 train_time:96970ms step_avg:141.56ms
step:696/5100 train_loss:3.7423 train_time:97111ms step_avg:141.56ms
step:697/5100 train_loss:4.0628 train_time:97251ms step_avg:141.56ms
step:698/5100 train_loss:3.8063 train_time:97391ms step_avg:141.56ms
step:699/5100 train_loss:3.8434 train_time:97531ms step_avg:141.55ms
step:700/5100 train_loss:4.0026 train_time:97669ms step_avg:141.55ms
step:701/5100 train_loss:3.7775 train_time:97811ms step_avg:141.55ms
step:702/5100 train_loss:3.7379 train_time:97951ms step_avg:141.55ms
step:703/5100 train_loss:3.7244 train_time:98091ms step_avg:141.54ms
step:704/5100 train_loss:3.6761 train_time:98231ms step_avg:141.54ms
step:705/5100 train_loss:3.7669 train_time:98371ms step_avg:141.54ms
step:706/5100 train_loss:3.7589 train_time:98511ms step_avg:141.54ms
step:707/5100 train_loss:3.7787 train_time:98650ms step_avg:141.54ms
step:708/5100 train_loss:3.8499 train_time:98790ms step_avg:141.53ms
step:709/5100 train_loss:3.7977 train_time:98932ms step_avg:141.53ms
step:710/5100 train_loss:3.7823 train_time:99071ms step_avg:141.53ms
step:711/5100 train_loss:3.7541 train_time:99211ms step_avg:141.53ms
step:712/5100 train_loss:3.7925 train_time:99351ms step_avg:141.53ms
step:713/5100 train_loss:3.8488 train_time:99491ms step_avg:141.52ms
step:714/5100 train_loss:3.8593 train_time:99633ms step_avg:141.52ms
step:715/5100 train_loss:3.7682 train_time:99770ms step_avg:141.52ms
step:716/5100 train_loss:3.7731 train_time:99912ms step_avg:141.52ms
step:717/5100 train_loss:3.7845 train_time:100051ms step_avg:141.51ms
step:718/5100 train_loss:3.9362 train_time:100193ms step_avg:141.52ms
step:719/5100 train_loss:3.7949 train_time:100331ms step_avg:141.51ms
step:720/5100 train_loss:3.8723 train_time:100470ms step_avg:141.51ms
step:721/5100 train_loss:4.0347 train_time:100610ms step_avg:141.51ms
step:722/5100 train_loss:3.6654 train_time:100752ms step_avg:141.51ms
step:723/5100 train_loss:3.9260 train_time:100891ms step_avg:141.50ms
step:724/5100 train_loss:3.9832 train_time:101032ms step_avg:141.50ms
step:725/5100 train_loss:3.7683 train_time:101171ms step_avg:141.50ms
step:726/5100 train_loss:3.8529 train_time:101312ms step_avg:141.50ms
step:727/5100 train_loss:3.7478 train_time:101451ms step_avg:141.49ms
step:728/5100 train_loss:3.7641 train_time:101591ms step_avg:141.49ms
step:729/5100 train_loss:3.9380 train_time:101731ms step_avg:141.49ms
step:730/5100 train_loss:3.8837 train_time:101871ms step_avg:141.49ms
step:731/5100 train_loss:3.8763 train_time:102011ms step_avg:141.48ms
step:732/5100 train_loss:3.7671 train_time:102151ms step_avg:141.48ms
step:733/5100 train_loss:3.7919 train_time:102291ms step_avg:141.48ms
step:734/5100 train_loss:4.0278 train_time:102430ms step_avg:141.48ms
step:735/5100 train_loss:3.7613 train_time:102570ms step_avg:141.48ms
step:736/5100 train_loss:3.8195 train_time:102711ms step_avg:141.48ms
step:737/5100 train_loss:3.9407 train_time:102854ms step_avg:141.48ms
step:738/5100 train_loss:3.8572 train_time:102990ms step_avg:141.47ms
step:739/5100 train_loss:3.8047 train_time:103130ms step_avg:141.47ms
step:740/5100 train_loss:3.6975 train_time:103270ms step_avg:141.47ms
step:741/5100 train_loss:4.3352 train_time:103411ms step_avg:141.47ms
step:742/5100 train_loss:3.7007 train_time:103550ms step_avg:141.46ms
step:743/5100 train_loss:3.7834 train_time:103691ms step_avg:141.46ms
step:744/5100 train_loss:3.7843 train_time:103832ms step_avg:141.46ms
step:745/5100 train_loss:3.8472 train_time:103974ms step_avg:141.46ms
step:746/5100 train_loss:3.8163 train_time:104111ms step_avg:141.46ms
step:747/5100 train_loss:3.7961 train_time:104251ms step_avg:141.45ms
step:748/5100 train_loss:3.8331 train_time:104390ms step_avg:141.45ms
step:749/5100 train_loss:3.7640 train_time:104531ms step_avg:141.45ms
step:750/5100 train_loss:3.7665 train_time:104670ms step_avg:141.45ms
step:750/5100 val_loss:3.7730 train_time:104728ms step_avg:141.52ms
step:751/5100 train_loss:3.7971 train_time:104823ms step_avg:141.46ms
step:752/5100 train_loss:3.7665 train_time:104969ms step_avg:141.47ms
step:753/5100 train_loss:3.8063 train_time:105109ms step_avg:141.47ms
step:754/5100 train_loss:3.8216 train_time:105249ms step_avg:141.46ms
step:755/5100 train_loss:3.7917 train_time:105387ms step_avg:141.46ms
step:756/5100 train_loss:3.8598 train_time:105693ms step_avg:141.68ms
step:757/5100 train_loss:3.6898 train_time:105831ms step_avg:141.67ms
step:758/5100 train_loss:3.9301 train_time:105969ms step_avg:141.67ms
step:759/5100 train_loss:3.8442 train_time:106109ms step_avg:141.67ms
step:760/5100 train_loss:3.7761 train_time:106405ms step_avg:141.87ms
step:761/5100 train_loss:3.8897 train_time:106545ms step_avg:141.87ms
step:762/5100 train_loss:3.6041 train_time:106683ms step_avg:141.87ms
step:763/5100 train_loss:3.7544 train_time:106822ms step_avg:141.86ms
step:764/5100 train_loss:3.8695 train_time:106960ms step_avg:141.86ms
step:765/5100 train_loss:3.5134 train_time:107098ms step_avg:141.85ms
step:766/5100 train_loss:3.9445 train_time:107236ms step_avg:141.85ms
step:767/5100 train_loss:3.7929 train_time:107383ms step_avg:141.85ms
step:768/5100 train_loss:3.7581 train_time:107523ms step_avg:141.85ms
step:769/5100 train_loss:3.7747 train_time:107665ms step_avg:141.85ms
step:770/5100 train_loss:3.7963 train_time:107802ms step_avg:141.84ms
step:771/5100 train_loss:3.8533 train_time:107944ms step_avg:141.84ms
step:772/5100 train_loss:4.0837 train_time:108079ms step_avg:141.84ms
step:773/5100 train_loss:3.6587 train_time:108217ms step_avg:141.83ms
step:774/5100 train_loss:3.8546 train_time:108358ms step_avg:141.83ms
step:775/5100 train_loss:3.8390 train_time:108499ms step_avg:141.83ms
step:776/5100 train_loss:3.8069 train_time:108638ms step_avg:141.82ms
step:777/5100 train_loss:3.6123 train_time:108779ms step_avg:141.82ms
step:778/5100 train_loss:3.6081 train_time:108916ms step_avg:141.82ms
step:779/5100 train_loss:3.6753 train_time:109055ms step_avg:141.81ms
step:780/5100 train_loss:3.7701 train_time:109194ms step_avg:141.81ms
step:781/5100 train_loss:3.8040 train_time:109336ms step_avg:141.81ms
step:782/5100 train_loss:3.8606 train_time:109476ms step_avg:141.81ms
step:783/5100 train_loss:3.7715 train_time:109616ms step_avg:141.81ms
step:784/5100 train_loss:3.7718 train_time:109755ms step_avg:141.80ms
step:785/5100 train_loss:3.7757 train_time:109898ms step_avg:141.80ms
step:786/5100 train_loss:3.7515 train_time:110034ms step_avg:141.80ms
step:787/5100 train_loss:3.6520 train_time:110173ms step_avg:141.79ms
step:788/5100 train_loss:3.9572 train_time:110315ms step_avg:141.79ms
step:789/5100 train_loss:3.6992 train_time:110456ms step_avg:141.79ms
step:790/5100 train_loss:3.7611 train_time:110596ms step_avg:141.79ms
step:791/5100 train_loss:3.8216 train_time:110735ms step_avg:141.79ms
step:792/5100 train_loss:3.9616 train_time:110874ms step_avg:141.78ms
step:793/5100 train_loss:3.9672 train_time:111014ms step_avg:141.78ms
step:794/5100 train_loss:3.6727 train_time:111154ms step_avg:141.78ms
step:795/5100 train_loss:3.7968 train_time:111294ms step_avg:141.78ms
step:796/5100 train_loss:3.8553 train_time:111439ms step_avg:141.78ms
step:797/5100 train_loss:3.9523 train_time:111574ms step_avg:141.77ms
step:798/5100 train_loss:3.7095 train_time:111715ms step_avg:141.77ms
step:799/5100 train_loss:3.8619 train_time:111855ms step_avg:141.77ms
step:800/5100 train_loss:3.7457 train_time:111995ms step_avg:141.77ms
step:801/5100 train_loss:3.7445 train_time:112135ms step_avg:141.76ms
step:802/5100 train_loss:3.8338 train_time:112274ms step_avg:141.76ms
step:803/5100 train_loss:3.6943 train_time:112415ms step_avg:141.76ms
step:804/5100 train_loss:3.7260 train_time:112554ms step_avg:141.76ms
step:805/5100 train_loss:3.8346 train_time:112695ms step_avg:141.75ms
step:806/5100 train_loss:3.7335 train_time:112834ms step_avg:141.75ms
step:807/5100 train_loss:3.7450 train_time:112974ms step_avg:141.75ms
step:808/5100 train_loss:3.8408 train_time:113114ms step_avg:141.75ms
step:809/5100 train_loss:3.7609 train_time:113255ms step_avg:141.75ms
step:810/5100 train_loss:3.6853 train_time:113395ms step_avg:141.74ms
step:811/5100 train_loss:3.7612 train_time:113534ms step_avg:141.74ms
step:812/5100 train_loss:3.8039 train_time:113674ms step_avg:141.74ms
step:813/5100 train_loss:3.7942 train_time:113815ms step_avg:141.74ms
step:814/5100 train_loss:3.8289 train_time:113954ms step_avg:141.73ms
step:815/5100 train_loss:3.7746 train_time:114094ms step_avg:141.73ms
step:816/5100 train_loss:3.7557 train_time:114236ms step_avg:141.73ms
step:817/5100 train_loss:3.8602 train_time:114378ms step_avg:141.73ms
step:818/5100 train_loss:3.9595 train_time:114515ms step_avg:141.73ms
step:819/5100 train_loss:3.7179 train_time:114654ms step_avg:141.72ms
step:820/5100 train_loss:3.9279 train_time:114794ms step_avg:141.72ms
step:821/5100 train_loss:3.7026 train_time:114936ms step_avg:141.72ms
step:822/5100 train_loss:3.7507 train_time:115073ms step_avg:141.72ms
step:823/5100 train_loss:3.8709 train_time:115216ms step_avg:141.72ms
step:824/5100 train_loss:3.7827 train_time:115355ms step_avg:141.71ms
step:825/5100 train_loss:3.7099 train_time:115495ms step_avg:141.71ms
step:826/5100 train_loss:3.8110 train_time:115635ms step_avg:141.71ms
step:827/5100 train_loss:3.7044 train_time:115775ms step_avg:141.71ms
step:828/5100 train_loss:3.9309 train_time:115915ms step_avg:141.71ms
step:829/5100 train_loss:3.8248 train_time:116054ms step_avg:141.70ms
step:830/5100 train_loss:3.8766 train_time:116195ms step_avg:141.70ms
step:831/5100 train_loss:3.7307 train_time:116342ms step_avg:141.71ms
step:832/5100 train_loss:3.7855 train_time:116475ms step_avg:141.70ms
step:833/5100 train_loss:3.7173 train_time:116617ms step_avg:141.70ms
step:834/5100 train_loss:3.8391 train_time:116755ms step_avg:141.69ms
step:835/5100 train_loss:3.6803 train_time:116895ms step_avg:141.69ms
step:836/5100 train_loss:3.6561 train_time:117034ms step_avg:141.69ms
step:837/5100 train_loss:3.9273 train_time:117174ms step_avg:141.69ms
step:838/5100 train_loss:3.6237 train_time:117315ms step_avg:141.68ms
step:839/5100 train_loss:3.7892 train_time:117455ms step_avg:141.68ms
step:840/5100 train_loss:3.6307 train_time:117595ms step_avg:141.68ms
step:841/5100 train_loss:3.6678 train_time:117735ms step_avg:141.68ms
step:842/5100 train_loss:3.7667 train_time:117874ms step_avg:141.68ms
step:843/5100 train_loss:3.7807 train_time:118015ms step_avg:141.67ms
step:844/5100 train_loss:3.7751 train_time:118155ms step_avg:141.67ms
step:845/5100 train_loss:3.6269 train_time:118295ms step_avg:141.67ms
step:846/5100 train_loss:3.8624 train_time:118435ms step_avg:141.67ms
step:847/5100 train_loss:3.7313 train_time:118577ms step_avg:141.67ms
step:848/5100 train_loss:3.6884 train_time:118715ms step_avg:141.66ms
step:849/5100 train_loss:3.8275 train_time:118855ms step_avg:141.66ms
step:850/5100 train_loss:3.6908 train_time:118995ms step_avg:141.66ms
step:851/5100 train_loss:3.6422 train_time:119135ms step_avg:141.66ms
step:852/5100 train_loss:3.9449 train_time:119275ms step_avg:141.66ms
step:853/5100 train_loss:3.6456 train_time:119415ms step_avg:141.66ms
step:854/5100 train_loss:3.7692 train_time:119555ms step_avg:141.65ms
step:855/5100 train_loss:3.8482 train_time:119695ms step_avg:141.65ms
step:856/5100 train_loss:3.7354 train_time:119835ms step_avg:141.65ms
step:857/5100 train_loss:3.7519 train_time:119980ms step_avg:141.65ms
step:858/5100 train_loss:3.7942 train_time:120115ms step_avg:141.65ms
step:859/5100 train_loss:3.6840 train_time:120254ms step_avg:141.64ms
step:860/5100 train_loss:3.7591 train_time:120395ms step_avg:141.64ms
step:861/5100 train_loss:3.7932 train_time:120536ms step_avg:141.64ms
step:862/5100 train_loss:3.8494 train_time:120675ms step_avg:141.64ms
step:863/5100 train_loss:3.7911 train_time:120815ms step_avg:141.64ms
step:864/5100 train_loss:3.7688 train_time:120954ms step_avg:141.63ms
step:865/5100 train_loss:3.5917 train_time:121095ms step_avg:141.63ms
step:866/5100 train_loss:3.7888 train_time:121234ms step_avg:141.63ms
step:867/5100 train_loss:4.0670 train_time:121374ms step_avg:141.63ms
step:868/5100 train_loss:3.6435 train_time:121515ms step_avg:141.63ms
step:869/5100 train_loss:3.8330 train_time:121656ms step_avg:141.63ms
step:870/5100 train_loss:3.8115 train_time:121796ms step_avg:141.62ms
step:871/5100 train_loss:3.6479 train_time:121935ms step_avg:141.62ms
step:872/5100 train_loss:3.6183 train_time:122075ms step_avg:141.62ms
step:873/5100 train_loss:3.8645 train_time:122216ms step_avg:141.62ms
step:874/5100 train_loss:3.6511 train_time:122354ms step_avg:141.61ms
step:875/5100 train_loss:3.3741 train_time:122494ms step_avg:141.61ms
step:875/5100 val_loss:3.7258 train_time:122552ms step_avg:141.68ms
step:876/5100 train_loss:3.8437 train_time:122649ms step_avg:141.63ms
step:877/5100 train_loss:3.6484 train_time:122795ms step_avg:141.63ms
step:878/5100 train_loss:3.8231 train_time:122935ms step_avg:141.63ms
step:879/5100 train_loss:3.6792 train_time:123070ms step_avg:141.62ms
step:880/5100 train_loss:3.8630 train_time:123208ms step_avg:141.62ms
step:881/5100 train_loss:3.5199 train_time:123346ms step_avg:141.61ms
step:882/5100 train_loss:3.6984 train_time:123484ms step_avg:141.61ms
step:883/5100 train_loss:3.8895 train_time:123628ms step_avg:141.61ms
step:884/5100 train_loss:4.0437 train_time:123773ms step_avg:141.62ms
step:885/5100 train_loss:3.7685 train_time:123914ms step_avg:141.62ms
step:886/5100 train_loss:3.6846 train_time:124051ms step_avg:141.61ms
step:887/5100 train_loss:3.7788 train_time:124191ms step_avg:141.61ms
step:888/5100 train_loss:4.2840 train_time:124330ms step_avg:141.61ms
step:889/5100 train_loss:4.0333 train_time:124468ms step_avg:141.60ms
step:890/5100 train_loss:3.7165 train_time:124608ms step_avg:141.60ms
step:891/5100 train_loss:3.7359 train_time:124752ms step_avg:141.60ms
step:892/5100 train_loss:3.5611 train_time:124892ms step_avg:141.60ms
step:893/5100 train_loss:3.9078 train_time:125032ms step_avg:141.60ms
step:894/5100 train_loss:3.6301 train_time:125170ms step_avg:141.59ms
step:895/5100 train_loss:3.8780 train_time:125308ms step_avg:141.59ms
step:896/5100 train_loss:3.8919 train_time:125447ms step_avg:141.59ms
step:897/5100 train_loss:3.6939 train_time:125586ms step_avg:141.59ms
step:898/5100 train_loss:3.7311 train_time:125728ms step_avg:141.59ms
step:899/5100 train_loss:3.7912 train_time:125871ms step_avg:141.59ms
step:900/5100 train_loss:3.6759 train_time:126009ms step_avg:141.58ms
step:901/5100 train_loss:3.6175 train_time:126149ms step_avg:141.58ms
step:902/5100 train_loss:3.8265 train_time:126287ms step_avg:141.58ms
step:903/5100 train_loss:3.8327 train_time:126427ms step_avg:141.58ms
step:904/5100 train_loss:3.7296 train_time:126571ms step_avg:141.58ms
step:905/5100 train_loss:3.7078 train_time:126708ms step_avg:141.57ms
step:906/5100 train_loss:3.6946 train_time:126849ms step_avg:141.57ms
step:907/5100 train_loss:3.9211 train_time:126988ms step_avg:141.57ms
step:908/5100 train_loss:3.7083 train_time:127128ms step_avg:141.57ms
step:909/5100 train_loss:3.7473 train_time:127268ms step_avg:141.57ms
step:910/5100 train_loss:3.6554 train_time:127407ms step_avg:141.56ms
step:911/5100 train_loss:3.7423 train_time:127547ms step_avg:141.56ms
step:912/5100 train_loss:3.8216 train_time:127690ms step_avg:141.56ms
step:913/5100 train_loss:3.8126 train_time:127828ms step_avg:141.56ms
step:914/5100 train_loss:3.6793 train_time:127967ms step_avg:141.56ms
step:915/5100 train_loss:3.9363 train_time:128107ms step_avg:141.56ms
step:916/5100 train_loss:3.7301 train_time:128251ms step_avg:141.56ms
step:917/5100 train_loss:3.8250 train_time:128386ms step_avg:141.55ms
step:918/5100 train_loss:3.7984 train_time:128527ms step_avg:141.55ms
step:919/5100 train_loss:5.0211 train_time:128667ms step_avg:141.55ms
step:920/5100 train_loss:3.7165 train_time:128807ms step_avg:141.55ms
step:921/5100 train_loss:3.7673 train_time:128949ms step_avg:141.55ms
step:922/5100 train_loss:3.7324 train_time:129088ms step_avg:141.54ms
step:923/5100 train_loss:3.7878 train_time:129228ms step_avg:141.54ms
step:924/5100 train_loss:3.7974 train_time:129367ms step_avg:141.54ms
step:925/5100 train_loss:3.8836 train_time:129508ms step_avg:141.54ms
step:926/5100 train_loss:3.8627 train_time:129648ms step_avg:141.54ms
step:927/5100 train_loss:3.7503 train_time:129787ms step_avg:141.53ms
step:928/5100 train_loss:3.7424 train_time:129929ms step_avg:141.53ms
step:929/5100 train_loss:3.9688 train_time:130067ms step_avg:141.53ms
step:930/5100 train_loss:3.8148 train_time:130207ms step_avg:141.53ms
step:931/5100 train_loss:3.5956 train_time:130348ms step_avg:141.53ms
step:932/5100 train_loss:3.6867 train_time:130488ms step_avg:141.53ms
step:933/5100 train_loss:3.8753 train_time:130628ms step_avg:141.52ms
step:934/5100 train_loss:3.5851 train_time:130768ms step_avg:141.52ms
step:935/5100 train_loss:3.7720 train_time:130908ms step_avg:141.52ms
step:936/5100 train_loss:3.6475 train_time:131049ms step_avg:141.52ms
step:937/5100 train_loss:3.7108 train_time:131188ms step_avg:141.52ms
step:938/5100 train_loss:3.8082 train_time:131328ms step_avg:141.52ms
step:939/5100 train_loss:3.7381 train_time:131467ms step_avg:141.51ms
step:940/5100 train_loss:3.9036 train_time:131608ms step_avg:141.51ms
step:941/5100 train_loss:3.6817 train_time:131747ms step_avg:141.51ms
step:942/5100 train_loss:3.7512 train_time:131887ms step_avg:141.51ms
step:943/5100 train_loss:3.5445 train_time:132028ms step_avg:141.51ms
step:944/5100 train_loss:3.9000 train_time:132167ms step_avg:141.51ms
step:945/5100 train_loss:3.6082 train_time:132478ms step_avg:141.69ms
step:946/5100 train_loss:3.6221 train_time:132614ms step_avg:141.68ms
step:947/5100 train_loss:5.2429 train_time:132753ms step_avg:141.68ms
step:948/5100 train_loss:3.7912 train_time:132893ms step_avg:141.68ms
step:949/5100 train_loss:3.6976 train_time:133031ms step_avg:141.67ms
step:950/5100 train_loss:3.5893 train_time:133329ms step_avg:141.84ms
step:951/5100 train_loss:3.6496 train_time:133466ms step_avg:141.83ms
step:952/5100 train_loss:3.6039 train_time:133604ms step_avg:141.83ms
step:953/5100 train_loss:3.6766 train_time:133742ms step_avg:141.83ms
step:954/5100 train_loss:3.7581 train_time:133882ms step_avg:141.82ms
step:955/5100 train_loss:3.6337 train_time:134020ms step_avg:141.82ms
step:956/5100 train_loss:3.6683 train_time:134158ms step_avg:141.82ms
step:957/5100 train_loss:3.6350 train_time:134305ms step_avg:141.82ms
step:958/5100 train_loss:3.7005 train_time:134448ms step_avg:141.82ms
step:959/5100 train_loss:3.6936 train_time:134586ms step_avg:141.82ms
step:960/5100 train_loss:3.7097 train_time:134726ms step_avg:141.82ms
step:961/5100 train_loss:3.5871 train_time:134865ms step_avg:141.81ms
step:962/5100 train_loss:3.8522 train_time:135005ms step_avg:141.81ms
step:963/5100 train_loss:3.8025 train_time:135147ms step_avg:141.81ms
step:964/5100 train_loss:3.5975 train_time:135288ms step_avg:141.81ms
step:965/5100 train_loss:3.6446 train_time:135428ms step_avg:141.81ms
step:966/5100 train_loss:3.6829 train_time:135568ms step_avg:141.81ms
step:967/5100 train_loss:3.9021 train_time:135707ms step_avg:141.80ms
step:968/5100 train_loss:3.7271 train_time:135846ms step_avg:141.80ms
step:969/5100 train_loss:3.7153 train_time:135984ms step_avg:141.80ms
step:970/5100 train_loss:3.7703 train_time:136126ms step_avg:141.80ms
step:971/5100 train_loss:3.5818 train_time:136268ms step_avg:141.80ms
step:972/5100 train_loss:3.7459 train_time:136408ms step_avg:141.80ms
step:973/5100 train_loss:3.6911 train_time:136548ms step_avg:141.79ms
step:974/5100 train_loss:3.7366 train_time:136687ms step_avg:141.79ms
step:975/5100 train_loss:3.8063 train_time:136828ms step_avg:141.79ms
step:976/5100 train_loss:3.6882 train_time:136967ms step_avg:141.79ms
step:977/5100 train_loss:3.8823 train_time:137107ms step_avg:141.79ms
step:978/5100 train_loss:3.7668 train_time:137249ms step_avg:141.79ms
step:979/5100 train_loss:3.5931 train_time:137389ms step_avg:141.78ms
step:980/5100 train_loss:3.8833 train_time:137529ms step_avg:141.78ms
step:981/5100 train_loss:3.6167 train_time:137668ms step_avg:141.78ms
step:982/5100 train_loss:3.7827 train_time:137808ms step_avg:141.78ms
step:983/5100 train_loss:3.7619 train_time:137951ms step_avg:141.78ms
step:984/5100 train_loss:3.7690 train_time:138090ms step_avg:141.78ms
step:985/5100 train_loss:3.7051 train_time:138229ms step_avg:141.77ms
step:986/5100 train_loss:3.7914 train_time:138369ms step_avg:141.77ms
step:987/5100 train_loss:3.6108 train_time:138509ms step_avg:141.77ms
step:988/5100 train_loss:3.6889 train_time:138649ms step_avg:141.77ms
step:989/5100 train_loss:3.6988 train_time:138788ms step_avg:141.76ms
step:990/5100 train_loss:3.6224 train_time:138932ms step_avg:141.77ms
step:991/5100 train_loss:3.8495 train_time:139066ms step_avg:141.76ms
step:992/5100 train_loss:3.6666 train_time:139210ms step_avg:141.76ms
step:993/5100 train_loss:3.6374 train_time:139351ms step_avg:141.76ms
step:994/5100 train_loss:3.7144 train_time:139488ms step_avg:141.76ms
step:995/5100 train_loss:3.7955 train_time:139628ms step_avg:141.75ms
step:996/5100 train_loss:3.7363 train_time:139767ms step_avg:141.75ms
step:997/5100 train_loss:3.6545 train_time:139908ms step_avg:141.75ms
step:998/5100 train_loss:4.0035 train_time:140048ms step_avg:141.75ms
step:999/5100 train_loss:3.6603 train_time:140188ms step_avg:141.75ms
step:1000/5100 train_loss:3.7906 train_time:140328ms step_avg:141.75ms
step:1000/5100 val_loss:3.6837 train_time:140385ms step_avg:141.80ms
step:1001/5100 train_loss:3.6573 train_time:140479ms step_avg:141.75ms
step:1002/5100 train_loss:3.7040 train_time:140627ms step_avg:141.76ms
step:1003/5100 train_loss:3.5904 train_time:140767ms step_avg:141.76ms
step:1004/5100 train_loss:3.7757 train_time:140906ms step_avg:141.76ms
step:1005/5100 train_loss:3.8327 train_time:141045ms step_avg:141.75ms
step:1006/5100 train_loss:3.5993 train_time:141184ms step_avg:141.75ms
step:1007/5100 train_loss:3.6758 train_time:141324ms step_avg:141.75ms
step:1008/5100 train_loss:3.6493 train_time:141468ms step_avg:141.75ms
step:1009/5100 train_loss:3.7648 train_time:141611ms step_avg:141.75ms
step:1010/5100 train_loss:3.8724 train_time:141751ms step_avg:141.75ms
step:1011/5100 train_loss:3.7649 train_time:141890ms step_avg:141.75ms
step:1012/5100 train_loss:3.7282 train_time:142029ms step_avg:141.75ms
step:1013/5100 train_loss:3.5879 train_time:142168ms step_avg:141.74ms
step:1014/5100 train_loss:3.7354 train_time:142308ms step_avg:141.74ms
step:1015/5100 train_loss:3.8516 train_time:142449ms step_avg:141.74ms
step:1016/5100 train_loss:3.5530 train_time:142590ms step_avg:141.74ms
step:1017/5100 train_loss:3.6459 train_time:142730ms step_avg:141.74ms
step:1018/5100 train_loss:3.6444 train_time:142870ms step_avg:141.74ms
step:1019/5100 train_loss:3.5910 train_time:143009ms step_avg:141.73ms
step:1020/5100 train_loss:3.7351 train_time:143149ms step_avg:141.73ms
step:1021/5100 train_loss:3.6441 train_time:143288ms step_avg:141.73ms
step:1022/5100 train_loss:3.5735 train_time:143430ms step_avg:141.73ms
step:1023/5100 train_loss:3.6935 train_time:143574ms step_avg:141.73ms
step:1024/5100 train_loss:3.7131 train_time:143711ms step_avg:141.73ms
step:1025/5100 train_loss:3.6875 train_time:143851ms step_avg:141.72ms
step:1026/5100 train_loss:3.7029 train_time:143990ms step_avg:141.72ms
step:1027/5100 train_loss:3.8710 train_time:144130ms step_avg:141.72ms
step:1028/5100 train_loss:3.5447 train_time:144269ms step_avg:141.72ms
step:1029/5100 train_loss:3.6065 train_time:144409ms step_avg:141.72ms
step:1030/5100 train_loss:3.5614 train_time:144549ms step_avg:141.72ms
step:1031/5100 train_loss:3.7276 train_time:144689ms step_avg:141.71ms
step:1032/5100 train_loss:3.7140 train_time:144830ms step_avg:141.71ms
step:1033/5100 train_loss:3.8925 train_time:144970ms step_avg:141.71ms
step:1034/5100 train_loss:3.7036 train_time:145110ms step_avg:141.71ms
step:1035/5100 train_loss:3.6276 train_time:145250ms step_avg:141.71ms
step:1036/5100 train_loss:3.6469 train_time:145389ms step_avg:141.70ms
step:1037/5100 train_loss:3.7110 train_time:145531ms step_avg:141.70ms
step:1038/5100 train_loss:4.0282 train_time:145670ms step_avg:141.70ms
step:1039/5100 train_loss:3.8419 train_time:145810ms step_avg:141.70ms
step:1040/5100 train_loss:3.7326 train_time:145950ms step_avg:141.70ms
step:1041/5100 train_loss:3.6287 train_time:146089ms step_avg:141.70ms
step:1042/5100 train_loss:3.6975 train_time:146230ms step_avg:141.70ms
step:1043/5100 train_loss:3.7342 train_time:146369ms step_avg:141.69ms
step:1044/5100 train_loss:3.6684 train_time:146510ms step_avg:141.69ms
step:1045/5100 train_loss:3.6749 train_time:146650ms step_avg:141.69ms
step:1046/5100 train_loss:3.7546 train_time:146789ms step_avg:141.69ms
step:1047/5100 train_loss:3.6563 train_time:146930ms step_avg:141.69ms
step:1048/5100 train_loss:3.8625 train_time:147070ms step_avg:141.69ms
step:1049/5100 train_loss:3.7198 train_time:147210ms step_avg:141.68ms
step:1050/5100 train_loss:3.6341 train_time:147350ms step_avg:141.68ms
step:1051/5100 train_loss:3.6040 train_time:147491ms step_avg:141.68ms
step:1052/5100 train_loss:3.7217 train_time:147631ms step_avg:141.68ms
step:1053/5100 train_loss:3.6021 train_time:147770ms step_avg:141.68ms
step:1054/5100 train_loss:3.9256 train_time:147909ms step_avg:141.68ms
step:1055/5100 train_loss:3.7607 train_time:148050ms step_avg:141.67ms
step:1056/5100 train_loss:3.6203 train_time:148189ms step_avg:141.67ms
step:1057/5100 train_loss:3.7179 train_time:148330ms step_avg:141.67ms
step:1058/5100 train_loss:3.7970 train_time:148470ms step_avg:141.67ms
step:1059/5100 train_loss:3.5239 train_time:148610ms step_avg:141.67ms
step:1060/5100 train_loss:3.6442 train_time:148750ms step_avg:141.67ms
step:1061/5100 train_loss:3.6687 train_time:148889ms step_avg:141.66ms
step:1062/5100 train_loss:3.6329 train_time:149030ms step_avg:141.66ms
step:1063/5100 train_loss:3.6060 train_time:149169ms step_avg:141.66ms
step:1064/5100 train_loss:3.7120 train_time:149310ms step_avg:141.66ms
step:1065/5100 train_loss:3.6086 train_time:149451ms step_avg:141.66ms
step:1066/5100 train_loss:3.5941 train_time:149590ms step_avg:141.66ms
step:1067/5100 train_loss:3.6175 train_time:149731ms step_avg:141.66ms
step:1068/5100 train_loss:3.5364 train_time:149870ms step_avg:141.65ms
step:1069/5100 train_loss:3.6443 train_time:150010ms step_avg:141.65ms
step:1070/5100 train_loss:3.5176 train_time:150150ms step_avg:141.65ms
step:1071/5100 train_loss:3.7801 train_time:150289ms step_avg:141.65ms
step:1072/5100 train_loss:3.7239 train_time:150430ms step_avg:141.65ms
step:1073/5100 train_loss:3.6806 train_time:150570ms step_avg:141.65ms
step:1074/5100 train_loss:3.7397 train_time:150711ms step_avg:141.65ms
step:1075/5100 train_loss:3.6845 train_time:150850ms step_avg:141.64ms
step:1076/5100 train_loss:3.6227 train_time:150989ms step_avg:141.64ms
step:1077/5100 train_loss:4.0134 train_time:151130ms step_avg:141.64ms
step:1078/5100 train_loss:3.6905 train_time:151269ms step_avg:141.64ms
step:1079/5100 train_loss:3.3604 train_time:151410ms step_avg:141.64ms
step:1080/5100 train_loss:3.7491 train_time:151550ms step_avg:141.64ms
step:1081/5100 train_loss:3.6747 train_time:151690ms step_avg:141.63ms
step:1082/5100 train_loss:3.7417 train_time:151831ms step_avg:141.63ms
step:1083/5100 train_loss:3.8361 train_time:151970ms step_avg:141.63ms
step:1084/5100 train_loss:3.7305 train_time:152110ms step_avg:141.63ms
step:1085/5100 train_loss:3.7053 train_time:152250ms step_avg:141.63ms
step:1086/5100 train_loss:3.6673 train_time:152389ms step_avg:141.63ms
step:1087/5100 train_loss:3.8649 train_time:152530ms step_avg:141.62ms
step:1088/5100 train_loss:3.7584 train_time:152670ms step_avg:141.62ms
step:1089/5100 train_loss:3.5808 train_time:152810ms step_avg:141.62ms
step:1090/5100 train_loss:3.6079 train_time:152950ms step_avg:141.62ms
step:1091/5100 train_loss:3.7303 train_time:153089ms step_avg:141.62ms
step:1092/5100 train_loss:3.5214 train_time:153230ms step_avg:141.62ms
step:1093/5100 train_loss:3.7219 train_time:153369ms step_avg:141.61ms
step:1094/5100 train_loss:3.8525 train_time:153510ms step_avg:141.61ms
step:1095/5100 train_loss:3.6864 train_time:153650ms step_avg:141.61ms
step:1096/5100 train_loss:3.6412 train_time:153790ms step_avg:141.61ms
step:1097/5100 train_loss:3.6684 train_time:153931ms step_avg:141.61ms
step:1098/5100 train_loss:3.7150 train_time:154070ms step_avg:141.61ms
step:1099/5100 train_loss:3.7917 train_time:154210ms step_avg:141.61ms
step:1100/5100 train_loss:3.7388 train_time:154350ms step_avg:141.61ms
step:1101/5100 train_loss:3.6778 train_time:154490ms step_avg:141.60ms
step:1102/5100 train_loss:3.5296 train_time:154630ms step_avg:141.60ms
step:1103/5100 train_loss:3.5977 train_time:154769ms step_avg:141.60ms
step:1104/5100 train_loss:3.6813 train_time:154910ms step_avg:141.60ms
step:1105/5100 train_loss:3.5551 train_time:155050ms step_avg:141.60ms
step:1106/5100 train_loss:4.3093 train_time:155190ms step_avg:141.60ms
step:1107/5100 train_loss:3.4603 train_time:155331ms step_avg:141.60ms
step:1108/5100 train_loss:3.8040 train_time:155470ms step_avg:141.59ms
step:1109/5100 train_loss:3.5841 train_time:155611ms step_avg:141.59ms
step:1110/5100 train_loss:3.7319 train_time:155751ms step_avg:141.59ms
step:1111/5100 train_loss:3.6652 train_time:155890ms step_avg:141.59ms
step:1112/5100 train_loss:3.7047 train_time:156031ms step_avg:141.59ms
step:1113/5100 train_loss:3.8021 train_time:156170ms step_avg:141.59ms
step:1114/5100 train_loss:3.6638 train_time:156310ms step_avg:141.59ms
step:1115/5100 train_loss:3.5935 train_time:156450ms step_avg:141.58ms
step:1116/5100 train_loss:3.5017 train_time:156590ms step_avg:141.58ms
step:1117/5100 train_loss:3.6672 train_time:156736ms step_avg:141.59ms
step:1118/5100 train_loss:3.8267 train_time:156869ms step_avg:141.58ms
step:1119/5100 train_loss:3.8599 train_time:157010ms step_avg:141.58ms
step:1120/5100 train_loss:3.6989 train_time:157150ms step_avg:141.58ms
step:1121/5100 train_loss:3.7273 train_time:157289ms step_avg:141.57ms
step:1122/5100 train_loss:3.6277 train_time:157429ms step_avg:141.57ms
step:1123/5100 train_loss:3.6888 train_time:157570ms step_avg:141.57ms
step:1124/5100 train_loss:3.8250 train_time:157709ms step_avg:141.57ms
step:1125/5100 train_loss:3.5967 train_time:157850ms step_avg:141.57ms
step:1125/5100 val_loss:3.6555 train_time:157906ms step_avg:141.62ms
step:1126/5100 train_loss:3.4938 train_time:158002ms step_avg:141.58ms
step:1127/5100 train_loss:3.7151 train_time:158148ms step_avg:141.58ms
step:1128/5100 train_loss:3.9323 train_time:158286ms step_avg:141.58ms
step:1129/5100 train_loss:3.4743 train_time:158429ms step_avg:141.58ms
step:1130/5100 train_loss:3.7934 train_time:158561ms step_avg:141.57ms
step:1131/5100 train_loss:3.6236 train_time:158700ms step_avg:141.57ms
step:1132/5100 train_loss:3.6515 train_time:158838ms step_avg:141.57ms
step:1133/5100 train_loss:3.6088 train_time:158978ms step_avg:141.57ms
step:1134/5100 train_loss:3.7695 train_time:159286ms step_avg:141.71ms
step:1135/5100 train_loss:3.7005 train_time:159423ms step_avg:141.71ms
step:1136/5100 train_loss:3.7472 train_time:159561ms step_avg:141.71ms
step:1137/5100 train_loss:3.7823 train_time:159701ms step_avg:141.70ms
step:1138/5100 train_loss:3.6961 train_time:159839ms step_avg:141.70ms
step:1139/5100 train_loss:3.5968 train_time:159977ms step_avg:141.70ms
step:1140/5100 train_loss:3.9069 train_time:160283ms step_avg:141.84ms
step:1141/5100 train_loss:3.7070 train_time:160423ms step_avg:141.84ms
step:1142/5100 train_loss:3.8098 train_time:160560ms step_avg:141.84ms
step:1143/5100 train_loss:3.6955 train_time:160698ms step_avg:141.83ms
step:1144/5100 train_loss:3.6066 train_time:160837ms step_avg:141.83ms
step:1145/5100 train_loss:3.7038 train_time:160976ms step_avg:141.83ms
step:1146/5100 train_loss:3.8314 train_time:161115ms step_avg:141.83ms
step:1147/5100 train_loss:3.8045 train_time:161262ms step_avg:141.83ms
step:1148/5100 train_loss:3.7304 train_time:161405ms step_avg:141.83ms
step:1149/5100 train_loss:3.7375 train_time:161546ms step_avg:141.83ms
step:1150/5100 train_loss:3.5864 train_time:161687ms step_avg:141.83ms
step:1151/5100 train_loss:3.6075 train_time:161823ms step_avg:141.83ms
step:1152/5100 train_loss:3.5720 train_time:161961ms step_avg:141.82ms
step:1153/5100 train_loss:3.7236 train_time:162100ms step_avg:141.82ms
step:1154/5100 train_loss:3.6921 train_time:162240ms step_avg:141.82ms
step:1155/5100 train_loss:3.7546 train_time:162382ms step_avg:141.82ms
step:1156/5100 train_loss:3.6045 train_time:162522ms step_avg:141.82ms
step:1157/5100 train_loss:3.7742 train_time:162665ms step_avg:141.82ms
step:1158/5100 train_loss:3.7319 train_time:162801ms step_avg:141.81ms
step:1159/5100 train_loss:3.5514 train_time:162939ms step_avg:141.81ms
step:1160/5100 train_loss:3.5745 train_time:163080ms step_avg:141.81ms
step:1161/5100 train_loss:3.5650 train_time:163217ms step_avg:141.80ms
step:1162/5100 train_loss:3.3809 train_time:163360ms step_avg:141.81ms
step:1163/5100 train_loss:3.6829 train_time:163501ms step_avg:141.80ms
step:1164/5100 train_loss:3.6545 train_time:163641ms step_avg:141.80ms
step:1165/5100 train_loss:3.5223 train_time:163779ms step_avg:141.80ms
step:1166/5100 train_loss:3.5155 train_time:163922ms step_avg:141.80ms
step:1167/5100 train_loss:3.6195 train_time:164060ms step_avg:141.80ms
step:1168/5100 train_loss:3.6331 train_time:164197ms step_avg:141.79ms
step:1169/5100 train_loss:3.9523 train_time:164339ms step_avg:141.79ms
step:1170/5100 train_loss:3.6362 train_time:164480ms step_avg:141.79ms
step:1171/5100 train_loss:3.6355 train_time:164620ms step_avg:141.79ms
step:1172/5100 train_loss:3.5727 train_time:164759ms step_avg:141.79ms
step:1173/5100 train_loss:3.6521 train_time:164897ms step_avg:141.79ms
step:1174/5100 train_loss:3.7862 train_time:165036ms step_avg:141.78ms
step:1175/5100 train_loss:3.6231 train_time:165176ms step_avg:141.78ms
step:1176/5100 train_loss:3.6508 train_time:165318ms step_avg:141.78ms
step:1177/5100 train_loss:3.6973 train_time:165459ms step_avg:141.78ms
step:1178/5100 train_loss:3.6804 train_time:165599ms step_avg:141.78ms
step:1179/5100 train_loss:3.7403 train_time:165739ms step_avg:141.78ms
step:1180/5100 train_loss:3.6476 train_time:165878ms step_avg:141.78ms
step:1181/5100 train_loss:3.6538 train_time:166018ms step_avg:141.77ms
step:1182/5100 train_loss:3.6040 train_time:166158ms step_avg:141.77ms
step:1183/5100 train_loss:3.6503 train_time:166300ms step_avg:141.77ms
step:1184/5100 train_loss:3.5738 train_time:166439ms step_avg:141.77ms
step:1185/5100 train_loss:3.7502 train_time:166579ms step_avg:141.77ms
step:1186/5100 train_loss:3.8070 train_time:166719ms step_avg:141.77ms
step:1187/5100 train_loss:3.6001 train_time:166858ms step_avg:141.77ms
step:1188/5100 train_loss:3.6564 train_time:166997ms step_avg:141.76ms
step:1189/5100 train_loss:3.6774 train_time:167138ms step_avg:141.76ms
step:1190/5100 train_loss:3.5199 train_time:167277ms step_avg:141.76ms
step:1191/5100 train_loss:3.7027 train_time:167420ms step_avg:141.76ms
step:1192/5100 train_loss:3.8426 train_time:167559ms step_avg:141.76ms
step:1193/5100 train_loss:3.6401 train_time:167698ms step_avg:141.76ms
step:1194/5100 train_loss:3.5285 train_time:167839ms step_avg:141.76ms
step:1195/5100 train_loss:3.8179 train_time:167978ms step_avg:141.75ms
step:1196/5100 train_loss:3.6217 train_time:168118ms step_avg:141.75ms
step:1197/5100 train_loss:3.6349 train_time:168258ms step_avg:141.75ms
step:1198/5100 train_loss:3.5297 train_time:168398ms step_avg:141.75ms
step:1199/5100 train_loss:3.5438 train_time:168539ms step_avg:141.75ms
step:1200/5100 train_loss:3.5969 train_time:168678ms step_avg:141.75ms
step:1201/5100 train_loss:3.6783 train_time:168818ms step_avg:141.74ms
step:1202/5100 train_loss:3.7522 train_time:168960ms step_avg:141.74ms
step:1203/5100 train_loss:3.8374 train_time:169098ms step_avg:141.74ms
step:1204/5100 train_loss:3.6647 train_time:169238ms step_avg:141.74ms
step:1205/5100 train_loss:3.5883 train_time:169378ms step_avg:141.74ms
step:1206/5100 train_loss:3.6697 train_time:169518ms step_avg:141.74ms
step:1207/5100 train_loss:3.7211 train_time:169661ms step_avg:141.74ms
step:1208/5100 train_loss:3.7718 train_time:169798ms step_avg:141.73ms
step:1209/5100 train_loss:3.6520 train_time:169938ms step_avg:141.73ms
step:1210/5100 train_loss:3.5050 train_time:170079ms step_avg:141.73ms
step:1211/5100 train_loss:3.5546 train_time:170219ms step_avg:141.73ms
step:1212/5100 train_loss:3.6551 train_time:170359ms step_avg:141.73ms
step:1213/5100 train_loss:3.6682 train_time:170498ms step_avg:141.73ms
step:1214/5100 train_loss:3.7019 train_time:170638ms step_avg:141.73ms
step:1215/5100 train_loss:3.6014 train_time:170778ms step_avg:141.72ms
step:1216/5100 train_loss:3.6504 train_time:170918ms step_avg:141.72ms
step:1217/5100 train_loss:3.5950 train_time:171059ms step_avg:141.72ms
step:1218/5100 train_loss:3.5797 train_time:171197ms step_avg:141.72ms
step:1219/5100 train_loss:3.6801 train_time:171339ms step_avg:141.72ms
step:1220/5100 train_loss:3.5332 train_time:171478ms step_avg:141.72ms
step:1221/5100 train_loss:3.7471 train_time:171618ms step_avg:141.72ms
step:1222/5100 train_loss:3.7710 train_time:171758ms step_avg:141.71ms
step:1223/5100 train_loss:3.6990 train_time:171897ms step_avg:141.71ms
step:1224/5100 train_loss:3.5490 train_time:172040ms step_avg:141.71ms
step:1225/5100 train_loss:3.5431 train_time:172179ms step_avg:141.71ms
step:1226/5100 train_loss:3.6197 train_time:172319ms step_avg:141.71ms
step:1227/5100 train_loss:3.6024 train_time:172460ms step_avg:141.71ms
step:1228/5100 train_loss:3.5413 train_time:172599ms step_avg:141.71ms
step:1229/5100 train_loss:3.7124 train_time:172739ms step_avg:141.71ms
step:1230/5100 train_loss:3.6297 train_time:172878ms step_avg:141.70ms
step:1231/5100 train_loss:3.6908 train_time:173019ms step_avg:141.70ms
step:1232/5100 train_loss:3.8489 train_time:173158ms step_avg:141.70ms
step:1233/5100 train_loss:3.7498 train_time:173298ms step_avg:141.70ms
step:1234/5100 train_loss:3.6856 train_time:173438ms step_avg:141.70ms
step:1235/5100 train_loss:3.8384 train_time:173579ms step_avg:141.70ms
step:1236/5100 train_loss:3.5988 train_time:173719ms step_avg:141.70ms
step:1237/5100 train_loss:3.5677 train_time:173858ms step_avg:141.69ms
step:1238/5100 train_loss:3.5158 train_time:173997ms step_avg:141.69ms
step:1239/5100 train_loss:3.5935 train_time:174138ms step_avg:141.69ms
step:1240/5100 train_loss:3.5999 train_time:174277ms step_avg:141.69ms
step:1241/5100 train_loss:3.6423 train_time:174419ms step_avg:141.69ms
step:1242/5100 train_loss:3.6911 train_time:174559ms step_avg:141.69ms
step:1243/5100 train_loss:3.5609 train_time:174699ms step_avg:141.69ms
step:1244/5100 train_loss:3.6614 train_time:174839ms step_avg:141.68ms
step:1245/5100 train_loss:3.6758 train_time:174978ms step_avg:141.68ms
step:1246/5100 train_loss:3.6759 train_time:175118ms step_avg:141.68ms
step:1247/5100 train_loss:3.5072 train_time:175259ms step_avg:141.68ms
step:1248/5100 train_loss:3.6522 train_time:175398ms step_avg:141.68ms
step:1249/5100 train_loss:3.7103 train_time:175539ms step_avg:141.68ms
step:1250/5100 train_loss:3.6709 train_time:175678ms step_avg:141.68ms
step:1250/5100 val_loss:3.6247 train_time:175734ms step_avg:141.72ms
step:1251/5100 train_loss:3.5732 train_time:175830ms step_avg:141.68ms
step:1252/5100 train_loss:3.7843 train_time:175975ms step_avg:141.69ms
step:1253/5100 train_loss:3.6459 train_time:176114ms step_avg:141.69ms
step:1254/5100 train_loss:3.5761 train_time:176255ms step_avg:141.68ms
step:1255/5100 train_loss:3.7007 train_time:176392ms step_avg:141.68ms
step:1256/5100 train_loss:3.7764 train_time:176531ms step_avg:141.68ms
step:1257/5100 train_loss:3.5812 train_time:176669ms step_avg:141.68ms
step:1258/5100 train_loss:3.6047 train_time:176814ms step_avg:141.68ms
step:1259/5100 train_loss:3.6342 train_time:176958ms step_avg:141.68ms
step:1260/5100 train_loss:3.6118 train_time:177097ms step_avg:141.68ms
step:1261/5100 train_loss:3.4669 train_time:177237ms step_avg:141.68ms
step:1262/5100 train_loss:3.5680 train_time:177375ms step_avg:141.67ms
step:1263/5100 train_loss:3.6331 train_time:177513ms step_avg:141.67ms
step:1264/5100 train_loss:3.4861 train_time:177654ms step_avg:141.67ms
step:1265/5100 train_loss:3.7053 train_time:177795ms step_avg:141.67ms
step:1266/5100 train_loss:3.6890 train_time:177939ms step_avg:141.67ms
step:1267/5100 train_loss:3.6992 train_time:178080ms step_avg:141.67ms
step:1268/5100 train_loss:3.6400 train_time:178219ms step_avg:141.67ms
step:1269/5100 train_loss:3.6755 train_time:178357ms step_avg:141.67ms
step:1270/5100 train_loss:3.5294 train_time:178496ms step_avg:141.66ms
step:1271/5100 train_loss:3.3754 train_time:178636ms step_avg:141.66ms
step:1272/5100 train_loss:3.6606 train_time:178776ms step_avg:141.66ms
step:1273/5100 train_loss:3.6224 train_time:178919ms step_avg:141.66ms
step:1274/5100 train_loss:3.6728 train_time:179057ms step_avg:141.66ms
step:1275/5100 train_loss:3.6182 train_time:179197ms step_avg:141.66ms
step:1276/5100 train_loss:3.7142 train_time:179336ms step_avg:141.66ms
step:1277/5100 train_loss:3.7360 train_time:179475ms step_avg:141.65ms
step:1278/5100 train_loss:3.6968 train_time:179614ms step_avg:141.65ms
step:1279/5100 train_loss:3.6860 train_time:179755ms step_avg:141.65ms
step:1280/5100 train_loss:3.5203 train_time:179896ms step_avg:141.65ms
step:1281/5100 train_loss:3.6372 train_time:180038ms step_avg:141.65ms
step:1282/5100 train_loss:3.7033 train_time:180177ms step_avg:141.65ms
step:1283/5100 train_loss:3.7364 train_time:180317ms step_avg:141.65ms
step:1284/5100 train_loss:3.6194 train_time:180456ms step_avg:141.65ms
step:1285/5100 train_loss:3.6500 train_time:180594ms step_avg:141.64ms
step:1286/5100 train_loss:3.6313 train_time:180736ms step_avg:141.64ms
step:1287/5100 train_loss:3.6075 train_time:180877ms step_avg:141.64ms
step:1288/5100 train_loss:3.7434 train_time:181018ms step_avg:141.64ms
step:1289/5100 train_loss:3.5758 train_time:181157ms step_avg:141.64ms
step:1290/5100 train_loss:3.6617 train_time:181296ms step_avg:141.64ms
step:1291/5100 train_loss:3.7293 train_time:181436ms step_avg:141.64ms
step:1292/5100 train_loss:3.6565 train_time:181574ms step_avg:141.63ms
step:1293/5100 train_loss:3.7575 train_time:181715ms step_avg:141.63ms
step:1294/5100 train_loss:3.7771 train_time:181856ms step_avg:141.63ms
step:1295/5100 train_loss:3.7442 train_time:181996ms step_avg:141.63ms
step:1296/5100 train_loss:3.5547 train_time:182139ms step_avg:141.63ms
step:1297/5100 train_loss:3.6322 train_time:182285ms step_avg:141.64ms
step:1298/5100 train_loss:3.5344 train_time:182416ms step_avg:141.63ms
step:1299/5100 train_loss:3.6042 train_time:182555ms step_avg:141.63ms
step:1300/5100 train_loss:3.6706 train_time:182695ms step_avg:141.62ms
step:1301/5100 train_loss:3.6745 train_time:182835ms step_avg:141.62ms
step:1302/5100 train_loss:3.6801 train_time:182976ms step_avg:141.62ms
step:1303/5100 train_loss:3.8363 train_time:183116ms step_avg:141.62ms
step:1304/5100 train_loss:3.6110 train_time:183256ms step_avg:141.62ms
step:1305/5100 train_loss:3.8251 train_time:183396ms step_avg:141.62ms
step:1306/5100 train_loss:3.5454 train_time:183535ms step_avg:141.62ms
step:1307/5100 train_loss:3.7267 train_time:183675ms step_avg:141.62ms
step:1308/5100 train_loss:3.7326 train_time:183816ms step_avg:141.61ms
step:1309/5100 train_loss:3.5957 train_time:183955ms step_avg:141.61ms
step:1310/5100 train_loss:3.5676 train_time:184095ms step_avg:141.61ms
step:1311/5100 train_loss:3.6032 train_time:184236ms step_avg:141.61ms
step:1312/5100 train_loss:3.5576 train_time:184375ms step_avg:141.61ms
step:1313/5100 train_loss:3.6768 train_time:184518ms step_avg:141.61ms
step:1314/5100 train_loss:3.6298 train_time:184656ms step_avg:141.61ms
step:1315/5100 train_loss:3.3360 train_time:184795ms step_avg:141.61ms
step:1316/5100 train_loss:3.5731 train_time:184936ms step_avg:141.60ms
step:1317/5100 train_loss:3.6504 train_time:185075ms step_avg:141.60ms
step:1318/5100 train_loss:3.6818 train_time:185216ms step_avg:141.60ms
step:1319/5100 train_loss:3.5477 train_time:185356ms step_avg:141.60ms
step:1320/5100 train_loss:3.6974 train_time:185495ms step_avg:141.60ms
step:1321/5100 train_loss:3.7501 train_time:185636ms step_avg:141.60ms
step:1322/5100 train_loss:3.6365 train_time:185775ms step_avg:141.60ms
step:1323/5100 train_loss:3.5827 train_time:186080ms step_avg:141.72ms
step:1324/5100 train_loss:3.6166 train_time:186217ms step_avg:141.72ms
step:1325/5100 train_loss:3.7019 train_time:186355ms step_avg:141.72ms
step:1326/5100 train_loss:3.7586 train_time:186493ms step_avg:141.71ms
step:1327/5100 train_loss:3.5277 train_time:186632ms step_avg:141.71ms
step:1328/5100 train_loss:3.4443 train_time:186770ms step_avg:141.71ms
step:1329/5100 train_loss:3.7461 train_time:186909ms step_avg:141.70ms
step:1330/5100 train_loss:3.5956 train_time:187214ms step_avg:141.83ms
step:1331/5100 train_loss:3.7242 train_time:187354ms step_avg:141.83ms
step:1332/5100 train_loss:3.6252 train_time:187492ms step_avg:141.82ms
step:1333/5100 train_loss:4.0368 train_time:187630ms step_avg:141.82ms
step:1334/5100 train_loss:3.7192 train_time:187768ms step_avg:141.82ms
step:1335/5100 train_loss:3.6359 train_time:187908ms step_avg:141.82ms
step:1336/5100 train_loss:3.5799 train_time:188046ms step_avg:141.81ms
step:1337/5100 train_loss:3.5746 train_time:188192ms step_avg:141.82ms
step:1338/5100 train_loss:3.8359 train_time:188335ms step_avg:141.82ms
step:1339/5100 train_loss:3.7704 train_time:188474ms step_avg:141.82ms
step:1340/5100 train_loss:3.6177 train_time:188614ms step_avg:141.81ms
step:1341/5100 train_loss:3.5723 train_time:188754ms step_avg:141.81ms
step:1342/5100 train_loss:3.8830 train_time:188892ms step_avg:141.81ms
step:1343/5100 train_loss:3.6440 train_time:189033ms step_avg:141.81ms
step:1344/5100 train_loss:3.6366 train_time:189175ms step_avg:141.81ms
step:1345/5100 train_loss:3.7034 train_time:189316ms step_avg:141.81ms
step:1346/5100 train_loss:3.6685 train_time:189456ms step_avg:141.81ms
step:1347/5100 train_loss:3.5724 train_time:189595ms step_avg:141.81ms
step:1348/5100 train_loss:3.5207 train_time:189735ms step_avg:141.81ms
step:1349/5100 train_loss:3.6140 train_time:189874ms step_avg:141.80ms
step:1350/5100 train_loss:3.5513 train_time:190014ms step_avg:141.80ms
step:1351/5100 train_loss:3.6807 train_time:190156ms step_avg:141.80ms
step:1352/5100 train_loss:3.5289 train_time:190296ms step_avg:141.80ms
step:1353/5100 train_loss:3.5904 train_time:190437ms step_avg:141.80ms
step:1354/5100 train_loss:3.6992 train_time:190576ms step_avg:141.80ms
step:1355/5100 train_loss:3.5379 train_time:190715ms step_avg:141.80ms
step:1356/5100 train_loss:3.4596 train_time:190855ms step_avg:141.79ms
step:1357/5100 train_loss:3.8070 train_time:190995ms step_avg:141.79ms
step:1358/5100 train_loss:3.7376 train_time:191136ms step_avg:141.79ms
step:1359/5100 train_loss:3.4501 train_time:191277ms step_avg:141.79ms
step:1360/5100 train_loss:3.7349 train_time:191417ms step_avg:141.79ms
step:1361/5100 train_loss:3.6293 train_time:191556ms step_avg:141.79ms
step:1362/5100 train_loss:3.4811 train_time:191695ms step_avg:141.79ms
step:1363/5100 train_loss:3.6615 train_time:191835ms step_avg:141.79ms
step:1364/5100 train_loss:3.5511 train_time:191975ms step_avg:141.78ms
step:1365/5100 train_loss:3.5785 train_time:192115ms step_avg:141.78ms
step:1366/5100 train_loss:3.6001 train_time:192256ms step_avg:141.78ms
step:1367/5100 train_loss:3.6986 train_time:192398ms step_avg:141.78ms
step:1368/5100 train_loss:3.6822 train_time:192535ms step_avg:141.78ms
step:1369/5100 train_loss:3.6357 train_time:192674ms step_avg:141.78ms
step:1370/5100 train_loss:3.5418 train_time:192815ms step_avg:141.78ms
step:1371/5100 train_loss:3.8657 train_time:192955ms step_avg:141.77ms
step:1372/5100 train_loss:3.6095 train_time:193102ms step_avg:141.78ms
step:1373/5100 train_loss:3.6442 train_time:193236ms step_avg:141.77ms
step:1374/5100 train_loss:3.6472 train_time:193376ms step_avg:141.77ms
step:1375/5100 train_loss:3.4434 train_time:193516ms step_avg:141.77ms
step:1375/5100 val_loss:3.6046 train_time:193572ms step_avg:141.81ms
step:1376/5100 train_loss:3.8450 train_time:193670ms step_avg:141.78ms
step:1377/5100 train_loss:3.6212 train_time:193814ms step_avg:141.78ms
step:1378/5100 train_loss:3.7704 train_time:193955ms step_avg:141.78ms
step:1379/5100 train_loss:3.8160 train_time:194094ms step_avg:141.78ms
step:1380/5100 train_loss:3.4805 train_time:194233ms step_avg:141.78ms
step:1381/5100 train_loss:3.6120 train_time:194373ms step_avg:141.77ms
step:1382/5100 train_loss:4.0506 train_time:194512ms step_avg:141.77ms
step:1383/5100 train_loss:3.5246 train_time:194654ms step_avg:141.77ms
step:1384/5100 train_loss:3.6807 train_time:194799ms step_avg:141.78ms
step:1385/5100 train_loss:3.7611 train_time:194940ms step_avg:141.77ms
step:1386/5100 train_loss:3.6652 train_time:195079ms step_avg:141.77ms
step:1387/5100 train_loss:3.6602 train_time:195219ms step_avg:141.77ms
step:1388/5100 train_loss:3.4846 train_time:195360ms step_avg:141.77ms
step:1389/5100 train_loss:3.6290 train_time:195497ms step_avg:141.77ms
step:1390/5100 train_loss:3.5965 train_time:195637ms step_avg:141.77ms
step:1391/5100 train_loss:3.8649 train_time:195779ms step_avg:141.77ms
step:1392/5100 train_loss:3.5825 train_time:195920ms step_avg:141.77ms
step:1393/5100 train_loss:3.5707 train_time:196059ms step_avg:141.76ms
step:1394/5100 train_loss:3.5395 train_time:196198ms step_avg:141.76ms
step:1395/5100 train_loss:3.8249 train_time:196338ms step_avg:141.76ms
step:1396/5100 train_loss:3.7129 train_time:196478ms step_avg:141.76ms
step:1397/5100 train_loss:3.7191 train_time:196618ms step_avg:141.76ms
step:1398/5100 train_loss:3.5850 train_time:196759ms step_avg:141.76ms
step:1399/5100 train_loss:3.5605 train_time:196900ms step_avg:141.76ms
step:1400/5100 train_loss:3.6267 train_time:197040ms step_avg:141.76ms
step:1401/5100 train_loss:3.5958 train_time:197178ms step_avg:141.75ms
step:1402/5100 train_loss:3.6231 train_time:197318ms step_avg:141.75ms
step:1403/5100 train_loss:3.5850 train_time:197459ms step_avg:141.75ms
step:1404/5100 train_loss:3.8204 train_time:197598ms step_avg:141.75ms
step:1405/5100 train_loss:3.5593 train_time:197739ms step_avg:141.75ms
step:1406/5100 train_loss:3.6048 train_time:197880ms step_avg:141.75ms
step:1407/5100 train_loss:3.5984 train_time:198021ms step_avg:141.75ms
step:1408/5100 train_loss:3.4763 train_time:198161ms step_avg:141.75ms
step:1409/5100 train_loss:3.5811 train_time:198299ms step_avg:141.74ms
step:1410/5100 train_loss:3.5721 train_time:198440ms step_avg:141.74ms
step:1411/5100 train_loss:3.5672 train_time:198579ms step_avg:141.74ms
step:1412/5100 train_loss:3.6573 train_time:198719ms step_avg:141.74ms
step:1413/5100 train_loss:3.5885 train_time:198861ms step_avg:141.74ms
step:1414/5100 train_loss:3.6409 train_time:199000ms step_avg:141.74ms
step:1415/5100 train_loss:3.6338 train_time:199140ms step_avg:141.74ms
step:1416/5100 train_loss:3.7143 train_time:199279ms step_avg:141.73ms
step:1417/5100 train_loss:3.5118 train_time:199419ms step_avg:141.73ms
step:1418/5100 train_loss:3.5787 train_time:199560ms step_avg:141.73ms
step:1419/5100 train_loss:3.6663 train_time:199698ms step_avg:141.73ms
step:1420/5100 train_loss:3.7069 train_time:199839ms step_avg:141.73ms
step:1421/5100 train_loss:3.6683 train_time:199979ms step_avg:141.73ms
step:1422/5100 train_loss:3.6636 train_time:200119ms step_avg:141.73ms
step:1423/5100 train_loss:3.6437 train_time:200259ms step_avg:141.73ms
step:1424/5100 train_loss:3.6344 train_time:200398ms step_avg:141.72ms
step:1425/5100 train_loss:3.6255 train_time:200539ms step_avg:141.72ms
step:1426/5100 train_loss:3.4947 train_time:200679ms step_avg:141.72ms
step:1427/5100 train_loss:3.6187 train_time:200820ms step_avg:141.72ms
step:1428/5100 train_loss:3.5574 train_time:200960ms step_avg:141.72ms
step:1429/5100 train_loss:3.6686 train_time:201099ms step_avg:141.72ms
step:1430/5100 train_loss:3.6272 train_time:201239ms step_avg:141.72ms
step:1431/5100 train_loss:3.5592 train_time:201379ms step_avg:141.72ms
step:1432/5100 train_loss:3.6060 train_time:201519ms step_avg:141.71ms
step:1433/5100 train_loss:3.6482 train_time:201660ms step_avg:141.71ms
step:1434/5100 train_loss:3.5249 train_time:201799ms step_avg:141.71ms
step:1435/5100 train_loss:3.6178 train_time:201939ms step_avg:141.71ms
step:1436/5100 train_loss:3.4460 train_time:202079ms step_avg:141.71ms
step:1437/5100 train_loss:3.5053 train_time:202218ms step_avg:141.71ms
step:1438/5100 train_loss:3.6980 train_time:202358ms step_avg:141.71ms
step:1439/5100 train_loss:3.6577 train_time:202498ms step_avg:141.71ms
step:1440/5100 train_loss:3.6084 train_time:202640ms step_avg:141.71ms
step:1441/5100 train_loss:3.4668 train_time:202779ms step_avg:141.70ms
step:1442/5100 train_loss:3.6357 train_time:202919ms step_avg:141.70ms
step:1443/5100 train_loss:3.6916 train_time:203060ms step_avg:141.70ms
step:1444/5100 train_loss:3.7673 train_time:203199ms step_avg:141.70ms
step:1445/5100 train_loss:3.7383 train_time:203338ms step_avg:141.70ms
step:1446/5100 train_loss:3.6252 train_time:203479ms step_avg:141.70ms
step:1447/5100 train_loss:3.4964 train_time:203619ms step_avg:141.70ms
step:1448/5100 train_loss:3.5713 train_time:203759ms step_avg:141.70ms
step:1449/5100 train_loss:3.5925 train_time:203899ms step_avg:141.70ms
step:1450/5100 train_loss:3.7146 train_time:204040ms step_avg:141.69ms
step:1451/5100 train_loss:3.6959 train_time:204179ms step_avg:141.69ms
step:1452/5100 train_loss:3.5173 train_time:204319ms step_avg:141.69ms
step:1453/5100 train_loss:3.6245 train_time:204459ms step_avg:141.69ms
step:1454/5100 train_loss:3.5436 train_time:204600ms step_avg:141.69ms
step:1455/5100 train_loss:3.5727 train_time:204740ms step_avg:141.69ms
step:1456/5100 train_loss:3.6162 train_time:204880ms step_avg:141.69ms
step:1457/5100 train_loss:3.5495 train_time:205020ms step_avg:141.69ms
step:1458/5100 train_loss:3.4528 train_time:205160ms step_avg:141.68ms
step:1459/5100 train_loss:3.6932 train_time:205299ms step_avg:141.68ms
step:1460/5100 train_loss:3.5664 train_time:205439ms step_avg:141.68ms
step:1461/5100 train_loss:3.6187 train_time:205579ms step_avg:141.68ms
step:1462/5100 train_loss:3.7443 train_time:205720ms step_avg:141.68ms
step:1463/5100 train_loss:3.5637 train_time:205862ms step_avg:141.68ms
step:1464/5100 train_loss:3.7487 train_time:206000ms step_avg:141.68ms
step:1465/5100 train_loss:3.6445 train_time:206140ms step_avg:141.68ms
step:1466/5100 train_loss:3.6523 train_time:206278ms step_avg:141.67ms
step:1467/5100 train_loss:3.5676 train_time:206419ms step_avg:141.67ms
step:1468/5100 train_loss:3.7259 train_time:206559ms step_avg:141.67ms
step:1469/5100 train_loss:3.5894 train_time:206699ms step_avg:141.67ms
step:1470/5100 train_loss:3.5582 train_time:206839ms step_avg:141.67ms
step:1471/5100 train_loss:3.6117 train_time:206978ms step_avg:141.67ms
step:1472/5100 train_loss:3.5361 train_time:207118ms step_avg:141.67ms
step:1473/5100 train_loss:3.6237 train_time:207258ms step_avg:141.67ms
step:1474/5100 train_loss:3.7228 train_time:207398ms step_avg:141.67ms
step:1475/5100 train_loss:3.5977 train_time:207539ms step_avg:141.66ms
step:1476/5100 train_loss:3.4277 train_time:207679ms step_avg:141.66ms
step:1477/5100 train_loss:3.5485 train_time:207819ms step_avg:141.66ms
step:1478/5100 train_loss:3.5272 train_time:207959ms step_avg:141.66ms
step:1479/5100 train_loss:3.6108 train_time:208098ms step_avg:141.66ms
step:1480/5100 train_loss:3.6928 train_time:208239ms step_avg:141.66ms
step:1481/5100 train_loss:3.5701 train_time:208379ms step_avg:141.66ms
step:1482/5100 train_loss:3.7391 train_time:208519ms step_avg:141.66ms
step:1483/5100 train_loss:3.6652 train_time:208662ms step_avg:141.66ms
step:1484/5100 train_loss:3.5606 train_time:208800ms step_avg:141.66ms
step:1485/5100 train_loss:3.5611 train_time:208940ms step_avg:141.65ms
step:1486/5100 train_loss:3.5505 train_time:209078ms step_avg:141.65ms
step:1487/5100 train_loss:3.5296 train_time:209222ms step_avg:141.65ms
step:1488/5100 train_loss:3.6171 train_time:209359ms step_avg:141.65ms
step:1489/5100 train_loss:3.5258 train_time:209498ms step_avg:141.65ms
step:1490/5100 train_loss:3.6162 train_time:209639ms step_avg:141.65ms
step:1491/5100 train_loss:3.5569 train_time:209779ms step_avg:141.65ms
step:1492/5100 train_loss:3.4706 train_time:209919ms step_avg:141.65ms
step:1493/5100 train_loss:3.5478 train_time:210059ms step_avg:141.64ms
step:1494/5100 train_loss:3.7204 train_time:210199ms step_avg:141.64ms
step:1495/5100 train_loss:3.5817 train_time:210340ms step_avg:141.64ms
step:1496/5100 train_loss:3.3347 train_time:210478ms step_avg:141.64ms
step:1497/5100 train_loss:3.6399 train_time:210619ms step_avg:141.64ms
step:1498/5100 train_loss:3.6003 train_time:210760ms step_avg:141.64ms
step:1499/5100 train_loss:3.6449 train_time:210900ms step_avg:141.64ms
step:1500/5100 train_loss:3.6081 train_time:211041ms step_avg:141.64ms
step:1500/5100 val_loss:3.5800 train_time:211096ms step_avg:141.68ms
step:1501/5100 train_loss:3.5821 train_time:211194ms step_avg:141.65ms
step:1502/5100 train_loss:3.3804 train_time:211337ms step_avg:141.65ms
step:1503/5100 train_loss:3.6550 train_time:211477ms step_avg:141.65ms
step:1504/5100 train_loss:3.5274 train_time:211616ms step_avg:141.64ms
step:1505/5100 train_loss:3.5404 train_time:211754ms step_avg:141.64ms
step:1506/5100 train_loss:3.5012 train_time:211892ms step_avg:141.64ms
step:1507/5100 train_loss:3.5818 train_time:212030ms step_avg:141.64ms
step:1508/5100 train_loss:3.4961 train_time:212172ms step_avg:141.64ms
step:1509/5100 train_loss:3.8152 train_time:212317ms step_avg:141.64ms
step:1510/5100 train_loss:3.5511 train_time:212458ms step_avg:141.64ms
step:1511/5100 train_loss:3.5604 train_time:212597ms step_avg:141.64ms
step:1512/5100 train_loss:3.6809 train_time:212910ms step_avg:141.75ms
step:1513/5100 train_loss:3.7131 train_time:213048ms step_avg:141.75ms
step:1514/5100 train_loss:3.5689 train_time:213187ms step_avg:141.75ms
step:1515/5100 train_loss:3.4133 train_time:213325ms step_avg:141.74ms
step:1516/5100 train_loss:3.5326 train_time:213464ms step_avg:141.74ms
step:1517/5100 train_loss:3.5449 train_time:213602ms step_avg:141.74ms
step:1518/5100 train_loss:3.6234 train_time:213740ms step_avg:141.74ms
step:1519/5100 train_loss:3.5057 train_time:213883ms step_avg:141.74ms
step:1520/5100 train_loss:3.8018 train_time:214187ms step_avg:141.85ms
step:1521/5100 train_loss:3.4656 train_time:214326ms step_avg:141.84ms
step:1522/5100 train_loss:3.5245 train_time:214466ms step_avg:141.84ms
step:1523/5100 train_loss:3.6626 train_time:214604ms step_avg:141.84ms
step:1524/5100 train_loss:3.5244 train_time:214746ms step_avg:141.84ms
step:1525/5100 train_loss:3.6168 train_time:214882ms step_avg:141.84ms
step:1526/5100 train_loss:3.6151 train_time:215021ms step_avg:141.83ms
step:1527/5100 train_loss:3.5797 train_time:215166ms step_avg:141.84ms
step:1528/5100 train_loss:3.5758 train_time:215307ms step_avg:141.84ms
step:1529/5100 train_loss:3.7284 train_time:215449ms step_avg:141.84ms
step:1530/5100 train_loss:3.6976 train_time:215590ms step_avg:141.84ms
step:1531/5100 train_loss:3.5304 train_time:215729ms step_avg:141.83ms
step:1532/5100 train_loss:3.4870 train_time:215869ms step_avg:141.83ms
step:1533/5100 train_loss:3.6363 train_time:216016ms step_avg:141.84ms
step:1534/5100 train_loss:3.5959 train_time:216150ms step_avg:141.83ms
step:1535/5100 train_loss:3.5798 train_time:216291ms step_avg:141.83ms
step:1536/5100 train_loss:3.5779 train_time:216431ms step_avg:141.83ms
step:1537/5100 train_loss:3.5199 train_time:216571ms step_avg:141.83ms
step:1538/5100 train_loss:3.5738 train_time:216711ms step_avg:141.83ms
step:1539/5100 train_loss:3.7520 train_time:216852ms step_avg:141.83ms
step:1540/5100 train_loss:3.6780 train_time:216992ms step_avg:141.82ms
step:1541/5100 train_loss:3.5947 train_time:217132ms step_avg:141.82ms
step:1542/5100 train_loss:3.5449 train_time:217274ms step_avg:141.82ms
step:1543/5100 train_loss:3.5428 train_time:217412ms step_avg:141.82ms
step:1544/5100 train_loss:3.5107 train_time:217553ms step_avg:141.82ms
step:1545/5100 train_loss:3.5968 train_time:217692ms step_avg:141.82ms
step:1546/5100 train_loss:3.5600 train_time:217832ms step_avg:141.82ms
step:1547/5100 train_loss:3.5361 train_time:217972ms step_avg:141.82ms
step:1548/5100 train_loss:3.5050 train_time:218112ms step_avg:141.82ms
step:1549/5100 train_loss:3.5377 train_time:218253ms step_avg:141.82ms
step:1550/5100 train_loss:3.6493 train_time:218393ms step_avg:141.81ms
step:1551/5100 train_loss:3.5782 train_time:218533ms step_avg:141.81ms
step:1552/5100 train_loss:3.5135 train_time:218672ms step_avg:141.81ms
step:1553/5100 train_loss:3.5170 train_time:218811ms step_avg:141.81ms
step:1554/5100 train_loss:3.5068 train_time:218952ms step_avg:141.81ms
step:1555/5100 train_loss:3.6276 train_time:219092ms step_avg:141.81ms
step:1556/5100 train_loss:3.6414 train_time:219232ms step_avg:141.81ms
step:1557/5100 train_loss:3.5678 train_time:219373ms step_avg:141.81ms
step:1558/5100 train_loss:3.6299 train_time:219511ms step_avg:141.80ms
step:1559/5100 train_loss:3.5511 train_time:219653ms step_avg:141.80ms
step:1560/5100 train_loss:3.4713 train_time:219792ms step_avg:141.80ms
step:1561/5100 train_loss:3.7046 train_time:219932ms step_avg:141.80ms
step:1562/5100 train_loss:3.5226 train_time:220075ms step_avg:141.80ms
step:1563/5100 train_loss:3.5026 train_time:220211ms step_avg:141.80ms
step:1564/5100 train_loss:3.6326 train_time:220353ms step_avg:141.80ms
step:1565/5100 train_loss:3.4590 train_time:220493ms step_avg:141.80ms
step:1566/5100 train_loss:3.5111 train_time:220632ms step_avg:141.79ms
step:1567/5100 train_loss:3.6610 train_time:220775ms step_avg:141.80ms
step:1568/5100 train_loss:3.5440 train_time:220913ms step_avg:141.79ms
step:1569/5100 train_loss:3.5262 train_time:221053ms step_avg:141.79ms
step:1570/5100 train_loss:3.6262 train_time:221192ms step_avg:141.79ms
step:1571/5100 train_loss:3.6360 train_time:221332ms step_avg:141.79ms
step:1572/5100 train_loss:3.4617 train_time:221472ms step_avg:141.79ms
step:1573/5100 train_loss:3.4840 train_time:221611ms step_avg:141.79ms
step:1574/5100 train_loss:3.6136 train_time:221754ms step_avg:141.79ms
step:1575/5100 train_loss:3.4765 train_time:221892ms step_avg:141.78ms
step:1576/5100 train_loss:3.6224 train_time:222031ms step_avg:141.78ms
step:1577/5100 train_loss:3.5292 train_time:222171ms step_avg:141.78ms
step:1578/5100 train_loss:3.5816 train_time:222311ms step_avg:141.78ms
step:1579/5100 train_loss:3.5527 train_time:222452ms step_avg:141.78ms
step:1580/5100 train_loss:3.5215 train_time:222602ms step_avg:141.78ms
step:1581/5100 train_loss:3.4880 train_time:222731ms step_avg:141.78ms
step:1582/5100 train_loss:3.7348 train_time:222872ms step_avg:141.78ms
step:1583/5100 train_loss:3.5136 train_time:223011ms step_avg:141.77ms
step:1584/5100 train_loss:3.6660 train_time:223152ms step_avg:141.77ms
step:1585/5100 train_loss:3.4949 train_time:223292ms step_avg:141.77ms
step:1586/5100 train_loss:3.6597 train_time:223432ms step_avg:141.77ms
step:1587/5100 train_loss:3.4398 train_time:223571ms step_avg:141.77ms
step:1588/5100 train_loss:3.6355 train_time:223712ms step_avg:141.77ms
step:1589/5100 train_loss:3.5491 train_time:223853ms step_avg:141.77ms
step:1590/5100 train_loss:3.7081 train_time:223992ms step_avg:141.77ms
step:1591/5100 train_loss:3.5223 train_time:224132ms step_avg:141.77ms
step:1592/5100 train_loss:3.5381 train_time:224273ms step_avg:141.77ms
step:1593/5100 train_loss:3.6100 train_time:224412ms step_avg:141.76ms
step:1594/5100 train_loss:3.5867 train_time:224552ms step_avg:141.76ms
step:1595/5100 train_loss:3.5602 train_time:224692ms step_avg:141.76ms
step:1596/5100 train_loss:3.6992 train_time:224832ms step_avg:141.76ms
step:1597/5100 train_loss:3.4307 train_time:224972ms step_avg:141.76ms
step:1598/5100 train_loss:3.5915 train_time:225111ms step_avg:141.76ms
step:1599/5100 train_loss:3.6373 train_time:225252ms step_avg:141.76ms
step:1600/5100 train_loss:3.6819 train_time:225393ms step_avg:141.76ms
step:1601/5100 train_loss:3.5319 train_time:225532ms step_avg:141.76ms
step:1602/5100 train_loss:3.8224 train_time:225673ms step_avg:141.75ms
step:1603/5100 train_loss:3.7166 train_time:225811ms step_avg:141.75ms
step:1604/5100 train_loss:3.4924 train_time:225952ms step_avg:141.75ms
step:1605/5100 train_loss:3.5337 train_time:226093ms step_avg:141.75ms
step:1606/5100 train_loss:3.4191 train_time:226232ms step_avg:141.75ms
step:1607/5100 train_loss:3.7420 train_time:226376ms step_avg:141.75ms
step:1608/5100 train_loss:3.5405 train_time:226512ms step_avg:141.75ms
step:1609/5100 train_loss:3.5669 train_time:226653ms step_avg:141.75ms
step:1610/5100 train_loss:3.5121 train_time:226793ms step_avg:141.75ms
step:1611/5100 train_loss:4.1237 train_time:226933ms step_avg:141.74ms
step:1612/5100 train_loss:3.7499 train_time:227073ms step_avg:141.74ms
step:1613/5100 train_loss:3.6651 train_time:227212ms step_avg:141.74ms
step:1614/5100 train_loss:3.5287 train_time:227352ms step_avg:141.74ms
step:1615/5100 train_loss:3.5751 train_time:227491ms step_avg:141.74ms
step:1616/5100 train_loss:3.5690 train_time:227632ms step_avg:141.74ms
step:1617/5100 train_loss:3.5241 train_time:227772ms step_avg:141.74ms
step:1618/5100 train_loss:3.6085 train_time:227913ms step_avg:141.74ms
step:1619/5100 train_loss:3.5613 train_time:228056ms step_avg:141.74ms
step:1620/5100 train_loss:3.4533 train_time:228192ms step_avg:141.73ms
step:1621/5100 train_loss:3.7187 train_time:228333ms step_avg:141.73ms
step:1622/5100 train_loss:3.6269 train_time:228473ms step_avg:141.73ms
step:1623/5100 train_loss:3.4202 train_time:228612ms step_avg:141.73ms
step:1624/5100 train_loss:3.5362 train_time:228752ms step_avg:141.73ms
step:1625/5100 train_loss:3.4946 train_time:228892ms step_avg:141.73ms
step:1625/5100 val_loss:3.5641 train_time:228949ms step_avg:141.76ms
step:1626/5100 train_loss:3.5686 train_time:229040ms step_avg:141.73ms
step:1627/5100 train_loss:3.5385 train_time:229188ms step_avg:141.74ms
step:1628/5100 train_loss:3.4970 train_time:229328ms step_avg:141.74ms
step:1629/5100 train_loss:3.6138 train_time:229466ms step_avg:141.73ms
step:1630/5100 train_loss:3.5064 train_time:229604ms step_avg:141.73ms
step:1631/5100 train_loss:3.5608 train_time:229742ms step_avg:141.73ms
step:1632/5100 train_loss:3.4394 train_time:229881ms step_avg:141.73ms
step:1633/5100 train_loss:3.4125 train_time:230027ms step_avg:141.73ms
step:1634/5100 train_loss:3.5785 train_time:230168ms step_avg:141.73ms
step:1635/5100 train_loss:3.5578 train_time:230308ms step_avg:141.73ms
step:1636/5100 train_loss:3.4986 train_time:230447ms step_avg:141.73ms
step:1637/5100 train_loss:3.5886 train_time:230585ms step_avg:141.72ms
step:1638/5100 train_loss:3.6365 train_time:230724ms step_avg:141.72ms
step:1639/5100 train_loss:3.6713 train_time:230863ms step_avg:141.72ms
step:1640/5100 train_loss:3.8369 train_time:231007ms step_avg:141.72ms
step:1641/5100 train_loss:3.6501 train_time:231146ms step_avg:141.72ms
step:1642/5100 train_loss:3.5630 train_time:231287ms step_avg:141.72ms
step:1643/5100 train_loss:3.6485 train_time:231426ms step_avg:141.72ms
step:1644/5100 train_loss:3.5505 train_time:231565ms step_avg:141.72ms
step:1645/5100 train_loss:3.5649 train_time:231703ms step_avg:141.71ms
step:1646/5100 train_loss:3.5676 train_time:231846ms step_avg:141.71ms
step:1647/5100 train_loss:3.3420 train_time:231983ms step_avg:141.71ms
step:1648/5100 train_loss:3.6024 train_time:232127ms step_avg:141.71ms
step:1649/5100 train_loss:3.4728 train_time:232265ms step_avg:141.71ms
step:1650/5100 train_loss:3.5401 train_time:232405ms step_avg:141.71ms
step:1651/5100 train_loss:3.5190 train_time:232544ms step_avg:141.71ms
step:1652/5100 train_loss:3.5878 train_time:232683ms step_avg:141.71ms
step:1653/5100 train_loss:3.5242 train_time:232824ms step_avg:141.71ms
step:1654/5100 train_loss:3.6471 train_time:232965ms step_avg:141.71ms
step:1655/5100 train_loss:3.6401 train_time:233106ms step_avg:141.71ms
step:1656/5100 train_loss:3.4547 train_time:233247ms step_avg:141.71ms
step:1657/5100 train_loss:3.6066 train_time:233386ms step_avg:141.70ms
step:1658/5100 train_loss:3.5124 train_time:233526ms step_avg:141.70ms
step:1659/5100 train_loss:3.4848 train_time:233666ms step_avg:141.70ms
step:1660/5100 train_loss:3.5758 train_time:233805ms step_avg:141.70ms
step:1661/5100 train_loss:3.5926 train_time:233945ms step_avg:141.70ms
step:1662/5100 train_loss:3.5128 train_time:234085ms step_avg:141.70ms
step:1663/5100 train_loss:3.6070 train_time:234226ms step_avg:141.70ms
step:1664/5100 train_loss:3.6150 train_time:234365ms step_avg:141.70ms
step:1665/5100 train_loss:3.6389 train_time:234505ms step_avg:141.69ms
step:1666/5100 train_loss:3.6217 train_time:234645ms step_avg:141.69ms
step:1667/5100 train_loss:3.7570 train_time:234784ms step_avg:141.69ms
step:1668/5100 train_loss:3.4702 train_time:234925ms step_avg:141.69ms
step:1669/5100 train_loss:3.5463 train_time:235066ms step_avg:141.69ms
step:1670/5100 train_loss:3.4761 train_time:235206ms step_avg:141.69ms
step:1671/5100 train_loss:3.4785 train_time:235346ms step_avg:141.69ms
step:1672/5100 train_loss:3.6423 train_time:235485ms step_avg:141.69ms
step:1673/5100 train_loss:3.8151 train_time:235624ms step_avg:141.69ms
step:1674/5100 train_loss:3.5351 train_time:235764ms step_avg:141.69ms
step:1675/5100 train_loss:3.5199 train_time:235904ms step_avg:141.68ms
step:1676/5100 train_loss:3.4085 train_time:236045ms step_avg:141.68ms
step:1677/5100 train_loss:3.6158 train_time:236184ms step_avg:141.68ms
step:1678/5100 train_loss:3.5272 train_time:236325ms step_avg:141.68ms
step:1679/5100 train_loss:3.5524 train_time:236465ms step_avg:141.68ms
step:1680/5100 train_loss:3.5363 train_time:236605ms step_avg:141.68ms
step:1681/5100 train_loss:3.3611 train_time:236745ms step_avg:141.68ms
step:1682/5100 train_loss:3.5371 train_time:236884ms step_avg:141.68ms
step:1683/5100 train_loss:3.5558 train_time:237026ms step_avg:141.68ms
step:1684/5100 train_loss:3.5954 train_time:237165ms step_avg:141.68ms
step:1685/5100 train_loss:3.6049 train_time:237306ms step_avg:141.67ms
step:1686/5100 train_loss:3.5093 train_time:237446ms step_avg:141.67ms
step:1687/5100 train_loss:3.6141 train_time:237585ms step_avg:141.67ms
step:1688/5100 train_loss:3.4954 train_time:237725ms step_avg:141.67ms
step:1689/5100 train_loss:3.5869 train_time:237865ms step_avg:141.67ms
step:1690/5100 train_loss:3.4949 train_time:238005ms step_avg:141.67ms
step:1691/5100 train_loss:3.4021 train_time:238145ms step_avg:141.67ms
step:1692/5100 train_loss:3.5493 train_time:238285ms step_avg:141.67ms
step:1693/5100 train_loss:3.5411 train_time:238425ms step_avg:141.67ms
step:1694/5100 train_loss:3.4613 train_time:238564ms step_avg:141.67ms
step:1695/5100 train_loss:3.9062 train_time:238707ms step_avg:141.67ms
step:1696/5100 train_loss:3.6176 train_time:238845ms step_avg:141.66ms
step:1697/5100 train_loss:3.5948 train_time:238984ms step_avg:141.66ms
step:1698/5100 train_loss:3.5056 train_time:239133ms step_avg:141.67ms
step:1699/5100 train_loss:3.4214 train_time:239265ms step_avg:141.66ms
step:1700/5100 train_loss:3.5122 train_time:239405ms step_avg:141.66ms
step:1701/5100 train_loss:3.4997 train_time:239714ms step_avg:141.76ms
step:1702/5100 train_loss:3.5746 train_time:239845ms step_avg:141.75ms
step:1703/5100 train_loss:3.4960 train_time:239984ms step_avg:141.75ms
step:1704/5100 train_loss:3.7003 train_time:240123ms step_avg:141.75ms
step:1705/5100 train_loss:3.4696 train_time:240261ms step_avg:141.75ms
step:1706/5100 train_loss:3.6975 train_time:240399ms step_avg:141.74ms
step:1707/5100 train_loss:3.5442 train_time:240539ms step_avg:141.74ms
step:1708/5100 train_loss:3.3166 train_time:240684ms step_avg:141.75ms
step:1709/5100 train_loss:3.6519 train_time:240826ms step_avg:141.75ms
step:1710/5100 train_loss:3.5659 train_time:241127ms step_avg:141.84ms
step:1711/5100 train_loss:3.5491 train_time:241265ms step_avg:141.84ms
step:1712/5100 train_loss:3.5429 train_time:241403ms step_avg:141.83ms
step:1713/5100 train_loss:3.5809 train_time:241541ms step_avg:141.83ms
step:1714/5100 train_loss:3.6023 train_time:241679ms step_avg:141.83ms
step:1715/5100 train_loss:3.5315 train_time:241819ms step_avg:141.83ms
step:1716/5100 train_loss:3.5249 train_time:241959ms step_avg:141.83ms
step:1717/5100 train_loss:3.3692 train_time:242106ms step_avg:141.83ms
step:1718/5100 train_loss:3.5079 train_time:242247ms step_avg:141.83ms
step:1719/5100 train_loss:3.5312 train_time:242386ms step_avg:141.83ms
step:1720/5100 train_loss:3.4776 train_time:242524ms step_avg:141.83ms
step:1721/5100 train_loss:3.6309 train_time:242664ms step_avg:141.83ms
step:1722/5100 train_loss:3.4377 train_time:242803ms step_avg:141.82ms
step:1723/5100 train_loss:3.5805 train_time:242945ms step_avg:141.82ms
step:1724/5100 train_loss:3.6688 train_time:243086ms step_avg:141.82ms
step:1725/5100 train_loss:3.5135 train_time:243228ms step_avg:141.82ms
step:1726/5100 train_loss:3.7420 train_time:243366ms step_avg:141.82ms
step:1727/5100 train_loss:3.5335 train_time:243506ms step_avg:141.82ms
step:1728/5100 train_loss:3.5881 train_time:243644ms step_avg:141.82ms
step:1729/5100 train_loss:3.5630 train_time:243783ms step_avg:141.82ms
step:1730/5100 train_loss:3.5701 train_time:243924ms step_avg:141.82ms
step:1731/5100 train_loss:3.9350 train_time:244067ms step_avg:141.82ms
step:1732/5100 train_loss:3.5518 train_time:244208ms step_avg:141.82ms
step:1733/5100 train_loss:3.6807 train_time:244347ms step_avg:141.81ms
step:1734/5100 train_loss:3.4647 train_time:244486ms step_avg:141.81ms
step:1735/5100 train_loss:3.5044 train_time:244625ms step_avg:141.81ms
step:1736/5100 train_loss:3.5221 train_time:244774ms step_avg:141.82ms
step:1737/5100 train_loss:3.5077 train_time:244904ms step_avg:141.81ms
step:1738/5100 train_loss:3.6516 train_time:245046ms step_avg:141.81ms
step:1739/5100 train_loss:3.5114 train_time:245187ms step_avg:141.81ms
step:1740/5100 train_loss:3.5787 train_time:245326ms step_avg:141.81ms
step:1741/5100 train_loss:3.6302 train_time:245470ms step_avg:141.81ms
step:1742/5100 train_loss:3.4270 train_time:245604ms step_avg:141.80ms
step:1743/5100 train_loss:3.3189 train_time:245744ms step_avg:141.80ms
step:1744/5100 train_loss:3.2531 train_time:245883ms step_avg:141.80ms
step:1745/5100 train_loss:3.5500 train_time:246025ms step_avg:141.80ms
step:1746/5100 train_loss:3.5591 train_time:246165ms step_avg:141.80ms
step:1747/5100 train_loss:3.5347 train_time:246305ms step_avg:141.80ms
step:1748/5100 train_loss:3.5401 train_time:246444ms step_avg:141.80ms
step:1749/5100 train_loss:3.7822 train_time:246583ms step_avg:141.80ms
step:1750/5100 train_loss:3.4966 train_time:246723ms step_avg:141.79ms
step:1750/5100 val_loss:3.5450 train_time:246779ms step_avg:141.83ms
step:1751/5100 train_loss:3.5639 train_time:246874ms step_avg:141.80ms
step:1752/5100 train_loss:3.5481 train_time:247019ms step_avg:141.80ms
step:1753/5100 train_loss:3.1891 train_time:247160ms step_avg:141.80ms
step:1754/5100 train_loss:3.3048 train_time:247301ms step_avg:141.80ms
step:1755/5100 train_loss:3.4249 train_time:247440ms step_avg:141.80ms
step:1756/5100 train_loss:3.3628 train_time:247581ms step_avg:141.80ms
step:1757/5100 train_loss:3.5185 train_time:247717ms step_avg:141.80ms
step:1758/5100 train_loss:3.3994 train_time:247859ms step_avg:141.80ms
step:1759/5100 train_loss:3.3977 train_time:248003ms step_avg:141.80ms
step:1760/5100 train_loss:4.4561 train_time:248144ms step_avg:141.80ms
step:1761/5100 train_loss:3.5231 train_time:248283ms step_avg:141.79ms
step:1762/5100 train_loss:3.5626 train_time:248424ms step_avg:141.79ms
step:1763/5100 train_loss:3.5622 train_time:248564ms step_avg:141.79ms
step:1764/5100 train_loss:3.5848 train_time:248703ms step_avg:141.79ms
step:1765/5100 train_loss:3.4952 train_time:248845ms step_avg:141.79ms
step:1766/5100 train_loss:3.5327 train_time:248985ms step_avg:141.79ms
step:1767/5100 train_loss:3.5520 train_time:249127ms step_avg:141.79ms
step:1768/5100 train_loss:3.8040 train_time:249266ms step_avg:141.79ms
step:1769/5100 train_loss:3.5343 train_time:249406ms step_avg:141.79ms
step:1770/5100 train_loss:3.6026 train_time:249546ms step_avg:141.79ms
step:1771/5100 train_loss:3.8815 train_time:249685ms step_avg:141.79ms
step:1772/5100 train_loss:3.5321 train_time:249826ms step_avg:141.79ms
step:1773/5100 train_loss:3.4349 train_time:249971ms step_avg:141.79ms
step:1774/5100 train_loss:3.6957 train_time:250108ms step_avg:141.78ms
step:1775/5100 train_loss:3.4247 train_time:250250ms step_avg:141.78ms
step:1776/5100 train_loss:3.5930 train_time:250387ms step_avg:141.78ms
step:1777/5100 train_loss:3.6463 train_time:250527ms step_avg:141.78ms
step:1778/5100 train_loss:3.7354 train_time:250666ms step_avg:141.78ms
step:1779/5100 train_loss:3.5424 train_time:250810ms step_avg:141.78ms
step:1780/5100 train_loss:3.8385 train_time:250951ms step_avg:141.78ms
step:1781/5100 train_loss:3.6111 train_time:251089ms step_avg:141.78ms
step:1782/5100 train_loss:3.6239 train_time:251228ms step_avg:141.78ms
step:1783/5100 train_loss:3.4110 train_time:251370ms step_avg:141.78ms
step:1784/5100 train_loss:3.4984 train_time:251508ms step_avg:141.77ms
step:1785/5100 train_loss:3.6362 train_time:251647ms step_avg:141.77ms
step:1786/5100 train_loss:3.5299 train_time:251786ms step_avg:141.77ms
step:1787/5100 train_loss:3.6956 train_time:251927ms step_avg:141.77ms
step:1788/5100 train_loss:3.5060 train_time:252068ms step_avg:141.77ms
step:1789/5100 train_loss:3.4905 train_time:252207ms step_avg:141.77ms
step:1790/5100 train_loss:3.6332 train_time:252347ms step_avg:141.77ms
step:1791/5100 train_loss:3.5267 train_time:252486ms step_avg:141.77ms
step:1792/5100 train_loss:3.4794 train_time:252631ms step_avg:141.77ms
step:1793/5100 train_loss:3.6102 train_time:252767ms step_avg:141.77ms
step:1794/5100 train_loss:3.4877 train_time:252907ms step_avg:141.76ms
step:1795/5100 train_loss:3.4760 train_time:253048ms step_avg:141.76ms
step:1796/5100 train_loss:3.5307 train_time:253186ms step_avg:141.76ms
step:1797/5100 train_loss:3.4966 train_time:253327ms step_avg:141.76ms
step:1798/5100 train_loss:3.6382 train_time:253467ms step_avg:141.76ms
step:1799/5100 train_loss:3.5200 train_time:253608ms step_avg:141.76ms
step:1800/5100 train_loss:3.5959 train_time:253748ms step_avg:141.76ms
step:1801/5100 train_loss:3.5260 train_time:253888ms step_avg:141.76ms
step:1802/5100 train_loss:3.5644 train_time:254028ms step_avg:141.76ms
step:1803/5100 train_loss:3.4727 train_time:254168ms step_avg:141.76ms
step:1804/5100 train_loss:3.4033 train_time:254308ms step_avg:141.75ms
step:1805/5100 train_loss:3.6525 train_time:254447ms step_avg:141.75ms
step:1806/5100 train_loss:3.5769 train_time:254587ms step_avg:141.75ms
step:1807/5100 train_loss:3.5817 train_time:254728ms step_avg:141.75ms
step:1808/5100 train_loss:3.6926 train_time:254867ms step_avg:141.75ms
step:1809/5100 train_loss:3.4946 train_time:255008ms step_avg:141.75ms
step:1810/5100 train_loss:3.5876 train_time:255148ms step_avg:141.75ms
step:1811/5100 train_loss:3.7274 train_time:255287ms step_avg:141.75ms
step:1812/5100 train_loss:3.5842 train_time:255427ms step_avg:141.75ms
step:1813/5100 train_loss:3.6201 train_time:255568ms step_avg:141.75ms
step:1814/5100 train_loss:3.6454 train_time:255708ms step_avg:141.74ms
step:1815/5100 train_loss:3.5923 train_time:255847ms step_avg:141.74ms
step:1816/5100 train_loss:3.6308 train_time:255987ms step_avg:141.74ms
step:1817/5100 train_loss:3.5756 train_time:256131ms step_avg:141.74ms
step:1818/5100 train_loss:3.6323 train_time:256267ms step_avg:141.74ms
step:1819/5100 train_loss:3.5602 train_time:256407ms step_avg:141.74ms
step:1820/5100 train_loss:3.5489 train_time:256547ms step_avg:141.74ms
step:1821/5100 train_loss:3.4982 train_time:256688ms step_avg:141.74ms
step:1822/5100 train_loss:3.4829 train_time:256829ms step_avg:141.74ms
step:1823/5100 train_loss:3.4017 train_time:256967ms step_avg:141.74ms
step:1824/5100 train_loss:3.5639 train_time:257107ms step_avg:141.74ms
step:1825/5100 train_loss:3.6825 train_time:257248ms step_avg:141.73ms
step:1826/5100 train_loss:3.6428 train_time:257387ms step_avg:141.73ms
step:1827/5100 train_loss:3.6163 train_time:257527ms step_avg:141.73ms
step:1828/5100 train_loss:3.4879 train_time:257667ms step_avg:141.73ms
step:1829/5100 train_loss:3.5134 train_time:257807ms step_avg:141.73ms
step:1830/5100 train_loss:3.6485 train_time:257948ms step_avg:141.73ms
step:1831/5100 train_loss:3.4224 train_time:258087ms step_avg:141.73ms
step:1832/5100 train_loss:3.5756 train_time:258228ms step_avg:141.73ms
step:1833/5100 train_loss:3.4492 train_time:258368ms step_avg:141.73ms
step:1834/5100 train_loss:3.7681 train_time:258508ms step_avg:141.73ms
step:1835/5100 train_loss:3.6106 train_time:258648ms step_avg:141.72ms
step:1836/5100 train_loss:3.5938 train_time:258787ms step_avg:141.72ms
step:1837/5100 train_loss:3.7170 train_time:258928ms step_avg:141.72ms
step:1838/5100 train_loss:3.5735 train_time:259067ms step_avg:141.72ms
step:1839/5100 train_loss:3.4525 train_time:259207ms step_avg:141.72ms
step:1840/5100 train_loss:3.5700 train_time:259346ms step_avg:141.72ms
step:1841/5100 train_loss:3.4600 train_time:259486ms step_avg:141.72ms
step:1842/5100 train_loss:3.5706 train_time:259627ms step_avg:141.72ms
step:1843/5100 train_loss:3.6254 train_time:259766ms step_avg:141.72ms
step:1844/5100 train_loss:3.3751 train_time:259906ms step_avg:141.72ms
step:1845/5100 train_loss:3.4963 train_time:260047ms step_avg:141.72ms
step:1846/5100 train_loss:3.5541 train_time:260188ms step_avg:141.71ms
step:1847/5100 train_loss:3.4959 train_time:260327ms step_avg:141.71ms
step:1848/5100 train_loss:3.3997 train_time:260466ms step_avg:141.71ms
step:1849/5100 train_loss:3.6584 train_time:260607ms step_avg:141.71ms
step:1850/5100 train_loss:3.4274 train_time:260748ms step_avg:141.71ms
step:1851/5100 train_loss:3.5132 train_time:260887ms step_avg:141.71ms
step:1852/5100 train_loss:3.4745 train_time:261027ms step_avg:141.71ms
step:1853/5100 train_loss:3.6714 train_time:261166ms step_avg:141.71ms
step:1854/5100 train_loss:3.6490 train_time:261307ms step_avg:141.71ms
step:1855/5100 train_loss:3.5204 train_time:261447ms step_avg:141.71ms
step:1856/5100 train_loss:3.4753 train_time:261586ms step_avg:141.70ms
step:1857/5100 train_loss:3.5094 train_time:261726ms step_avg:141.70ms
step:1858/5100 train_loss:3.7480 train_time:261867ms step_avg:141.70ms
step:1859/5100 train_loss:3.5923 train_time:262007ms step_avg:141.70ms
step:1860/5100 train_loss:3.5335 train_time:262147ms step_avg:141.70ms
step:1861/5100 train_loss:3.5723 train_time:262287ms step_avg:141.70ms
step:1862/5100 train_loss:3.4668 train_time:262428ms step_avg:141.70ms
step:1863/5100 train_loss:3.4571 train_time:262566ms step_avg:141.70ms
step:1864/5100 train_loss:3.5380 train_time:262707ms step_avg:141.70ms
step:1865/5100 train_loss:3.5710 train_time:262847ms step_avg:141.70ms
step:1866/5100 train_loss:3.3332 train_time:262986ms step_avg:141.70ms
step:1867/5100 train_loss:3.4672 train_time:263128ms step_avg:141.70ms
step:1868/5100 train_loss:3.4193 train_time:263267ms step_avg:141.69ms
step:1869/5100 train_loss:3.4204 train_time:263408ms step_avg:141.69ms
step:1870/5100 train_loss:3.5841 train_time:263547ms step_avg:141.69ms
step:1871/5100 train_loss:3.5647 train_time:263687ms step_avg:141.69ms
step:1872/5100 train_loss:3.5184 train_time:263831ms step_avg:141.69ms
step:1873/5100 train_loss:3.5247 train_time:263968ms step_avg:141.69ms
step:1874/5100 train_loss:3.4507 train_time:264107ms step_avg:141.69ms
step:1875/5100 train_loss:3.5548 train_time:264248ms step_avg:141.69ms
step:1875/5100 val_loss:3.5305 train_time:264303ms step_avg:141.72ms
step:1876/5100 train_loss:3.5543 train_time:264400ms step_avg:141.69ms
step:1877/5100 train_loss:3.4753 train_time:264544ms step_avg:141.69ms
step:1878/5100 train_loss:3.5250 train_time:264684ms step_avg:141.69ms
step:1879/5100 train_loss:3.6401 train_time:264822ms step_avg:141.69ms
step:1880/5100 train_loss:3.5111 train_time:264961ms step_avg:141.69ms
step:1881/5100 train_loss:3.5742 train_time:265098ms step_avg:141.69ms
step:1882/5100 train_loss:3.4805 train_time:265237ms step_avg:141.69ms
step:1883/5100 train_loss:3.5577 train_time:265381ms step_avg:141.69ms
step:1884/5100 train_loss:3.5534 train_time:265523ms step_avg:141.69ms
step:1885/5100 train_loss:3.3067 train_time:265663ms step_avg:141.69ms
step:1886/5100 train_loss:3.7038 train_time:265802ms step_avg:141.69ms
step:1887/5100 train_loss:3.4369 train_time:265942ms step_avg:141.68ms
step:1888/5100 train_loss:3.4632 train_time:266080ms step_avg:141.68ms
step:1889/5100 train_loss:3.5305 train_time:266220ms step_avg:141.68ms
step:1890/5100 train_loss:3.5802 train_time:266531ms step_avg:141.77ms
step:1891/5100 train_loss:3.3999 train_time:266667ms step_avg:141.77ms
step:1892/5100 train_loss:3.6711 train_time:266806ms step_avg:141.77ms
step:1893/5100 train_loss:3.4269 train_time:266944ms step_avg:141.77ms
step:1894/5100 train_loss:3.5588 train_time:267083ms step_avg:141.76ms
step:1895/5100 train_loss:3.5998 train_time:267220ms step_avg:141.76ms
step:1896/5100 train_loss:3.4000 train_time:267359ms step_avg:141.76ms
step:1897/5100 train_loss:3.5584 train_time:267502ms step_avg:141.76ms
step:1898/5100 train_loss:3.5260 train_time:267645ms step_avg:141.76ms
step:1899/5100 train_loss:3.6007 train_time:267785ms step_avg:141.76ms
step:1900/5100 train_loss:3.3783 train_time:268087ms step_avg:141.85ms
step:1901/5100 train_loss:3.6199 train_time:268226ms step_avg:141.84ms
step:1902/5100 train_loss:3.5053 train_time:268366ms step_avg:141.84ms
step:1903/5100 train_loss:3.6672 train_time:268504ms step_avg:141.84ms
step:1904/5100 train_loss:3.4756 train_time:268642ms step_avg:141.84ms
step:1905/5100 train_loss:3.7491 train_time:268779ms step_avg:141.84ms
step:1906/5100 train_loss:3.4821 train_time:268917ms step_avg:141.83ms
step:1907/5100 train_loss:3.4789 train_time:269063ms step_avg:141.84ms
step:1908/5100 train_loss:3.5578 train_time:269205ms step_avg:141.84ms
step:1909/5100 train_loss:3.4406 train_time:269345ms step_avg:141.84ms
step:1910/5100 train_loss:3.5032 train_time:269484ms step_avg:141.83ms
step:1911/5100 train_loss:3.5985 train_time:269622ms step_avg:141.83ms
step:1912/5100 train_loss:3.5192 train_time:269760ms step_avg:141.83ms
step:1913/5100 train_loss:3.3928 train_time:269900ms step_avg:141.83ms
step:1914/5100 train_loss:3.2745 train_time:270042ms step_avg:141.83ms
step:1915/5100 train_loss:3.4688 train_time:270184ms step_avg:141.83ms
step:1916/5100 train_loss:3.6904 train_time:270325ms step_avg:141.83ms
step:1917/5100 train_loss:3.6832 train_time:270464ms step_avg:141.83ms
step:1918/5100 train_loss:3.6366 train_time:270602ms step_avg:141.83ms
step:1919/5100 train_loss:3.4660 train_time:270742ms step_avg:141.82ms
step:1920/5100 train_loss:3.7100 train_time:270881ms step_avg:141.82ms
step:1921/5100 train_loss:3.5303 train_time:271021ms step_avg:141.82ms
step:1922/5100 train_loss:3.4613 train_time:271165ms step_avg:141.82ms
step:1923/5100 train_loss:3.6430 train_time:271303ms step_avg:141.82ms
step:1924/5100 train_loss:3.6117 train_time:271443ms step_avg:141.82ms
step:1925/5100 train_loss:3.4498 train_time:271582ms step_avg:141.82ms
step:1926/5100 train_loss:3.4847 train_time:271721ms step_avg:141.82ms
step:1927/5100 train_loss:3.3942 train_time:271861ms step_avg:141.82ms
step:1928/5100 train_loss:3.4963 train_time:272000ms step_avg:141.81ms
step:1929/5100 train_loss:3.3573 train_time:272142ms step_avg:141.81ms
step:1930/5100 train_loss:3.4655 train_time:272282ms step_avg:141.81ms
step:1931/5100 train_loss:3.6055 train_time:272422ms step_avg:141.81ms
step:1932/5100 train_loss:3.4735 train_time:272562ms step_avg:141.81ms
step:1933/5100 train_loss:3.6091 train_time:272701ms step_avg:141.81ms
step:1934/5100 train_loss:3.4886 train_time:272842ms step_avg:141.81ms
step:1935/5100 train_loss:3.5309 train_time:272981ms step_avg:141.81ms
step:1936/5100 train_loss:3.5727 train_time:273121ms step_avg:141.81ms
step:1937/5100 train_loss:3.5286 train_time:273262ms step_avg:141.81ms
step:1938/5100 train_loss:3.5503 train_time:273402ms step_avg:141.81ms
step:1939/5100 train_loss:3.4836 train_time:273542ms step_avg:141.80ms
step:1940/5100 train_loss:3.5766 train_time:273682ms step_avg:141.80ms
step:1941/5100 train_loss:3.6056 train_time:273821ms step_avg:141.80ms
step:1942/5100 train_loss:3.4450 train_time:273962ms step_avg:141.80ms
step:1943/5100 train_loss:3.4849 train_time:274104ms step_avg:141.80ms
step:1944/5100 train_loss:3.5494 train_time:274244ms step_avg:141.80ms
step:1945/5100 train_loss:3.3930 train_time:274382ms step_avg:141.80ms
step:1946/5100 train_loss:3.6646 train_time:274522ms step_avg:141.80ms
step:1947/5100 train_loss:3.5373 train_time:274662ms step_avg:141.80ms
step:1948/5100 train_loss:3.5129 train_time:274801ms step_avg:141.80ms
step:1949/5100 train_loss:3.5176 train_time:274942ms step_avg:141.80ms
step:1950/5100 train_loss:3.3949 train_time:275083ms step_avg:141.80ms
step:1951/5100 train_loss:3.5150 train_time:275225ms step_avg:141.80ms
step:1952/5100 train_loss:3.3581 train_time:275362ms step_avg:141.79ms
step:1953/5100 train_loss:3.5725 train_time:275503ms step_avg:141.79ms
step:1954/5100 train_loss:3.5696 train_time:275642ms step_avg:141.79ms
step:1955/5100 train_loss:3.5197 train_time:275782ms step_avg:141.79ms
step:1956/5100 train_loss:3.4097 train_time:275921ms step_avg:141.79ms
step:1957/5100 train_loss:3.5006 train_time:276065ms step_avg:141.79ms
step:1958/5100 train_loss:3.6850 train_time:276201ms step_avg:141.79ms
step:1959/5100 train_loss:3.6027 train_time:276342ms step_avg:141.79ms
step:1960/5100 train_loss:3.6237 train_time:276481ms step_avg:141.79ms
step:1961/5100 train_loss:3.4202 train_time:276621ms step_avg:141.78ms
step:1962/5100 train_loss:3.5440 train_time:276763ms step_avg:141.78ms
step:1963/5100 train_loss:3.5948 train_time:276902ms step_avg:141.78ms
step:1964/5100 train_loss:3.5383 train_time:277044ms step_avg:141.78ms
step:1965/5100 train_loss:3.4502 train_time:277182ms step_avg:141.78ms
step:1966/5100 train_loss:3.8561 train_time:277322ms step_avg:141.78ms
step:1967/5100 train_loss:3.4642 train_time:277462ms step_avg:141.78ms
step:1968/5100 train_loss:3.5072 train_time:277601ms step_avg:141.78ms
step:1969/5100 train_loss:3.5515 train_time:277742ms step_avg:141.78ms
step:1970/5100 train_loss:3.5143 train_time:277882ms step_avg:141.78ms
step:1971/5100 train_loss:3.4032 train_time:278022ms step_avg:141.78ms
step:1972/5100 train_loss:3.3822 train_time:278162ms step_avg:141.77ms
step:1973/5100 train_loss:3.5045 train_time:278302ms step_avg:141.77ms
step:1974/5100 train_loss:3.4732 train_time:278443ms step_avg:141.77ms
step:1975/5100 train_loss:3.4450 train_time:278582ms step_avg:141.77ms
step:1976/5100 train_loss:3.6034 train_time:278722ms step_avg:141.77ms
step:1977/5100 train_loss:3.4745 train_time:278862ms step_avg:141.77ms
step:1978/5100 train_loss:3.8460 train_time:279002ms step_avg:141.77ms
step:1979/5100 train_loss:3.5282 train_time:279142ms step_avg:141.77ms
step:1980/5100 train_loss:3.5211 train_time:279282ms step_avg:141.77ms
step:1981/5100 train_loss:3.5288 train_time:279422ms step_avg:141.77ms
step:1982/5100 train_loss:3.5541 train_time:279562ms step_avg:141.77ms
step:1983/5100 train_loss:3.4881 train_time:279705ms step_avg:141.77ms
step:1984/5100 train_loss:3.4520 train_time:279842ms step_avg:141.76ms
step:1985/5100 train_loss:3.4971 train_time:279982ms step_avg:141.76ms
step:1986/5100 train_loss:3.5696 train_time:280123ms step_avg:141.76ms
step:1987/5100 train_loss:3.5448 train_time:280263ms step_avg:141.76ms
step:1988/5100 train_loss:3.5109 train_time:280402ms step_avg:141.76ms
step:1989/5100 train_loss:3.5912 train_time:280542ms step_avg:141.76ms
step:1990/5100 train_loss:3.6300 train_time:280681ms step_avg:141.76ms
step:1991/5100 train_loss:3.4062 train_time:280824ms step_avg:141.76ms
step:1992/5100 train_loss:3.4054 train_time:280962ms step_avg:141.76ms
step:1993/5100 train_loss:3.5881 train_time:281101ms step_avg:141.76ms
step:1994/5100 train_loss:3.4146 train_time:281245ms step_avg:141.76ms
step:1995/5100 train_loss:3.4980 train_time:281382ms step_avg:141.75ms
step:1996/5100 train_loss:3.5720 train_time:281525ms step_avg:141.75ms
step:1997/5100 train_loss:3.4308 train_time:281661ms step_avg:141.75ms
step:1998/5100 train_loss:3.5429 train_time:281801ms step_avg:141.75ms
step:1999/5100 train_loss:3.5388 train_time:281942ms step_avg:141.75ms
step:2000/5100 train_loss:3.4600 train_time:282081ms step_avg:141.75ms
step:2000/5100 val_loss:3.5168 train_time:282137ms step_avg:141.78ms
step:2001/5100 train_loss:3.6123 train_time:282232ms step_avg:141.75ms
step:2002/5100 train_loss:3.5519 train_time:282379ms step_avg:141.76ms
step:2003/5100 train_loss:3.6376 train_time:282516ms step_avg:141.75ms
step:2004/5100 train_loss:3.5582 train_time:282655ms step_avg:141.75ms
step:2005/5100 train_loss:3.5708 train_time:282793ms step_avg:141.75ms
step:2006/5100 train_loss:3.4555 train_time:282934ms step_avg:141.75ms
step:2007/5100 train_loss:3.4831 train_time:283069ms step_avg:141.75ms
step:2008/5100 train_loss:3.5239 train_time:283208ms step_avg:141.75ms
step:2009/5100 train_loss:3.5663 train_time:283352ms step_avg:141.75ms
step:2010/5100 train_loss:3.4646 train_time:283493ms step_avg:141.75ms
step:2011/5100 train_loss:3.5517 train_time:283634ms step_avg:141.75ms
step:2012/5100 train_loss:3.5259 train_time:283772ms step_avg:141.74ms
step:2013/5100 train_loss:3.5307 train_time:283910ms step_avg:141.74ms
step:2014/5100 train_loss:3.4471 train_time:284049ms step_avg:141.74ms
step:2015/5100 train_loss:3.4919 train_time:284189ms step_avg:141.74ms
step:2016/5100 train_loss:3.5056 train_time:284330ms step_avg:141.74ms
step:2017/5100 train_loss:3.6377 train_time:284472ms step_avg:141.74ms
step:2018/5100 train_loss:3.4849 train_time:284611ms step_avg:141.74ms
step:2019/5100 train_loss:3.6368 train_time:284751ms step_avg:141.74ms
step:2020/5100 train_loss:3.6511 train_time:284888ms step_avg:141.74ms
step:2021/5100 train_loss:3.3615 train_time:285028ms step_avg:141.73ms
step:2022/5100 train_loss:3.5927 train_time:285171ms step_avg:141.73ms
step:2023/5100 train_loss:3.5127 train_time:285309ms step_avg:141.73ms
step:2024/5100 train_loss:3.6092 train_time:285451ms step_avg:141.73ms
step:2025/5100 train_loss:3.6571 train_time:285590ms step_avg:141.73ms
step:2026/5100 train_loss:3.4376 train_time:285730ms step_avg:141.73ms
step:2027/5100 train_loss:3.4838 train_time:285870ms step_avg:141.73ms
step:2028/5100 train_loss:3.3854 train_time:286009ms step_avg:141.73ms
step:2029/5100 train_loss:3.4896 train_time:286149ms step_avg:141.73ms
step:2030/5100 train_loss:3.4185 train_time:286289ms step_avg:141.73ms
step:2031/5100 train_loss:3.5074 train_time:286430ms step_avg:141.73ms
step:2032/5100 train_loss:3.5040 train_time:286569ms step_avg:141.73ms
step:2033/5100 train_loss:3.5174 train_time:286709ms step_avg:141.72ms
step:2034/5100 train_loss:3.4155 train_time:286851ms step_avg:141.72ms
step:2035/5100 train_loss:3.5789 train_time:286992ms step_avg:141.72ms
step:2036/5100 train_loss:3.5724 train_time:287129ms step_avg:141.72ms
step:2037/5100 train_loss:3.5638 train_time:287269ms step_avg:141.72ms
step:2038/5100 train_loss:3.4315 train_time:287410ms step_avg:141.72ms
step:2039/5100 train_loss:3.6853 train_time:287550ms step_avg:141.72ms
step:2040/5100 train_loss:3.5308 train_time:287690ms step_avg:141.72ms
step:2041/5100 train_loss:3.5475 train_time:287829ms step_avg:141.72ms
step:2042/5100 train_loss:3.4898 train_time:287969ms step_avg:141.72ms
step:2043/5100 train_loss:3.3910 train_time:288108ms step_avg:141.72ms
step:2044/5100 train_loss:3.5149 train_time:288250ms step_avg:141.72ms
step:2045/5100 train_loss:3.5222 train_time:288389ms step_avg:141.71ms
step:2046/5100 train_loss:3.3817 train_time:288530ms step_avg:141.71ms
step:2047/5100 train_loss:3.4519 train_time:288670ms step_avg:141.71ms
step:2048/5100 train_loss:3.5372 train_time:288809ms step_avg:141.71ms
step:2049/5100 train_loss:3.4852 train_time:288949ms step_avg:141.71ms
step:2050/5100 train_loss:3.5264 train_time:289089ms step_avg:141.71ms
step:2051/5100 train_loss:3.6765 train_time:289229ms step_avg:141.71ms
step:2052/5100 train_loss:3.5435 train_time:289370ms step_avg:141.71ms
step:2053/5100 train_loss:3.4915 train_time:289509ms step_avg:141.71ms
step:2054/5100 train_loss:3.4688 train_time:289652ms step_avg:141.71ms
step:2055/5100 train_loss:3.3364 train_time:289790ms step_avg:141.71ms
step:2056/5100 train_loss:3.4536 train_time:289930ms step_avg:141.71ms
step:2057/5100 train_loss:3.6311 train_time:290070ms step_avg:141.70ms
step:2058/5100 train_loss:3.6495 train_time:290209ms step_avg:141.70ms
step:2059/5100 train_loss:3.5153 train_time:290349ms step_avg:141.70ms
step:2060/5100 train_loss:3.5585 train_time:290489ms step_avg:141.70ms
step:2061/5100 train_loss:3.5389 train_time:290630ms step_avg:141.70ms
step:2062/5100 train_loss:3.4908 train_time:290770ms step_avg:141.70ms
step:2063/5100 train_loss:3.4067 train_time:290909ms step_avg:141.70ms
step:2064/5100 train_loss:3.7106 train_time:291050ms step_avg:141.70ms
step:2065/5100 train_loss:3.5784 train_time:291190ms step_avg:141.70ms
step:2066/5100 train_loss:3.5265 train_time:291329ms step_avg:141.70ms
step:2067/5100 train_loss:3.5636 train_time:291469ms step_avg:141.70ms
step:2068/5100 train_loss:3.4706 train_time:291609ms step_avg:141.70ms
step:2069/5100 train_loss:3.5273 train_time:291750ms step_avg:141.69ms
step:2070/5100 train_loss:3.6539 train_time:291889ms step_avg:141.69ms
step:2071/5100 train_loss:3.6571 train_time:292030ms step_avg:141.69ms
step:2072/5100 train_loss:3.5095 train_time:292170ms step_avg:141.69ms
step:2073/5100 train_loss:3.5424 train_time:292309ms step_avg:141.69ms
step:2074/5100 train_loss:3.4297 train_time:292449ms step_avg:141.69ms
step:2075/5100 train_loss:3.9574 train_time:292590ms step_avg:141.69ms
step:2076/5100 train_loss:3.3832 train_time:292729ms step_avg:141.69ms
step:2077/5100 train_loss:3.5569 train_time:292871ms step_avg:141.69ms
step:2078/5100 train_loss:3.4481 train_time:293014ms step_avg:141.69ms
step:2079/5100 train_loss:3.4260 train_time:293324ms step_avg:141.77ms
step:2080/5100 train_loss:3.5095 train_time:293457ms step_avg:141.77ms
step:2081/5100 train_loss:3.7720 train_time:293595ms step_avg:141.77ms
step:2082/5100 train_loss:3.3932 train_time:293734ms step_avg:141.76ms
step:2083/5100 train_loss:3.7332 train_time:293872ms step_avg:141.76ms
step:2084/5100 train_loss:3.4369 train_time:294009ms step_avg:141.76ms
step:2085/5100 train_loss:3.4177 train_time:294149ms step_avg:141.76ms
step:2086/5100 train_loss:3.6704 train_time:294290ms step_avg:141.76ms
step:2087/5100 train_loss:3.5985 train_time:294433ms step_avg:141.76ms
step:2088/5100 train_loss:3.5800 train_time:294572ms step_avg:141.76ms
step:2089/5100 train_loss:3.6364 train_time:294711ms step_avg:141.76ms
step:2090/5100 train_loss:3.5566 train_time:295017ms step_avg:141.84ms
step:2091/5100 train_loss:3.5498 train_time:295156ms step_avg:141.83ms
step:2092/5100 train_loss:3.4999 train_time:295294ms step_avg:141.83ms
step:2093/5100 train_loss:3.5743 train_time:295432ms step_avg:141.83ms
step:2094/5100 train_loss:3.4749 train_time:295570ms step_avg:141.83ms
step:2095/5100 train_loss:3.2644 train_time:295708ms step_avg:141.83ms
step:2096/5100 train_loss:3.4948 train_time:295848ms step_avg:141.83ms
step:2097/5100 train_loss:3.6673 train_time:295995ms step_avg:141.83ms
step:2098/5100 train_loss:3.4970 train_time:296135ms step_avg:141.83ms
step:2099/5100 train_loss:3.3800 train_time:296272ms step_avg:141.82ms
step:2100/5100 train_loss:3.4867 train_time:296409ms step_avg:141.82ms
step:2101/5100 train_loss:3.4476 train_time:296549ms step_avg:141.82ms
step:2102/5100 train_loss:3.5892 train_time:296688ms step_avg:141.82ms
step:2103/5100 train_loss:3.4234 train_time:296827ms step_avg:141.82ms
step:2104/5100 train_loss:3.3872 train_time:296969ms step_avg:141.82ms
step:2105/5100 train_loss:3.6450 train_time:297111ms step_avg:141.82ms
step:2106/5100 train_loss:3.3792 train_time:297251ms step_avg:141.82ms
step:2107/5100 train_loss:3.7765 train_time:297389ms step_avg:141.82ms
step:2108/5100 train_loss:3.6109 train_time:297529ms step_avg:141.82ms
step:2109/5100 train_loss:3.5181 train_time:297669ms step_avg:141.81ms
step:2110/5100 train_loss:3.5292 train_time:297807ms step_avg:141.81ms
step:2111/5100 train_loss:3.3575 train_time:297950ms step_avg:141.81ms
step:2112/5100 train_loss:3.8399 train_time:298091ms step_avg:141.81ms
step:2113/5100 train_loss:3.5363 train_time:298231ms step_avg:141.81ms
step:2114/5100 train_loss:3.4590 train_time:298370ms step_avg:141.81ms
step:2115/5100 train_loss:3.5774 train_time:298509ms step_avg:141.81ms
step:2116/5100 train_loss:3.5285 train_time:298650ms step_avg:141.81ms
step:2117/5100 train_loss:3.5213 train_time:298789ms step_avg:141.81ms
step:2118/5100 train_loss:3.5775 train_time:298933ms step_avg:141.81ms
step:2119/5100 train_loss:3.4334 train_time:299071ms step_avg:141.81ms
step:2120/5100 train_loss:3.4908 train_time:299212ms step_avg:141.81ms
step:2121/5100 train_loss:3.1982 train_time:299350ms step_avg:141.81ms
step:2122/5100 train_loss:3.3952 train_time:299490ms step_avg:141.80ms
step:2123/5100 train_loss:3.5624 train_time:299630ms step_avg:141.80ms
step:2124/5100 train_loss:3.4791 train_time:299769ms step_avg:141.80ms
step:2125/5100 train_loss:3.6313 train_time:299909ms step_avg:141.80ms
step:2125/5100 val_loss:3.5054 train_time:299965ms step_avg:141.83ms
step:2126/5100 train_loss:3.4920 train_time:300059ms step_avg:141.80ms
step:2127/5100 train_loss:3.6041 train_time:300210ms step_avg:141.81ms
step:2128/5100 train_loss:3.5846 train_time:300347ms step_avg:141.81ms
step:2129/5100 train_loss:3.4460 train_time:300487ms step_avg:141.81ms
step:2130/5100 train_loss:3.4279 train_time:300626ms step_avg:141.80ms
step:2131/5100 train_loss:3.4548 train_time:300766ms step_avg:141.80ms
step:2132/5100 train_loss:3.6091 train_time:300907ms step_avg:141.80ms
step:2133/5100 train_loss:3.4904 train_time:301045ms step_avg:141.80ms
step:2134/5100 train_loss:3.3951 train_time:301189ms step_avg:141.80ms
step:2135/5100 train_loss:3.4543 train_time:301331ms step_avg:141.80ms
step:2136/5100 train_loss:3.5840 train_time:301472ms step_avg:141.80ms
step:2137/5100 train_loss:3.5994 train_time:301610ms step_avg:141.80ms
step:2138/5100 train_loss:3.5442 train_time:301749ms step_avg:141.80ms
step:2139/5100 train_loss:3.5314 train_time:301888ms step_avg:141.80ms
step:2140/5100 train_loss:3.5133 train_time:302028ms step_avg:141.80ms
step:2141/5100 train_loss:3.6048 train_time:302172ms step_avg:141.80ms
step:2142/5100 train_loss:3.9059 train_time:302313ms step_avg:141.80ms
step:2143/5100 train_loss:3.4340 train_time:302451ms step_avg:141.80ms
step:2144/5100 train_loss:3.4655 train_time:302591ms step_avg:141.80ms
step:2145/5100 train_loss:3.5039 train_time:302730ms step_avg:141.79ms
step:2146/5100 train_loss:3.6325 train_time:302870ms step_avg:141.79ms
step:2147/5100 train_loss:3.5646 train_time:303009ms step_avg:141.79ms
step:2148/5100 train_loss:3.9713 train_time:303151ms step_avg:141.79ms
step:2149/5100 train_loss:3.4862 train_time:303292ms step_avg:141.79ms
step:2150/5100 train_loss:3.4683 train_time:303434ms step_avg:141.79ms
step:2151/5100 train_loss:3.5234 train_time:303570ms step_avg:141.79ms
step:2152/5100 train_loss:3.5537 train_time:303709ms step_avg:141.79ms
step:2153/5100 train_loss:3.5138 train_time:303850ms step_avg:141.79ms
step:2154/5100 train_loss:3.4467 train_time:303989ms step_avg:141.79ms
step:2155/5100 train_loss:3.6601 train_time:304130ms step_avg:141.79ms
step:2156/5100 train_loss:3.2815 train_time:304272ms step_avg:141.79ms
step:2157/5100 train_loss:3.4430 train_time:304411ms step_avg:141.78ms
step:2158/5100 train_loss:3.5779 train_time:304553ms step_avg:141.78ms
step:2159/5100 train_loss:3.5092 train_time:304690ms step_avg:141.78ms
step:2160/5100 train_loss:3.6705 train_time:304830ms step_avg:141.78ms
step:2161/5100 train_loss:3.5891 train_time:304971ms step_avg:141.78ms
step:2162/5100 train_loss:3.5108 train_time:305111ms step_avg:141.78ms
step:2163/5100 train_loss:3.4829 train_time:305252ms step_avg:141.78ms
step:2164/5100 train_loss:3.4852 train_time:305391ms step_avg:141.78ms
step:2165/5100 train_loss:3.5664 train_time:305530ms step_avg:141.78ms
step:2166/5100 train_loss:3.5895 train_time:305674ms step_avg:141.78ms
step:2167/5100 train_loss:3.5190 train_time:305809ms step_avg:141.78ms
step:2168/5100 train_loss:3.4116 train_time:305950ms step_avg:141.77ms
step:2169/5100 train_loss:3.4986 train_time:306090ms step_avg:141.77ms
step:2170/5100 train_loss:3.5391 train_time:306231ms step_avg:141.77ms
step:2171/5100 train_loss:3.6597 train_time:306372ms step_avg:141.77ms
step:2172/5100 train_loss:3.4580 train_time:306511ms step_avg:141.77ms
step:2173/5100 train_loss:3.4440 train_time:306651ms step_avg:141.77ms
step:2174/5100 train_loss:3.4508 train_time:306791ms step_avg:141.77ms
step:2175/5100 train_loss:3.5074 train_time:306931ms step_avg:141.77ms
step:2176/5100 train_loss:3.4691 train_time:307071ms step_avg:141.77ms
step:2177/5100 train_loss:3.4410 train_time:307210ms step_avg:141.77ms
step:2178/5100 train_loss:3.6579 train_time:307355ms step_avg:141.77ms
step:2179/5100 train_loss:3.4825 train_time:307491ms step_avg:141.77ms
step:2180/5100 train_loss:3.4908 train_time:307631ms step_avg:141.77ms
step:2181/5100 train_loss:3.5540 train_time:307771ms step_avg:141.76ms
step:2182/5100 train_loss:3.5271 train_time:307912ms step_avg:141.76ms
step:2183/5100 train_loss:3.5076 train_time:308052ms step_avg:141.76ms
step:2184/5100 train_loss:3.3922 train_time:308191ms step_avg:141.76ms
step:2185/5100 train_loss:3.5651 train_time:308331ms step_avg:141.76ms
step:2186/5100 train_loss:3.7386 train_time:308472ms step_avg:141.76ms
step:2187/5100 train_loss:3.3777 train_time:308610ms step_avg:141.76ms
step:2188/5100 train_loss:3.4218 train_time:308754ms step_avg:141.76ms
step:2189/5100 train_loss:3.2708 train_time:308890ms step_avg:141.76ms
step:2190/5100 train_loss:3.4254 train_time:309031ms step_avg:141.76ms
step:2191/5100 train_loss:3.5677 train_time:309171ms step_avg:141.76ms
step:2192/5100 train_loss:3.5064 train_time:309310ms step_avg:141.76ms
step:2193/5100 train_loss:3.7395 train_time:309452ms step_avg:141.76ms
step:2194/5100 train_loss:3.5077 train_time:309590ms step_avg:141.75ms
step:2195/5100 train_loss:3.5722 train_time:309730ms step_avg:141.75ms
step:2196/5100 train_loss:3.5073 train_time:309871ms step_avg:141.75ms
step:2197/5100 train_loss:3.4299 train_time:310011ms step_avg:141.75ms
step:2198/5100 train_loss:3.5116 train_time:310151ms step_avg:141.75ms
step:2199/5100 train_loss:3.4545 train_time:310290ms step_avg:141.75ms
step:2200/5100 train_loss:3.4590 train_time:310431ms step_avg:141.75ms
step:2201/5100 train_loss:3.5096 train_time:310571ms step_avg:141.75ms
step:2202/5100 train_loss:3.4889 train_time:310710ms step_avg:141.75ms
step:2203/5100 train_loss:3.4646 train_time:310851ms step_avg:141.75ms
step:2204/5100 train_loss:3.9685 train_time:310991ms step_avg:141.75ms
step:2205/5100 train_loss:3.3818 train_time:311131ms step_avg:141.75ms
step:2206/5100 train_loss:3.5075 train_time:311275ms step_avg:141.75ms
step:2207/5100 train_loss:3.5188 train_time:311411ms step_avg:141.74ms
step:2208/5100 train_loss:3.5397 train_time:311551ms step_avg:141.74ms
step:2209/5100 train_loss:3.4308 train_time:311691ms step_avg:141.74ms
step:2210/5100 train_loss:3.5095 train_time:311831ms step_avg:141.74ms
step:2211/5100 train_loss:3.5263 train_time:311971ms step_avg:141.74ms
step:2212/5100 train_loss:3.5234 train_time:312110ms step_avg:141.74ms
step:2213/5100 train_loss:3.5439 train_time:312251ms step_avg:141.74ms
step:2214/5100 train_loss:3.4060 train_time:312391ms step_avg:141.74ms
step:2215/5100 train_loss:3.4639 train_time:312531ms step_avg:141.74ms
step:2216/5100 train_loss:3.6047 train_time:312671ms step_avg:141.74ms
step:2217/5100 train_loss:3.5624 train_time:312810ms step_avg:141.74ms
step:2218/5100 train_loss:3.5152 train_time:312952ms step_avg:141.74ms
step:2219/5100 train_loss:3.5287 train_time:313090ms step_avg:141.73ms
step:2220/5100 train_loss:3.4420 train_time:313231ms step_avg:141.73ms
step:2221/5100 train_loss:3.6947 train_time:313371ms step_avg:141.73ms
step:2222/5100 train_loss:3.5782 train_time:313511ms step_avg:141.73ms
step:2223/5100 train_loss:3.6048 train_time:313652ms step_avg:141.73ms
step:2224/5100 train_loss:3.4940 train_time:313791ms step_avg:141.73ms
step:2225/5100 train_loss:3.6092 train_time:313931ms step_avg:141.73ms
step:2226/5100 train_loss:3.3604 train_time:314072ms step_avg:141.73ms
step:2227/5100 train_loss:3.6309 train_time:314211ms step_avg:141.73ms
step:2228/5100 train_loss:3.5687 train_time:314354ms step_avg:141.73ms
step:2229/5100 train_loss:3.3736 train_time:314491ms step_avg:141.73ms
step:2230/5100 train_loss:3.7142 train_time:314630ms step_avg:141.73ms
step:2231/5100 train_loss:3.4111 train_time:314771ms step_avg:141.72ms
step:2232/5100 train_loss:3.8739 train_time:314911ms step_avg:141.72ms
step:2233/5100 train_loss:3.5630 train_time:315051ms step_avg:141.72ms
step:2234/5100 train_loss:3.5059 train_time:315191ms step_avg:141.72ms
step:2235/5100 train_loss:3.5543 train_time:315331ms step_avg:141.72ms
step:2236/5100 train_loss:3.3246 train_time:315471ms step_avg:141.72ms
step:2237/5100 train_loss:3.3289 train_time:315610ms step_avg:141.72ms
step:2238/5100 train_loss:3.5512 train_time:315753ms step_avg:141.72ms
step:2239/5100 train_loss:3.6462 train_time:315891ms step_avg:141.72ms
step:2240/5100 train_loss:3.3650 train_time:316031ms step_avg:141.72ms
step:2241/5100 train_loss:3.4400 train_time:316172ms step_avg:141.72ms
step:2242/5100 train_loss:3.6161 train_time:316311ms step_avg:141.72ms
step:2243/5100 train_loss:3.5833 train_time:316451ms step_avg:141.72ms
step:2244/5100 train_loss:3.4395 train_time:316590ms step_avg:141.71ms
step:2245/5100 train_loss:3.5070 train_time:316731ms step_avg:141.71ms
step:2246/5100 train_loss:3.5294 train_time:316873ms step_avg:141.71ms
step:2247/5100 train_loss:3.3673 train_time:317014ms step_avg:141.71ms
step:2248/5100 train_loss:3.3833 train_time:317151ms step_avg:141.71ms
step:2249/5100 train_loss:3.6430 train_time:317291ms step_avg:141.71ms
step:2250/5100 train_loss:3.3757 train_time:317431ms step_avg:141.71ms
step:2250/5100 val_loss:3.4936 train_time:317487ms step_avg:141.74ms
step:2251/5100 train_loss:3.3682 train_time:317585ms step_avg:141.72ms
step:2252/5100 train_loss:3.4404 train_time:317729ms step_avg:141.72ms
step:2253/5100 train_loss:3.4163 train_time:317868ms step_avg:141.72ms
step:2254/5100 train_loss:3.4648 train_time:318006ms step_avg:141.71ms
step:2255/5100 train_loss:3.5218 train_time:318145ms step_avg:141.71ms
step:2256/5100 train_loss:3.3999 train_time:318284ms step_avg:141.71ms
step:2257/5100 train_loss:3.6824 train_time:318422ms step_avg:141.71ms
step:2258/5100 train_loss:3.5575 train_time:318567ms step_avg:141.71ms
step:2259/5100 train_loss:3.8691 train_time:318710ms step_avg:141.71ms
step:2260/5100 train_loss:3.5610 train_time:318849ms step_avg:141.71ms
step:2261/5100 train_loss:3.6152 train_time:318989ms step_avg:141.71ms
step:2262/5100 train_loss:3.5242 train_time:319128ms step_avg:141.71ms
step:2263/5100 train_loss:3.5279 train_time:319267ms step_avg:141.71ms
step:2264/5100 train_loss:3.2813 train_time:319406ms step_avg:141.71ms
step:2265/5100 train_loss:3.4078 train_time:319549ms step_avg:141.71ms
step:2266/5100 train_loss:3.6248 train_time:319690ms step_avg:141.71ms
step:2267/5100 train_loss:3.3597 train_time:319830ms step_avg:141.71ms
step:2268/5100 train_loss:3.4290 train_time:320142ms step_avg:141.78ms
step:2269/5100 train_loss:3.4021 train_time:320278ms step_avg:141.78ms
step:2270/5100 train_loss:3.3717 train_time:320416ms step_avg:141.78ms
step:2271/5100 train_loss:3.7771 train_time:320555ms step_avg:141.78ms
step:2272/5100 train_loss:3.4235 train_time:320693ms step_avg:141.77ms
step:2273/5100 train_loss:3.4304 train_time:320831ms step_avg:141.77ms
step:2274/5100 train_loss:3.5167 train_time:320970ms step_avg:141.77ms
step:2275/5100 train_loss:3.4626 train_time:321110ms step_avg:141.77ms
step:2276/5100 train_loss:3.4768 train_time:321255ms step_avg:141.77ms
step:2277/5100 train_loss:3.3611 train_time:321396ms step_avg:141.77ms
step:2278/5100 train_loss:3.4673 train_time:321536ms step_avg:141.77ms
step:2279/5100 train_loss:3.5840 train_time:321673ms step_avg:141.77ms
step:2280/5100 train_loss:3.3908 train_time:321980ms step_avg:141.84ms
step:2281/5100 train_loss:3.4485 train_time:322119ms step_avg:141.84ms
step:2282/5100 train_loss:3.4670 train_time:322262ms step_avg:141.84ms
step:2283/5100 train_loss:3.6039 train_time:322401ms step_avg:141.84ms
step:2284/5100 train_loss:3.4852 train_time:322540ms step_avg:141.84ms
step:2285/5100 train_loss:3.5042 train_time:322678ms step_avg:141.84ms
step:2286/5100 train_loss:3.5041 train_time:322817ms step_avg:141.84ms
step:2287/5100 train_loss:3.5040 train_time:322963ms step_avg:141.84ms
step:2288/5100 train_loss:3.4544 train_time:323104ms step_avg:141.84ms
step:2289/5100 train_loss:3.5876 train_time:323245ms step_avg:141.84ms
step:2290/5100 train_loss:3.5610 train_time:323385ms step_avg:141.84ms
step:2291/5100 train_loss:3.4470 train_time:323525ms step_avg:141.83ms
step:2292/5100 train_loss:3.7838 train_time:323665ms step_avg:141.83ms
step:2293/5100 train_loss:3.4473 train_time:323805ms step_avg:141.83ms
step:2294/5100 train_loss:3.3955 train_time:323949ms step_avg:141.83ms
step:2295/5100 train_loss:3.5737 train_time:324088ms step_avg:141.83ms
step:2296/5100 train_loss:3.5263 train_time:324229ms step_avg:141.83ms
step:2297/5100 train_loss:3.4872 train_time:324369ms step_avg:141.83ms
step:2298/5100 train_loss:3.8795 train_time:324507ms step_avg:141.83ms
step:2299/5100 train_loss:3.4004 train_time:324648ms step_avg:141.83ms
step:2300/5100 train_loss:3.3957 train_time:324788ms step_avg:141.83ms
step:2301/5100 train_loss:3.7311 train_time:324929ms step_avg:141.83ms
step:2302/5100 train_loss:3.4584 train_time:325069ms step_avg:141.83ms
step:2303/5100 train_loss:3.4759 train_time:325208ms step_avg:141.83ms
step:2304/5100 train_loss:3.4584 train_time:325348ms step_avg:141.83ms
step:2305/5100 train_loss:3.3991 train_time:325488ms step_avg:141.82ms
step:2306/5100 train_loss:3.5564 train_time:325629ms step_avg:141.82ms
step:2307/5100 train_loss:3.4171 train_time:325768ms step_avg:141.82ms
step:2308/5100 train_loss:3.4339 train_time:325908ms step_avg:141.82ms
step:2309/5100 train_loss:3.5719 train_time:326049ms step_avg:141.82ms
step:2310/5100 train_loss:3.5224 train_time:326189ms step_avg:141.82ms
step:2311/5100 train_loss:3.3901 train_time:326328ms step_avg:141.82ms
step:2312/5100 train_loss:3.5115 train_time:326469ms step_avg:141.82ms
step:2313/5100 train_loss:3.6351 train_time:326608ms step_avg:141.82ms
step:2314/5100 train_loss:3.4525 train_time:326749ms step_avg:141.82ms
step:2315/5100 train_loss:3.3772 train_time:326888ms step_avg:141.82ms
step:2316/5100 train_loss:3.4642 train_time:327028ms step_avg:141.82ms
step:2317/5100 train_loss:3.3584 train_time:327168ms step_avg:141.82ms
step:2318/5100 train_loss:3.4566 train_time:327308ms step_avg:141.81ms
step:2319/5100 train_loss:3.4825 train_time:327448ms step_avg:141.81ms
step:2320/5100 train_loss:3.3231 train_time:327588ms step_avg:141.81ms
step:2321/5100 train_loss:3.4618 train_time:327727ms step_avg:141.81ms
step:2322/5100 train_loss:3.5188 train_time:327869ms step_avg:141.81ms
step:2323/5100 train_loss:3.4301 train_time:328007ms step_avg:141.81ms
step:2324/5100 train_loss:3.4719 train_time:328150ms step_avg:141.81ms
step:2325/5100 train_loss:3.3976 train_time:328289ms step_avg:141.81ms
step:2326/5100 train_loss:3.5359 train_time:328430ms step_avg:141.81ms
step:2327/5100 train_loss:3.5435 train_time:328569ms step_avg:141.81ms
step:2328/5100 train_loss:3.3184 train_time:328708ms step_avg:141.81ms
step:2329/5100 train_loss:3.4281 train_time:328848ms step_avg:141.81ms
step:2330/5100 train_loss:3.4571 train_time:328989ms step_avg:141.81ms
step:2331/5100 train_loss:3.4255 train_time:329129ms step_avg:141.80ms
step:2332/5100 train_loss:3.5983 train_time:329275ms step_avg:141.81ms
step:2333/5100 train_loss:3.4949 train_time:329408ms step_avg:141.80ms
step:2334/5100 train_loss:3.4680 train_time:329549ms step_avg:141.80ms
step:2335/5100 train_loss:3.5474 train_time:329688ms step_avg:141.80ms
step:2336/5100 train_loss:3.3916 train_time:329827ms step_avg:141.80ms
step:2337/5100 train_loss:3.5381 train_time:329970ms step_avg:141.80ms
step:2338/5100 train_loss:3.4996 train_time:330108ms step_avg:141.80ms
step:2339/5100 train_loss:3.4497 train_time:330250ms step_avg:141.80ms
step:2340/5100 train_loss:3.5227 train_time:330389ms step_avg:141.80ms
step:2341/5100 train_loss:3.5774 train_time:330529ms step_avg:141.80ms
step:2342/5100 train_loss:3.4432 train_time:330670ms step_avg:141.80ms
step:2343/5100 train_loss:3.4488 train_time:330812ms step_avg:141.80ms
step:2344/5100 train_loss:3.5201 train_time:330948ms step_avg:141.79ms
step:2345/5100 train_loss:3.4638 train_time:331088ms step_avg:141.79ms
step:2346/5100 train_loss:3.5831 train_time:331229ms step_avg:141.79ms
step:2347/5100 train_loss:3.4911 train_time:331369ms step_avg:141.79ms
step:2348/5100 train_loss:3.5965 train_time:331509ms step_avg:141.79ms
step:2349/5100 train_loss:3.5557 train_time:331649ms step_avg:141.79ms
step:2350/5100 train_loss:3.5929 train_time:331789ms step_avg:141.79ms
step:2351/5100 train_loss:3.2874 train_time:331931ms step_avg:141.79ms
step:2352/5100 train_loss:3.4137 train_time:332069ms step_avg:141.79ms
step:2353/5100 train_loss:3.4010 train_time:332209ms step_avg:141.79ms
step:2354/5100 train_loss:3.6204 train_time:332349ms step_avg:141.79ms
step:2355/5100 train_loss:3.4158 train_time:332488ms step_avg:141.79ms
step:2356/5100 train_loss:3.4119 train_time:332629ms step_avg:141.79ms
step:2357/5100 train_loss:3.5644 train_time:332768ms step_avg:141.78ms
step:2358/5100 train_loss:3.4164 train_time:332907ms step_avg:141.78ms
step:2359/5100 train_loss:3.5187 train_time:333048ms step_avg:141.78ms
step:2360/5100 train_loss:3.4175 train_time:333189ms step_avg:141.78ms
step:2361/5100 train_loss:3.4355 train_time:333328ms step_avg:141.78ms
step:2362/5100 train_loss:3.4622 train_time:333469ms step_avg:141.78ms
step:2363/5100 train_loss:3.5302 train_time:333608ms step_avg:141.78ms
step:2364/5100 train_loss:3.4715 train_time:333751ms step_avg:141.78ms
step:2365/5100 train_loss:3.9100 train_time:333888ms step_avg:141.78ms
step:2366/5100 train_loss:3.5367 train_time:334028ms step_avg:141.78ms
step:2367/5100 train_loss:3.6853 train_time:334172ms step_avg:141.78ms
step:2368/5100 train_loss:3.5098 train_time:334307ms step_avg:141.78ms
step:2369/5100 train_loss:3.5060 train_time:334449ms step_avg:141.78ms
step:2370/5100 train_loss:3.5365 train_time:334589ms step_avg:141.77ms
step:2371/5100 train_loss:3.4183 train_time:334729ms step_avg:141.77ms
step:2372/5100 train_loss:3.6495 train_time:334869ms step_avg:141.77ms
step:2373/5100 train_loss:3.4968 train_time:335008ms step_avg:141.77ms
step:2374/5100 train_loss:4.0533 train_time:335149ms step_avg:141.77ms
step:2375/5100 train_loss:3.4766 train_time:335288ms step_avg:141.77ms
step:2375/5100 val_loss:3.4849 train_time:335344ms step_avg:141.79ms
step:2376/5100 train_loss:3.3822 train_time:335438ms step_avg:141.77ms
step:2377/5100 train_loss:3.5487 train_time:335584ms step_avg:141.78ms
step:2378/5100 train_loss:3.5107 train_time:335725ms step_avg:141.78ms
step:2379/5100 train_loss:3.5299 train_time:335863ms step_avg:141.77ms
step:2380/5100 train_loss:3.5114 train_time:336001ms step_avg:141.77ms
step:2381/5100 train_loss:3.4117 train_time:336141ms step_avg:141.77ms
step:2382/5100 train_loss:3.5080 train_time:336279ms step_avg:141.77ms
step:2383/5100 train_loss:3.5276 train_time:336422ms step_avg:141.77ms
step:2384/5100 train_loss:3.4737 train_time:336564ms step_avg:141.77ms
step:2385/5100 train_loss:3.4032 train_time:336709ms step_avg:141.77ms
step:2386/5100 train_loss:3.5148 train_time:336844ms step_avg:141.77ms
step:2387/5100 train_loss:3.4704 train_time:336983ms step_avg:141.77ms
step:2388/5100 train_loss:3.4724 train_time:337122ms step_avg:141.77ms
step:2389/5100 train_loss:3.5096 train_time:337261ms step_avg:141.77ms
step:2390/5100 train_loss:3.4927 train_time:337402ms step_avg:141.77ms
step:2391/5100 train_loss:3.4909 train_time:337545ms step_avg:141.77ms
step:2392/5100 train_loss:3.3641 train_time:337685ms step_avg:141.77ms
step:2393/5100 train_loss:3.5915 train_time:337824ms step_avg:141.76ms
step:2394/5100 train_loss:3.4286 train_time:337964ms step_avg:141.76ms
step:2395/5100 train_loss:3.5250 train_time:338102ms step_avg:141.76ms
step:2396/5100 train_loss:3.6434 train_time:338242ms step_avg:141.76ms
step:2397/5100 train_loss:3.6547 train_time:338383ms step_avg:141.76ms
step:2398/5100 train_loss:3.6033 train_time:338524ms step_avg:141.76ms
step:2399/5100 train_loss:3.5687 train_time:338668ms step_avg:141.76ms
step:2400/5100 train_loss:3.4453 train_time:338805ms step_avg:141.76ms
step:2401/5100 train_loss:3.4405 train_time:338945ms step_avg:141.76ms
step:2402/5100 train_loss:3.5468 train_time:339083ms step_avg:141.76ms
step:2403/5100 train_loss:3.3928 train_time:339223ms step_avg:141.76ms
step:2404/5100 train_loss:3.5248 train_time:339364ms step_avg:141.76ms
step:2405/5100 train_loss:3.7344 train_time:339503ms step_avg:141.75ms
step:2406/5100 train_loss:3.4650 train_time:339644ms step_avg:141.75ms
step:2407/5100 train_loss:3.6090 train_time:339787ms step_avg:141.76ms
step:2408/5100 train_loss:3.4698 train_time:339924ms step_avg:141.75ms
step:2409/5100 train_loss:3.4040 train_time:340064ms step_avg:141.75ms
step:2410/5100 train_loss:3.5423 train_time:340203ms step_avg:141.75ms
step:2411/5100 train_loss:3.3340 train_time:340344ms step_avg:141.75ms
step:2412/5100 train_loss:3.7676 train_time:340484ms step_avg:141.75ms
step:2413/5100 train_loss:3.4472 train_time:340625ms step_avg:141.75ms
step:2414/5100 train_loss:3.5211 train_time:340765ms step_avg:141.75ms
step:2415/5100 train_loss:3.4428 train_time:340905ms step_avg:141.75ms
step:2416/5100 train_loss:3.5156 train_time:341045ms step_avg:141.75ms
step:2417/5100 train_loss:3.3288 train_time:341184ms step_avg:141.75ms
step:2418/5100 train_loss:3.2663 train_time:341323ms step_avg:141.75ms
step:2419/5100 train_loss:3.5579 train_time:341464ms step_avg:141.75ms
step:2420/5100 train_loss:3.4393 train_time:341604ms step_avg:141.74ms
step:2421/5100 train_loss:3.4687 train_time:341746ms step_avg:141.74ms
step:2422/5100 train_loss:3.5725 train_time:341885ms step_avg:141.74ms
step:2423/5100 train_loss:3.6077 train_time:342024ms step_avg:141.74ms
step:2424/5100 train_loss:3.4405 train_time:342164ms step_avg:141.74ms
step:2425/5100 train_loss:3.5274 train_time:342304ms step_avg:141.74ms
step:2426/5100 train_loss:3.5271 train_time:342444ms step_avg:141.74ms
step:2427/5100 train_loss:3.4500 train_time:342585ms step_avg:141.74ms
step:2428/5100 train_loss:3.4102 train_time:342725ms step_avg:141.74ms
step:2429/5100 train_loss:3.5302 train_time:342865ms step_avg:141.74ms
step:2430/5100 train_loss:3.4280 train_time:343004ms step_avg:141.74ms
step:2431/5100 train_loss:3.4852 train_time:343145ms step_avg:141.74ms
step:2432/5100 train_loss:3.5380 train_time:343284ms step_avg:141.74ms
step:2433/5100 train_loss:3.5039 train_time:343424ms step_avg:141.73ms
step:2434/5100 train_loss:3.3779 train_time:343566ms step_avg:141.74ms
step:2435/5100 train_loss:3.3398 train_time:343704ms step_avg:141.73ms
step:2436/5100 train_loss:3.5075 train_time:343844ms step_avg:141.73ms
step:2437/5100 train_loss:3.3735 train_time:343983ms step_avg:141.73ms
step:2438/5100 train_loss:3.4454 train_time:344124ms step_avg:141.73ms
step:2439/5100 train_loss:3.5352 train_time:344264ms step_avg:141.73ms
step:2440/5100 train_loss:3.4520 train_time:344404ms step_avg:141.73ms
step:2441/5100 train_loss:3.5407 train_time:344545ms step_avg:141.73ms
step:2442/5100 train_loss:3.4226 train_time:344685ms step_avg:141.73ms
step:2443/5100 train_loss:3.4922 train_time:344825ms step_avg:141.73ms
step:2444/5100 train_loss:3.3660 train_time:344965ms step_avg:141.73ms
step:2445/5100 train_loss:3.3793 train_time:345104ms step_avg:141.73ms
step:2446/5100 train_loss:3.5484 train_time:345244ms step_avg:141.73ms
step:2447/5100 train_loss:3.4030 train_time:345384ms step_avg:141.73ms
step:2448/5100 train_loss:3.4723 train_time:345525ms step_avg:141.72ms
step:2449/5100 train_loss:3.6385 train_time:345664ms step_avg:141.72ms
step:2450/5100 train_loss:3.4568 train_time:345803ms step_avg:141.72ms
step:2451/5100 train_loss:3.5391 train_time:345944ms step_avg:141.72ms
step:2452/5100 train_loss:3.4421 train_time:346083ms step_avg:141.72ms
step:2453/5100 train_loss:3.5415 train_time:346224ms step_avg:141.72ms
step:2454/5100 train_loss:3.4352 train_time:346364ms step_avg:141.72ms
step:2455/5100 train_loss:3.5669 train_time:346504ms step_avg:141.72ms
step:2456/5100 train_loss:3.4983 train_time:346645ms step_avg:141.72ms
step:2457/5100 train_loss:3.4163 train_time:346956ms step_avg:141.79ms
step:2458/5100 train_loss:3.3475 train_time:347093ms step_avg:141.79ms
step:2459/5100 train_loss:3.4748 train_time:347232ms step_avg:141.79ms
step:2460/5100 train_loss:4.0816 train_time:347370ms step_avg:141.78ms
step:2461/5100 train_loss:3.5395 train_time:347509ms step_avg:141.78ms
step:2462/5100 train_loss:3.3566 train_time:347647ms step_avg:141.78ms
step:2463/5100 train_loss:3.5522 train_time:347787ms step_avg:141.78ms
step:2464/5100 train_loss:3.4713 train_time:347927ms step_avg:141.78ms
step:2465/5100 train_loss:3.6692 train_time:348069ms step_avg:141.78ms
step:2466/5100 train_loss:3.8694 train_time:348209ms step_avg:141.78ms
step:2467/5100 train_loss:3.5826 train_time:348349ms step_avg:141.78ms
step:2468/5100 train_loss:3.4514 train_time:348487ms step_avg:141.78ms
step:2469/5100 train_loss:3.5656 train_time:348628ms step_avg:141.78ms
step:2470/5100 train_loss:3.5825 train_time:348928ms step_avg:141.84ms
step:2471/5100 train_loss:3.3895 train_time:349070ms step_avg:141.84ms
step:2472/5100 train_loss:3.4759 train_time:349206ms step_avg:141.84ms
step:2473/5100 train_loss:3.4729 train_time:349344ms step_avg:141.84ms
step:2474/5100 train_loss:3.6195 train_time:349481ms step_avg:141.83ms
step:2475/5100 train_loss:3.7472 train_time:349620ms step_avg:141.83ms
step:2476/5100 train_loss:3.3411 train_time:349759ms step_avg:141.83ms
step:2477/5100 train_loss:3.5453 train_time:349908ms step_avg:141.84ms
step:2478/5100 train_loss:3.5093 train_time:350051ms step_avg:141.84ms
step:2479/5100 train_loss:3.3474 train_time:350190ms step_avg:141.83ms
step:2480/5100 train_loss:3.3453 train_time:350329ms step_avg:141.83ms
step:2481/5100 train_loss:3.4975 train_time:350468ms step_avg:141.83ms
step:2482/5100 train_loss:3.5060 train_time:350607ms step_avg:141.83ms
step:2483/5100 train_loss:3.5164 train_time:350746ms step_avg:141.83ms
step:2484/5100 train_loss:3.4762 train_time:350888ms step_avg:141.83ms
step:2485/5100 train_loss:3.4902 train_time:351030ms step_avg:141.83ms
step:2486/5100 train_loss:3.3747 train_time:351169ms step_avg:141.83ms
step:2487/5100 train_loss:3.5703 train_time:351309ms step_avg:141.83ms
step:2488/5100 train_loss:3.5294 train_time:351448ms step_avg:141.83ms
step:2489/5100 train_loss:3.4348 train_time:351587ms step_avg:141.83ms
step:2490/5100 train_loss:3.5408 train_time:351726ms step_avg:141.83ms
step:2491/5100 train_loss:3.5935 train_time:351867ms step_avg:141.82ms
step:2492/5100 train_loss:3.6732 train_time:352006ms step_avg:141.82ms
step:2493/5100 train_loss:3.5212 train_time:352147ms step_avg:141.82ms
step:2494/5100 train_loss:3.4445 train_time:352286ms step_avg:141.82ms
step:2495/5100 train_loss:3.5735 train_time:352425ms step_avg:141.82ms
step:2496/5100 train_loss:3.5165 train_time:352565ms step_avg:141.82ms
step:2497/5100 train_loss:3.4291 train_time:352703ms step_avg:141.82ms
step:2498/5100 train_loss:3.5294 train_time:352848ms step_avg:141.82ms
step:2499/5100 train_loss:3.5819 train_time:352990ms step_avg:141.82ms
step:2500/5100 train_loss:3.6026 train_time:353125ms step_avg:141.82ms
step:2500/5100 val_loss:3.4758 train_time:353183ms step_avg:141.84ms
step:2501/5100 train_loss:3.5458 train_time:353279ms step_avg:141.82ms
step:2502/5100 train_loss:3.4981 train_time:353425ms step_avg:141.82ms
step:2503/5100 train_loss:3.5173 train_time:353563ms step_avg:141.82ms
step:2504/5100 train_loss:3.3931 train_time:353702ms step_avg:141.82ms
step:2505/5100 train_loss:3.5800 train_time:353840ms step_avg:141.82ms
step:2506/5100 train_loss:3.5322 train_time:353978ms step_avg:141.82ms
step:2507/5100 train_loss:3.4784 train_time:354116ms step_avg:141.82ms
step:2508/5100 train_loss:3.4815 train_time:354256ms step_avg:141.82ms
step:2509/5100 train_loss:3.4460 train_time:354400ms step_avg:141.82ms
step:2510/5100 train_loss:3.6133 train_time:354540ms step_avg:141.82ms
step:2511/5100 train_loss:3.4374 train_time:354679ms step_avg:141.81ms
step:2512/5100 train_loss:3.4260 train_time:354817ms step_avg:141.81ms
step:2513/5100 train_loss:3.5027 train_time:354957ms step_avg:141.81ms
step:2514/5100 train_loss:3.5323 train_time:355095ms step_avg:141.81ms
step:2515/5100 train_loss:3.4276 train_time:355234ms step_avg:141.81ms
step:2516/5100 train_loss:3.5241 train_time:355377ms step_avg:141.81ms
step:2517/5100 train_loss:3.5159 train_time:355517ms step_avg:141.81ms
step:2518/5100 train_loss:3.3937 train_time:355656ms step_avg:141.81ms
step:2519/5100 train_loss:3.4219 train_time:355796ms step_avg:141.81ms
step:2520/5100 train_loss:3.5415 train_time:355935ms step_avg:141.81ms
step:2521/5100 train_loss:3.5428 train_time:356075ms step_avg:141.81ms
step:2522/5100 train_loss:3.4201 train_time:356214ms step_avg:141.81ms
step:2523/5100 train_loss:3.4020 train_time:356356ms step_avg:141.81ms
step:2524/5100 train_loss:3.5038 train_time:356497ms step_avg:141.80ms
step:2525/5100 train_loss:3.3382 train_time:356636ms step_avg:141.80ms
step:2526/5100 train_loss:3.5638 train_time:356779ms step_avg:141.80ms
step:2527/5100 train_loss:3.4706 train_time:356915ms step_avg:141.80ms
step:2528/5100 train_loss:3.4741 train_time:357055ms step_avg:141.80ms
step:2529/5100 train_loss:3.4588 train_time:357195ms step_avg:141.80ms
step:2530/5100 train_loss:3.4870 train_time:357335ms step_avg:141.80ms
step:2531/5100 train_loss:3.5244 train_time:357476ms step_avg:141.80ms
step:2532/5100 train_loss:3.3430 train_time:357616ms step_avg:141.80ms
step:2533/5100 train_loss:3.5048 train_time:357756ms step_avg:141.80ms
step:2534/5100 train_loss:3.4035 train_time:357895ms step_avg:141.80ms
step:2535/5100 train_loss:3.4281 train_time:358035ms step_avg:141.80ms
step:2536/5100 train_loss:3.4877 train_time:358176ms step_avg:141.80ms
step:2537/5100 train_loss:3.4989 train_time:358315ms step_avg:141.79ms
step:2538/5100 train_loss:3.3238 train_time:358455ms step_avg:141.79ms
step:2539/5100 train_loss:3.6308 train_time:358595ms step_avg:141.79ms
step:2540/5100 train_loss:3.3226 train_time:358736ms step_avg:141.79ms
step:2541/5100 train_loss:3.4977 train_time:358875ms step_avg:141.79ms
step:2542/5100 train_loss:3.2576 train_time:359014ms step_avg:141.79ms
step:2543/5100 train_loss:3.7064 train_time:359156ms step_avg:141.79ms
step:2544/5100 train_loss:3.4709 train_time:359295ms step_avg:141.79ms
step:2545/5100 train_loss:3.6294 train_time:359435ms step_avg:141.79ms
step:2546/5100 train_loss:3.4570 train_time:359575ms step_avg:141.79ms
step:2547/5100 train_loss:3.4495 train_time:359714ms step_avg:141.79ms
step:2548/5100 train_loss:3.4537 train_time:359856ms step_avg:141.79ms
step:2549/5100 train_loss:3.6139 train_time:359994ms step_avg:141.79ms
step:2550/5100 train_loss:3.4711 train_time:360134ms step_avg:141.79ms
step:2551/5100 train_loss:3.4688 train_time:360276ms step_avg:141.78ms
step:2552/5100 train_loss:3.4999 train_time:360415ms step_avg:141.78ms
step:2553/5100 train_loss:3.5123 train_time:360556ms step_avg:141.78ms
step:2554/5100 train_loss:3.4308 train_time:360694ms step_avg:141.78ms
step:2555/5100 train_loss:3.5325 train_time:360834ms step_avg:141.78ms
step:2556/5100 train_loss:3.5858 train_time:360975ms step_avg:141.78ms
step:2557/5100 train_loss:3.5777 train_time:361114ms step_avg:141.78ms
step:2558/5100 train_loss:3.4145 train_time:361256ms step_avg:141.78ms
step:2559/5100 train_loss:3.4181 train_time:361395ms step_avg:141.78ms
step:2560/5100 train_loss:3.4311 train_time:361535ms step_avg:141.78ms
step:2561/5100 train_loss:3.5512 train_time:361676ms step_avg:141.78ms
step:2562/5100 train_loss:3.5826 train_time:361814ms step_avg:141.78ms
step:2563/5100 train_loss:3.4664 train_time:361955ms step_avg:141.78ms
step:2564/5100 train_loss:3.4997 train_time:362096ms step_avg:141.78ms
step:2565/5100 train_loss:3.4137 train_time:362235ms step_avg:141.77ms
step:2566/5100 train_loss:3.4224 train_time:362376ms step_avg:141.77ms
step:2567/5100 train_loss:3.4148 train_time:362515ms step_avg:141.77ms
step:2568/5100 train_loss:3.4689 train_time:362656ms step_avg:141.77ms
step:2569/5100 train_loss:3.6132 train_time:362796ms step_avg:141.77ms
step:2570/5100 train_loss:3.5176 train_time:362936ms step_avg:141.77ms
step:2571/5100 train_loss:3.6013 train_time:363077ms step_avg:141.77ms
step:2572/5100 train_loss:3.3522 train_time:363216ms step_avg:141.77ms
step:2573/5100 train_loss:3.4618 train_time:363356ms step_avg:141.77ms
step:2574/5100 train_loss:3.1267 train_time:363495ms step_avg:141.77ms
step:2575/5100 train_loss:3.3706 train_time:363635ms step_avg:141.77ms
step:2576/5100 train_loss:3.3119 train_time:363774ms step_avg:141.77ms
step:2577/5100 train_loss:3.4274 train_time:363914ms step_avg:141.77ms
step:2578/5100 train_loss:3.4752 train_time:364055ms step_avg:141.77ms
step:2579/5100 train_loss:3.3824 train_time:364195ms step_avg:141.77ms
step:2580/5100 train_loss:3.4367 train_time:364335ms step_avg:141.76ms
step:2581/5100 train_loss:3.3916 train_time:364478ms step_avg:141.77ms
step:2582/5100 train_loss:3.4917 train_time:364615ms step_avg:141.76ms
step:2583/5100 train_loss:3.3740 train_time:364756ms step_avg:141.76ms
step:2584/5100 train_loss:3.5638 train_time:364896ms step_avg:141.76ms
step:2585/5100 train_loss:3.4870 train_time:365039ms step_avg:141.76ms
step:2586/5100 train_loss:3.4883 train_time:365176ms step_avg:141.76ms
step:2587/5100 train_loss:3.6170 train_time:365315ms step_avg:141.76ms
step:2588/5100 train_loss:3.5049 train_time:365456ms step_avg:141.76ms
step:2589/5100 train_loss:3.3614 train_time:365598ms step_avg:141.76ms
step:2590/5100 train_loss:3.5283 train_time:365736ms step_avg:141.76ms
step:2591/5100 train_loss:3.4346 train_time:365876ms step_avg:141.76ms
step:2592/5100 train_loss:3.6420 train_time:366015ms step_avg:141.76ms
step:2593/5100 train_loss:3.5134 train_time:366156ms step_avg:141.76ms
step:2594/5100 train_loss:3.3312 train_time:366295ms step_avg:141.75ms
step:2595/5100 train_loss:3.4051 train_time:366435ms step_avg:141.75ms
step:2596/5100 train_loss:3.8642 train_time:366576ms step_avg:141.75ms
step:2597/5100 train_loss:3.4966 train_time:366714ms step_avg:141.75ms
step:2598/5100 train_loss:3.4935 train_time:366856ms step_avg:141.75ms
step:2599/5100 train_loss:3.3526 train_time:366995ms step_avg:141.75ms
step:2600/5100 train_loss:3.5884 train_time:367135ms step_avg:141.75ms
step:2601/5100 train_loss:3.7514 train_time:367274ms step_avg:141.75ms
step:2602/5100 train_loss:3.3282 train_time:367415ms step_avg:141.75ms
step:2603/5100 train_loss:3.4760 train_time:367555ms step_avg:141.75ms
step:2604/5100 train_loss:3.3011 train_time:367694ms step_avg:141.75ms
step:2605/5100 train_loss:3.5998 train_time:367835ms step_avg:141.75ms
step:2606/5100 train_loss:3.4698 train_time:367975ms step_avg:141.75ms
step:2607/5100 train_loss:3.3698 train_time:368114ms step_avg:141.75ms
step:2608/5100 train_loss:3.3212 train_time:368255ms step_avg:141.75ms
step:2609/5100 train_loss:3.4388 train_time:368395ms step_avg:141.74ms
step:2610/5100 train_loss:3.6214 train_time:368535ms step_avg:141.74ms
step:2611/5100 train_loss:3.4905 train_time:368675ms step_avg:141.74ms
step:2612/5100 train_loss:3.3121 train_time:368815ms step_avg:141.74ms
step:2613/5100 train_loss:3.4199 train_time:368955ms step_avg:141.74ms
step:2614/5100 train_loss:3.5314 train_time:369095ms step_avg:141.74ms
step:2615/5100 train_loss:3.4626 train_time:369235ms step_avg:141.74ms
step:2616/5100 train_loss:3.4505 train_time:369376ms step_avg:141.74ms
step:2617/5100 train_loss:3.4988 train_time:369516ms step_avg:141.74ms
step:2618/5100 train_loss:3.5370 train_time:369656ms step_avg:141.74ms
step:2619/5100 train_loss:3.3843 train_time:369795ms step_avg:141.74ms
step:2620/5100 train_loss:3.5578 train_time:369934ms step_avg:141.74ms
step:2621/5100 train_loss:3.5175 train_time:370075ms step_avg:141.74ms
step:2622/5100 train_loss:3.6487 train_time:370215ms step_avg:141.74ms
step:2623/5100 train_loss:3.5560 train_time:370355ms step_avg:141.74ms
step:2624/5100 train_loss:3.4733 train_time:370495ms step_avg:141.74ms
step:2625/5100 train_loss:3.4351 train_time:370636ms step_avg:141.73ms
step:2625/5100 val_loss:3.4656 train_time:370691ms step_avg:141.76ms
step:2626/5100 train_loss:3.4578 train_time:370786ms step_avg:141.74ms
step:2627/5100 train_loss:3.5119 train_time:370933ms step_avg:141.74ms
step:2628/5100 train_loss:3.3488 train_time:371074ms step_avg:141.74ms
step:2629/5100 train_loss:3.6033 train_time:371213ms step_avg:141.74ms
step:2630/5100 train_loss:3.4858 train_time:371352ms step_avg:141.74ms
step:2631/5100 train_loss:3.5448 train_time:371491ms step_avg:141.74ms
step:2632/5100 train_loss:3.7706 train_time:371630ms step_avg:141.74ms
step:2633/5100 train_loss:3.5072 train_time:371773ms step_avg:141.74ms
step:2634/5100 train_loss:3.4288 train_time:371915ms step_avg:141.74ms
step:2635/5100 train_loss:3.4027 train_time:372057ms step_avg:141.74ms
step:2636/5100 train_loss:3.4458 train_time:372196ms step_avg:141.73ms
step:2637/5100 train_loss:3.2315 train_time:372335ms step_avg:141.73ms
step:2638/5100 train_loss:3.5409 train_time:372481ms step_avg:141.74ms
step:2639/5100 train_loss:3.5203 train_time:372614ms step_avg:141.73ms
step:2640/5100 train_loss:3.4155 train_time:372755ms step_avg:141.73ms
step:2641/5100 train_loss:3.4904 train_time:372896ms step_avg:141.73ms
step:2642/5100 train_loss:3.5291 train_time:373036ms step_avg:141.73ms
step:2643/5100 train_loss:3.3142 train_time:373176ms step_avg:141.73ms
step:2644/5100 train_loss:3.4367 train_time:373314ms step_avg:141.73ms
step:2645/5100 train_loss:3.5075 train_time:373455ms step_avg:141.73ms
step:2646/5100 train_loss:3.4740 train_time:373760ms step_avg:141.79ms
step:2647/5100 train_loss:3.3655 train_time:373897ms step_avg:141.79ms
step:2648/5100 train_loss:3.5931 train_time:374035ms step_avg:141.79ms
step:2649/5100 train_loss:3.8494 train_time:374173ms step_avg:141.79ms
step:2650/5100 train_loss:3.4822 train_time:374311ms step_avg:141.78ms
step:2651/5100 train_loss:3.4493 train_time:374450ms step_avg:141.78ms
step:2652/5100 train_loss:3.5805 train_time:374588ms step_avg:141.78ms
step:2653/5100 train_loss:3.4206 train_time:374734ms step_avg:141.78ms
step:2654/5100 train_loss:3.4024 train_time:374878ms step_avg:141.78ms
step:2655/5100 train_loss:3.4756 train_time:375018ms step_avg:141.78ms
step:2656/5100 train_loss:3.3901 train_time:375157ms step_avg:141.78ms
step:2657/5100 train_loss:3.4291 train_time:375296ms step_avg:141.78ms
step:2658/5100 train_loss:3.3913 train_time:375435ms step_avg:141.78ms
step:2659/5100 train_loss:3.4790 train_time:375575ms step_avg:141.78ms
step:2660/5100 train_loss:3.6231 train_time:375889ms step_avg:141.84ms
step:2661/5100 train_loss:3.4233 train_time:376023ms step_avg:141.84ms
step:2662/5100 train_loss:3.5675 train_time:376162ms step_avg:141.84ms
step:2663/5100 train_loss:3.4388 train_time:376299ms step_avg:141.84ms
step:2664/5100 train_loss:3.4376 train_time:376437ms step_avg:141.84ms
step:2665/5100 train_loss:3.3622 train_time:376576ms step_avg:141.84ms
step:2666/5100 train_loss:3.4133 train_time:376713ms step_avg:141.83ms
step:2667/5100 train_loss:3.4529 train_time:376858ms step_avg:141.84ms
step:2668/5100 train_loss:3.4916 train_time:377000ms step_avg:141.84ms
step:2669/5100 train_loss:3.4046 train_time:377139ms step_avg:141.83ms
step:2670/5100 train_loss:3.4689 train_time:377278ms step_avg:141.83ms
step:2671/5100 train_loss:3.3536 train_time:377415ms step_avg:141.83ms
step:2672/5100 train_loss:3.4188 train_time:377554ms step_avg:141.83ms
step:2673/5100 train_loss:3.4094 train_time:377694ms step_avg:141.83ms
step:2674/5100 train_loss:3.4670 train_time:377835ms step_avg:141.83ms
step:2675/5100 train_loss:3.4881 train_time:377983ms step_avg:141.83ms
step:2676/5100 train_loss:3.4610 train_time:378120ms step_avg:141.83ms
step:2677/5100 train_loss:3.4496 train_time:378259ms step_avg:141.83ms
step:2678/5100 train_loss:3.4906 train_time:378397ms step_avg:141.83ms
step:2679/5100 train_loss:3.5255 train_time:378536ms step_avg:141.83ms
step:2680/5100 train_loss:3.4424 train_time:378676ms step_avg:141.83ms
step:2681/5100 train_loss:3.3654 train_time:378816ms step_avg:141.83ms
step:2682/5100 train_loss:3.4093 train_time:378959ms step_avg:141.83ms
step:2683/5100 train_loss:3.8788 train_time:379099ms step_avg:141.83ms
step:2684/5100 train_loss:3.4710 train_time:379238ms step_avg:141.82ms
step:2685/5100 train_loss:3.4938 train_time:379378ms step_avg:141.82ms
step:2686/5100 train_loss:3.5473 train_time:379520ms step_avg:141.82ms
step:2687/5100 train_loss:3.4574 train_time:379657ms step_avg:141.82ms
step:2688/5100 train_loss:3.5422 train_time:379796ms step_avg:141.82ms
step:2689/5100 train_loss:3.4691 train_time:379938ms step_avg:141.82ms
step:2690/5100 train_loss:3.4664 train_time:380079ms step_avg:141.82ms
step:2691/5100 train_loss:3.4861 train_time:380217ms step_avg:141.82ms
step:2692/5100 train_loss:3.5620 train_time:380357ms step_avg:141.82ms
step:2693/5100 train_loss:3.3607 train_time:380495ms step_avg:141.82ms
step:2694/5100 train_loss:3.7331 train_time:380635ms step_avg:141.82ms
step:2695/5100 train_loss:3.5396 train_time:380777ms step_avg:141.82ms
step:2696/5100 train_loss:3.3403 train_time:380916ms step_avg:141.82ms
step:2697/5100 train_loss:3.5197 train_time:381058ms step_avg:141.82ms
step:2698/5100 train_loss:3.4791 train_time:381197ms step_avg:141.81ms
step:2699/5100 train_loss:3.4351 train_time:381337ms step_avg:141.81ms
step:2700/5100 train_loss:3.5392 train_time:381477ms step_avg:141.81ms
step:2701/5100 train_loss:3.5029 train_time:381616ms step_avg:141.81ms
step:2702/5100 train_loss:3.4119 train_time:381759ms step_avg:141.81ms
step:2703/5100 train_loss:3.4424 train_time:381897ms step_avg:141.81ms
step:2704/5100 train_loss:3.4519 train_time:382037ms step_avg:141.81ms
step:2705/5100 train_loss:3.4178 train_time:382177ms step_avg:141.81ms
step:2706/5100 train_loss:3.5979 train_time:382316ms step_avg:141.81ms
step:2707/5100 train_loss:3.5512 train_time:382457ms step_avg:141.81ms
step:2708/5100 train_loss:3.4624 train_time:382595ms step_avg:141.81ms
step:2709/5100 train_loss:3.4536 train_time:382736ms step_avg:141.81ms
step:2710/5100 train_loss:3.5555 train_time:382877ms step_avg:141.81ms
step:2711/5100 train_loss:3.4366 train_time:383016ms step_avg:141.81ms
step:2712/5100 train_loss:3.5490 train_time:383156ms step_avg:141.80ms
step:2713/5100 train_loss:3.2903 train_time:383295ms step_avg:141.80ms
step:2714/5100 train_loss:3.4835 train_time:383435ms step_avg:141.80ms
step:2715/5100 train_loss:3.3793 train_time:383575ms step_avg:141.80ms
step:2716/5100 train_loss:3.3847 train_time:383714ms step_avg:141.80ms
step:2717/5100 train_loss:3.5727 train_time:383856ms step_avg:141.80ms
step:2718/5100 train_loss:3.4735 train_time:383996ms step_avg:141.80ms
step:2719/5100 train_loss:3.7033 train_time:384137ms step_avg:141.80ms
step:2720/5100 train_loss:3.4376 train_time:384276ms step_avg:141.80ms
step:2721/5100 train_loss:3.4456 train_time:384414ms step_avg:141.80ms
step:2722/5100 train_loss:3.6684 train_time:384556ms step_avg:141.80ms
step:2723/5100 train_loss:3.4432 train_time:384706ms step_avg:141.80ms
step:2724/5100 train_loss:3.6134 train_time:384836ms step_avg:141.80ms
step:2725/5100 train_loss:3.4902 train_time:384976ms step_avg:141.80ms
step:2726/5100 train_loss:3.4485 train_time:385116ms step_avg:141.80ms
step:2727/5100 train_loss:3.4608 train_time:385258ms step_avg:141.80ms
step:2728/5100 train_loss:3.8009 train_time:385396ms step_avg:141.79ms
step:2729/5100 train_loss:3.5278 train_time:385540ms step_avg:141.79ms
step:2730/5100 train_loss:3.3968 train_time:385677ms step_avg:141.79ms
step:2731/5100 train_loss:3.5046 train_time:385816ms step_avg:141.79ms
step:2732/5100 train_loss:3.4101 train_time:385956ms step_avg:141.79ms
step:2733/5100 train_loss:3.2975 train_time:386096ms step_avg:141.79ms
step:2734/5100 train_loss:3.4093 train_time:386236ms step_avg:141.79ms
step:2735/5100 train_loss:3.4807 train_time:386377ms step_avg:141.79ms
step:2736/5100 train_loss:3.3757 train_time:386516ms step_avg:141.79ms
step:2737/5100 train_loss:3.7802 train_time:386657ms step_avg:141.79ms
step:2738/5100 train_loss:3.5182 train_time:386797ms step_avg:141.79ms
step:2739/5100 train_loss:3.7214 train_time:386937ms step_avg:141.79ms
step:2740/5100 train_loss:3.4668 train_time:387077ms step_avg:141.79ms
step:2741/5100 train_loss:3.4673 train_time:387216ms step_avg:141.79ms
step:2742/5100 train_loss:3.4015 train_time:387356ms step_avg:141.78ms
step:2743/5100 train_loss:3.4813 train_time:387496ms step_avg:141.78ms
step:2744/5100 train_loss:3.4890 train_time:387636ms step_avg:141.78ms
step:2745/5100 train_loss:3.6047 train_time:387776ms step_avg:141.78ms
step:2746/5100 train_loss:3.3576 train_time:387916ms step_avg:141.78ms
step:2747/5100 train_loss:3.4516 train_time:388057ms step_avg:141.78ms
step:2748/5100 train_loss:3.4837 train_time:388196ms step_avg:141.78ms
step:2749/5100 train_loss:3.6046 train_time:388337ms step_avg:141.78ms
step:2750/5100 train_loss:3.4335 train_time:388479ms step_avg:141.78ms
step:2750/5100 val_loss:3.4582 train_time:388532ms step_avg:141.80ms
step:2751/5100 train_loss:3.5067 train_time:388630ms step_avg:141.78ms
step:2752/5100 train_loss:3.5682 train_time:388775ms step_avg:141.79ms
step:2753/5100 train_loss:3.4828 train_time:388918ms step_avg:141.79ms
step:2754/5100 train_loss:3.3999 train_time:389056ms step_avg:141.78ms
step:2755/5100 train_loss:3.4010 train_time:389194ms step_avg:141.78ms
step:2756/5100 train_loss:3.4892 train_time:389333ms step_avg:141.78ms
step:2757/5100 train_loss:3.4331 train_time:389473ms step_avg:141.78ms
step:2758/5100 train_loss:3.3033 train_time:389612ms step_avg:141.78ms
step:2759/5100 train_loss:3.6976 train_time:389757ms step_avg:141.78ms
step:2760/5100 train_loss:3.5032 train_time:389898ms step_avg:141.78ms
step:2761/5100 train_loss:3.4681 train_time:390038ms step_avg:141.78ms
step:2762/5100 train_loss:3.4451 train_time:390178ms step_avg:141.78ms
step:2763/5100 train_loss:3.3592 train_time:390315ms step_avg:141.78ms
step:2764/5100 train_loss:3.5279 train_time:390455ms step_avg:141.78ms
step:2765/5100 train_loss:3.4476 train_time:390596ms step_avg:141.78ms
step:2766/5100 train_loss:3.3466 train_time:390737ms step_avg:141.78ms
step:2767/5100 train_loss:3.4294 train_time:390880ms step_avg:141.78ms
step:2768/5100 train_loss:3.5216 train_time:391019ms step_avg:141.78ms
step:2769/5100 train_loss:3.3953 train_time:391159ms step_avg:141.78ms
step:2770/5100 train_loss:3.4728 train_time:391297ms step_avg:141.77ms
step:2771/5100 train_loss:3.4451 train_time:391436ms step_avg:141.77ms
step:2772/5100 train_loss:3.8841 train_time:391577ms step_avg:141.77ms
step:2773/5100 train_loss:3.3551 train_time:391717ms step_avg:141.77ms
step:2774/5100 train_loss:3.4905 train_time:391861ms step_avg:141.77ms
step:2775/5100 train_loss:3.5510 train_time:391998ms step_avg:141.77ms
step:2776/5100 train_loss:3.5148 train_time:392138ms step_avg:141.77ms
step:2777/5100 train_loss:3.5907 train_time:392278ms step_avg:141.77ms
step:2778/5100 train_loss:3.6026 train_time:392417ms step_avg:141.77ms
step:2779/5100 train_loss:3.4607 train_time:392558ms step_avg:141.77ms
step:2780/5100 train_loss:3.3355 train_time:392698ms step_avg:141.77ms
step:2781/5100 train_loss:3.4822 train_time:392838ms step_avg:141.77ms
step:2782/5100 train_loss:3.5038 train_time:392982ms step_avg:141.77ms
step:2783/5100 train_loss:3.3603 train_time:393118ms step_avg:141.77ms
step:2784/5100 train_loss:3.4772 train_time:393258ms step_avg:141.77ms
step:2785/5100 train_loss:3.5280 train_time:393397ms step_avg:141.76ms
step:2786/5100 train_loss:3.4036 train_time:393537ms step_avg:141.76ms
step:2787/5100 train_loss:3.5393 train_time:393678ms step_avg:141.76ms
step:2788/5100 train_loss:3.4929 train_time:393818ms step_avg:141.76ms
step:2789/5100 train_loss:3.4273 train_time:393959ms step_avg:141.76ms
step:2790/5100 train_loss:3.5087 train_time:394098ms step_avg:141.76ms
step:2791/5100 train_loss:3.4387 train_time:394239ms step_avg:141.76ms
step:2792/5100 train_loss:3.3324 train_time:394379ms step_avg:141.76ms
step:2793/5100 train_loss:3.4376 train_time:394521ms step_avg:141.76ms
step:2794/5100 train_loss:3.4784 train_time:394659ms step_avg:141.76ms
step:2795/5100 train_loss:3.3914 train_time:394798ms step_avg:141.76ms
step:2796/5100 train_loss:3.4344 train_time:394939ms step_avg:141.76ms
step:2797/5100 train_loss:3.3537 train_time:395080ms step_avg:141.76ms
step:2798/5100 train_loss:3.4636 train_time:395217ms step_avg:141.76ms
step:2799/5100 train_loss:3.4106 train_time:395358ms step_avg:141.76ms
step:2800/5100 train_loss:3.5816 train_time:395497ms step_avg:141.76ms
step:2801/5100 train_loss:3.5408 train_time:395639ms step_avg:141.76ms
step:2802/5100 train_loss:3.5070 train_time:395779ms step_avg:141.75ms
step:2803/5100 train_loss:3.4461 train_time:395918ms step_avg:141.75ms
step:2804/5100 train_loss:3.6297 train_time:396058ms step_avg:141.75ms
step:2805/5100 train_loss:3.5941 train_time:396197ms step_avg:141.75ms
step:2806/5100 train_loss:3.3175 train_time:396336ms step_avg:141.75ms
step:2807/5100 train_loss:3.7241 train_time:396477ms step_avg:141.75ms
step:2808/5100 train_loss:3.4690 train_time:396616ms step_avg:141.75ms
step:2809/5100 train_loss:3.3899 train_time:396758ms step_avg:141.75ms
step:2810/5100 train_loss:3.4320 train_time:396898ms step_avg:141.75ms
step:2811/5100 train_loss:3.5785 train_time:397037ms step_avg:141.75ms
step:2812/5100 train_loss:3.5630 train_time:397178ms step_avg:141.75ms
step:2813/5100 train_loss:3.3154 train_time:397317ms step_avg:141.75ms
step:2814/5100 train_loss:3.5354 train_time:397459ms step_avg:141.75ms
step:2815/5100 train_loss:3.6080 train_time:397597ms step_avg:141.75ms
step:2816/5100 train_loss:3.4198 train_time:397738ms step_avg:141.75ms
step:2817/5100 train_loss:3.0473 train_time:397880ms step_avg:141.75ms
step:2818/5100 train_loss:3.4424 train_time:398018ms step_avg:141.74ms
step:2819/5100 train_loss:3.4170 train_time:398159ms step_avg:141.74ms
step:2820/5100 train_loss:3.6062 train_time:398297ms step_avg:141.74ms
step:2821/5100 train_loss:3.6498 train_time:398438ms step_avg:141.74ms
step:2822/5100 train_loss:3.5299 train_time:398582ms step_avg:141.74ms
step:2823/5100 train_loss:3.4640 train_time:398717ms step_avg:141.74ms
step:2824/5100 train_loss:3.4249 train_time:398858ms step_avg:141.74ms
step:2825/5100 train_loss:3.3259 train_time:398999ms step_avg:141.74ms
step:2826/5100 train_loss:3.5911 train_time:399138ms step_avg:141.74ms
step:2827/5100 train_loss:3.4926 train_time:399279ms step_avg:141.74ms
step:2828/5100 train_loss:3.3639 train_time:399418ms step_avg:141.74ms
step:2829/5100 train_loss:3.5039 train_time:399559ms step_avg:141.74ms
step:2830/5100 train_loss:3.5012 train_time:399698ms step_avg:141.74ms
step:2831/5100 train_loss:3.4369 train_time:399843ms step_avg:141.74ms
step:2832/5100 train_loss:3.5789 train_time:399979ms step_avg:141.74ms
step:2833/5100 train_loss:3.5005 train_time:400118ms step_avg:141.74ms
step:2834/5100 train_loss:3.4874 train_time:400258ms step_avg:141.73ms
step:2835/5100 train_loss:3.2966 train_time:400558ms step_avg:141.79ms
step:2836/5100 train_loss:3.5179 train_time:400694ms step_avg:141.79ms
step:2837/5100 train_loss:3.4445 train_time:400832ms step_avg:141.79ms
step:2838/5100 train_loss:3.7675 train_time:400972ms step_avg:141.79ms
step:2839/5100 train_loss:3.4038 train_time:401110ms step_avg:141.79ms
step:2840/5100 train_loss:3.4094 train_time:401250ms step_avg:141.78ms
step:2841/5100 train_loss:3.4714 train_time:401389ms step_avg:141.78ms
step:2842/5100 train_loss:3.3992 train_time:401533ms step_avg:141.78ms
step:2843/5100 train_loss:3.3981 train_time:401677ms step_avg:141.79ms
step:2844/5100 train_loss:3.5769 train_time:401816ms step_avg:141.78ms
step:2845/5100 train_loss:3.4607 train_time:401957ms step_avg:141.78ms
step:2846/5100 train_loss:3.4888 train_time:402097ms step_avg:141.78ms
step:2847/5100 train_loss:3.4421 train_time:402236ms step_avg:141.78ms
step:2848/5100 train_loss:3.7147 train_time:402377ms step_avg:141.78ms
step:2849/5100 train_loss:3.3794 train_time:402517ms step_avg:141.78ms
step:2850/5100 train_loss:3.4089 train_time:402822ms step_avg:141.84ms
step:2851/5100 train_loss:3.5143 train_time:402957ms step_avg:141.84ms
step:2852/5100 train_loss:3.4835 train_time:403095ms step_avg:141.84ms
step:2853/5100 train_loss:3.4499 train_time:403234ms step_avg:141.83ms
step:2854/5100 train_loss:3.5179 train_time:403372ms step_avg:141.83ms
step:2855/5100 train_loss:3.3458 train_time:403511ms step_avg:141.83ms
step:2856/5100 train_loss:3.3614 train_time:403650ms step_avg:141.83ms
step:2857/5100 train_loss:3.4583 train_time:403796ms step_avg:141.83ms
step:2858/5100 train_loss:3.4609 train_time:403939ms step_avg:141.83ms
step:2859/5100 train_loss:3.3385 train_time:404078ms step_avg:141.83ms
step:2860/5100 train_loss:3.4434 train_time:404215ms step_avg:141.83ms
step:2861/5100 train_loss:3.4087 train_time:404355ms step_avg:141.83ms
step:2862/5100 train_loss:3.4568 train_time:404494ms step_avg:141.83ms
step:2863/5100 train_loss:3.4894 train_time:404635ms step_avg:141.83ms
step:2864/5100 train_loss:3.7570 train_time:404777ms step_avg:141.83ms
step:2865/5100 train_loss:3.5731 train_time:404918ms step_avg:141.83ms
step:2866/5100 train_loss:3.4526 train_time:405058ms step_avg:141.83ms
step:2867/5100 train_loss:3.3452 train_time:405196ms step_avg:141.83ms
step:2868/5100 train_loss:3.5434 train_time:405337ms step_avg:141.83ms
step:2869/5100 train_loss:3.4942 train_time:405477ms step_avg:141.82ms
step:2870/5100 train_loss:3.4548 train_time:405616ms step_avg:141.82ms
step:2871/5100 train_loss:3.5924 train_time:405758ms step_avg:141.82ms
step:2872/5100 train_loss:3.3713 train_time:405899ms step_avg:141.82ms
step:2873/5100 train_loss:3.4292 train_time:406038ms step_avg:141.82ms
step:2874/5100 train_loss:3.2942 train_time:406177ms step_avg:141.82ms
step:2875/5100 train_loss:3.4503 train_time:406316ms step_avg:141.82ms
step:2875/5100 val_loss:3.4520 train_time:406373ms step_avg:141.84ms
step:2876/5100 train_loss:3.3679 train_time:406469ms step_avg:141.82ms
step:2877/5100 train_loss:3.3492 train_time:406615ms step_avg:141.83ms
step:2878/5100 train_loss:3.4419 train_time:406757ms step_avg:141.83ms
step:2879/5100 train_loss:3.5609 train_time:406895ms step_avg:141.82ms
step:2880/5100 train_loss:3.5123 train_time:407033ms step_avg:141.82ms
step:2881/5100 train_loss:3.4515 train_time:407172ms step_avg:141.82ms
step:2882/5100 train_loss:3.4384 train_time:407311ms step_avg:141.82ms
step:2883/5100 train_loss:3.5643 train_time:407453ms step_avg:141.82ms
step:2884/5100 train_loss:3.3491 train_time:407597ms step_avg:141.82ms
step:2885/5100 train_loss:3.3747 train_time:407737ms step_avg:141.82ms
step:2886/5100 train_loss:3.4128 train_time:407876ms step_avg:141.82ms
step:2887/5100 train_loss:3.4138 train_time:408014ms step_avg:141.82ms
step:2888/5100 train_loss:3.4200 train_time:408155ms step_avg:141.82ms
step:2889/5100 train_loss:3.4428 train_time:408295ms step_avg:141.82ms
step:2890/5100 train_loss:3.6265 train_time:408435ms step_avg:141.82ms
step:2891/5100 train_loss:3.4746 train_time:408578ms step_avg:141.82ms
step:2892/5100 train_loss:3.3116 train_time:408718ms step_avg:141.82ms
step:2893/5100 train_loss:3.2454 train_time:408857ms step_avg:141.82ms
step:2894/5100 train_loss:3.3811 train_time:408995ms step_avg:141.82ms
step:2895/5100 train_loss:3.2592 train_time:409135ms step_avg:141.81ms
step:2896/5100 train_loss:3.4415 train_time:409275ms step_avg:141.81ms
step:2897/5100 train_loss:3.5653 train_time:409414ms step_avg:141.81ms
step:2898/5100 train_loss:3.3924 train_time:409558ms step_avg:141.81ms
step:2899/5100 train_loss:3.4945 train_time:409697ms step_avg:141.81ms
step:2900/5100 train_loss:3.3725 train_time:409838ms step_avg:141.81ms
step:2901/5100 train_loss:3.5615 train_time:409977ms step_avg:141.81ms
step:2902/5100 train_loss:3.5499 train_time:410115ms step_avg:141.81ms
step:2903/5100 train_loss:3.5967 train_time:410256ms step_avg:141.81ms
step:2904/5100 train_loss:3.3021 train_time:410396ms step_avg:141.81ms
step:2905/5100 train_loss:3.4445 train_time:410537ms step_avg:141.81ms
step:2906/5100 train_loss:3.4207 train_time:410678ms step_avg:141.81ms
step:2907/5100 train_loss:3.5073 train_time:410817ms step_avg:141.81ms
step:2908/5100 train_loss:3.4384 train_time:410957ms step_avg:141.81ms
step:2909/5100 train_loss:3.4010 train_time:411095ms step_avg:141.81ms
step:2910/5100 train_loss:3.7406 train_time:411236ms step_avg:141.81ms
step:2911/5100 train_loss:3.4585 train_time:411376ms step_avg:141.81ms
step:2912/5100 train_loss:3.3569 train_time:411517ms step_avg:141.80ms
step:2913/5100 train_loss:3.3479 train_time:411657ms step_avg:141.80ms
step:2914/5100 train_loss:3.8356 train_time:411797ms step_avg:141.80ms
step:2915/5100 train_loss:3.4200 train_time:411936ms step_avg:141.80ms
step:2916/5100 train_loss:3.3686 train_time:412076ms step_avg:141.80ms
step:2917/5100 train_loss:3.3523 train_time:412215ms step_avg:141.80ms
step:2918/5100 train_loss:3.6410 train_time:412357ms step_avg:141.80ms
step:2919/5100 train_loss:3.1524 train_time:412496ms step_avg:141.80ms
step:2920/5100 train_loss:3.3490 train_time:412636ms step_avg:141.80ms
step:2921/5100 train_loss:3.3694 train_time:412778ms step_avg:141.80ms
step:2922/5100 train_loss:3.4632 train_time:412916ms step_avg:141.80ms
step:2923/5100 train_loss:3.5038 train_time:413056ms step_avg:141.80ms
step:2924/5100 train_loss:3.5358 train_time:413196ms step_avg:141.80ms
step:2925/5100 train_loss:3.5372 train_time:413336ms step_avg:141.80ms
step:2926/5100 train_loss:3.4189 train_time:413480ms step_avg:141.80ms
step:2927/5100 train_loss:3.4381 train_time:413621ms step_avg:141.80ms
step:2928/5100 train_loss:3.4216 train_time:413755ms step_avg:141.79ms
step:2929/5100 train_loss:3.4234 train_time:413896ms step_avg:141.79ms
step:2930/5100 train_loss:3.3871 train_time:414046ms step_avg:141.80ms
step:2931/5100 train_loss:3.4198 train_time:414176ms step_avg:141.79ms
step:2932/5100 train_loss:3.5539 train_time:414315ms step_avg:141.79ms
step:2933/5100 train_loss:3.5941 train_time:414456ms step_avg:141.79ms
step:2934/5100 train_loss:3.5606 train_time:414595ms step_avg:141.79ms
step:2935/5100 train_loss:3.3999 train_time:414736ms step_avg:141.79ms
step:2936/5100 train_loss:3.4611 train_time:414877ms step_avg:141.79ms
step:2937/5100 train_loss:3.3913 train_time:415015ms step_avg:141.79ms
step:2938/5100 train_loss:3.4209 train_time:415156ms step_avg:141.79ms
step:2939/5100 train_loss:3.4484 train_time:415296ms step_avg:141.79ms
step:2940/5100 train_loss:3.4858 train_time:415436ms step_avg:141.79ms
step:2941/5100 train_loss:3.5329 train_time:415576ms step_avg:141.79ms
step:2942/5100 train_loss:3.5160 train_time:415715ms step_avg:141.79ms
step:2943/5100 train_loss:3.4469 train_time:415856ms step_avg:141.79ms
step:2944/5100 train_loss:3.3312 train_time:415996ms step_avg:141.78ms
step:2945/5100 train_loss:3.2707 train_time:416136ms step_avg:141.78ms
step:2946/5100 train_loss:3.4714 train_time:416277ms step_avg:141.78ms
step:2947/5100 train_loss:3.5394 train_time:416416ms step_avg:141.78ms
step:2948/5100 train_loss:3.4683 train_time:416556ms step_avg:141.78ms
step:2949/5100 train_loss:3.6479 train_time:416695ms step_avg:141.78ms
step:2950/5100 train_loss:3.4737 train_time:416835ms step_avg:141.78ms
step:2951/5100 train_loss:3.4699 train_time:416976ms step_avg:141.78ms
step:2952/5100 train_loss:3.8920 train_time:417116ms step_avg:141.78ms
step:2953/5100 train_loss:3.5579 train_time:417256ms step_avg:141.78ms
step:2954/5100 train_loss:3.4936 train_time:417396ms step_avg:141.78ms
step:2955/5100 train_loss:3.5036 train_time:417536ms step_avg:141.78ms
step:2956/5100 train_loss:3.4349 train_time:417676ms step_avg:141.78ms
step:2957/5100 train_loss:3.4636 train_time:417815ms step_avg:141.78ms
step:2958/5100 train_loss:3.3303 train_time:417956ms step_avg:141.78ms
step:2959/5100 train_loss:3.4222 train_time:418095ms step_avg:141.78ms
step:2960/5100 train_loss:3.5596 train_time:418236ms step_avg:141.77ms
step:2961/5100 train_loss:3.3711 train_time:418377ms step_avg:141.77ms
step:2962/5100 train_loss:3.4973 train_time:418516ms step_avg:141.77ms
step:2963/5100 train_loss:3.3558 train_time:418658ms step_avg:141.77ms
step:2964/5100 train_loss:3.4186 train_time:418796ms step_avg:141.77ms
step:2965/5100 train_loss:3.4038 train_time:418936ms step_avg:141.77ms
step:2966/5100 train_loss:3.5216 train_time:419076ms step_avg:141.77ms
step:2967/5100 train_loss:3.3841 train_time:419216ms step_avg:141.77ms
step:2968/5100 train_loss:3.6341 train_time:419357ms step_avg:141.77ms
step:2969/5100 train_loss:3.4836 train_time:419497ms step_avg:141.77ms
step:2970/5100 train_loss:3.4973 train_time:419636ms step_avg:141.77ms
step:2971/5100 train_loss:3.4852 train_time:419778ms step_avg:141.77ms
step:2972/5100 train_loss:3.5592 train_time:419916ms step_avg:141.77ms
step:2973/5100 train_loss:3.3856 train_time:420058ms step_avg:141.77ms
step:2974/5100 train_loss:3.3950 train_time:420197ms step_avg:141.77ms
step:2975/5100 train_loss:3.3093 train_time:420337ms step_avg:141.77ms
step:2976/5100 train_loss:3.3888 train_time:420477ms step_avg:141.77ms
step:2977/5100 train_loss:3.3843 train_time:420616ms step_avg:141.76ms
step:2978/5100 train_loss:3.3973 train_time:420757ms step_avg:141.76ms
step:2979/5100 train_loss:3.6775 train_time:420896ms step_avg:141.76ms
step:2980/5100 train_loss:3.4887 train_time:421036ms step_avg:141.76ms
step:2981/5100 train_loss:3.5260 train_time:421176ms step_avg:141.76ms
step:2982/5100 train_loss:3.5470 train_time:421317ms step_avg:141.76ms
step:2983/5100 train_loss:3.6248 train_time:421458ms step_avg:141.76ms
step:2984/5100 train_loss:3.4215 train_time:421597ms step_avg:141.76ms
step:2985/5100 train_loss:3.5216 train_time:421737ms step_avg:141.76ms
step:2986/5100 train_loss:3.5184 train_time:421876ms step_avg:141.76ms
step:2987/5100 train_loss:3.4758 train_time:422016ms step_avg:141.76ms
step:2988/5100 train_loss:3.5963 train_time:422157ms step_avg:141.76ms
step:2989/5100 train_loss:3.1931 train_time:422300ms step_avg:141.76ms
step:2990/5100 train_loss:3.5369 train_time:422436ms step_avg:141.76ms
step:2991/5100 train_loss:3.4856 train_time:422577ms step_avg:141.76ms
step:2992/5100 train_loss:3.4739 train_time:422719ms step_avg:141.76ms
step:2993/5100 train_loss:3.3806 train_time:422857ms step_avg:141.76ms
step:2994/5100 train_loss:3.5218 train_time:422998ms step_avg:141.76ms
step:2995/5100 train_loss:3.3386 train_time:423136ms step_avg:141.75ms
step:2996/5100 train_loss:3.3700 train_time:423276ms step_avg:141.75ms
step:2997/5100 train_loss:3.4417 train_time:423415ms step_avg:141.75ms
step:2998/5100 train_loss:3.3831 train_time:423556ms step_avg:141.75ms
step:2999/5100 train_loss:3.4991 train_time:423696ms step_avg:141.75ms
step:3000/5100 train_loss:3.4052 train_time:423837ms step_avg:141.75ms
step:3000/5100 val_loss:3.4440 train_time:423891ms step_avg:141.77ms
step:3001/5100 train_loss:3.3988 train_time:423988ms step_avg:141.75ms
step:3002/5100 train_loss:3.3351 train_time:424135ms step_avg:141.76ms
step:3003/5100 train_loss:3.3891 train_time:424273ms step_avg:141.75ms
step:3004/5100 train_loss:3.5104 train_time:424411ms step_avg:141.75ms
step:3005/5100 train_loss:3.8572 train_time:424549ms step_avg:141.75ms
step:3006/5100 train_loss:3.4309 train_time:424688ms step_avg:141.75ms
step:3007/5100 train_loss:3.4950 train_time:424825ms step_avg:141.75ms
step:3008/5100 train_loss:3.3040 train_time:424967ms step_avg:141.75ms
step:3009/5100 train_loss:3.5255 train_time:425110ms step_avg:141.75ms
step:3010/5100 train_loss:3.4132 train_time:425250ms step_avg:141.75ms
step:3011/5100 train_loss:3.4773 train_time:425389ms step_avg:141.75ms
step:3012/5100 train_loss:3.4785 train_time:425527ms step_avg:141.75ms
step:3013/5100 train_loss:3.3688 train_time:425665ms step_avg:141.75ms
step:3014/5100 train_loss:3.5693 train_time:425804ms step_avg:141.75ms
step:3015/5100 train_loss:3.5259 train_time:425944ms step_avg:141.75ms
step:3016/5100 train_loss:3.3957 train_time:426088ms step_avg:141.75ms
step:3017/5100 train_loss:3.4311 train_time:426228ms step_avg:141.75ms
step:3018/5100 train_loss:3.4734 train_time:426368ms step_avg:141.74ms
step:3019/5100 train_loss:3.5089 train_time:426506ms step_avg:141.74ms
step:3020/5100 train_loss:3.3020 train_time:426645ms step_avg:141.74ms
step:3021/5100 train_loss:3.5910 train_time:426785ms step_avg:141.74ms
step:3022/5100 train_loss:3.4308 train_time:426925ms step_avg:141.74ms
step:3023/5100 train_loss:3.3477 train_time:427067ms step_avg:141.74ms
step:3024/5100 train_loss:3.4373 train_time:427365ms step_avg:141.79ms
step:3025/5100 train_loss:3.4265 train_time:427502ms step_avg:141.79ms
step:3026/5100 train_loss:3.4778 train_time:427640ms step_avg:141.79ms
step:3027/5100 train_loss:3.5072 train_time:427780ms step_avg:141.79ms
step:3028/5100 train_loss:3.4073 train_time:427918ms step_avg:141.79ms
step:3029/5100 train_loss:3.2103 train_time:428057ms step_avg:141.79ms
step:3030/5100 train_loss:3.5563 train_time:428196ms step_avg:141.79ms
step:3031/5100 train_loss:3.3130 train_time:428340ms step_avg:141.79ms
step:3032/5100 train_loss:3.3084 train_time:428483ms step_avg:141.79ms
step:3033/5100 train_loss:3.6500 train_time:428625ms step_avg:141.79ms
step:3034/5100 train_loss:3.6402 train_time:428763ms step_avg:141.79ms
step:3035/5100 train_loss:3.4131 train_time:428902ms step_avg:141.79ms
step:3036/5100 train_loss:3.4838 train_time:429042ms step_avg:141.79ms
step:3037/5100 train_loss:3.4397 train_time:429182ms step_avg:141.78ms
step:3038/5100 train_loss:3.3347 train_time:429324ms step_avg:141.78ms
step:3039/5100 train_loss:3.3970 train_time:429467ms step_avg:141.79ms
step:3040/5100 train_loss:3.4867 train_time:429777ms step_avg:141.84ms
step:3041/5100 train_loss:3.4820 train_time:429911ms step_avg:141.84ms
step:3042/5100 train_loss:3.2994 train_time:430050ms step_avg:141.84ms
step:3043/5100 train_loss:3.4385 train_time:430189ms step_avg:141.84ms
step:3044/5100 train_loss:3.4620 train_time:430325ms step_avg:141.83ms
step:3045/5100 train_loss:3.4758 train_time:430464ms step_avg:141.83ms
step:3046/5100 train_loss:3.5525 train_time:430602ms step_avg:141.83ms
step:3047/5100 train_loss:3.3660 train_time:430747ms step_avg:141.83ms
step:3048/5100 train_loss:3.4897 train_time:430891ms step_avg:141.83ms
step:3049/5100 train_loss:3.4388 train_time:431029ms step_avg:141.83ms
step:3050/5100 train_loss:3.3690 train_time:431168ms step_avg:141.83ms
step:3051/5100 train_loss:3.4936 train_time:431306ms step_avg:141.83ms
step:3052/5100 train_loss:3.3396 train_time:431445ms step_avg:141.83ms
step:3053/5100 train_loss:3.5800 train_time:431585ms step_avg:141.83ms
step:3054/5100 train_loss:3.5234 train_time:431726ms step_avg:141.83ms
step:3055/5100 train_loss:3.5067 train_time:431869ms step_avg:141.83ms
step:3056/5100 train_loss:3.5014 train_time:432009ms step_avg:141.83ms
step:3057/5100 train_loss:3.3867 train_time:432148ms step_avg:141.83ms
step:3058/5100 train_loss:3.4206 train_time:432286ms step_avg:141.83ms
step:3059/5100 train_loss:3.4991 train_time:432425ms step_avg:141.83ms
step:3060/5100 train_loss:3.3940 train_time:432566ms step_avg:141.82ms
step:3061/5100 train_loss:3.4492 train_time:432706ms step_avg:141.82ms
step:3062/5100 train_loss:3.4561 train_time:432848ms step_avg:141.82ms
step:3063/5100 train_loss:3.3897 train_time:432990ms step_avg:141.82ms
step:3064/5100 train_loss:3.3677 train_time:433128ms step_avg:141.82ms
step:3065/5100 train_loss:3.3837 train_time:433268ms step_avg:141.82ms
step:3066/5100 train_loss:3.3662 train_time:433406ms step_avg:141.82ms
step:3067/5100 train_loss:3.3504 train_time:433546ms step_avg:141.82ms
step:3068/5100 train_loss:3.3139 train_time:433687ms step_avg:141.82ms
step:3069/5100 train_loss:3.3523 train_time:433827ms step_avg:141.82ms
step:3070/5100 train_loss:3.3481 train_time:433968ms step_avg:141.82ms
step:3071/5100 train_loss:3.5359 train_time:434107ms step_avg:141.82ms
step:3072/5100 train_loss:3.4639 train_time:434247ms step_avg:141.82ms
step:3073/5100 train_loss:3.5087 train_time:434386ms step_avg:141.82ms
step:3074/5100 train_loss:3.4960 train_time:434525ms step_avg:141.82ms
step:3075/5100 train_loss:3.4378 train_time:434667ms step_avg:141.82ms
step:3076/5100 train_loss:3.4879 train_time:434807ms step_avg:141.82ms
step:3077/5100 train_loss:3.5482 train_time:434947ms step_avg:141.81ms
step:3078/5100 train_loss:3.3486 train_time:435087ms step_avg:141.81ms
step:3079/5100 train_loss:3.8868 train_time:435225ms step_avg:141.81ms
step:3080/5100 train_loss:3.4419 train_time:435366ms step_avg:141.81ms
step:3081/5100 train_loss:3.4074 train_time:435506ms step_avg:141.81ms
step:3082/5100 train_loss:3.5466 train_time:435646ms step_avg:141.81ms
step:3083/5100 train_loss:3.3612 train_time:435787ms step_avg:141.81ms
step:3084/5100 train_loss:3.3895 train_time:435926ms step_avg:141.81ms
step:3085/5100 train_loss:3.4374 train_time:436067ms step_avg:141.81ms
step:3086/5100 train_loss:3.5362 train_time:436206ms step_avg:141.81ms
step:3087/5100 train_loss:3.4441 train_time:436346ms step_avg:141.81ms
step:3088/5100 train_loss:3.3557 train_time:436492ms step_avg:141.81ms
step:3089/5100 train_loss:3.5069 train_time:436626ms step_avg:141.81ms
step:3090/5100 train_loss:3.3753 train_time:436766ms step_avg:141.81ms
step:3091/5100 train_loss:3.6232 train_time:436906ms step_avg:141.81ms
step:3092/5100 train_loss:4.1959 train_time:437047ms step_avg:141.81ms
step:3093/5100 train_loss:3.4700 train_time:437188ms step_avg:141.81ms
step:3094/5100 train_loss:3.3610 train_time:437326ms step_avg:141.80ms
step:3095/5100 train_loss:3.3147 train_time:437466ms step_avg:141.80ms
step:3096/5100 train_loss:3.4885 train_time:437605ms step_avg:141.80ms
step:3097/5100 train_loss:3.6084 train_time:437746ms step_avg:141.80ms
step:3098/5100 train_loss:3.3897 train_time:437887ms step_avg:141.80ms
step:3099/5100 train_loss:3.4259 train_time:438027ms step_avg:141.80ms
step:3100/5100 train_loss:3.6030 train_time:438168ms step_avg:141.80ms
step:3101/5100 train_loss:3.5003 train_time:438307ms step_avg:141.80ms
step:3102/5100 train_loss:3.4984 train_time:438446ms step_avg:141.80ms
step:3103/5100 train_loss:3.4043 train_time:438586ms step_avg:141.80ms
step:3104/5100 train_loss:3.6571 train_time:438728ms step_avg:141.80ms
step:3105/5100 train_loss:3.4731 train_time:438866ms step_avg:141.80ms
step:3106/5100 train_loss:3.3325 train_time:439006ms step_avg:141.80ms
step:3107/5100 train_loss:3.3663 train_time:439145ms step_avg:141.80ms
step:3108/5100 train_loss:3.3222 train_time:439286ms step_avg:141.80ms
step:3109/5100 train_loss:3.5409 train_time:439429ms step_avg:141.80ms
step:3110/5100 train_loss:3.4309 train_time:439567ms step_avg:141.80ms
step:3111/5100 train_loss:3.4660 train_time:439706ms step_avg:141.79ms
step:3112/5100 train_loss:3.4475 train_time:439846ms step_avg:141.79ms
step:3113/5100 train_loss:3.5002 train_time:439987ms step_avg:141.79ms
step:3114/5100 train_loss:3.4462 train_time:440126ms step_avg:141.79ms
step:3115/5100 train_loss:3.4615 train_time:440266ms step_avg:141.79ms
step:3116/5100 train_loss:3.4920 train_time:440405ms step_avg:141.79ms
step:3117/5100 train_loss:3.3488 train_time:440546ms step_avg:141.79ms
step:3118/5100 train_loss:3.3571 train_time:440686ms step_avg:141.79ms
step:3119/5100 train_loss:3.5536 train_time:440826ms step_avg:141.79ms
step:3120/5100 train_loss:3.5283 train_time:440966ms step_avg:141.79ms
step:3121/5100 train_loss:3.3193 train_time:441106ms step_avg:141.79ms
step:3122/5100 train_loss:3.5068 train_time:441246ms step_avg:141.79ms
step:3123/5100 train_loss:3.5658 train_time:441386ms step_avg:141.79ms
step:3124/5100 train_loss:3.5334 train_time:441525ms step_avg:141.79ms
step:3125/5100 train_loss:3.3220 train_time:441667ms step_avg:141.79ms
step:3125/5100 val_loss:3.4368 train_time:441722ms step_avg:141.80ms
step:3126/5100 train_loss:3.4044 train_time:441819ms step_avg:141.79ms
step:3127/5100 train_loss:3.4482 train_time:441964ms step_avg:141.79ms
step:3128/5100 train_loss:3.5373 train_time:442107ms step_avg:141.79ms
step:3129/5100 train_loss:3.6062 train_time:442245ms step_avg:141.79ms
step:3130/5100 train_loss:3.3058 train_time:442384ms step_avg:141.79ms
step:3131/5100 train_loss:3.4739 train_time:442522ms step_avg:141.79ms
step:3132/5100 train_loss:3.4709 train_time:442661ms step_avg:141.79ms
step:3133/5100 train_loss:3.4941 train_time:442803ms step_avg:141.79ms
step:3134/5100 train_loss:3.3945 train_time:442946ms step_avg:141.79ms
step:3135/5100 train_loss:3.5079 train_time:443088ms step_avg:141.79ms
step:3136/5100 train_loss:3.4285 train_time:443227ms step_avg:141.79ms
step:3137/5100 train_loss:3.4908 train_time:443365ms step_avg:141.79ms
step:3138/5100 train_loss:3.6843 train_time:443505ms step_avg:141.79ms
step:3139/5100 train_loss:3.6437 train_time:443647ms step_avg:141.79ms
step:3140/5100 train_loss:3.4142 train_time:443784ms step_avg:141.78ms
step:3141/5100 train_loss:3.4356 train_time:443925ms step_avg:141.78ms
step:3142/5100 train_loss:3.3545 train_time:444067ms step_avg:141.78ms
step:3143/5100 train_loss:3.4491 train_time:444206ms step_avg:141.78ms
step:3144/5100 train_loss:3.2449 train_time:444345ms step_avg:141.78ms
step:3145/5100 train_loss:3.4872 train_time:444486ms step_avg:141.78ms
step:3146/5100 train_loss:3.3966 train_time:444625ms step_avg:141.78ms
step:3147/5100 train_loss:3.4214 train_time:444766ms step_avg:141.78ms
step:3148/5100 train_loss:3.5916 train_time:444907ms step_avg:141.78ms
step:3149/5100 train_loss:3.6829 train_time:445048ms step_avg:141.78ms
step:3150/5100 train_loss:3.5442 train_time:445188ms step_avg:141.78ms
step:3151/5100 train_loss:3.3596 train_time:445326ms step_avg:141.78ms
step:3152/5100 train_loss:3.4071 train_time:445466ms step_avg:141.78ms
step:3153/5100 train_loss:3.3856 train_time:445606ms step_avg:141.78ms
step:3154/5100 train_loss:3.5013 train_time:445746ms step_avg:141.78ms
step:3155/5100 train_loss:3.3192 train_time:445886ms step_avg:141.78ms
step:3156/5100 train_loss:3.4563 train_time:446027ms step_avg:141.78ms
step:3157/5100 train_loss:3.4127 train_time:446168ms step_avg:141.78ms
step:3158/5100 train_loss:3.5282 train_time:446306ms step_avg:141.77ms
step:3159/5100 train_loss:3.5891 train_time:446446ms step_avg:141.77ms
step:3160/5100 train_loss:3.4309 train_time:446587ms step_avg:141.77ms
step:3161/5100 train_loss:3.4999 train_time:446726ms step_avg:141.77ms
step:3162/5100 train_loss:3.5700 train_time:446866ms step_avg:141.77ms
step:3163/5100 train_loss:3.4749 train_time:447007ms step_avg:141.77ms
step:3164/5100 train_loss:3.5336 train_time:447146ms step_avg:141.77ms
step:3165/5100 train_loss:3.3530 train_time:447290ms step_avg:141.77ms
step:3166/5100 train_loss:3.3389 train_time:447425ms step_avg:141.77ms
step:3167/5100 train_loss:3.3708 train_time:447566ms step_avg:141.77ms
step:3168/5100 train_loss:3.2033 train_time:447707ms step_avg:141.77ms
step:3169/5100 train_loss:3.3684 train_time:447846ms step_avg:141.77ms
step:3170/5100 train_loss:3.5121 train_time:447988ms step_avg:141.77ms
step:3171/5100 train_loss:3.5527 train_time:448128ms step_avg:141.77ms
step:3172/5100 train_loss:3.5025 train_time:448268ms step_avg:141.77ms
step:3173/5100 train_loss:3.4734 train_time:448408ms step_avg:141.77ms
step:3174/5100 train_loss:3.4410 train_time:448547ms step_avg:141.77ms
step:3175/5100 train_loss:3.4470 train_time:448687ms step_avg:141.77ms
step:3176/5100 train_loss:3.4466 train_time:448826ms step_avg:141.76ms
step:3177/5100 train_loss:3.3752 train_time:448967ms step_avg:141.76ms
step:3178/5100 train_loss:3.5075 train_time:449108ms step_avg:141.76ms
step:3179/5100 train_loss:3.5766 train_time:449247ms step_avg:141.76ms
step:3180/5100 train_loss:3.4202 train_time:449387ms step_avg:141.76ms
step:3181/5100 train_loss:3.4058 train_time:449525ms step_avg:141.76ms
step:3182/5100 train_loss:3.4518 train_time:449667ms step_avg:141.76ms
step:3183/5100 train_loss:3.5570 train_time:449806ms step_avg:141.76ms
step:3184/5100 train_loss:3.5680 train_time:449946ms step_avg:141.76ms
step:3185/5100 train_loss:3.4691 train_time:450088ms step_avg:141.76ms
step:3186/5100 train_loss:3.5379 train_time:450226ms step_avg:141.76ms
step:3187/5100 train_loss:3.5208 train_time:450366ms step_avg:141.76ms
step:3188/5100 train_loss:3.3162 train_time:450506ms step_avg:141.76ms
step:3189/5100 train_loss:3.4727 train_time:450646ms step_avg:141.76ms
step:3190/5100 train_loss:3.4282 train_time:450788ms step_avg:141.76ms
step:3191/5100 train_loss:3.4556 train_time:450926ms step_avg:141.76ms
step:3192/5100 train_loss:3.4094 train_time:451067ms step_avg:141.76ms
step:3193/5100 train_loss:3.3389 train_time:451207ms step_avg:141.76ms
step:3194/5100 train_loss:4.3615 train_time:451347ms step_avg:141.75ms
step:3195/5100 train_loss:3.4764 train_time:451486ms step_avg:141.75ms
step:3196/5100 train_loss:3.2542 train_time:451625ms step_avg:141.75ms
step:3197/5100 train_loss:3.4122 train_time:451769ms step_avg:141.75ms
step:3198/5100 train_loss:3.2894 train_time:451906ms step_avg:141.75ms
step:3199/5100 train_loss:3.4004 train_time:452046ms step_avg:141.75ms
step:3200/5100 train_loss:3.3296 train_time:452187ms step_avg:141.75ms
step:3201/5100 train_loss:3.4141 train_time:452326ms step_avg:141.75ms
step:3202/5100 train_loss:3.5123 train_time:452466ms step_avg:141.75ms
step:3203/5100 train_loss:3.3593 train_time:452606ms step_avg:141.75ms
step:3204/5100 train_loss:3.4045 train_time:452746ms step_avg:141.75ms
step:3205/5100 train_loss:3.4932 train_time:452887ms step_avg:141.75ms
step:3206/5100 train_loss:3.6518 train_time:453026ms step_avg:141.75ms
step:3207/5100 train_loss:3.2463 train_time:453167ms step_avg:141.75ms
step:3208/5100 train_loss:3.6041 train_time:453307ms step_avg:141.75ms
step:3209/5100 train_loss:3.4524 train_time:453446ms step_avg:141.75ms
step:3210/5100 train_loss:3.5204 train_time:453587ms step_avg:141.75ms
step:3211/5100 train_loss:3.6058 train_time:453727ms step_avg:141.75ms
step:3212/5100 train_loss:3.2932 train_time:453866ms step_avg:141.74ms
step:3213/5100 train_loss:3.3448 train_time:454168ms step_avg:141.79ms
step:3214/5100 train_loss:3.5544 train_time:454305ms step_avg:141.79ms
step:3215/5100 train_loss:3.3443 train_time:454443ms step_avg:141.79ms
step:3216/5100 train_loss:3.4148 train_time:454582ms step_avg:141.79ms
step:3217/5100 train_loss:3.3167 train_time:454720ms step_avg:141.79ms
step:3218/5100 train_loss:3.4427 train_time:454859ms step_avg:141.79ms
step:3219/5100 train_loss:3.4884 train_time:455000ms step_avg:141.79ms
step:3220/5100 train_loss:3.5326 train_time:455143ms step_avg:141.79ms
step:3221/5100 train_loss:3.4796 train_time:455286ms step_avg:141.79ms
step:3222/5100 train_loss:3.4799 train_time:455427ms step_avg:141.79ms
step:3223/5100 train_loss:3.3440 train_time:455565ms step_avg:141.79ms
step:3224/5100 train_loss:3.3707 train_time:455704ms step_avg:141.79ms
step:3225/5100 train_loss:3.3640 train_time:455844ms step_avg:141.79ms
step:3226/5100 train_loss:3.4086 train_time:455985ms step_avg:141.79ms
step:3227/5100 train_loss:3.3400 train_time:456126ms step_avg:141.79ms
step:3228/5100 train_loss:3.2619 train_time:456267ms step_avg:141.79ms
step:3229/5100 train_loss:3.3844 train_time:456406ms step_avg:141.79ms
step:3230/5100 train_loss:3.1594 train_time:456728ms step_avg:141.84ms
step:3231/5100 train_loss:3.3375 train_time:456863ms step_avg:141.84ms
step:3232/5100 train_loss:3.3280 train_time:457001ms step_avg:141.84ms
step:3233/5100 train_loss:3.5691 train_time:457139ms step_avg:141.84ms
step:3234/5100 train_loss:3.5530 train_time:457279ms step_avg:141.84ms
step:3235/5100 train_loss:3.5215 train_time:457419ms step_avg:141.84ms
step:3236/5100 train_loss:3.4081 train_time:457556ms step_avg:141.83ms
step:3237/5100 train_loss:3.5681 train_time:457701ms step_avg:141.83ms
step:3238/5100 train_loss:3.4297 train_time:457844ms step_avg:141.84ms
step:3239/5100 train_loss:3.5607 train_time:457984ms step_avg:141.83ms
step:3240/5100 train_loss:3.5173 train_time:458123ms step_avg:141.83ms
step:3241/5100 train_loss:3.4207 train_time:458264ms step_avg:141.83ms
step:3242/5100 train_loss:3.3766 train_time:458402ms step_avg:141.83ms
step:3243/5100 train_loss:3.6014 train_time:458543ms step_avg:141.83ms
step:3244/5100 train_loss:3.4730 train_time:458685ms step_avg:141.83ms
step:3245/5100 train_loss:3.5183 train_time:458825ms step_avg:141.83ms
step:3246/5100 train_loss:3.4061 train_time:458965ms step_avg:141.83ms
step:3247/5100 train_loss:3.5370 train_time:459109ms step_avg:141.83ms
step:3248/5100 train_loss:3.4664 train_time:459245ms step_avg:141.83ms
step:3249/5100 train_loss:3.4133 train_time:459387ms step_avg:141.83ms
step:3250/5100 train_loss:3.2812 train_time:459524ms step_avg:141.83ms
step:3250/5100 val_loss:3.4283 train_time:459582ms step_avg:141.85ms
step:3251/5100 train_loss:3.4783 train_time:459682ms step_avg:141.83ms
step:3252/5100 train_loss:3.4850 train_time:459822ms step_avg:141.83ms
step:3253/5100 train_loss:3.4441 train_time:459965ms step_avg:141.83ms
step:3254/5100 train_loss:3.3606 train_time:460104ms step_avg:141.83ms
step:3255/5100 train_loss:3.5106 train_time:460243ms step_avg:141.83ms
step:3256/5100 train_loss:3.5438 train_time:460382ms step_avg:141.83ms
step:3257/5100 train_loss:3.4802 train_time:460520ms step_avg:141.83ms
step:3258/5100 train_loss:3.5144 train_time:460660ms step_avg:141.83ms
step:3259/5100 train_loss:3.3494 train_time:460803ms step_avg:141.83ms
step:3260/5100 train_loss:3.4369 train_time:460944ms step_avg:141.83ms
step:3261/5100 train_loss:3.3018 train_time:461085ms step_avg:141.83ms
step:3262/5100 train_loss:3.3401 train_time:461223ms step_avg:141.83ms
step:3263/5100 train_loss:3.3748 train_time:461362ms step_avg:141.83ms
step:3264/5100 train_loss:3.5246 train_time:461500ms step_avg:141.83ms
step:3265/5100 train_loss:3.4036 train_time:461640ms step_avg:141.82ms
step:3266/5100 train_loss:3.4644 train_time:461781ms step_avg:141.82ms
step:3267/5100 train_loss:3.4805 train_time:461922ms step_avg:141.82ms
step:3268/5100 train_loss:3.5605 train_time:462063ms step_avg:141.82ms
step:3269/5100 train_loss:3.3816 train_time:462202ms step_avg:141.82ms
step:3270/5100 train_loss:3.5021 train_time:462343ms step_avg:141.82ms
step:3271/5100 train_loss:3.3702 train_time:462482ms step_avg:141.82ms
step:3272/5100 train_loss:3.2757 train_time:462622ms step_avg:141.82ms
step:3273/5100 train_loss:3.3833 train_time:462761ms step_avg:141.82ms
step:3274/5100 train_loss:3.5199 train_time:462901ms step_avg:141.82ms
step:3275/5100 train_loss:3.3150 train_time:463043ms step_avg:141.82ms
step:3276/5100 train_loss:3.4645 train_time:463182ms step_avg:141.82ms
step:3277/5100 train_loss:3.4642 train_time:463322ms step_avg:141.82ms
step:3278/5100 train_loss:3.4552 train_time:463460ms step_avg:141.82ms
step:3279/5100 train_loss:3.4292 train_time:463600ms step_avg:141.82ms
step:3280/5100 train_loss:3.5762 train_time:463740ms step_avg:141.82ms
step:3281/5100 train_loss:3.4317 train_time:463881ms step_avg:141.82ms
step:3282/5100 train_loss:3.4783 train_time:464022ms step_avg:141.82ms
step:3283/5100 train_loss:3.3312 train_time:464162ms step_avg:141.82ms
step:3284/5100 train_loss:3.4600 train_time:464301ms step_avg:141.81ms
step:3285/5100 train_loss:3.5196 train_time:464441ms step_avg:141.81ms
step:3286/5100 train_loss:3.5002 train_time:464581ms step_avg:141.81ms
step:3287/5100 train_loss:3.5322 train_time:464721ms step_avg:141.81ms
step:3288/5100 train_loss:3.4093 train_time:464862ms step_avg:141.81ms
step:3289/5100 train_loss:3.5220 train_time:465001ms step_avg:141.81ms
step:3290/5100 train_loss:3.4389 train_time:465142ms step_avg:141.81ms
step:3291/5100 train_loss:3.3256 train_time:465283ms step_avg:141.81ms
step:3292/5100 train_loss:3.4490 train_time:465423ms step_avg:141.81ms
step:3293/5100 train_loss:3.4839 train_time:465562ms step_avg:141.81ms
step:3294/5100 train_loss:3.4661 train_time:465701ms step_avg:141.81ms
step:3295/5100 train_loss:3.3532 train_time:465841ms step_avg:141.81ms
step:3296/5100 train_loss:3.4087 train_time:465982ms step_avg:141.81ms
step:3297/5100 train_loss:3.4534 train_time:466123ms step_avg:141.81ms
step:3298/5100 train_loss:3.4505 train_time:466264ms step_avg:141.81ms
step:3299/5100 train_loss:3.4336 train_time:466404ms step_avg:141.81ms
step:3300/5100 train_loss:3.4912 train_time:466543ms step_avg:141.81ms
step:3301/5100 train_loss:3.3963 train_time:466683ms step_avg:141.81ms
step:3302/5100 train_loss:3.4624 train_time:466823ms step_avg:141.81ms
step:3303/5100 train_loss:3.4150 train_time:466963ms step_avg:141.80ms
step:3304/5100 train_loss:3.4267 train_time:467103ms step_avg:141.80ms
step:3305/5100 train_loss:3.4195 train_time:467245ms step_avg:141.80ms
step:3306/5100 train_loss:3.5161 train_time:467387ms step_avg:141.80ms
step:3307/5100 train_loss:3.4448 train_time:467527ms step_avg:141.80ms
step:3308/5100 train_loss:3.4108 train_time:467667ms step_avg:141.80ms
step:3309/5100 train_loss:3.5290 train_time:467807ms step_avg:141.80ms
step:3310/5100 train_loss:3.4054 train_time:467946ms step_avg:141.80ms
step:3311/5100 train_loss:3.3500 train_time:468088ms step_avg:141.80ms
step:3312/5100 train_loss:3.4581 train_time:468229ms step_avg:141.80ms
step:3313/5100 train_loss:3.4280 train_time:468370ms step_avg:141.80ms
step:3314/5100 train_loss:3.6280 train_time:468511ms step_avg:141.80ms
step:3315/5100 train_loss:3.4591 train_time:468651ms step_avg:141.80ms
step:3316/5100 train_loss:3.4237 train_time:468791ms step_avg:141.80ms
step:3317/5100 train_loss:3.0572 train_time:468931ms step_avg:141.80ms
step:3318/5100 train_loss:3.5629 train_time:469072ms step_avg:141.80ms
step:3319/5100 train_loss:3.3973 train_time:469212ms step_avg:141.80ms
step:3320/5100 train_loss:3.4733 train_time:469353ms step_avg:141.80ms
step:3321/5100 train_loss:3.4034 train_time:469496ms step_avg:141.80ms
step:3322/5100 train_loss:3.4801 train_time:469634ms step_avg:141.80ms
step:3323/5100 train_loss:3.4200 train_time:469774ms step_avg:141.80ms
step:3324/5100 train_loss:3.3418 train_time:469914ms step_avg:141.80ms
step:3325/5100 train_loss:3.2743 train_time:470054ms step_avg:141.80ms
step:3326/5100 train_loss:3.4410 train_time:470194ms step_avg:141.80ms
step:3327/5100 train_loss:3.4027 train_time:470333ms step_avg:141.79ms
step:3328/5100 train_loss:3.3255 train_time:470476ms step_avg:141.79ms
step:3329/5100 train_loss:3.3650 train_time:470614ms step_avg:141.79ms
step:3330/5100 train_loss:3.3186 train_time:470755ms step_avg:141.79ms
step:3331/5100 train_loss:3.5666 train_time:470896ms step_avg:141.79ms
step:3332/5100 train_loss:3.4644 train_time:471034ms step_avg:141.79ms
step:3333/5100 train_loss:3.4521 train_time:471175ms step_avg:141.79ms
step:3334/5100 train_loss:3.3104 train_time:471314ms step_avg:141.79ms
step:3335/5100 train_loss:3.3792 train_time:471455ms step_avg:141.79ms
step:3336/5100 train_loss:3.4808 train_time:471595ms step_avg:141.79ms
step:3337/5100 train_loss:3.4538 train_time:471734ms step_avg:141.79ms
step:3338/5100 train_loss:3.4866 train_time:471875ms step_avg:141.79ms
step:3339/5100 train_loss:3.4149 train_time:472014ms step_avg:141.79ms
step:3340/5100 train_loss:3.4371 train_time:472155ms step_avg:141.79ms
step:3341/5100 train_loss:3.4565 train_time:472295ms step_avg:141.79ms
step:3342/5100 train_loss:3.4641 train_time:472448ms step_avg:141.79ms
step:3343/5100 train_loss:3.4578 train_time:472575ms step_avg:141.79ms
step:3344/5100 train_loss:3.3833 train_time:472714ms step_avg:141.79ms
step:3345/5100 train_loss:3.2949 train_time:472856ms step_avg:141.79ms
step:3346/5100 train_loss:3.6218 train_time:472995ms step_avg:141.79ms
step:3347/5100 train_loss:3.3842 train_time:473134ms step_avg:141.78ms
step:3348/5100 train_loss:3.5492 train_time:473275ms step_avg:141.78ms
step:3349/5100 train_loss:3.4174 train_time:473414ms step_avg:141.78ms
step:3350/5100 train_loss:3.4972 train_time:473554ms step_avg:141.78ms
step:3351/5100 train_loss:3.2373 train_time:473695ms step_avg:141.78ms
step:3352/5100 train_loss:3.2669 train_time:473834ms step_avg:141.78ms
step:3353/5100 train_loss:3.4400 train_time:473974ms step_avg:141.78ms
step:3354/5100 train_loss:3.3139 train_time:474114ms step_avg:141.78ms
step:3355/5100 train_loss:3.4697 train_time:474255ms step_avg:141.78ms
step:3356/5100 train_loss:3.3305 train_time:474395ms step_avg:141.78ms
step:3357/5100 train_loss:3.5025 train_time:474533ms step_avg:141.78ms
step:3358/5100 train_loss:3.3567 train_time:474675ms step_avg:141.78ms
step:3359/5100 train_loss:3.5331 train_time:474815ms step_avg:141.78ms
step:3360/5100 train_loss:3.3398 train_time:474955ms step_avg:141.78ms
step:3361/5100 train_loss:4.1080 train_time:475096ms step_avg:141.78ms
step:3362/5100 train_loss:3.4939 train_time:475234ms step_avg:141.78ms
step:3363/5100 train_loss:3.5234 train_time:475375ms step_avg:141.78ms
step:3364/5100 train_loss:3.3987 train_time:475514ms step_avg:141.78ms
step:3365/5100 train_loss:3.5171 train_time:475655ms step_avg:141.77ms
step:3366/5100 train_loss:3.4208 train_time:475798ms step_avg:141.78ms
step:3367/5100 train_loss:3.5892 train_time:475934ms step_avg:141.77ms
step:3368/5100 train_loss:3.4010 train_time:476074ms step_avg:141.77ms
step:3369/5100 train_loss:3.4184 train_time:476214ms step_avg:141.77ms
step:3370/5100 train_loss:3.3875 train_time:476355ms step_avg:141.77ms
step:3371/5100 train_loss:3.3479 train_time:476494ms step_avg:141.77ms
step:3372/5100 train_loss:3.3478 train_time:476634ms step_avg:141.77ms
step:3373/5100 train_loss:3.4095 train_time:476775ms step_avg:141.77ms
step:3374/5100 train_loss:3.4469 train_time:476914ms step_avg:141.77ms
step:3375/5100 train_loss:3.4159 train_time:477054ms step_avg:141.77ms
step:3375/5100 val_loss:3.4228 train_time:477111ms step_avg:141.79ms
step:3376/5100 train_loss:3.4573 train_time:477206ms step_avg:141.77ms
step:3377/5100 train_loss:3.4578 train_time:477352ms step_avg:141.77ms
step:3378/5100 train_loss:3.5412 train_time:477493ms step_avg:141.77ms
step:3379/5100 train_loss:3.3884 train_time:477632ms step_avg:141.77ms
step:3380/5100 train_loss:3.4053 train_time:477771ms step_avg:141.77ms
step:3381/5100 train_loss:3.4083 train_time:477910ms step_avg:141.77ms
step:3382/5100 train_loss:3.5128 train_time:478048ms step_avg:141.77ms
step:3383/5100 train_loss:3.3551 train_time:478188ms step_avg:141.77ms
step:3384/5100 train_loss:3.5232 train_time:478332ms step_avg:141.77ms
step:3385/5100 train_loss:3.3795 train_time:478475ms step_avg:141.77ms
step:3386/5100 train_loss:3.4028 train_time:478615ms step_avg:141.77ms
step:3387/5100 train_loss:3.3407 train_time:478754ms step_avg:141.77ms
step:3388/5100 train_loss:3.5188 train_time:478893ms step_avg:141.77ms
step:3389/5100 train_loss:3.4851 train_time:479033ms step_avg:141.77ms
step:3390/5100 train_loss:3.5013 train_time:479174ms step_avg:141.77ms
step:3391/5100 train_loss:3.4851 train_time:479315ms step_avg:141.77ms
step:3392/5100 train_loss:3.4223 train_time:479457ms step_avg:141.77ms
step:3393/5100 train_loss:3.5444 train_time:479599ms step_avg:141.77ms
step:3394/5100 train_loss:3.4968 train_time:479738ms step_avg:141.77ms
step:3395/5100 train_loss:3.5918 train_time:479878ms step_avg:141.77ms
step:3396/5100 train_loss:3.4470 train_time:480018ms step_avg:141.77ms
step:3397/5100 train_loss:3.4292 train_time:480158ms step_avg:141.76ms
step:3398/5100 train_loss:3.3928 train_time:480299ms step_avg:141.76ms
step:3399/5100 train_loss:3.4465 train_time:480440ms step_avg:141.76ms
step:3400/5100 train_loss:3.4418 train_time:480580ms step_avg:141.76ms
step:3401/5100 train_loss:3.5323 train_time:480719ms step_avg:141.76ms
step:3402/5100 train_loss:3.4001 train_time:481020ms step_avg:141.81ms
step:3403/5100 train_loss:3.5857 train_time:481158ms step_avg:141.81ms
step:3404/5100 train_loss:3.4103 train_time:481296ms step_avg:141.81ms
step:3405/5100 train_loss:3.4261 train_time:481435ms step_avg:141.81ms
step:3406/5100 train_loss:3.3682 train_time:481573ms step_avg:141.81ms
step:3407/5100 train_loss:3.4314 train_time:481713ms step_avg:141.81ms
step:3408/5100 train_loss:3.4348 train_time:481852ms step_avg:141.80ms
step:3409/5100 train_loss:3.4136 train_time:481998ms step_avg:141.81ms
step:3410/5100 train_loss:3.4387 train_time:482139ms step_avg:141.81ms
step:3411/5100 train_loss:3.4011 train_time:482280ms step_avg:141.81ms
step:3412/5100 train_loss:3.4336 train_time:482418ms step_avg:141.80ms
step:3413/5100 train_loss:3.3617 train_time:482558ms step_avg:141.80ms
step:3414/5100 train_loss:3.5745 train_time:482698ms step_avg:141.80ms
step:3415/5100 train_loss:3.3250 train_time:482838ms step_avg:141.80ms
step:3416/5100 train_loss:3.4893 train_time:482980ms step_avg:141.80ms
step:3417/5100 train_loss:3.3588 train_time:483121ms step_avg:141.80ms
step:3418/5100 train_loss:3.4681 train_time:483261ms step_avg:141.80ms
step:3419/5100 train_loss:3.4656 train_time:483400ms step_avg:141.80ms
step:3420/5100 train_loss:3.4929 train_time:483719ms step_avg:141.85ms
step:3421/5100 train_loss:3.3668 train_time:483851ms step_avg:141.85ms
step:3422/5100 train_loss:3.4120 train_time:483989ms step_avg:141.85ms
step:3423/5100 train_loss:3.3417 train_time:484129ms step_avg:141.85ms
step:3424/5100 train_loss:3.6677 train_time:484265ms step_avg:141.85ms
step:3425/5100 train_loss:3.5515 train_time:484404ms step_avg:141.85ms
step:3426/5100 train_loss:3.4234 train_time:484543ms step_avg:141.85ms
step:3427/5100 train_loss:3.3829 train_time:484688ms step_avg:141.85ms
step:3428/5100 train_loss:3.3555 train_time:484832ms step_avg:141.85ms
step:3429/5100 train_loss:3.3547 train_time:484972ms step_avg:141.85ms
step:3430/5100 train_loss:3.4211 train_time:485113ms step_avg:141.85ms
step:3431/5100 train_loss:3.4430 train_time:485252ms step_avg:141.85ms
step:3432/5100 train_loss:3.5427 train_time:485394ms step_avg:141.85ms
step:3433/5100 train_loss:3.3577 train_time:485531ms step_avg:141.84ms
step:3434/5100 train_loss:3.5714 train_time:485672ms step_avg:141.84ms
step:3435/5100 train_loss:3.4956 train_time:485815ms step_avg:141.84ms
step:3436/5100 train_loss:3.3391 train_time:485957ms step_avg:141.84ms
step:3437/5100 train_loss:3.3851 train_time:486099ms step_avg:141.84ms
step:3438/5100 train_loss:3.4317 train_time:486241ms step_avg:141.84ms
step:3439/5100 train_loss:3.5242 train_time:486378ms step_avg:141.84ms
step:3440/5100 train_loss:3.3026 train_time:486519ms step_avg:141.84ms
step:3441/5100 train_loss:3.4790 train_time:486660ms step_avg:141.84ms
step:3442/5100 train_loss:3.3803 train_time:486800ms step_avg:141.84ms
step:3443/5100 train_loss:3.5632 train_time:486940ms step_avg:141.84ms
step:3444/5100 train_loss:3.4301 train_time:487080ms step_avg:141.84ms
step:3445/5100 train_loss:3.3134 train_time:487219ms step_avg:141.84ms
step:3446/5100 train_loss:3.5251 train_time:487358ms step_avg:141.84ms
step:3447/5100 train_loss:3.6008 train_time:487499ms step_avg:141.84ms
step:3448/5100 train_loss:3.4180 train_time:487640ms step_avg:141.84ms
step:3449/5100 train_loss:3.4331 train_time:487781ms step_avg:141.84ms
step:3450/5100 train_loss:3.5214 train_time:487921ms step_avg:141.84ms
step:3451/5100 train_loss:3.5169 train_time:488061ms step_avg:141.84ms
step:3452/5100 train_loss:3.5219 train_time:488200ms step_avg:141.84ms
step:3453/5100 train_loss:3.3259 train_time:488339ms step_avg:141.84ms
step:3454/5100 train_loss:3.4469 train_time:488480ms step_avg:141.83ms
step:3455/5100 train_loss:3.3285 train_time:488620ms step_avg:141.83ms
step:3456/5100 train_loss:3.6178 train_time:488760ms step_avg:141.83ms
step:3457/5100 train_loss:3.2971 train_time:488901ms step_avg:141.83ms
step:3458/5100 train_loss:3.4407 train_time:489039ms step_avg:141.83ms
step:3459/5100 train_loss:3.3857 train_time:489180ms step_avg:141.83ms
step:3460/5100 train_loss:3.3867 train_time:489320ms step_avg:141.83ms
step:3461/5100 train_loss:3.3814 train_time:489460ms step_avg:141.83ms
step:3462/5100 train_loss:3.3883 train_time:489601ms step_avg:141.83ms
step:3463/5100 train_loss:3.4927 train_time:489741ms step_avg:141.83ms
step:3464/5100 train_loss:3.3661 train_time:489881ms step_avg:141.83ms
step:3465/5100 train_loss:3.3743 train_time:490020ms step_avg:141.83ms
step:3466/5100 train_loss:3.3543 train_time:490159ms step_avg:141.83ms
step:3467/5100 train_loss:3.5123 train_time:490301ms step_avg:141.83ms
step:3468/5100 train_loss:3.3985 train_time:490439ms step_avg:141.83ms
step:3469/5100 train_loss:3.4140 train_time:490581ms step_avg:141.83ms
step:3470/5100 train_loss:3.6048 train_time:490720ms step_avg:141.83ms
step:3471/5100 train_loss:3.5016 train_time:490861ms step_avg:141.83ms
step:3472/5100 train_loss:3.5504 train_time:491001ms step_avg:141.83ms
step:3473/5100 train_loss:4.1785 train_time:491140ms step_avg:141.83ms
step:3474/5100 train_loss:3.4212 train_time:491280ms step_avg:141.82ms
step:3475/5100 train_loss:3.4274 train_time:491419ms step_avg:141.82ms
step:3476/5100 train_loss:3.4129 train_time:491560ms step_avg:141.82ms
step:3477/5100 train_loss:3.3509 train_time:491700ms step_avg:141.82ms
step:3478/5100 train_loss:3.4278 train_time:491840ms step_avg:141.82ms
step:3479/5100 train_loss:3.4229 train_time:491980ms step_avg:141.82ms
step:3480/5100 train_loss:3.3257 train_time:492119ms step_avg:141.82ms
step:3481/5100 train_loss:3.6242 train_time:492260ms step_avg:141.82ms
step:3482/5100 train_loss:3.4924 train_time:492400ms step_avg:141.82ms
step:3483/5100 train_loss:3.4430 train_time:492539ms step_avg:141.82ms
step:3484/5100 train_loss:3.4597 train_time:492680ms step_avg:141.82ms
step:3485/5100 train_loss:3.4327 train_time:492821ms step_avg:141.82ms
step:3486/5100 train_loss:3.6225 train_time:492961ms step_avg:141.82ms
step:3487/5100 train_loss:3.6415 train_time:493101ms step_avg:141.82ms
step:3488/5100 train_loss:3.4997 train_time:493240ms step_avg:141.82ms
step:3489/5100 train_loss:3.3582 train_time:493383ms step_avg:141.82ms
step:3490/5100 train_loss:3.5273 train_time:493520ms step_avg:141.82ms
step:3491/5100 train_loss:3.4320 train_time:493660ms step_avg:141.82ms
step:3492/5100 train_loss:3.4813 train_time:493801ms step_avg:141.82ms
step:3493/5100 train_loss:3.3170 train_time:493940ms step_avg:141.81ms
step:3494/5100 train_loss:3.4470 train_time:494080ms step_avg:141.81ms
step:3495/5100 train_loss:3.4021 train_time:494219ms step_avg:141.81ms
step:3496/5100 train_loss:3.4216 train_time:494360ms step_avg:141.81ms
step:3497/5100 train_loss:3.5892 train_time:494501ms step_avg:141.81ms
step:3498/5100 train_loss:3.4250 train_time:494639ms step_avg:141.81ms
step:3499/5100 train_loss:3.4490 train_time:494781ms step_avg:141.81ms
step:3500/5100 train_loss:3.4457 train_time:494920ms step_avg:141.81ms
step:3500/5100 val_loss:3.4154 train_time:494975ms step_avg:141.83ms
step:3501/5100 train_loss:3.4901 train_time:495071ms step_avg:141.81ms
step:3502/5100 train_loss:3.5633 train_time:495218ms step_avg:141.82ms
step:3503/5100 train_loss:3.2692 train_time:495358ms step_avg:141.81ms
step:3504/5100 train_loss:3.4289 train_time:495496ms step_avg:141.81ms
step:3505/5100 train_loss:3.4545 train_time:495635ms step_avg:141.81ms
step:3506/5100 train_loss:3.4776 train_time:495773ms step_avg:141.81ms
step:3507/5100 train_loss:3.3515 train_time:495912ms step_avg:141.81ms
step:3508/5100 train_loss:3.4994 train_time:496054ms step_avg:141.81ms
step:3509/5100 train_loss:3.3937 train_time:496197ms step_avg:141.81ms
step:3510/5100 train_loss:3.6087 train_time:496338ms step_avg:141.81ms
step:3511/5100 train_loss:3.4177 train_time:496478ms step_avg:141.81ms
step:3512/5100 train_loss:3.3646 train_time:496618ms step_avg:141.81ms
step:3513/5100 train_loss:3.4275 train_time:496757ms step_avg:141.81ms
step:3514/5100 train_loss:3.3882 train_time:496898ms step_avg:141.81ms
step:3515/5100 train_loss:3.4435 train_time:497039ms step_avg:141.81ms
step:3516/5100 train_loss:3.4474 train_time:497180ms step_avg:141.81ms
step:3517/5100 train_loss:3.4157 train_time:497320ms step_avg:141.81ms
step:3518/5100 train_loss:3.4409 train_time:497460ms step_avg:141.81ms
step:3519/5100 train_loss:3.4295 train_time:497599ms step_avg:141.81ms
step:3520/5100 train_loss:3.4493 train_time:497738ms step_avg:141.81ms
step:3521/5100 train_loss:3.5185 train_time:497878ms step_avg:141.81ms
step:3522/5100 train_loss:3.4327 train_time:498019ms step_avg:141.80ms
step:3523/5100 train_loss:3.3552 train_time:498160ms step_avg:141.80ms
step:3524/5100 train_loss:3.3992 train_time:498301ms step_avg:141.80ms
step:3525/5100 train_loss:3.3980 train_time:498439ms step_avg:141.80ms
step:3526/5100 train_loss:3.3920 train_time:498580ms step_avg:141.80ms
step:3527/5100 train_loss:3.5253 train_time:498718ms step_avg:141.80ms
step:3528/5100 train_loss:3.3545 train_time:498858ms step_avg:141.80ms
step:3529/5100 train_loss:3.2408 train_time:498999ms step_avg:141.80ms
step:3530/5100 train_loss:3.5339 train_time:499139ms step_avg:141.80ms
step:3531/5100 train_loss:3.3413 train_time:499281ms step_avg:141.80ms
step:3532/5100 train_loss:3.3750 train_time:499420ms step_avg:141.80ms
step:3533/5100 train_loss:3.2968 train_time:499560ms step_avg:141.80ms
step:3534/5100 train_loss:3.3559 train_time:499700ms step_avg:141.80ms
step:3535/5100 train_loss:3.3227 train_time:499840ms step_avg:141.80ms
step:3536/5100 train_loss:3.4994 train_time:499981ms step_avg:141.80ms
step:3537/5100 train_loss:3.4606 train_time:500121ms step_avg:141.80ms
step:3538/5100 train_loss:3.4852 train_time:500261ms step_avg:141.80ms
step:3539/5100 train_loss:3.3857 train_time:500401ms step_avg:141.80ms
step:3540/5100 train_loss:3.3575 train_time:500540ms step_avg:141.80ms
step:3541/5100 train_loss:3.4934 train_time:500680ms step_avg:141.80ms
step:3542/5100 train_loss:3.3407 train_time:500820ms step_avg:141.80ms
step:3543/5100 train_loss:3.5672 train_time:500961ms step_avg:141.79ms
step:3544/5100 train_loss:3.6810 train_time:501102ms step_avg:141.79ms
step:3545/5100 train_loss:3.4927 train_time:501241ms step_avg:141.79ms
step:3546/5100 train_loss:3.5243 train_time:501382ms step_avg:141.79ms
step:3547/5100 train_loss:3.2905 train_time:501521ms step_avg:141.79ms
step:3548/5100 train_loss:3.3744 train_time:501659ms step_avg:141.79ms
step:3549/5100 train_loss:3.4000 train_time:501800ms step_avg:141.79ms
step:3550/5100 train_loss:3.4885 train_time:501940ms step_avg:141.79ms
step:3551/5100 train_loss:3.4496 train_time:502080ms step_avg:141.79ms
step:3552/5100 train_loss:3.3869 train_time:502220ms step_avg:141.79ms
step:3553/5100 train_loss:3.4987 train_time:502361ms step_avg:141.79ms
step:3554/5100 train_loss:3.4241 train_time:502502ms step_avg:141.79ms
step:3555/5100 train_loss:3.3845 train_time:502640ms step_avg:141.79ms
step:3556/5100 train_loss:3.3261 train_time:502781ms step_avg:141.79ms
step:3557/5100 train_loss:3.2837 train_time:502921ms step_avg:141.79ms
step:3558/5100 train_loss:3.4014 train_time:503061ms step_avg:141.79ms
step:3559/5100 train_loss:3.4222 train_time:503201ms step_avg:141.79ms
step:3560/5100 train_loss:3.6236 train_time:503340ms step_avg:141.79ms
step:3561/5100 train_loss:3.4930 train_time:503481ms step_avg:141.79ms
step:3562/5100 train_loss:3.4030 train_time:503620ms step_avg:141.78ms
step:3563/5100 train_loss:3.2777 train_time:503760ms step_avg:141.78ms
step:3564/5100 train_loss:3.7901 train_time:503901ms step_avg:141.78ms
step:3565/5100 train_loss:3.3590 train_time:504040ms step_avg:141.78ms
step:3566/5100 train_loss:3.3003 train_time:504183ms step_avg:141.78ms
step:3567/5100 train_loss:3.3286 train_time:504321ms step_avg:141.78ms
step:3568/5100 train_loss:3.4621 train_time:504461ms step_avg:141.78ms
step:3569/5100 train_loss:3.3989 train_time:504601ms step_avg:141.78ms
step:3570/5100 train_loss:3.5195 train_time:504740ms step_avg:141.78ms
step:3571/5100 train_loss:3.4346 train_time:504880ms step_avg:141.78ms
step:3572/5100 train_loss:3.7203 train_time:505020ms step_avg:141.78ms
step:3573/5100 train_loss:3.3478 train_time:505160ms step_avg:141.78ms
step:3574/5100 train_loss:3.4202 train_time:505301ms step_avg:141.78ms
step:3575/5100 train_loss:3.5869 train_time:505441ms step_avg:141.78ms
step:3576/5100 train_loss:3.4640 train_time:505581ms step_avg:141.78ms
step:3577/5100 train_loss:3.4029 train_time:505720ms step_avg:141.78ms
step:3578/5100 train_loss:3.3698 train_time:505860ms step_avg:141.78ms
step:3579/5100 train_loss:3.4434 train_time:506001ms step_avg:141.78ms
step:3580/5100 train_loss:3.3912 train_time:506141ms step_avg:141.78ms
step:3581/5100 train_loss:3.2954 train_time:506281ms step_avg:141.78ms
step:3582/5100 train_loss:3.3681 train_time:506421ms step_avg:141.78ms
step:3583/5100 train_loss:3.3188 train_time:506560ms step_avg:141.77ms
step:3584/5100 train_loss:3.4247 train_time:506701ms step_avg:141.77ms
step:3585/5100 train_loss:3.5165 train_time:506840ms step_avg:141.77ms
step:3586/5100 train_loss:3.3661 train_time:506982ms step_avg:141.77ms
step:3587/5100 train_loss:3.4180 train_time:507120ms step_avg:141.77ms
step:3588/5100 train_loss:3.4147 train_time:507261ms step_avg:141.77ms
step:3589/5100 train_loss:3.4020 train_time:507401ms step_avg:141.77ms
step:3590/5100 train_loss:3.4018 train_time:507540ms step_avg:141.77ms
step:3591/5100 train_loss:3.5199 train_time:507849ms step_avg:141.82ms
step:3592/5100 train_loss:3.3934 train_time:507986ms step_avg:141.82ms
step:3593/5100 train_loss:3.4777 train_time:508124ms step_avg:141.82ms
step:3594/5100 train_loss:3.4614 train_time:508262ms step_avg:141.81ms
step:3595/5100 train_loss:3.4183 train_time:508400ms step_avg:141.81ms
step:3596/5100 train_loss:3.3438 train_time:508538ms step_avg:141.81ms
step:3597/5100 train_loss:3.3546 train_time:508677ms step_avg:141.81ms
step:3598/5100 train_loss:3.6203 train_time:508820ms step_avg:141.81ms
step:3599/5100 train_loss:3.4036 train_time:508962ms step_avg:141.81ms
step:3600/5100 train_loss:3.4157 train_time:509102ms step_avg:141.81ms
step:3601/5100 train_loss:3.2796 train_time:509240ms step_avg:141.81ms
step:3602/5100 train_loss:3.4536 train_time:509381ms step_avg:141.81ms
step:3603/5100 train_loss:3.4003 train_time:509518ms step_avg:141.81ms
step:3604/5100 train_loss:3.5521 train_time:509657ms step_avg:141.81ms
step:3605/5100 train_loss:3.6006 train_time:509800ms step_avg:141.81ms
step:3606/5100 train_loss:3.3955 train_time:509942ms step_avg:141.81ms
step:3607/5100 train_loss:3.4351 train_time:510084ms step_avg:141.81ms
step:3608/5100 train_loss:3.6962 train_time:510221ms step_avg:141.81ms
step:3609/5100 train_loss:3.3939 train_time:510359ms step_avg:141.81ms
step:3610/5100 train_loss:3.5407 train_time:510676ms step_avg:141.85ms
step:3611/5100 train_loss:3.2887 train_time:510812ms step_avg:141.85ms
step:3612/5100 train_loss:3.3931 train_time:510950ms step_avg:141.85ms
step:3613/5100 train_loss:3.4628 train_time:511088ms step_avg:141.85ms
step:3614/5100 train_loss:3.6663 train_time:511226ms step_avg:141.85ms
step:3615/5100 train_loss:3.6749 train_time:511364ms step_avg:141.85ms
step:3616/5100 train_loss:3.3351 train_time:511502ms step_avg:141.85ms
step:3617/5100 train_loss:3.4244 train_time:511646ms step_avg:141.85ms
step:3618/5100 train_loss:3.3936 train_time:511787ms step_avg:141.85ms
step:3619/5100 train_loss:3.5155 train_time:511927ms step_avg:141.85ms
step:3620/5100 train_loss:3.4577 train_time:512067ms step_avg:141.85ms
step:3621/5100 train_loss:3.2809 train_time:512206ms step_avg:141.85ms
step:3622/5100 train_loss:3.4279 train_time:512345ms step_avg:141.85ms
step:3623/5100 train_loss:3.4287 train_time:512484ms step_avg:141.84ms
step:3624/5100 train_loss:3.3662 train_time:512625ms step_avg:141.84ms
step:3625/5100 train_loss:3.4862 train_time:512765ms step_avg:141.84ms
step:3625/5100 val_loss:3.4130 train_time:512821ms step_avg:141.86ms
step:3626/5100 train_loss:3.5503 train_time:512921ms step_avg:141.85ms
step:3627/5100 train_loss:3.5605 train_time:513062ms step_avg:141.85ms
step:3628/5100 train_loss:3.4460 train_time:513201ms step_avg:141.85ms
step:3629/5100 train_loss:3.5957 train_time:513339ms step_avg:141.85ms
step:3630/5100 train_loss:3.4206 train_time:513479ms step_avg:141.84ms
step:3631/5100 train_loss:3.4171 train_time:513617ms step_avg:141.84ms
step:3632/5100 train_loss:3.5043 train_time:513756ms step_avg:141.84ms
step:3633/5100 train_loss:3.4851 train_time:513898ms step_avg:141.84ms
step:3634/5100 train_loss:3.4103 train_time:514042ms step_avg:141.84ms
step:3635/5100 train_loss:3.4100 train_time:514182ms step_avg:141.84ms
step:3636/5100 train_loss:3.4653 train_time:514321ms step_avg:141.84ms
step:3637/5100 train_loss:3.6408 train_time:514460ms step_avg:141.84ms
step:3638/5100 train_loss:3.4365 train_time:514599ms step_avg:141.84ms
step:3639/5100 train_loss:3.4015 train_time:514739ms step_avg:141.84ms
step:3640/5100 train_loss:3.3964 train_time:514881ms step_avg:141.84ms
step:3641/5100 train_loss:3.6864 train_time:515022ms step_avg:141.84ms
step:3642/5100 train_loss:3.4211 train_time:515162ms step_avg:141.84ms
step:3643/5100 train_loss:3.4604 train_time:515302ms step_avg:141.84ms
step:3644/5100 train_loss:3.4365 train_time:515441ms step_avg:141.84ms
step:3645/5100 train_loss:3.3605 train_time:515580ms step_avg:141.84ms
step:3646/5100 train_loss:3.5578 train_time:515720ms step_avg:141.84ms
step:3647/5100 train_loss:3.3363 train_time:515861ms step_avg:141.84ms
step:3648/5100 train_loss:3.4136 train_time:516003ms step_avg:141.84ms
step:3649/5100 train_loss:3.4745 train_time:516142ms step_avg:141.84ms
step:3650/5100 train_loss:3.4343 train_time:516282ms step_avg:141.84ms
step:3651/5100 train_loss:3.4707 train_time:516420ms step_avg:141.83ms
step:3652/5100 train_loss:3.5245 train_time:516561ms step_avg:141.83ms
step:3653/5100 train_loss:3.3452 train_time:516700ms step_avg:141.83ms
step:3654/5100 train_loss:3.4533 train_time:516840ms step_avg:141.83ms
step:3655/5100 train_loss:3.4805 train_time:516982ms step_avg:141.83ms
step:3656/5100 train_loss:4.1596 train_time:517122ms step_avg:141.83ms
step:3657/5100 train_loss:3.5418 train_time:517272ms step_avg:141.83ms
step:3658/5100 train_loss:3.4537 train_time:517401ms step_avg:141.83ms
step:3659/5100 train_loss:3.4460 train_time:517540ms step_avg:141.83ms
step:3660/5100 train_loss:3.3266 train_time:517681ms step_avg:141.83ms
step:3661/5100 train_loss:3.4476 train_time:517821ms step_avg:141.83ms
step:3662/5100 train_loss:3.3218 train_time:517962ms step_avg:141.83ms
step:3663/5100 train_loss:3.4795 train_time:518103ms step_avg:141.83ms
step:3664/5100 train_loss:3.4814 train_time:518241ms step_avg:141.83ms
step:3665/5100 train_loss:3.3248 train_time:518382ms step_avg:141.83ms
step:3666/5100 train_loss:3.2487 train_time:518520ms step_avg:141.83ms
step:3667/5100 train_loss:3.6847 train_time:518660ms step_avg:141.83ms
step:3668/5100 train_loss:3.4665 train_time:518801ms step_avg:141.83ms
step:3669/5100 train_loss:3.4929 train_time:518942ms step_avg:141.83ms
step:3670/5100 train_loss:3.4123 train_time:519082ms step_avg:141.83ms
step:3671/5100 train_loss:3.4790 train_time:519221ms step_avg:141.83ms
step:3672/5100 train_loss:3.3766 train_time:519361ms step_avg:141.82ms
step:3673/5100 train_loss:3.3774 train_time:519501ms step_avg:141.82ms
step:3674/5100 train_loss:3.2831 train_time:519641ms step_avg:141.82ms
step:3675/5100 train_loss:3.3544 train_time:519781ms step_avg:141.82ms
step:3676/5100 train_loss:3.5115 train_time:519922ms step_avg:141.82ms
step:3677/5100 train_loss:3.3162 train_time:520061ms step_avg:141.82ms
step:3678/5100 train_loss:3.4865 train_time:520202ms step_avg:141.82ms
step:3679/5100 train_loss:3.4678 train_time:520341ms step_avg:141.82ms
step:3680/5100 train_loss:3.3661 train_time:520481ms step_avg:141.82ms
step:3681/5100 train_loss:3.4341 train_time:520620ms step_avg:141.82ms
step:3682/5100 train_loss:3.4972 train_time:520762ms step_avg:141.82ms
step:3683/5100 train_loss:3.5939 train_time:520901ms step_avg:141.82ms
step:3684/5100 train_loss:3.3387 train_time:521040ms step_avg:141.82ms
step:3685/5100 train_loss:3.4222 train_time:521181ms step_avg:141.82ms
step:3686/5100 train_loss:3.5630 train_time:521320ms step_avg:141.82ms
step:3687/5100 train_loss:3.3390 train_time:521461ms step_avg:141.82ms
step:3688/5100 train_loss:3.5524 train_time:521602ms step_avg:141.82ms
step:3689/5100 train_loss:3.2852 train_time:521740ms step_avg:141.82ms
step:3690/5100 train_loss:3.3732 train_time:521882ms step_avg:141.82ms
step:3691/5100 train_loss:3.5088 train_time:522022ms step_avg:141.82ms
step:3692/5100 train_loss:3.2968 train_time:522162ms step_avg:141.81ms
step:3693/5100 train_loss:3.4509 train_time:522301ms step_avg:141.81ms
step:3694/5100 train_loss:3.4284 train_time:522441ms step_avg:141.81ms
step:3695/5100 train_loss:3.4264 train_time:522582ms step_avg:141.81ms
step:3696/5100 train_loss:3.4654 train_time:522721ms step_avg:141.81ms
step:3697/5100 train_loss:3.3039 train_time:522861ms step_avg:141.81ms
step:3698/5100 train_loss:3.4445 train_time:523003ms step_avg:141.81ms
step:3699/5100 train_loss:3.4532 train_time:523142ms step_avg:141.81ms
step:3700/5100 train_loss:3.4244 train_time:523283ms step_avg:141.81ms
step:3701/5100 train_loss:3.4890 train_time:523421ms step_avg:141.81ms
step:3702/5100 train_loss:3.4508 train_time:523562ms step_avg:141.81ms
step:3703/5100 train_loss:3.3712 train_time:523705ms step_avg:141.81ms
step:3704/5100 train_loss:3.3475 train_time:523842ms step_avg:141.81ms
step:3705/5100 train_loss:3.4842 train_time:523982ms step_avg:141.81ms
step:3706/5100 train_loss:3.4939 train_time:524123ms step_avg:141.81ms
step:3707/5100 train_loss:3.4927 train_time:524262ms step_avg:141.81ms
step:3708/5100 train_loss:3.4520 train_time:524402ms step_avg:141.81ms
step:3709/5100 train_loss:3.3213 train_time:524541ms step_avg:141.81ms
step:3710/5100 train_loss:3.6260 train_time:524681ms step_avg:141.81ms
step:3711/5100 train_loss:3.2038 train_time:524821ms step_avg:141.81ms
step:3712/5100 train_loss:3.4798 train_time:524962ms step_avg:141.80ms
step:3713/5100 train_loss:3.3654 train_time:525102ms step_avg:141.80ms
step:3714/5100 train_loss:3.3965 train_time:525241ms step_avg:141.80ms
step:3715/5100 train_loss:3.7717 train_time:525381ms step_avg:141.80ms
step:3716/5100 train_loss:3.6217 train_time:525520ms step_avg:141.80ms
step:3717/5100 train_loss:3.8957 train_time:525661ms step_avg:141.80ms
step:3718/5100 train_loss:3.3979 train_time:525802ms step_avg:141.80ms
step:3719/5100 train_loss:3.3065 train_time:525941ms step_avg:141.80ms
step:3720/5100 train_loss:3.5709 train_time:526082ms step_avg:141.80ms
step:3721/5100 train_loss:3.3342 train_time:526221ms step_avg:141.80ms
step:3722/5100 train_loss:3.4297 train_time:526363ms step_avg:141.80ms
step:3723/5100 train_loss:3.2931 train_time:526501ms step_avg:141.80ms
step:3724/5100 train_loss:3.2851 train_time:526640ms step_avg:141.80ms
step:3725/5100 train_loss:3.4149 train_time:526781ms step_avg:141.80ms
step:3726/5100 train_loss:3.3628 train_time:526921ms step_avg:141.80ms
step:3727/5100 train_loss:3.6387 train_time:527061ms step_avg:141.80ms
step:3728/5100 train_loss:3.3561 train_time:527202ms step_avg:141.80ms
step:3729/5100 train_loss:3.3545 train_time:527341ms step_avg:141.80ms
step:3730/5100 train_loss:3.7102 train_time:527482ms step_avg:141.80ms
step:3731/5100 train_loss:3.4678 train_time:527621ms step_avg:141.80ms
step:3732/5100 train_loss:3.3738 train_time:527762ms step_avg:141.80ms
step:3733/5100 train_loss:3.2399 train_time:527902ms step_avg:141.79ms
step:3734/5100 train_loss:3.4762 train_time:528041ms step_avg:141.79ms
step:3735/5100 train_loss:3.3423 train_time:528181ms step_avg:141.79ms
step:3736/5100 train_loss:3.4407 train_time:528320ms step_avg:141.79ms
step:3737/5100 train_loss:3.3498 train_time:528461ms step_avg:141.79ms
step:3738/5100 train_loss:3.4488 train_time:528600ms step_avg:141.79ms
step:3739/5100 train_loss:3.3517 train_time:528741ms step_avg:141.79ms
step:3740/5100 train_loss:3.3969 train_time:528882ms step_avg:141.79ms
step:3741/5100 train_loss:3.6894 train_time:529021ms step_avg:141.79ms
step:3742/5100 train_loss:3.3523 train_time:529163ms step_avg:141.79ms
step:3743/5100 train_loss:3.4064 train_time:529303ms step_avg:141.79ms
step:3744/5100 train_loss:3.6140 train_time:529441ms step_avg:141.79ms
step:3745/5100 train_loss:3.3337 train_time:529582ms step_avg:141.79ms
step:3746/5100 train_loss:3.2704 train_time:529721ms step_avg:141.79ms
step:3747/5100 train_loss:3.4512 train_time:529861ms step_avg:141.79ms
step:3748/5100 train_loss:3.3045 train_time:530002ms step_avg:141.79ms
step:3749/5100 train_loss:3.3443 train_time:530142ms step_avg:141.79ms
step:3750/5100 train_loss:3.5427 train_time:530282ms step_avg:141.79ms
step:3750/5100 val_loss:3.4013 train_time:530337ms step_avg:141.80ms
step:3751/5100 train_loss:3.4276 train_time:530434ms step_avg:141.79ms
step:3752/5100 train_loss:3.6635 train_time:530578ms step_avg:141.79ms
step:3753/5100 train_loss:3.3772 train_time:530717ms step_avg:141.79ms
step:3754/5100 train_loss:3.3835 train_time:530855ms step_avg:141.79ms
step:3755/5100 train_loss:3.3460 train_time:530993ms step_avg:141.79ms
step:3756/5100 train_loss:3.4501 train_time:531131ms step_avg:141.79ms
step:3757/5100 train_loss:3.3967 train_time:531270ms step_avg:141.79ms
step:3758/5100 train_loss:3.4052 train_time:531413ms step_avg:141.79ms
step:3759/5100 train_loss:3.5897 train_time:531556ms step_avg:141.79ms
step:3760/5100 train_loss:3.4706 train_time:531696ms step_avg:141.79ms
step:3761/5100 train_loss:3.5840 train_time:531835ms step_avg:141.78ms
step:3762/5100 train_loss:3.3548 train_time:531974ms step_avg:141.78ms
step:3763/5100 train_loss:3.3707 train_time:532115ms step_avg:141.78ms
step:3764/5100 train_loss:3.5296 train_time:532252ms step_avg:141.78ms
step:3765/5100 train_loss:3.2882 train_time:532394ms step_avg:141.78ms
step:3766/5100 train_loss:3.3819 train_time:532535ms step_avg:141.78ms
step:3767/5100 train_loss:3.4750 train_time:532676ms step_avg:141.78ms
step:3768/5100 train_loss:3.2811 train_time:532814ms step_avg:141.78ms
step:3769/5100 train_loss:3.5520 train_time:532954ms step_avg:141.78ms
step:3770/5100 train_loss:3.3606 train_time:533094ms step_avg:141.78ms
step:3771/5100 train_loss:3.2381 train_time:533232ms step_avg:141.78ms
step:3772/5100 train_loss:3.4869 train_time:533374ms step_avg:141.78ms
step:3773/5100 train_loss:3.4065 train_time:533515ms step_avg:141.78ms
step:3774/5100 train_loss:3.4046 train_time:533655ms step_avg:141.78ms
step:3775/5100 train_loss:3.3987 train_time:533794ms step_avg:141.78ms
step:3776/5100 train_loss:3.4462 train_time:533933ms step_avg:141.78ms
step:3777/5100 train_loss:3.2795 train_time:534073ms step_avg:141.78ms
step:3778/5100 train_loss:3.3924 train_time:534215ms step_avg:141.78ms
step:3779/5100 train_loss:3.5164 train_time:534352ms step_avg:141.78ms
step:3780/5100 train_loss:3.4859 train_time:534659ms step_avg:141.82ms
step:3781/5100 train_loss:3.4878 train_time:534796ms step_avg:141.82ms
step:3782/5100 train_loss:3.4389 train_time:534933ms step_avg:141.82ms
step:3783/5100 train_loss:3.4403 train_time:535072ms step_avg:141.82ms
step:3784/5100 train_loss:3.4080 train_time:535210ms step_avg:141.81ms
step:3785/5100 train_loss:3.2927 train_time:535348ms step_avg:141.81ms
step:3786/5100 train_loss:3.3729 train_time:535488ms step_avg:141.81ms
step:3787/5100 train_loss:3.4104 train_time:535632ms step_avg:141.81ms
step:3788/5100 train_loss:3.3990 train_time:535777ms step_avg:141.81ms
step:3789/5100 train_loss:3.3491 train_time:535916ms step_avg:141.81ms
step:3790/5100 train_loss:3.3759 train_time:536055ms step_avg:141.81ms
step:3791/5100 train_loss:3.2494 train_time:536194ms step_avg:141.81ms
step:3792/5100 train_loss:3.4675 train_time:536333ms step_avg:141.81ms
step:3793/5100 train_loss:3.4521 train_time:536474ms step_avg:141.81ms
step:3794/5100 train_loss:3.3775 train_time:536614ms step_avg:141.81ms
step:3795/5100 train_loss:3.3388 train_time:536756ms step_avg:141.81ms
step:3796/5100 train_loss:3.1862 train_time:536897ms step_avg:141.81ms
step:3797/5100 train_loss:3.3818 train_time:537034ms step_avg:141.81ms
step:3798/5100 train_loss:3.3771 train_time:537174ms step_avg:141.81ms
step:3799/5100 train_loss:3.4412 train_time:537312ms step_avg:141.81ms
step:3800/5100 train_loss:3.3512 train_time:537611ms step_avg:141.85ms
step:3801/5100 train_loss:3.3287 train_time:537747ms step_avg:141.85ms
step:3802/5100 train_loss:3.2856 train_time:537887ms step_avg:141.85ms
step:3803/5100 train_loss:3.5900 train_time:538027ms step_avg:141.85ms
step:3804/5100 train_loss:3.4424 train_time:538164ms step_avg:141.85ms
step:3805/5100 train_loss:3.2975 train_time:538302ms step_avg:141.85ms
step:3806/5100 train_loss:3.5388 train_time:538440ms step_avg:141.84ms
step:3807/5100 train_loss:3.5029 train_time:538586ms step_avg:141.85ms
step:3808/5100 train_loss:3.3683 train_time:538729ms step_avg:141.85ms
step:3809/5100 train_loss:3.4335 train_time:538870ms step_avg:141.85ms
step:3810/5100 train_loss:3.3429 train_time:539010ms step_avg:141.84ms
step:3811/5100 train_loss:3.4152 train_time:539149ms step_avg:141.84ms
step:3812/5100 train_loss:3.3944 train_time:539289ms step_avg:141.84ms
step:3813/5100 train_loss:3.4218 train_time:539428ms step_avg:141.84ms
step:3814/5100 train_loss:3.4219 train_time:539572ms step_avg:141.84ms
step:3815/5100 train_loss:3.3004 train_time:539713ms step_avg:141.84ms
step:3816/5100 train_loss:3.6459 train_time:539853ms step_avg:141.84ms
step:3817/5100 train_loss:3.2508 train_time:539994ms step_avg:141.84ms
step:3818/5100 train_loss:3.4265 train_time:540134ms step_avg:141.84ms
step:3819/5100 train_loss:3.4048 train_time:540273ms step_avg:141.84ms
step:3820/5100 train_loss:3.3840 train_time:540413ms step_avg:141.84ms
step:3821/5100 train_loss:3.3210 train_time:540553ms step_avg:141.84ms
step:3822/5100 train_loss:3.4922 train_time:540696ms step_avg:141.84ms
step:3823/5100 train_loss:3.2216 train_time:540839ms step_avg:141.84ms
step:3824/5100 train_loss:3.3382 train_time:540975ms step_avg:141.84ms
step:3825/5100 train_loss:3.3887 train_time:541113ms step_avg:141.84ms
step:3826/5100 train_loss:3.5287 train_time:541253ms step_avg:141.84ms
step:3827/5100 train_loss:3.4674 train_time:541394ms step_avg:141.84ms
step:3828/5100 train_loss:3.8433 train_time:541533ms step_avg:141.84ms
step:3829/5100 train_loss:3.4677 train_time:541675ms step_avg:141.84ms
step:3830/5100 train_loss:3.2919 train_time:541814ms step_avg:141.84ms
step:3831/5100 train_loss:3.3357 train_time:541954ms step_avg:141.84ms
step:3832/5100 train_loss:3.5437 train_time:542094ms step_avg:141.84ms
step:3833/5100 train_loss:3.3590 train_time:542232ms step_avg:141.83ms
step:3834/5100 train_loss:3.4782 train_time:542373ms step_avg:141.83ms
step:3835/5100 train_loss:3.4142 train_time:542514ms step_avg:141.83ms
step:3836/5100 train_loss:3.2136 train_time:542654ms step_avg:141.83ms
step:3837/5100 train_loss:3.5005 train_time:542795ms step_avg:141.83ms
step:3838/5100 train_loss:3.4767 train_time:542934ms step_avg:141.83ms
step:3839/5100 train_loss:3.4313 train_time:543075ms step_avg:141.83ms
step:3840/5100 train_loss:3.4908 train_time:543213ms step_avg:141.83ms
step:3841/5100 train_loss:3.6131 train_time:543354ms step_avg:141.83ms
step:3842/5100 train_loss:3.3669 train_time:543495ms step_avg:141.83ms
step:3843/5100 train_loss:3.4198 train_time:543634ms step_avg:141.83ms
step:3844/5100 train_loss:3.5621 train_time:543776ms step_avg:141.83ms
step:3845/5100 train_loss:3.3592 train_time:543914ms step_avg:141.83ms
step:3846/5100 train_loss:3.2292 train_time:544054ms step_avg:141.83ms
step:3847/5100 train_loss:3.4694 train_time:544193ms step_avg:141.83ms
step:3848/5100 train_loss:3.3936 train_time:544333ms step_avg:141.83ms
step:3849/5100 train_loss:3.4366 train_time:544474ms step_avg:141.83ms
step:3850/5100 train_loss:3.3194 train_time:544614ms step_avg:141.83ms
step:3851/5100 train_loss:3.3188 train_time:544754ms step_avg:141.83ms
step:3852/5100 train_loss:3.4710 train_time:544895ms step_avg:141.83ms
step:3853/5100 train_loss:3.3213 train_time:545035ms step_avg:141.83ms
step:3854/5100 train_loss:3.2904 train_time:545174ms step_avg:141.82ms
step:3855/5100 train_loss:3.3749 train_time:545313ms step_avg:141.82ms
step:3856/5100 train_loss:3.4092 train_time:545454ms step_avg:141.82ms
step:3857/5100 train_loss:3.3853 train_time:545594ms step_avg:141.82ms
step:3858/5100 train_loss:3.4091 train_time:545733ms step_avg:141.82ms
step:3859/5100 train_loss:3.3898 train_time:545875ms step_avg:141.82ms
step:3860/5100 train_loss:3.3967 train_time:546014ms step_avg:141.82ms
step:3861/5100 train_loss:3.5586 train_time:546153ms step_avg:141.82ms
step:3862/5100 train_loss:3.3701 train_time:546294ms step_avg:141.82ms
step:3863/5100 train_loss:3.4831 train_time:546435ms step_avg:141.82ms
step:3864/5100 train_loss:3.4404 train_time:546575ms step_avg:141.82ms
step:3865/5100 train_loss:3.4901 train_time:546714ms step_avg:141.82ms
step:3866/5100 train_loss:3.4589 train_time:546856ms step_avg:141.82ms
step:3867/5100 train_loss:3.3969 train_time:546994ms step_avg:141.82ms
step:3868/5100 train_loss:3.4888 train_time:547133ms step_avg:141.82ms
step:3869/5100 train_loss:3.6583 train_time:547274ms step_avg:141.82ms
step:3870/5100 train_loss:3.4924 train_time:547415ms step_avg:141.82ms
step:3871/5100 train_loss:3.3864 train_time:547555ms step_avg:141.82ms
step:3872/5100 train_loss:3.5272 train_time:547696ms step_avg:141.82ms
step:3873/5100 train_loss:3.4281 train_time:547835ms step_avg:141.82ms
step:3874/5100 train_loss:3.3718 train_time:547975ms step_avg:141.82ms
step:3875/5100 train_loss:3.4677 train_time:548114ms step_avg:141.81ms
step:3875/5100 val_loss:3.3883 train_time:548170ms step_avg:141.83ms
step:3876/5100 train_loss:4.0001 train_time:548266ms step_avg:141.82ms
step:3877/5100 train_loss:3.4122 train_time:548411ms step_avg:141.82ms
step:3878/5100 train_loss:3.4059 train_time:548553ms step_avg:141.82ms
step:3879/5100 train_loss:3.3835 train_time:548693ms step_avg:141.82ms
step:3880/5100 train_loss:3.5935 train_time:548831ms step_avg:141.82ms
step:3881/5100 train_loss:3.3947 train_time:548970ms step_avg:141.82ms
step:3882/5100 train_loss:3.4640 train_time:549108ms step_avg:141.82ms
step:3883/5100 train_loss:3.5146 train_time:549251ms step_avg:141.82ms
step:3884/5100 train_loss:3.3294 train_time:549397ms step_avg:141.82ms
step:3885/5100 train_loss:3.3209 train_time:549537ms step_avg:141.82ms
step:3886/5100 train_loss:3.3595 train_time:549678ms step_avg:141.82ms
step:3887/5100 train_loss:3.3957 train_time:549818ms step_avg:141.82ms
step:3888/5100 train_loss:3.5732 train_time:549958ms step_avg:141.81ms
step:3889/5100 train_loss:3.4200 train_time:550098ms step_avg:141.81ms
step:3890/5100 train_loss:3.3581 train_time:550237ms step_avg:141.81ms
step:3891/5100 train_loss:3.5043 train_time:550378ms step_avg:141.81ms
step:3892/5100 train_loss:3.3615 train_time:550520ms step_avg:141.81ms
step:3893/5100 train_loss:3.6079 train_time:550660ms step_avg:141.81ms
step:3894/5100 train_loss:3.3464 train_time:550801ms step_avg:141.81ms
step:3895/5100 train_loss:3.3541 train_time:550939ms step_avg:141.81ms
step:3896/5100 train_loss:3.4391 train_time:551079ms step_avg:141.81ms
step:3897/5100 train_loss:3.6839 train_time:551218ms step_avg:141.81ms
step:3898/5100 train_loss:3.2334 train_time:551359ms step_avg:141.81ms
step:3899/5100 train_loss:3.3625 train_time:551500ms step_avg:141.81ms
step:3900/5100 train_loss:3.5016 train_time:551640ms step_avg:141.81ms
step:3901/5100 train_loss:3.4220 train_time:551780ms step_avg:141.81ms
step:3902/5100 train_loss:3.4668 train_time:551919ms step_avg:141.81ms
step:3903/5100 train_loss:3.7533 train_time:552058ms step_avg:141.81ms
step:3904/5100 train_loss:3.3468 train_time:552198ms step_avg:141.81ms
step:3905/5100 train_loss:3.3661 train_time:552339ms step_avg:141.81ms
step:3906/5100 train_loss:3.3184 train_time:552480ms step_avg:141.81ms
step:3907/5100 train_loss:3.4727 train_time:552620ms step_avg:141.81ms
step:3908/5100 train_loss:3.4903 train_time:552760ms step_avg:141.81ms
step:3909/5100 train_loss:3.4845 train_time:552900ms step_avg:141.81ms
step:3910/5100 train_loss:3.4296 train_time:553039ms step_avg:141.80ms
step:3911/5100 train_loss:3.3652 train_time:553179ms step_avg:141.80ms
step:3912/5100 train_loss:3.3844 train_time:553319ms step_avg:141.80ms
step:3913/5100 train_loss:3.3745 train_time:553459ms step_avg:141.80ms
step:3914/5100 train_loss:3.5014 train_time:553601ms step_avg:141.80ms
step:3915/5100 train_loss:3.3341 train_time:553739ms step_avg:141.80ms
step:3916/5100 train_loss:3.3144 train_time:553880ms step_avg:141.80ms
step:3917/5100 train_loss:3.3090 train_time:554019ms step_avg:141.80ms
step:3918/5100 train_loss:3.4314 train_time:554159ms step_avg:141.80ms
step:3919/5100 train_loss:3.5458 train_time:554300ms step_avg:141.80ms
step:3920/5100 train_loss:3.3264 train_time:554438ms step_avg:141.80ms
step:3921/5100 train_loss:3.3037 train_time:554580ms step_avg:141.80ms
step:3922/5100 train_loss:3.3805 train_time:554719ms step_avg:141.80ms
step:3923/5100 train_loss:3.3738 train_time:554859ms step_avg:141.80ms
step:3924/5100 train_loss:3.3902 train_time:554999ms step_avg:141.80ms
step:3925/5100 train_loss:3.4664 train_time:555139ms step_avg:141.80ms
step:3926/5100 train_loss:3.4332 train_time:555279ms step_avg:141.80ms
step:3927/5100 train_loss:3.5297 train_time:555420ms step_avg:141.80ms
step:3928/5100 train_loss:3.4150 train_time:555560ms step_avg:141.80ms
step:3929/5100 train_loss:3.2707 train_time:555700ms step_avg:141.80ms
step:3930/5100 train_loss:3.5976 train_time:555839ms step_avg:141.80ms
step:3931/5100 train_loss:3.3798 train_time:555980ms step_avg:141.80ms
step:3932/5100 train_loss:3.4294 train_time:556119ms step_avg:141.79ms
step:3933/5100 train_loss:3.4687 train_time:556259ms step_avg:141.79ms
step:3934/5100 train_loss:3.3433 train_time:556400ms step_avg:141.79ms
step:3935/5100 train_loss:3.4673 train_time:556540ms step_avg:141.79ms
step:3936/5100 train_loss:3.4746 train_time:556680ms step_avg:141.79ms
step:3937/5100 train_loss:3.4010 train_time:556819ms step_avg:141.79ms
step:3938/5100 train_loss:3.4640 train_time:556961ms step_avg:141.79ms
step:3939/5100 train_loss:3.3852 train_time:557100ms step_avg:141.79ms
step:3940/5100 train_loss:3.1482 train_time:557239ms step_avg:141.79ms
step:3941/5100 train_loss:3.3551 train_time:557380ms step_avg:141.79ms
step:3942/5100 train_loss:3.4631 train_time:557520ms step_avg:141.79ms
step:3943/5100 train_loss:3.5607 train_time:557659ms step_avg:141.79ms
step:3944/5100 train_loss:3.5942 train_time:557799ms step_avg:141.79ms
step:3945/5100 train_loss:3.4344 train_time:557940ms step_avg:141.79ms
step:3946/5100 train_loss:3.3446 train_time:558080ms step_avg:141.79ms
step:3947/5100 train_loss:3.3652 train_time:558219ms step_avg:141.79ms
step:3948/5100 train_loss:3.4408 train_time:558359ms step_avg:141.79ms
step:3949/5100 train_loss:3.2439 train_time:558500ms step_avg:141.79ms
step:3950/5100 train_loss:3.4496 train_time:558639ms step_avg:141.79ms
step:3951/5100 train_loss:3.3805 train_time:558782ms step_avg:141.79ms
step:3952/5100 train_loss:3.1908 train_time:558921ms step_avg:141.79ms
step:3953/5100 train_loss:3.2596 train_time:559061ms step_avg:141.79ms
step:3954/5100 train_loss:3.5298 train_time:559199ms step_avg:141.78ms
step:3955/5100 train_loss:3.4407 train_time:559339ms step_avg:141.78ms
step:3956/5100 train_loss:3.3818 train_time:559481ms step_avg:141.78ms
step:3957/5100 train_loss:3.4370 train_time:559620ms step_avg:141.78ms
step:3958/5100 train_loss:3.1555 train_time:559761ms step_avg:141.78ms
step:3959/5100 train_loss:3.4395 train_time:559900ms step_avg:141.78ms
step:3960/5100 train_loss:3.3948 train_time:560040ms step_avg:141.78ms
step:3961/5100 train_loss:3.3633 train_time:560180ms step_avg:141.78ms
step:3962/5100 train_loss:3.3774 train_time:560319ms step_avg:141.78ms
step:3963/5100 train_loss:3.4031 train_time:560462ms step_avg:141.78ms
step:3964/5100 train_loss:3.4321 train_time:560600ms step_avg:141.78ms
step:3965/5100 train_loss:3.2797 train_time:560739ms step_avg:141.78ms
step:3966/5100 train_loss:3.3978 train_time:560880ms step_avg:141.78ms
step:3967/5100 train_loss:3.4728 train_time:561019ms step_avg:141.78ms
step:3968/5100 train_loss:3.3864 train_time:561159ms step_avg:141.78ms
step:3969/5100 train_loss:3.4720 train_time:561471ms step_avg:141.82ms
step:3970/5100 train_loss:3.3626 train_time:561608ms step_avg:141.82ms
step:3971/5100 train_loss:3.5517 train_time:561748ms step_avg:141.82ms
step:3972/5100 train_loss:3.4708 train_time:561886ms step_avg:141.82ms
step:3973/5100 train_loss:3.4280 train_time:562024ms step_avg:141.82ms
step:3974/5100 train_loss:3.3187 train_time:562163ms step_avg:141.82ms
step:3975/5100 train_loss:3.3783 train_time:562301ms step_avg:141.82ms
step:3976/5100 train_loss:3.4363 train_time:562442ms step_avg:141.82ms
step:3977/5100 train_loss:3.3508 train_time:562584ms step_avg:141.82ms
step:3978/5100 train_loss:3.4108 train_time:562723ms step_avg:141.82ms
step:3979/5100 train_loss:3.4878 train_time:562863ms step_avg:141.81ms
step:3980/5100 train_loss:3.4296 train_time:563002ms step_avg:141.81ms
step:3981/5100 train_loss:3.4436 train_time:563140ms step_avg:141.81ms
step:3982/5100 train_loss:3.6391 train_time:563280ms step_avg:141.81ms
step:3983/5100 train_loss:3.3706 train_time:563422ms step_avg:141.81ms
step:3984/5100 train_loss:3.4401 train_time:563562ms step_avg:141.81ms
step:3985/5100 train_loss:3.3672 train_time:563701ms step_avg:141.81ms
step:3986/5100 train_loss:3.2943 train_time:563842ms step_avg:141.81ms
step:3987/5100 train_loss:3.3396 train_time:563982ms step_avg:141.81ms
step:3988/5100 train_loss:3.3606 train_time:564120ms step_avg:141.81ms
step:3989/5100 train_loss:3.0877 train_time:564259ms step_avg:141.81ms
step:3990/5100 train_loss:3.4093 train_time:564577ms step_avg:141.85ms
step:3991/5100 train_loss:3.3930 train_time:564713ms step_avg:141.85ms
step:3992/5100 train_loss:3.2322 train_time:564850ms step_avg:141.85ms
step:3993/5100 train_loss:3.3419 train_time:564988ms step_avg:141.85ms
step:3994/5100 train_loss:3.5358 train_time:565126ms step_avg:141.85ms
step:3995/5100 train_loss:3.3577 train_time:565264ms step_avg:141.85ms
step:3996/5100 train_loss:3.2739 train_time:565403ms step_avg:141.85ms
step:3997/5100 train_loss:3.4237 train_time:565544ms step_avg:141.85ms
step:3998/5100 train_loss:3.3392 train_time:565688ms step_avg:141.85ms
step:3999/5100 train_loss:3.2951 train_time:565827ms step_avg:141.85ms
step:4000/5100 train_loss:3.3672 train_time:565965ms step_avg:141.85ms
step:4000/5100 val_loss:3.3723 train_time:566023ms step_avg:141.86ms
step:4001/5100 train_loss:3.4923 train_time:566117ms step_avg:141.85ms
step:4002/5100 train_loss:3.5628 train_time:566268ms step_avg:141.85ms
step:4003/5100 train_loss:3.2342 train_time:566405ms step_avg:141.85ms
step:4004/5100 train_loss:3.4341 train_time:566543ms step_avg:141.85ms
step:4005/5100 train_loss:3.3358 train_time:566682ms step_avg:141.85ms
step:4006/5100 train_loss:3.3832 train_time:566821ms step_avg:141.85ms
step:4007/5100 train_loss:3.3668 train_time:566959ms step_avg:141.85ms
step:4008/5100 train_loss:3.5703 train_time:567101ms step_avg:141.85ms
step:4009/5100 train_loss:3.1593 train_time:567245ms step_avg:141.85ms
step:4010/5100 train_loss:3.3536 train_time:567386ms step_avg:141.85ms
step:4011/5100 train_loss:3.3368 train_time:567525ms step_avg:141.85ms
step:4012/5100 train_loss:3.3048 train_time:567666ms step_avg:141.85ms
step:4013/5100 train_loss:3.4733 train_time:567805ms step_avg:141.84ms
step:4014/5100 train_loss:3.3372 train_time:567944ms step_avg:141.84ms
step:4015/5100 train_loss:3.4390 train_time:568085ms step_avg:141.84ms
step:4016/5100 train_loss:3.5235 train_time:568226ms step_avg:141.84ms
step:4017/5100 train_loss:3.5162 train_time:568366ms step_avg:141.84ms
step:4018/5100 train_loss:3.2737 train_time:568506ms step_avg:141.84ms
step:4019/5100 train_loss:3.4094 train_time:568645ms step_avg:141.84ms
step:4020/5100 train_loss:3.3209 train_time:568785ms step_avg:141.84ms
step:4021/5100 train_loss:3.5936 train_time:568924ms step_avg:141.84ms
step:4022/5100 train_loss:3.4732 train_time:569064ms step_avg:141.84ms
step:4023/5100 train_loss:3.4516 train_time:569205ms step_avg:141.84ms
step:4024/5100 train_loss:3.4238 train_time:569345ms step_avg:141.84ms
step:4025/5100 train_loss:3.4637 train_time:569486ms step_avg:141.84ms
step:4026/5100 train_loss:3.2130 train_time:569626ms step_avg:141.84ms
step:4027/5100 train_loss:3.4317 train_time:569766ms step_avg:141.84ms
step:4028/5100 train_loss:3.3775 train_time:569906ms step_avg:141.84ms
step:4029/5100 train_loss:3.2608 train_time:570046ms step_avg:141.84ms
step:4030/5100 train_loss:3.2920 train_time:570187ms step_avg:141.84ms
step:4031/5100 train_loss:3.3479 train_time:570327ms step_avg:141.84ms
step:4032/5100 train_loss:3.4360 train_time:570467ms step_avg:141.84ms
step:4033/5100 train_loss:3.4004 train_time:570607ms step_avg:141.84ms
step:4034/5100 train_loss:3.3791 train_time:570746ms step_avg:141.84ms
step:4035/5100 train_loss:3.3714 train_time:570886ms step_avg:141.84ms
step:4036/5100 train_loss:3.3106 train_time:571026ms step_avg:141.83ms
step:4037/5100 train_loss:3.4767 train_time:571165ms step_avg:141.83ms
step:4038/5100 train_loss:3.4066 train_time:571306ms step_avg:141.83ms
step:4039/5100 train_loss:3.3977 train_time:571445ms step_avg:141.83ms
step:4040/5100 train_loss:3.3832 train_time:571586ms step_avg:141.83ms
step:4041/5100 train_loss:3.4300 train_time:571725ms step_avg:141.83ms
step:4042/5100 train_loss:3.6179 train_time:571865ms step_avg:141.83ms
step:4043/5100 train_loss:3.5161 train_time:572006ms step_avg:141.83ms
step:4044/5100 train_loss:3.3050 train_time:572145ms step_avg:141.83ms
step:4045/5100 train_loss:3.4729 train_time:572286ms step_avg:141.83ms
step:4046/5100 train_loss:3.1840 train_time:572426ms step_avg:141.83ms
step:4047/5100 train_loss:3.4396 train_time:572565ms step_avg:141.83ms
step:4048/5100 train_loss:3.5241 train_time:572705ms step_avg:141.83ms
step:4049/5100 train_loss:3.3949 train_time:572845ms step_avg:141.83ms
step:4050/5100 train_loss:3.3274 train_time:572986ms step_avg:141.83ms
step:4051/5100 train_loss:3.3664 train_time:573126ms step_avg:141.83ms
step:4052/5100 train_loss:3.3006 train_time:573266ms step_avg:141.83ms
step:4053/5100 train_loss:3.5084 train_time:573407ms step_avg:141.83ms
step:4054/5100 train_loss:3.3590 train_time:573546ms step_avg:141.83ms
step:4055/5100 train_loss:3.4485 train_time:573686ms step_avg:141.83ms
step:4056/5100 train_loss:3.4272 train_time:573826ms step_avg:141.83ms
step:4057/5100 train_loss:3.4067 train_time:573972ms step_avg:141.83ms
step:4058/5100 train_loss:3.2696 train_time:574107ms step_avg:141.82ms
step:4059/5100 train_loss:3.4258 train_time:574246ms step_avg:141.82ms
step:4060/5100 train_loss:3.2854 train_time:574387ms step_avg:141.82ms
step:4061/5100 train_loss:3.3709 train_time:574526ms step_avg:141.82ms
step:4062/5100 train_loss:3.4850 train_time:574667ms step_avg:141.82ms
step:4063/5100 train_loss:3.6529 train_time:574806ms step_avg:141.82ms
step:4064/5100 train_loss:3.0461 train_time:574945ms step_avg:141.82ms
step:4065/5100 train_loss:3.4082 train_time:575086ms step_avg:141.82ms
step:4066/5100 train_loss:3.2913 train_time:575227ms step_avg:141.82ms
step:4067/5100 train_loss:3.4493 train_time:575367ms step_avg:141.82ms
step:4068/5100 train_loss:3.4521 train_time:575507ms step_avg:141.82ms
step:4069/5100 train_loss:3.2579 train_time:575646ms step_avg:141.82ms
step:4070/5100 train_loss:3.4212 train_time:575786ms step_avg:141.82ms
step:4071/5100 train_loss:3.2379 train_time:575926ms step_avg:141.82ms
step:4072/5100 train_loss:3.4240 train_time:576066ms step_avg:141.82ms
step:4073/5100 train_loss:3.5294 train_time:576206ms step_avg:141.82ms
step:4074/5100 train_loss:3.4571 train_time:576348ms step_avg:141.82ms
step:4075/5100 train_loss:3.3666 train_time:576487ms step_avg:141.82ms
step:4076/5100 train_loss:3.3656 train_time:576626ms step_avg:141.82ms
step:4077/5100 train_loss:3.2343 train_time:576765ms step_avg:141.82ms
step:4078/5100 train_loss:3.4034 train_time:576906ms step_avg:141.82ms
step:4079/5100 train_loss:3.4176 train_time:577045ms step_avg:141.81ms
step:4080/5100 train_loss:3.2140 train_time:577186ms step_avg:141.81ms
step:4081/5100 train_loss:3.3881 train_time:577325ms step_avg:141.81ms
step:4082/5100 train_loss:3.3390 train_time:577466ms step_avg:141.81ms
step:4083/5100 train_loss:3.3868 train_time:577606ms step_avg:141.81ms
step:4084/5100 train_loss:3.3931 train_time:577745ms step_avg:141.81ms
step:4085/5100 train_loss:3.4227 train_time:577887ms step_avg:141.81ms
step:4086/5100 train_loss:3.3791 train_time:578026ms step_avg:141.81ms
step:4087/5100 train_loss:3.3560 train_time:578166ms step_avg:141.81ms
step:4088/5100 train_loss:3.4776 train_time:578306ms step_avg:141.81ms
step:4089/5100 train_loss:3.3072 train_time:578446ms step_avg:141.81ms
step:4090/5100 train_loss:3.3374 train_time:578587ms step_avg:141.81ms
step:4091/5100 train_loss:3.3528 train_time:578726ms step_avg:141.81ms
step:4092/5100 train_loss:3.3091 train_time:578866ms step_avg:141.81ms
step:4093/5100 train_loss:3.2990 train_time:579006ms step_avg:141.81ms
step:4094/5100 train_loss:3.4793 train_time:579145ms step_avg:141.81ms
step:4095/5100 train_loss:3.4656 train_time:579286ms step_avg:141.81ms
step:4096/5100 train_loss:3.3776 train_time:579426ms step_avg:141.81ms
step:4097/5100 train_loss:3.4418 train_time:579565ms step_avg:141.81ms
step:4098/5100 train_loss:3.2169 train_time:579706ms step_avg:141.81ms
step:4099/5100 train_loss:3.3506 train_time:579846ms step_avg:141.81ms
step:4100/5100 train_loss:3.3302 train_time:579986ms step_avg:141.81ms
step:4101/5100 train_loss:3.1261 train_time:580126ms step_avg:141.81ms
step:4102/5100 train_loss:3.4063 train_time:580266ms step_avg:141.80ms
step:4103/5100 train_loss:3.3651 train_time:580411ms step_avg:141.81ms
step:4104/5100 train_loss:3.2021 train_time:580546ms step_avg:141.80ms
step:4105/5100 train_loss:3.2883 train_time:580687ms step_avg:141.80ms
step:4106/5100 train_loss:3.4432 train_time:580827ms step_avg:141.80ms
step:4107/5100 train_loss:3.4900 train_time:580967ms step_avg:141.80ms
step:4108/5100 train_loss:3.3902 train_time:581106ms step_avg:141.80ms
step:4109/5100 train_loss:3.4787 train_time:581246ms step_avg:141.80ms
step:4110/5100 train_loss:3.4716 train_time:581391ms step_avg:141.80ms
step:4111/5100 train_loss:3.6389 train_time:581527ms step_avg:141.80ms
step:4112/5100 train_loss:3.2891 train_time:581666ms step_avg:141.80ms
step:4113/5100 train_loss:3.4239 train_time:581806ms step_avg:141.80ms
step:4114/5100 train_loss:3.3339 train_time:581949ms step_avg:141.80ms
step:4115/5100 train_loss:3.4268 train_time:582086ms step_avg:141.80ms
step:4116/5100 train_loss:3.4302 train_time:582227ms step_avg:141.80ms
step:4117/5100 train_loss:3.6671 train_time:582366ms step_avg:141.80ms
step:4118/5100 train_loss:3.2366 train_time:582506ms step_avg:141.80ms
step:4119/5100 train_loss:3.4096 train_time:582646ms step_avg:141.80ms
step:4120/5100 train_loss:3.3246 train_time:582788ms step_avg:141.80ms
step:4121/5100 train_loss:3.4337 train_time:582926ms step_avg:141.80ms
step:4122/5100 train_loss:3.4196 train_time:583066ms step_avg:141.80ms
step:4123/5100 train_loss:3.4100 train_time:583207ms step_avg:141.80ms
step:4124/5100 train_loss:3.2532 train_time:583346ms step_avg:141.80ms
step:4125/5100 train_loss:3.2531 train_time:583486ms step_avg:141.79ms
step:4125/5100 val_loss:3.3595 train_time:583541ms step_avg:141.81ms
step:4126/5100 train_loss:3.3793 train_time:583637ms step_avg:141.80ms
step:4127/5100 train_loss:3.3141 train_time:583782ms step_avg:141.80ms
step:4128/5100 train_loss:3.3895 train_time:583923ms step_avg:141.80ms
step:4129/5100 train_loss:3.3851 train_time:584064ms step_avg:141.80ms
step:4130/5100 train_loss:3.1640 train_time:584204ms step_avg:141.80ms
step:4131/5100 train_loss:3.4582 train_time:584343ms step_avg:141.80ms
step:4132/5100 train_loss:3.4135 train_time:584482ms step_avg:141.80ms
step:4133/5100 train_loss:3.3328 train_time:584624ms step_avg:141.80ms
step:4134/5100 train_loss:3.5430 train_time:584765ms step_avg:141.80ms
step:4135/5100 train_loss:3.3653 train_time:584907ms step_avg:141.80ms
step:4136/5100 train_loss:3.3440 train_time:585046ms step_avg:141.80ms
step:4137/5100 train_loss:3.4964 train_time:585187ms step_avg:141.79ms
step:4138/5100 train_loss:3.3419 train_time:585327ms step_avg:141.79ms
step:4139/5100 train_loss:3.3942 train_time:585467ms step_avg:141.79ms
step:4140/5100 train_loss:3.4845 train_time:585609ms step_avg:141.79ms
step:4141/5100 train_loss:3.5136 train_time:585749ms step_avg:141.79ms
step:4142/5100 train_loss:3.4708 train_time:585890ms step_avg:141.79ms
step:4143/5100 train_loss:3.4591 train_time:586029ms step_avg:141.79ms
step:4144/5100 train_loss:3.3607 train_time:586169ms step_avg:141.79ms
step:4145/5100 train_loss:3.3284 train_time:586309ms step_avg:141.79ms
step:4146/5100 train_loss:3.4395 train_time:586449ms step_avg:141.79ms
step:4147/5100 train_loss:2.9951 train_time:586590ms step_avg:141.79ms
step:4148/5100 train_loss:3.3519 train_time:586730ms step_avg:141.79ms
step:4149/5100 train_loss:3.3921 train_time:586870ms step_avg:141.79ms
step:4150/5100 train_loss:3.2019 train_time:587010ms step_avg:141.79ms
step:4151/5100 train_loss:3.2341 train_time:587150ms step_avg:141.79ms
step:4152/5100 train_loss:3.2779 train_time:587291ms step_avg:141.79ms
step:4153/5100 train_loss:3.3290 train_time:587430ms step_avg:141.79ms
step:4154/5100 train_loss:3.3892 train_time:587571ms step_avg:141.79ms
step:4155/5100 train_loss:3.4901 train_time:587712ms step_avg:141.79ms
step:4156/5100 train_loss:3.2973 train_time:587850ms step_avg:141.79ms
step:4157/5100 train_loss:3.2564 train_time:587992ms step_avg:141.79ms
step:4158/5100 train_loss:3.3684 train_time:588302ms step_avg:141.83ms
step:4159/5100 train_loss:3.3730 train_time:588440ms step_avg:141.83ms
step:4160/5100 train_loss:3.2922 train_time:588577ms step_avg:141.83ms
step:4161/5100 train_loss:3.3761 train_time:588715ms step_avg:141.82ms
step:4162/5100 train_loss:3.3139 train_time:588854ms step_avg:141.82ms
step:4163/5100 train_loss:3.5366 train_time:588992ms step_avg:141.82ms
step:4164/5100 train_loss:3.2168 train_time:589129ms step_avg:141.82ms
step:4165/5100 train_loss:3.3265 train_time:589274ms step_avg:141.82ms
step:4166/5100 train_loss:3.3062 train_time:589415ms step_avg:141.82ms
step:4167/5100 train_loss:3.3531 train_time:589556ms step_avg:141.82ms
step:4168/5100 train_loss:3.3441 train_time:589693ms step_avg:141.82ms
step:4169/5100 train_loss:3.3732 train_time:589831ms step_avg:141.82ms
step:4170/5100 train_loss:3.2122 train_time:589970ms step_avg:141.82ms
step:4171/5100 train_loss:3.3141 train_time:590109ms step_avg:141.82ms
step:4172/5100 train_loss:3.4422 train_time:590251ms step_avg:141.82ms
step:4173/5100 train_loss:3.5107 train_time:590393ms step_avg:141.82ms
step:4174/5100 train_loss:3.8780 train_time:590532ms step_avg:141.82ms
step:4175/5100 train_loss:3.3197 train_time:590672ms step_avg:141.82ms
step:4176/5100 train_loss:3.4742 train_time:590811ms step_avg:141.82ms
step:4177/5100 train_loss:3.2715 train_time:590949ms step_avg:141.82ms
step:4178/5100 train_loss:3.2975 train_time:591090ms step_avg:141.82ms
step:4179/5100 train_loss:3.4576 train_time:591232ms step_avg:141.82ms
step:4180/5100 train_loss:3.4053 train_time:591532ms step_avg:141.85ms
step:4181/5100 train_loss:3.3903 train_time:591669ms step_avg:141.85ms
step:4182/5100 train_loss:3.3967 train_time:591806ms step_avg:141.85ms
step:4183/5100 train_loss:3.4331 train_time:591945ms step_avg:141.85ms
step:4184/5100 train_loss:3.8579 train_time:592085ms step_avg:141.85ms
step:4185/5100 train_loss:3.3731 train_time:592222ms step_avg:141.85ms
step:4186/5100 train_loss:3.4141 train_time:592362ms step_avg:141.85ms
step:4187/5100 train_loss:3.4657 train_time:592508ms step_avg:141.85ms
step:4188/5100 train_loss:3.4570 train_time:592649ms step_avg:141.85ms
step:4189/5100 train_loss:3.1012 train_time:592789ms step_avg:141.85ms
step:4190/5100 train_loss:3.4540 train_time:592929ms step_avg:141.85ms
step:4191/5100 train_loss:3.4624 train_time:593068ms step_avg:141.85ms
step:4192/5100 train_loss:3.4337 train_time:593209ms step_avg:141.85ms
step:4193/5100 train_loss:3.3673 train_time:593348ms step_avg:141.85ms
step:4194/5100 train_loss:3.3825 train_time:593491ms step_avg:141.85ms
step:4195/5100 train_loss:3.3566 train_time:593631ms step_avg:141.85ms
step:4196/5100 train_loss:3.3051 train_time:593770ms step_avg:141.85ms
step:4197/5100 train_loss:3.6603 train_time:593910ms step_avg:141.85ms
step:4198/5100 train_loss:3.1035 train_time:594048ms step_avg:141.85ms
step:4199/5100 train_loss:3.5146 train_time:594190ms step_avg:141.85ms
step:4200/5100 train_loss:3.3833 train_time:594330ms step_avg:141.84ms
step:4201/5100 train_loss:3.2680 train_time:594471ms step_avg:141.84ms
step:4202/5100 train_loss:3.4099 train_time:594612ms step_avg:141.84ms
step:4203/5100 train_loss:3.2710 train_time:594751ms step_avg:141.84ms
step:4204/5100 train_loss:3.2890 train_time:594892ms step_avg:141.84ms
step:4205/5100 train_loss:3.3006 train_time:595032ms step_avg:141.84ms
step:4206/5100 train_loss:3.2968 train_time:595170ms step_avg:141.84ms
step:4207/5100 train_loss:3.7577 train_time:595311ms step_avg:141.84ms
step:4208/5100 train_loss:3.3302 train_time:595451ms step_avg:141.84ms
step:4209/5100 train_loss:3.4618 train_time:595593ms step_avg:141.84ms
step:4210/5100 train_loss:3.3488 train_time:595733ms step_avg:141.84ms
step:4211/5100 train_loss:3.7373 train_time:595870ms step_avg:141.84ms
step:4212/5100 train_loss:3.4003 train_time:596010ms step_avg:141.84ms
step:4213/5100 train_loss:3.3992 train_time:596149ms step_avg:141.84ms
step:4214/5100 train_loss:3.2746 train_time:596290ms step_avg:141.84ms
step:4215/5100 train_loss:3.3340 train_time:596430ms step_avg:141.84ms
step:4216/5100 train_loss:3.4114 train_time:596571ms step_avg:141.84ms
step:4217/5100 train_loss:3.2650 train_time:596712ms step_avg:141.84ms
step:4218/5100 train_loss:3.3323 train_time:596851ms step_avg:141.84ms
step:4219/5100 train_loss:3.3871 train_time:596991ms step_avg:141.84ms
step:4220/5100 train_loss:3.1965 train_time:597131ms step_avg:141.84ms
step:4221/5100 train_loss:3.3635 train_time:597271ms step_avg:141.84ms
step:4222/5100 train_loss:3.3878 train_time:597412ms step_avg:141.84ms
step:4223/5100 train_loss:3.3554 train_time:597551ms step_avg:141.83ms
step:4224/5100 train_loss:3.5583 train_time:597691ms step_avg:141.83ms
step:4225/5100 train_loss:3.4440 train_time:597831ms step_avg:141.83ms
step:4226/5100 train_loss:3.4866 train_time:597970ms step_avg:141.83ms
step:4227/5100 train_loss:3.2714 train_time:598111ms step_avg:141.83ms
step:4228/5100 train_loss:3.3431 train_time:598250ms step_avg:141.83ms
step:4229/5100 train_loss:3.3769 train_time:598391ms step_avg:141.83ms
step:4230/5100 train_loss:3.2862 train_time:598530ms step_avg:141.83ms
step:4231/5100 train_loss:3.4730 train_time:598670ms step_avg:141.83ms
step:4232/5100 train_loss:3.4833 train_time:598811ms step_avg:141.83ms
step:4233/5100 train_loss:3.4723 train_time:598950ms step_avg:141.83ms
step:4234/5100 train_loss:3.5807 train_time:599090ms step_avg:141.83ms
step:4235/5100 train_loss:3.4304 train_time:599230ms step_avg:141.83ms
step:4236/5100 train_loss:3.3809 train_time:599371ms step_avg:141.83ms
step:4237/5100 train_loss:3.2245 train_time:599510ms step_avg:141.83ms
step:4238/5100 train_loss:3.4352 train_time:599650ms step_avg:141.83ms
step:4239/5100 train_loss:3.3544 train_time:599796ms step_avg:141.83ms
step:4240/5100 train_loss:3.2699 train_time:599931ms step_avg:141.83ms
step:4241/5100 train_loss:3.3154 train_time:600070ms step_avg:141.83ms
step:4242/5100 train_loss:3.2588 train_time:600213ms step_avg:141.83ms
step:4243/5100 train_loss:3.3397 train_time:600351ms step_avg:141.83ms
step:4244/5100 train_loss:3.2638 train_time:600491ms step_avg:141.83ms
step:4245/5100 train_loss:3.1738 train_time:600631ms step_avg:141.83ms
step:4246/5100 train_loss:3.4851 train_time:600771ms step_avg:141.83ms
step:4247/5100 train_loss:3.2778 train_time:600915ms step_avg:141.83ms
step:4248/5100 train_loss:3.2107 train_time:601050ms step_avg:141.82ms
step:4249/5100 train_loss:3.4226 train_time:601190ms step_avg:141.82ms
step:4250/5100 train_loss:3.7076 train_time:601331ms step_avg:141.82ms
step:4250/5100 val_loss:3.3472 train_time:601387ms step_avg:141.84ms
step:4251/5100 train_loss:3.3302 train_time:601482ms step_avg:141.83ms
step:4252/5100 train_loss:3.5822 train_time:601626ms step_avg:141.83ms
step:4253/5100 train_loss:3.4201 train_time:601766ms step_avg:141.83ms
step:4254/5100 train_loss:3.2303 train_time:601905ms step_avg:141.82ms
step:4255/5100 train_loss:3.3069 train_time:602042ms step_avg:141.82ms
step:4256/5100 train_loss:3.2369 train_time:602180ms step_avg:141.82ms
step:4257/5100 train_loss:3.4691 train_time:602322ms step_avg:141.82ms
step:4258/5100 train_loss:3.3609 train_time:602459ms step_avg:141.82ms
step:4259/5100 train_loss:3.4076 train_time:602603ms step_avg:141.82ms
step:4260/5100 train_loss:3.2315 train_time:602743ms step_avg:141.82ms
step:4261/5100 train_loss:3.5287 train_time:602882ms step_avg:141.82ms
step:4262/5100 train_loss:3.3556 train_time:603021ms step_avg:141.82ms
step:4263/5100 train_loss:3.3709 train_time:603158ms step_avg:141.82ms
step:4264/5100 train_loss:3.4128 train_time:603297ms step_avg:141.82ms
step:4265/5100 train_loss:3.3627 train_time:603441ms step_avg:141.82ms
step:4266/5100 train_loss:3.3643 train_time:603580ms step_avg:141.82ms
step:4267/5100 train_loss:3.4900 train_time:603723ms step_avg:141.82ms
step:4268/5100 train_loss:3.3143 train_time:603862ms step_avg:141.82ms
step:4269/5100 train_loss:3.8509 train_time:604001ms step_avg:141.82ms
step:4270/5100 train_loss:3.3032 train_time:604140ms step_avg:141.82ms
step:4271/5100 train_loss:3.3975 train_time:604279ms step_avg:141.82ms
step:4272/5100 train_loss:3.3248 train_time:604420ms step_avg:141.82ms
step:4273/5100 train_loss:3.5324 train_time:604561ms step_avg:141.82ms
step:4274/5100 train_loss:3.4529 train_time:604703ms step_avg:141.82ms
step:4275/5100 train_loss:3.3074 train_time:604842ms step_avg:141.82ms
step:4276/5100 train_loss:3.3651 train_time:604981ms step_avg:141.81ms
step:4277/5100 train_loss:3.2914 train_time:605120ms step_avg:141.81ms
step:4278/5100 train_loss:3.3286 train_time:605259ms step_avg:141.81ms
step:4279/5100 train_loss:3.3295 train_time:605399ms step_avg:141.81ms
step:4280/5100 train_loss:3.3956 train_time:605540ms step_avg:141.81ms
step:4281/5100 train_loss:3.3803 train_time:605681ms step_avg:141.81ms
step:4282/5100 train_loss:3.3946 train_time:605821ms step_avg:141.81ms
step:4283/5100 train_loss:3.3254 train_time:605960ms step_avg:141.81ms
step:4284/5100 train_loss:3.3650 train_time:606100ms step_avg:141.81ms
step:4285/5100 train_loss:3.4387 train_time:606240ms step_avg:141.81ms
step:4286/5100 train_loss:3.3819 train_time:606379ms step_avg:141.81ms
step:4287/5100 train_loss:3.2691 train_time:606520ms step_avg:141.81ms
step:4288/5100 train_loss:3.3051 train_time:606660ms step_avg:141.81ms
step:4289/5100 train_loss:3.4018 train_time:606800ms step_avg:141.81ms
step:4290/5100 train_loss:3.3633 train_time:606940ms step_avg:141.81ms
step:4291/5100 train_loss:3.2641 train_time:607078ms step_avg:141.81ms
step:4292/5100 train_loss:3.2891 train_time:607218ms step_avg:141.81ms
step:4293/5100 train_loss:3.3624 train_time:607358ms step_avg:141.81ms
step:4294/5100 train_loss:3.1418 train_time:607499ms step_avg:141.81ms
step:4295/5100 train_loss:3.5039 train_time:607640ms step_avg:141.81ms
step:4296/5100 train_loss:3.3881 train_time:607780ms step_avg:141.81ms
step:4297/5100 train_loss:3.3428 train_time:607922ms step_avg:141.81ms
step:4298/5100 train_loss:3.5016 train_time:608059ms step_avg:141.80ms
step:4299/5100 train_loss:3.4297 train_time:608200ms step_avg:141.80ms
step:4300/5100 train_loss:3.2610 train_time:608340ms step_avg:141.80ms
step:4301/5100 train_loss:3.2564 train_time:608480ms step_avg:141.80ms
step:4302/5100 train_loss:3.4041 train_time:608625ms step_avg:141.80ms
step:4303/5100 train_loss:3.2456 train_time:608760ms step_avg:141.80ms
step:4304/5100 train_loss:3.3904 train_time:608900ms step_avg:141.80ms
step:4305/5100 train_loss:3.4743 train_time:609040ms step_avg:141.80ms
step:4306/5100 train_loss:3.2285 train_time:609180ms step_avg:141.80ms
step:4307/5100 train_loss:3.7420 train_time:609320ms step_avg:141.80ms
step:4308/5100 train_loss:3.3555 train_time:609462ms step_avg:141.80ms
step:4309/5100 train_loss:3.2838 train_time:609600ms step_avg:141.80ms
step:4310/5100 train_loss:3.2933 train_time:609740ms step_avg:141.80ms
step:4311/5100 train_loss:3.5946 train_time:609880ms step_avg:141.80ms
step:4312/5100 train_loss:3.4174 train_time:610019ms step_avg:141.80ms
step:4313/5100 train_loss:3.2808 train_time:610159ms step_avg:141.80ms
step:4314/5100 train_loss:3.4846 train_time:610300ms step_avg:141.80ms
step:4315/5100 train_loss:3.4143 train_time:610439ms step_avg:141.80ms
step:4316/5100 train_loss:3.3291 train_time:610580ms step_avg:141.80ms
step:4317/5100 train_loss:3.3690 train_time:610721ms step_avg:141.80ms
step:4318/5100 train_loss:3.3188 train_time:610860ms step_avg:141.80ms
step:4319/5100 train_loss:3.4377 train_time:611000ms step_avg:141.80ms
step:4320/5100 train_loss:3.4879 train_time:611140ms step_avg:141.80ms
step:4321/5100 train_loss:3.3108 train_time:611280ms step_avg:141.80ms
step:4322/5100 train_loss:3.4778 train_time:611421ms step_avg:141.80ms
step:4323/5100 train_loss:3.3399 train_time:611560ms step_avg:141.79ms
step:4324/5100 train_loss:3.2680 train_time:611701ms step_avg:141.79ms
step:4325/5100 train_loss:3.2094 train_time:611840ms step_avg:141.79ms
step:4326/5100 train_loss:3.3057 train_time:611980ms step_avg:141.79ms
step:4327/5100 train_loss:3.1879 train_time:612121ms step_avg:141.79ms
step:4328/5100 train_loss:3.3131 train_time:612260ms step_avg:141.79ms
step:4329/5100 train_loss:3.3373 train_time:612402ms step_avg:141.79ms
step:4330/5100 train_loss:3.2801 train_time:612540ms step_avg:141.79ms
step:4331/5100 train_loss:3.5363 train_time:612680ms step_avg:141.79ms
step:4332/5100 train_loss:3.3356 train_time:612821ms step_avg:141.79ms
step:4333/5100 train_loss:3.4496 train_time:612959ms step_avg:141.79ms
step:4334/5100 train_loss:3.8137 train_time:613101ms step_avg:141.79ms
step:4335/5100 train_loss:3.3470 train_time:613241ms step_avg:141.79ms
step:4336/5100 train_loss:3.4504 train_time:613381ms step_avg:141.79ms
step:4337/5100 train_loss:3.3434 train_time:613521ms step_avg:141.79ms
step:4338/5100 train_loss:3.2336 train_time:613660ms step_avg:141.79ms
step:4339/5100 train_loss:3.3717 train_time:613800ms step_avg:141.79ms
step:4340/5100 train_loss:3.2670 train_time:613939ms step_avg:141.79ms
step:4341/5100 train_loss:3.3638 train_time:614080ms step_avg:141.79ms
step:4342/5100 train_loss:3.3875 train_time:614225ms step_avg:141.79ms
step:4343/5100 train_loss:3.3856 train_time:614359ms step_avg:141.79ms
step:4344/5100 train_loss:3.3772 train_time:614500ms step_avg:141.79ms
step:4345/5100 train_loss:4.0121 train_time:614639ms step_avg:141.79ms
step:4346/5100 train_loss:3.4515 train_time:614779ms step_avg:141.78ms
step:4347/5100 train_loss:3.2468 train_time:615094ms step_avg:141.82ms
step:4348/5100 train_loss:3.3840 train_time:615230ms step_avg:141.82ms
step:4349/5100 train_loss:3.3335 train_time:615369ms step_avg:141.82ms
step:4350/5100 train_loss:3.2507 train_time:615508ms step_avg:141.82ms
step:4351/5100 train_loss:3.4137 train_time:615646ms step_avg:141.82ms
step:4352/5100 train_loss:3.3615 train_time:615784ms step_avg:141.82ms
step:4353/5100 train_loss:3.4351 train_time:615923ms step_avg:141.82ms
step:4354/5100 train_loss:3.4767 train_time:616066ms step_avg:141.82ms
step:4355/5100 train_loss:3.2922 train_time:616209ms step_avg:141.82ms
step:4356/5100 train_loss:3.2482 train_time:616349ms step_avg:141.82ms
step:4357/5100 train_loss:3.3816 train_time:616489ms step_avg:141.82ms
step:4358/5100 train_loss:3.3239 train_time:616628ms step_avg:141.82ms
step:4359/5100 train_loss:3.5128 train_time:616768ms step_avg:141.82ms
step:4360/5100 train_loss:3.3764 train_time:616907ms step_avg:141.82ms
step:4361/5100 train_loss:3.4543 train_time:617047ms step_avg:141.82ms
step:4362/5100 train_loss:3.6146 train_time:617188ms step_avg:141.82ms
step:4363/5100 train_loss:3.4052 train_time:617329ms step_avg:141.82ms
step:4364/5100 train_loss:3.3690 train_time:617471ms step_avg:141.82ms
step:4365/5100 train_loss:3.5736 train_time:617613ms step_avg:141.82ms
step:4366/5100 train_loss:3.4665 train_time:617754ms step_avg:141.82ms
step:4367/5100 train_loss:3.2832 train_time:617895ms step_avg:141.82ms
step:4368/5100 train_loss:3.3093 train_time:618035ms step_avg:141.82ms
step:4369/5100 train_loss:3.3993 train_time:618176ms step_avg:141.82ms
step:4370/5100 train_loss:3.4019 train_time:618497ms step_avg:141.86ms
step:4371/5100 train_loss:3.5479 train_time:618633ms step_avg:141.86ms
step:4372/5100 train_loss:3.2426 train_time:618772ms step_avg:141.86ms
step:4373/5100 train_loss:3.2357 train_time:618908ms step_avg:141.85ms
step:4374/5100 train_loss:3.3946 train_time:619046ms step_avg:141.85ms
step:4375/5100 train_loss:3.4179 train_time:619186ms step_avg:141.85ms
step:4375/5100 val_loss:3.3331 train_time:619241ms step_avg:141.87ms
step:4376/5100 train_loss:3.4844 train_time:619337ms step_avg:141.85ms
step:4377/5100 train_loss:3.2757 train_time:619482ms step_avg:141.86ms
step:4378/5100 train_loss:3.3230 train_time:619624ms step_avg:141.86ms
step:4379/5100 train_loss:3.3379 train_time:619764ms step_avg:141.85ms
step:4380/5100 train_loss:3.3753 train_time:619901ms step_avg:141.85ms
step:4381/5100 train_loss:3.2739 train_time:620040ms step_avg:141.85ms
step:4382/5100 train_loss:3.5160 train_time:620180ms step_avg:141.85ms
step:4383/5100 train_loss:3.4312 train_time:620319ms step_avg:141.85ms
step:4384/5100 train_loss:3.4246 train_time:620462ms step_avg:141.85ms
step:4385/5100 train_loss:3.3044 train_time:620606ms step_avg:141.85ms
step:4386/5100 train_loss:3.4350 train_time:620746ms step_avg:141.85ms
step:4387/5100 train_loss:3.3210 train_time:620886ms step_avg:141.85ms
step:4388/5100 train_loss:3.4551 train_time:621027ms step_avg:141.85ms
step:4389/5100 train_loss:3.2843 train_time:621166ms step_avg:141.85ms
step:4390/5100 train_loss:3.4043 train_time:621307ms step_avg:141.85ms
step:4391/5100 train_loss:3.4195 train_time:621449ms step_avg:141.85ms
step:4392/5100 train_loss:3.2428 train_time:621589ms step_avg:141.85ms
step:4393/5100 train_loss:3.9211 train_time:621730ms step_avg:141.85ms
step:4394/5100 train_loss:3.3135 train_time:621869ms step_avg:141.85ms
step:4395/5100 train_loss:3.5010 train_time:622014ms step_avg:141.85ms
step:4396/5100 train_loss:3.2759 train_time:622150ms step_avg:141.85ms
step:4397/5100 train_loss:3.4006 train_time:622290ms step_avg:141.85ms
step:4398/5100 train_loss:3.2026 train_time:622430ms step_avg:141.85ms
step:4399/5100 train_loss:3.4232 train_time:622569ms step_avg:141.85ms
step:4400/5100 train_loss:3.2502 train_time:622711ms step_avg:141.85ms
step:4401/5100 train_loss:3.3222 train_time:622850ms step_avg:141.85ms
step:4402/5100 train_loss:3.3917 train_time:622991ms step_avg:141.85ms
step:4403/5100 train_loss:3.2274 train_time:623131ms step_avg:141.85ms
step:4404/5100 train_loss:3.2772 train_time:623270ms step_avg:141.85ms
step:4405/5100 train_loss:3.4770 train_time:623411ms step_avg:141.85ms
step:4406/5100 train_loss:3.3371 train_time:623551ms step_avg:141.85ms
step:4407/5100 train_loss:3.3513 train_time:623692ms step_avg:141.84ms
step:4408/5100 train_loss:3.3170 train_time:623831ms step_avg:141.84ms
step:4409/5100 train_loss:3.3963 train_time:623970ms step_avg:141.84ms
step:4410/5100 train_loss:3.3803 train_time:624111ms step_avg:141.84ms
step:4411/5100 train_loss:3.5068 train_time:624250ms step_avg:141.84ms
step:4412/5100 train_loss:3.3449 train_time:624391ms step_avg:141.84ms
step:4413/5100 train_loss:3.3665 train_time:624531ms step_avg:141.84ms
step:4414/5100 train_loss:3.3595 train_time:624671ms step_avg:141.84ms
step:4415/5100 train_loss:3.4103 train_time:624812ms step_avg:141.84ms
step:4416/5100 train_loss:3.3537 train_time:624951ms step_avg:141.84ms
step:4417/5100 train_loss:3.4207 train_time:625091ms step_avg:141.84ms
step:4418/5100 train_loss:3.3499 train_time:625232ms step_avg:141.84ms
step:4419/5100 train_loss:3.2559 train_time:625371ms step_avg:141.84ms
step:4420/5100 train_loss:3.3117 train_time:625511ms step_avg:141.84ms
step:4421/5100 train_loss:3.5396 train_time:625650ms step_avg:141.84ms
step:4422/5100 train_loss:3.3420 train_time:625790ms step_avg:141.84ms
step:4423/5100 train_loss:3.2618 train_time:625930ms step_avg:141.84ms
step:4424/5100 train_loss:3.2841 train_time:626069ms step_avg:141.84ms
step:4425/5100 train_loss:3.4603 train_time:626211ms step_avg:141.84ms
step:4426/5100 train_loss:3.4111 train_time:626350ms step_avg:141.84ms
step:4427/5100 train_loss:3.3260 train_time:626492ms step_avg:141.84ms
step:4428/5100 train_loss:3.5400 train_time:626631ms step_avg:141.84ms
step:4429/5100 train_loss:3.4455 train_time:626770ms step_avg:141.84ms
step:4430/5100 train_loss:3.2446 train_time:626910ms step_avg:141.83ms
step:4431/5100 train_loss:3.2411 train_time:627050ms step_avg:141.83ms
step:4432/5100 train_loss:3.3784 train_time:627191ms step_avg:141.83ms
step:4433/5100 train_loss:3.2581 train_time:627331ms step_avg:141.83ms
step:4434/5100 train_loss:3.3765 train_time:627470ms step_avg:141.83ms
step:4435/5100 train_loss:3.4252 train_time:627611ms step_avg:141.83ms
step:4436/5100 train_loss:3.3292 train_time:627750ms step_avg:141.83ms
step:4437/5100 train_loss:3.2434 train_time:627890ms step_avg:141.83ms
step:4438/5100 train_loss:3.4863 train_time:628030ms step_avg:141.83ms
step:4439/5100 train_loss:3.4127 train_time:628170ms step_avg:141.83ms
step:4440/5100 train_loss:3.3112 train_time:628311ms step_avg:141.83ms
step:4441/5100 train_loss:3.4168 train_time:628451ms step_avg:141.83ms
step:4442/5100 train_loss:3.4445 train_time:628593ms step_avg:141.83ms
step:4443/5100 train_loss:3.5024 train_time:628732ms step_avg:141.83ms
step:4444/5100 train_loss:3.3872 train_time:628871ms step_avg:141.83ms
step:4445/5100 train_loss:3.2014 train_time:629012ms step_avg:141.83ms
step:4446/5100 train_loss:3.4795 train_time:629150ms step_avg:141.83ms
step:4447/5100 train_loss:3.3590 train_time:629291ms step_avg:141.83ms
step:4448/5100 train_loss:3.2507 train_time:629431ms step_avg:141.83ms
step:4449/5100 train_loss:3.3848 train_time:629571ms step_avg:141.83ms
step:4450/5100 train_loss:3.3749 train_time:629711ms step_avg:141.83ms
step:4451/5100 train_loss:3.3905 train_time:629850ms step_avg:141.83ms
step:4452/5100 train_loss:3.4194 train_time:629991ms step_avg:141.83ms
step:4453/5100 train_loss:3.2854 train_time:630131ms step_avg:141.83ms
step:4454/5100 train_loss:3.3150 train_time:630270ms step_avg:141.82ms
step:4455/5100 train_loss:3.3091 train_time:630411ms step_avg:141.82ms
step:4456/5100 train_loss:3.2026 train_time:630550ms step_avg:141.82ms
step:4457/5100 train_loss:3.4183 train_time:630690ms step_avg:141.82ms
step:4458/5100 train_loss:3.2851 train_time:630831ms step_avg:141.82ms
step:4459/5100 train_loss:3.2561 train_time:630970ms step_avg:141.82ms
step:4460/5100 train_loss:3.3779 train_time:631111ms step_avg:141.82ms
step:4461/5100 train_loss:3.8667 train_time:631250ms step_avg:141.82ms
step:4462/5100 train_loss:3.3716 train_time:631390ms step_avg:141.82ms
step:4463/5100 train_loss:3.4792 train_time:631531ms step_avg:141.82ms
step:4464/5100 train_loss:3.3929 train_time:631670ms step_avg:141.82ms
step:4465/5100 train_loss:3.3602 train_time:631811ms step_avg:141.82ms
step:4466/5100 train_loss:3.4363 train_time:631951ms step_avg:141.82ms
step:4467/5100 train_loss:3.2436 train_time:632091ms step_avg:141.82ms
step:4468/5100 train_loss:3.2837 train_time:632231ms step_avg:141.82ms
step:4469/5100 train_loss:3.4206 train_time:632371ms step_avg:141.82ms
step:4470/5100 train_loss:3.3969 train_time:632512ms step_avg:141.82ms
step:4471/5100 train_loss:3.3241 train_time:632651ms step_avg:141.82ms
step:4472/5100 train_loss:3.2837 train_time:632791ms step_avg:141.82ms
step:4473/5100 train_loss:3.3625 train_time:632931ms step_avg:141.82ms
step:4474/5100 train_loss:3.2127 train_time:633070ms step_avg:141.82ms
step:4475/5100 train_loss:3.2909 train_time:633210ms step_avg:141.82ms
step:4476/5100 train_loss:3.3112 train_time:633350ms step_avg:141.82ms
step:4477/5100 train_loss:3.4767 train_time:633491ms step_avg:141.82ms
step:4478/5100 train_loss:3.2152 train_time:633631ms step_avg:141.82ms
step:4479/5100 train_loss:3.3339 train_time:633771ms step_avg:141.81ms
step:4480/5100 train_loss:3.3735 train_time:633911ms step_avg:141.81ms
step:4481/5100 train_loss:3.3406 train_time:634050ms step_avg:141.81ms
step:4482/5100 train_loss:3.3397 train_time:634189ms step_avg:141.81ms
step:4483/5100 train_loss:3.1509 train_time:634330ms step_avg:141.81ms
step:4484/5100 train_loss:3.3057 train_time:634470ms step_avg:141.81ms
step:4485/5100 train_loss:3.2492 train_time:634611ms step_avg:141.81ms
step:4486/5100 train_loss:3.3828 train_time:634751ms step_avg:141.81ms
step:4487/5100 train_loss:3.2589 train_time:634891ms step_avg:141.81ms
step:4488/5100 train_loss:3.3308 train_time:635031ms step_avg:141.81ms
step:4489/5100 train_loss:3.4719 train_time:635170ms step_avg:141.81ms
step:4490/5100 train_loss:3.4362 train_time:635311ms step_avg:141.81ms
step:4491/5100 train_loss:3.3146 train_time:635451ms step_avg:141.81ms
step:4492/5100 train_loss:3.2732 train_time:635590ms step_avg:141.81ms
step:4493/5100 train_loss:3.3178 train_time:635730ms step_avg:141.81ms
step:4494/5100 train_loss:3.3500 train_time:635870ms step_avg:141.81ms
step:4495/5100 train_loss:3.3459 train_time:636011ms step_avg:141.81ms
step:4496/5100 train_loss:3.2864 train_time:636150ms step_avg:141.81ms
step:4497/5100 train_loss:3.4488 train_time:636291ms step_avg:141.81ms
step:4498/5100 train_loss:3.3270 train_time:636431ms step_avg:141.81ms
step:4499/5100 train_loss:3.1757 train_time:636571ms step_avg:141.81ms
step:4500/5100 train_loss:3.4789 train_time:636712ms step_avg:141.81ms
step:4500/5100 val_loss:3.3214 train_time:636767ms step_avg:141.82ms
step:4501/5100 train_loss:3.2748 train_time:636864ms step_avg:141.81ms
step:4502/5100 train_loss:3.2368 train_time:637007ms step_avg:141.81ms
step:4503/5100 train_loss:3.4239 train_time:637147ms step_avg:141.81ms
step:4504/5100 train_loss:3.3080 train_time:637285ms step_avg:141.81ms
step:4505/5100 train_loss:3.4079 train_time:637424ms step_avg:141.81ms
step:4506/5100 train_loss:3.3177 train_time:637562ms step_avg:141.81ms
step:4507/5100 train_loss:3.3976 train_time:637701ms step_avg:141.81ms
step:4508/5100 train_loss:3.1183 train_time:637844ms step_avg:141.81ms
step:4509/5100 train_loss:3.3961 train_time:637990ms step_avg:141.81ms
step:4510/5100 train_loss:3.2353 train_time:638132ms step_avg:141.81ms
step:4511/5100 train_loss:3.3111 train_time:638272ms step_avg:141.81ms
step:4512/5100 train_loss:3.2565 train_time:638411ms step_avg:141.81ms
step:4513/5100 train_loss:3.2544 train_time:638551ms step_avg:141.81ms
step:4514/5100 train_loss:3.2114 train_time:638689ms step_avg:141.80ms
step:4515/5100 train_loss:3.3514 train_time:638828ms step_avg:141.80ms
step:4516/5100 train_loss:3.2092 train_time:638968ms step_avg:141.80ms
step:4517/5100 train_loss:3.3113 train_time:639110ms step_avg:141.80ms
step:4518/5100 train_loss:3.3187 train_time:639249ms step_avg:141.80ms
step:4519/5100 train_loss:3.3256 train_time:639389ms step_avg:141.80ms
step:4520/5100 train_loss:3.2523 train_time:639529ms step_avg:141.80ms
step:4521/5100 train_loss:3.4298 train_time:639665ms step_avg:141.80ms
step:4522/5100 train_loss:3.5012 train_time:639805ms step_avg:141.80ms
step:4523/5100 train_loss:3.8407 train_time:639946ms step_avg:141.80ms
step:4524/5100 train_loss:3.5783 train_time:640088ms step_avg:141.80ms
step:4525/5100 train_loss:3.3452 train_time:640228ms step_avg:141.80ms
step:4526/5100 train_loss:3.2934 train_time:640366ms step_avg:141.80ms
step:4527/5100 train_loss:3.3595 train_time:640516ms step_avg:141.80ms
step:4528/5100 train_loss:3.3229 train_time:640646ms step_avg:141.80ms
step:4529/5100 train_loss:3.2307 train_time:640786ms step_avg:141.80ms
step:4530/5100 train_loss:3.9312 train_time:640927ms step_avg:141.80ms
step:4531/5100 train_loss:3.4040 train_time:641066ms step_avg:141.80ms
step:4532/5100 train_loss:3.1332 train_time:641208ms step_avg:141.80ms
step:4533/5100 train_loss:3.2378 train_time:641347ms step_avg:141.80ms
step:4534/5100 train_loss:3.3550 train_time:641486ms step_avg:141.80ms
step:4535/5100 train_loss:3.5698 train_time:641626ms step_avg:141.80ms
step:4536/5100 train_loss:3.5703 train_time:641939ms step_avg:141.83ms
step:4537/5100 train_loss:3.3047 train_time:642075ms step_avg:141.83ms
step:4538/5100 train_loss:3.3017 train_time:642215ms step_avg:141.83ms
step:4539/5100 train_loss:3.3374 train_time:642353ms step_avg:141.83ms
step:4540/5100 train_loss:3.9038 train_time:642491ms step_avg:141.83ms
step:4541/5100 train_loss:3.3891 train_time:642630ms step_avg:141.83ms
step:4542/5100 train_loss:3.2986 train_time:642768ms step_avg:141.83ms
step:4543/5100 train_loss:3.4631 train_time:642910ms step_avg:141.83ms
step:4544/5100 train_loss:3.2454 train_time:643051ms step_avg:141.83ms
step:4545/5100 train_loss:3.3546 train_time:643190ms step_avg:141.83ms
step:4546/5100 train_loss:3.5297 train_time:643329ms step_avg:141.83ms
step:4547/5100 train_loss:3.3912 train_time:643467ms step_avg:141.83ms
step:4548/5100 train_loss:3.3356 train_time:643606ms step_avg:141.83ms
step:4549/5100 train_loss:3.3277 train_time:643745ms step_avg:141.83ms
step:4550/5100 train_loss:3.2758 train_time:643886ms step_avg:141.83ms
step:4551/5100 train_loss:3.2559 train_time:644028ms step_avg:141.83ms
step:4552/5100 train_loss:3.2302 train_time:644167ms step_avg:141.82ms
step:4553/5100 train_loss:3.3430 train_time:644308ms step_avg:141.82ms
step:4554/5100 train_loss:3.5362 train_time:644446ms step_avg:141.82ms
step:4555/5100 train_loss:3.4197 train_time:644586ms step_avg:141.82ms
step:4556/5100 train_loss:3.1724 train_time:644727ms step_avg:141.82ms
step:4557/5100 train_loss:3.3785 train_time:644866ms step_avg:141.82ms
step:4558/5100 train_loss:3.3898 train_time:645007ms step_avg:141.82ms
step:4559/5100 train_loss:3.3661 train_time:645146ms step_avg:141.82ms
step:4560/5100 train_loss:3.4762 train_time:645464ms step_avg:141.86ms
step:4561/5100 train_loss:3.3030 train_time:645600ms step_avg:141.86ms
step:4562/5100 train_loss:3.3092 train_time:645740ms step_avg:141.86ms
step:4563/5100 train_loss:3.3348 train_time:645879ms step_avg:141.86ms
step:4564/5100 train_loss:3.3749 train_time:646018ms step_avg:141.86ms
step:4565/5100 train_loss:3.4626 train_time:646157ms step_avg:141.86ms
step:4566/5100 train_loss:3.5204 train_time:646296ms step_avg:141.86ms
step:4567/5100 train_loss:3.3725 train_time:646440ms step_avg:141.86ms
step:4568/5100 train_loss:3.2397 train_time:646586ms step_avg:141.86ms
step:4569/5100 train_loss:3.3397 train_time:646724ms step_avg:141.86ms
step:4570/5100 train_loss:3.2253 train_time:646863ms step_avg:141.86ms
step:4571/5100 train_loss:3.2535 train_time:647005ms step_avg:141.86ms
step:4572/5100 train_loss:3.4629 train_time:647143ms step_avg:141.86ms
step:4573/5100 train_loss:3.1663 train_time:647282ms step_avg:141.85ms
step:4574/5100 train_loss:3.2476 train_time:647425ms step_avg:141.85ms
step:4575/5100 train_loss:3.3739 train_time:647567ms step_avg:141.85ms
step:4576/5100 train_loss:3.4099 train_time:647707ms step_avg:141.85ms
step:4577/5100 train_loss:3.3675 train_time:647846ms step_avg:141.85ms
step:4578/5100 train_loss:3.3245 train_time:647986ms step_avg:141.85ms
step:4579/5100 train_loss:3.3497 train_time:648126ms step_avg:141.85ms
step:4580/5100 train_loss:3.4552 train_time:648266ms step_avg:141.85ms
step:4581/5100 train_loss:3.2846 train_time:648409ms step_avg:141.85ms
step:4582/5100 train_loss:3.3096 train_time:648547ms step_avg:141.85ms
step:4583/5100 train_loss:3.4077 train_time:648687ms step_avg:141.85ms
step:4584/5100 train_loss:3.2658 train_time:648827ms step_avg:141.85ms
step:4585/5100 train_loss:3.3781 train_time:648965ms step_avg:141.85ms
step:4586/5100 train_loss:3.3540 train_time:649105ms step_avg:141.85ms
step:4587/5100 train_loss:3.3382 train_time:649246ms step_avg:141.85ms
step:4588/5100 train_loss:3.1932 train_time:649387ms step_avg:141.85ms
step:4589/5100 train_loss:3.3162 train_time:649527ms step_avg:141.85ms
step:4590/5100 train_loss:3.5136 train_time:649667ms step_avg:141.85ms
step:4591/5100 train_loss:3.3381 train_time:649808ms step_avg:141.85ms
step:4592/5100 train_loss:3.3318 train_time:649946ms step_avg:141.85ms
step:4593/5100 train_loss:3.2903 train_time:650086ms step_avg:141.85ms
step:4594/5100 train_loss:3.4579 train_time:650227ms step_avg:141.85ms
step:4595/5100 train_loss:3.3329 train_time:650366ms step_avg:141.85ms
step:4596/5100 train_loss:3.2374 train_time:650507ms step_avg:141.85ms
step:4597/5100 train_loss:3.2307 train_time:650647ms step_avg:141.85ms
step:4598/5100 train_loss:3.4128 train_time:650787ms step_avg:141.85ms
step:4599/5100 train_loss:3.3338 train_time:650926ms step_avg:141.84ms
step:4600/5100 train_loss:3.4586 train_time:651066ms step_avg:141.84ms
step:4601/5100 train_loss:3.3597 train_time:651206ms step_avg:141.84ms
step:4602/5100 train_loss:3.1717 train_time:651346ms step_avg:141.84ms
step:4603/5100 train_loss:3.2817 train_time:651486ms step_avg:141.84ms
step:4604/5100 train_loss:3.3605 train_time:651627ms step_avg:141.84ms
step:4605/5100 train_loss:3.3769 train_time:651767ms step_avg:141.84ms
step:4606/5100 train_loss:3.2956 train_time:651906ms step_avg:141.84ms
step:4607/5100 train_loss:3.4127 train_time:652045ms step_avg:141.84ms
step:4608/5100 train_loss:3.2852 train_time:652185ms step_avg:141.84ms
step:4609/5100 train_loss:3.3921 train_time:652327ms step_avg:141.84ms
step:4610/5100 train_loss:3.3355 train_time:652466ms step_avg:141.84ms
step:4611/5100 train_loss:3.3771 train_time:652606ms step_avg:141.84ms
step:4612/5100 train_loss:3.5217 train_time:652746ms step_avg:141.84ms
step:4613/5100 train_loss:3.2171 train_time:652886ms step_avg:141.84ms
step:4614/5100 train_loss:3.0675 train_time:653026ms step_avg:141.84ms
step:4615/5100 train_loss:3.2784 train_time:653165ms step_avg:141.84ms
step:4616/5100 train_loss:3.2010 train_time:653306ms step_avg:141.84ms
step:4617/5100 train_loss:3.3035 train_time:653445ms step_avg:141.84ms
step:4618/5100 train_loss:3.1760 train_time:653586ms step_avg:141.84ms
step:4619/5100 train_loss:3.3875 train_time:653727ms step_avg:141.84ms
step:4620/5100 train_loss:3.4331 train_time:653866ms step_avg:141.84ms
step:4621/5100 train_loss:3.4708 train_time:654007ms step_avg:141.84ms
step:4622/5100 train_loss:3.2603 train_time:654147ms step_avg:141.84ms
step:4623/5100 train_loss:3.2676 train_time:654287ms step_avg:141.84ms
step:4624/5100 train_loss:3.3002 train_time:654427ms step_avg:141.84ms
step:4625/5100 train_loss:3.2087 train_time:654566ms step_avg:141.83ms
step:4625/5100 val_loss:3.3094 train_time:654624ms step_avg:141.85ms
step:4626/5100 train_loss:3.3907 train_time:654720ms step_avg:141.84ms
step:4627/5100 train_loss:3.2667 train_time:654866ms step_avg:141.84ms
step:4628/5100 train_loss:3.3298 train_time:655006ms step_avg:141.84ms
step:4629/5100 train_loss:3.5219 train_time:655146ms step_avg:141.84ms
step:4630/5100 train_loss:3.3696 train_time:655286ms step_avg:141.84ms
step:4631/5100 train_loss:3.4561 train_time:655426ms step_avg:141.84ms
step:4632/5100 train_loss:3.2556 train_time:655564ms step_avg:141.84ms
step:4633/5100 train_loss:3.4495 train_time:655706ms step_avg:141.84ms
step:4634/5100 train_loss:3.3117 train_time:655849ms step_avg:141.84ms
step:4635/5100 train_loss:3.3631 train_time:655991ms step_avg:141.84ms
step:4636/5100 train_loss:3.3743 train_time:656131ms step_avg:141.84ms
step:4637/5100 train_loss:3.2042 train_time:656270ms step_avg:141.83ms
step:4638/5100 train_loss:3.3779 train_time:656409ms step_avg:141.83ms
step:4639/5100 train_loss:3.3268 train_time:656549ms step_avg:141.83ms
step:4640/5100 train_loss:3.3274 train_time:656691ms step_avg:141.83ms
step:4641/5100 train_loss:3.2732 train_time:656833ms step_avg:141.83ms
step:4642/5100 train_loss:3.2938 train_time:656973ms step_avg:141.83ms
step:4643/5100 train_loss:3.3185 train_time:657115ms step_avg:141.83ms
step:4644/5100 train_loss:3.5362 train_time:657253ms step_avg:141.83ms
step:4645/5100 train_loss:3.4021 train_time:657393ms step_avg:141.83ms
step:4646/5100 train_loss:3.4403 train_time:657534ms step_avg:141.83ms
step:4647/5100 train_loss:3.2629 train_time:657674ms step_avg:141.83ms
step:4648/5100 train_loss:3.3795 train_time:657815ms step_avg:141.83ms
step:4649/5100 train_loss:3.3245 train_time:657955ms step_avg:141.83ms
step:4650/5100 train_loss:3.3687 train_time:658094ms step_avg:141.83ms
step:4651/5100 train_loss:3.4903 train_time:658236ms step_avg:141.83ms
step:4652/5100 train_loss:3.3215 train_time:658374ms step_avg:141.83ms
step:4653/5100 train_loss:3.4216 train_time:658514ms step_avg:141.83ms
step:4654/5100 train_loss:3.2828 train_time:658653ms step_avg:141.83ms
step:4655/5100 train_loss:3.3136 train_time:658794ms step_avg:141.83ms
step:4656/5100 train_loss:3.3413 train_time:658937ms step_avg:141.83ms
step:4657/5100 train_loss:3.2896 train_time:659073ms step_avg:141.83ms
step:4658/5100 train_loss:3.2146 train_time:659213ms step_avg:141.83ms
step:4659/5100 train_loss:3.2533 train_time:659353ms step_avg:141.83ms
step:4660/5100 train_loss:3.1847 train_time:659496ms step_avg:141.83ms
step:4661/5100 train_loss:3.3617 train_time:659634ms step_avg:141.83ms
step:4662/5100 train_loss:3.3440 train_time:659774ms step_avg:141.83ms
step:4663/5100 train_loss:3.3009 train_time:659916ms step_avg:141.83ms
step:4664/5100 train_loss:3.2295 train_time:660054ms step_avg:141.83ms
step:4665/5100 train_loss:3.2359 train_time:660194ms step_avg:141.82ms
step:4666/5100 train_loss:3.2814 train_time:660334ms step_avg:141.82ms
step:4667/5100 train_loss:3.3818 train_time:660473ms step_avg:141.82ms
step:4668/5100 train_loss:3.2970 train_time:660614ms step_avg:141.82ms
step:4669/5100 train_loss:3.2723 train_time:660753ms step_avg:141.82ms
step:4670/5100 train_loss:3.3337 train_time:660894ms step_avg:141.82ms
step:4671/5100 train_loss:3.4181 train_time:661035ms step_avg:141.82ms
step:4672/5100 train_loss:3.3105 train_time:661173ms step_avg:141.82ms
step:4673/5100 train_loss:3.3614 train_time:661313ms step_avg:141.82ms
step:4674/5100 train_loss:3.2981 train_time:661454ms step_avg:141.82ms
step:4675/5100 train_loss:3.3336 train_time:661594ms step_avg:141.82ms
step:4676/5100 train_loss:3.4127 train_time:661734ms step_avg:141.82ms
step:4677/5100 train_loss:3.1421 train_time:661874ms step_avg:141.82ms
step:4678/5100 train_loss:3.1978 train_time:662015ms step_avg:141.82ms
step:4679/5100 train_loss:3.3285 train_time:662154ms step_avg:141.82ms
step:4680/5100 train_loss:3.2966 train_time:662294ms step_avg:141.82ms
step:4681/5100 train_loss:3.3143 train_time:662434ms step_avg:141.82ms
step:4682/5100 train_loss:3.3061 train_time:662573ms step_avg:141.82ms
step:4683/5100 train_loss:3.2337 train_time:662716ms step_avg:141.82ms
step:4684/5100 train_loss:3.2096 train_time:662854ms step_avg:141.82ms
step:4685/5100 train_loss:3.4644 train_time:662994ms step_avg:141.82ms
step:4686/5100 train_loss:3.5396 train_time:663134ms step_avg:141.82ms
step:4687/5100 train_loss:3.2332 train_time:663273ms step_avg:141.82ms
step:4688/5100 train_loss:3.2448 train_time:663415ms step_avg:141.82ms
step:4689/5100 train_loss:3.4280 train_time:663553ms step_avg:141.82ms
step:4690/5100 train_loss:3.2621 train_time:663694ms step_avg:141.81ms
step:4691/5100 train_loss:3.1158 train_time:663835ms step_avg:141.81ms
step:4692/5100 train_loss:3.2213 train_time:663973ms step_avg:141.81ms
step:4693/5100 train_loss:3.2294 train_time:664113ms step_avg:141.81ms
step:4694/5100 train_loss:3.2494 train_time:664254ms step_avg:141.81ms
step:4695/5100 train_loss:3.2593 train_time:664395ms step_avg:141.81ms
step:4696/5100 train_loss:3.2826 train_time:664534ms step_avg:141.81ms
step:4697/5100 train_loss:3.3479 train_time:664673ms step_avg:141.81ms
step:4698/5100 train_loss:3.2599 train_time:664814ms step_avg:141.81ms
step:4699/5100 train_loss:3.2879 train_time:664954ms step_avg:141.81ms
step:4700/5100 train_loss:3.3785 train_time:665093ms step_avg:141.81ms
step:4701/5100 train_loss:3.3082 train_time:665233ms step_avg:141.81ms
step:4702/5100 train_loss:3.2928 train_time:665374ms step_avg:141.81ms
step:4703/5100 train_loss:3.2489 train_time:665514ms step_avg:141.81ms
step:4704/5100 train_loss:3.3298 train_time:665653ms step_avg:141.81ms
step:4705/5100 train_loss:3.2856 train_time:665793ms step_avg:141.81ms
step:4706/5100 train_loss:3.2113 train_time:665933ms step_avg:141.81ms
step:4707/5100 train_loss:3.3469 train_time:666073ms step_avg:141.81ms
step:4708/5100 train_loss:3.4369 train_time:666214ms step_avg:141.81ms
step:4709/5100 train_loss:3.2397 train_time:666354ms step_avg:141.81ms
step:4710/5100 train_loss:3.2259 train_time:666494ms step_avg:141.81ms
step:4711/5100 train_loss:3.2504 train_time:666634ms step_avg:141.81ms
step:4712/5100 train_loss:3.2612 train_time:666774ms step_avg:141.81ms
step:4713/5100 train_loss:3.3821 train_time:666914ms step_avg:141.81ms
step:4714/5100 train_loss:3.2240 train_time:667053ms step_avg:141.81ms
step:4715/5100 train_loss:3.2925 train_time:667193ms step_avg:141.81ms
step:4716/5100 train_loss:3.2259 train_time:667334ms step_avg:141.81ms
step:4717/5100 train_loss:3.2962 train_time:667473ms step_avg:141.80ms
step:4718/5100 train_loss:3.2219 train_time:667613ms step_avg:141.80ms
step:4719/5100 train_loss:3.1849 train_time:667754ms step_avg:141.80ms
step:4720/5100 train_loss:3.3541 train_time:667894ms step_avg:141.80ms
step:4721/5100 train_loss:3.3383 train_time:668033ms step_avg:141.80ms
step:4722/5100 train_loss:3.3434 train_time:668173ms step_avg:141.80ms
step:4723/5100 train_loss:3.1914 train_time:668316ms step_avg:141.80ms
step:4724/5100 train_loss:3.3786 train_time:668454ms step_avg:141.80ms
step:4725/5100 train_loss:3.2591 train_time:668773ms step_avg:141.84ms
step:4726/5100 train_loss:3.5422 train_time:668902ms step_avg:141.84ms
step:4727/5100 train_loss:3.3868 train_time:669040ms step_avg:141.84ms
step:4728/5100 train_loss:3.2594 train_time:669180ms step_avg:141.84ms
step:4729/5100 train_loss:3.1930 train_time:669317ms step_avg:141.83ms
step:4730/5100 train_loss:3.1485 train_time:669454ms step_avg:141.83ms
step:4731/5100 train_loss:3.2459 train_time:669592ms step_avg:141.83ms
step:4732/5100 train_loss:3.3126 train_time:669737ms step_avg:141.83ms
step:4733/5100 train_loss:3.2140 train_time:669879ms step_avg:141.83ms
step:4734/5100 train_loss:3.0981 train_time:670019ms step_avg:141.83ms
step:4735/5100 train_loss:3.3900 train_time:670158ms step_avg:141.83ms
step:4736/5100 train_loss:3.2617 train_time:670297ms step_avg:141.83ms
step:4737/5100 train_loss:3.4331 train_time:670435ms step_avg:141.83ms
step:4738/5100 train_loss:3.3553 train_time:670573ms step_avg:141.83ms
step:4739/5100 train_loss:3.3133 train_time:670715ms step_avg:141.83ms
step:4740/5100 train_loss:3.2772 train_time:670857ms step_avg:141.83ms
step:4741/5100 train_loss:3.2928 train_time:670997ms step_avg:141.83ms
step:4742/5100 train_loss:3.2884 train_time:671137ms step_avg:141.83ms
step:4743/5100 train_loss:3.1674 train_time:671275ms step_avg:141.83ms
step:4744/5100 train_loss:3.3057 train_time:671415ms step_avg:141.83ms
step:4745/5100 train_loss:3.2637 train_time:671553ms step_avg:141.83ms
step:4746/5100 train_loss:3.2665 train_time:671694ms step_avg:141.83ms
step:4747/5100 train_loss:3.2452 train_time:671834ms step_avg:141.83ms
step:4748/5100 train_loss:3.4294 train_time:671974ms step_avg:141.83ms
step:4749/5100 train_loss:3.2820 train_time:672114ms step_avg:141.83ms
step:4750/5100 train_loss:3.3739 train_time:672443ms step_avg:141.87ms
step:4750/5100 val_loss:3.2975 train_time:672488ms step_avg:141.88ms
step:4751/5100 train_loss:3.1891 train_time:672583ms step_avg:141.87ms
step:4752/5100 train_loss:3.1151 train_time:672728ms step_avg:141.87ms
step:4753/5100 train_loss:3.1932 train_time:672868ms step_avg:141.87ms
step:4754/5100 train_loss:3.3921 train_time:673007ms step_avg:141.86ms
step:4755/5100 train_loss:3.2716 train_time:673146ms step_avg:141.86ms
step:4756/5100 train_loss:3.5156 train_time:673283ms step_avg:141.86ms
step:4757/5100 train_loss:3.3840 train_time:673421ms step_avg:141.86ms
step:4758/5100 train_loss:3.2880 train_time:673563ms step_avg:141.86ms
step:4759/5100 train_loss:3.3480 train_time:673707ms step_avg:141.86ms
step:4760/5100 train_loss:3.3307 train_time:673849ms step_avg:141.86ms
step:4761/5100 train_loss:3.2666 train_time:673987ms step_avg:141.86ms
step:4762/5100 train_loss:3.3099 train_time:674125ms step_avg:141.86ms
step:4763/5100 train_loss:3.2829 train_time:674263ms step_avg:141.86ms
step:4764/5100 train_loss:3.1374 train_time:674400ms step_avg:141.86ms
step:4765/5100 train_loss:3.1530 train_time:674543ms step_avg:141.86ms
step:4766/5100 train_loss:3.1496 train_time:674684ms step_avg:141.86ms
step:4767/5100 train_loss:3.3879 train_time:674824ms step_avg:141.86ms
step:4768/5100 train_loss:3.6502 train_time:674964ms step_avg:141.86ms
step:4769/5100 train_loss:3.3704 train_time:675103ms step_avg:141.86ms
step:4770/5100 train_loss:3.2688 train_time:675243ms step_avg:141.86ms
step:4771/5100 train_loss:3.3356 train_time:675385ms step_avg:141.86ms
step:4772/5100 train_loss:3.2899 train_time:675523ms step_avg:141.86ms
step:4773/5100 train_loss:3.2645 train_time:675664ms step_avg:141.86ms
step:4774/5100 train_loss:3.4533 train_time:675804ms step_avg:141.86ms
step:4775/5100 train_loss:3.2668 train_time:675944ms step_avg:141.86ms
step:4776/5100 train_loss:3.4077 train_time:676083ms step_avg:141.86ms
step:4777/5100 train_loss:3.3188 train_time:676223ms step_avg:141.86ms
step:4778/5100 train_loss:3.1594 train_time:676365ms step_avg:141.86ms
step:4779/5100 train_loss:3.3346 train_time:676502ms step_avg:141.85ms
step:4780/5100 train_loss:3.2662 train_time:676644ms step_avg:141.85ms
step:4781/5100 train_loss:3.3389 train_time:676783ms step_avg:141.85ms
step:4782/5100 train_loss:3.2607 train_time:676923ms step_avg:141.85ms
step:4783/5100 train_loss:3.2087 train_time:677063ms step_avg:141.85ms
step:4784/5100 train_loss:3.2659 train_time:677203ms step_avg:141.85ms
step:4785/5100 train_loss:3.1942 train_time:677343ms step_avg:141.85ms
step:4786/5100 train_loss:3.5224 train_time:677483ms step_avg:141.85ms
step:4787/5100 train_loss:3.4115 train_time:677625ms step_avg:141.85ms
step:4788/5100 train_loss:3.3289 train_time:677767ms step_avg:141.85ms
step:4789/5100 train_loss:3.3199 train_time:677905ms step_avg:141.85ms
step:4790/5100 train_loss:3.2375 train_time:678045ms step_avg:141.85ms
step:4791/5100 train_loss:3.3382 train_time:678184ms step_avg:141.85ms
step:4792/5100 train_loss:3.3489 train_time:678324ms step_avg:141.85ms
step:4793/5100 train_loss:3.2654 train_time:678463ms step_avg:141.85ms
step:4794/5100 train_loss:3.3394 train_time:678603ms step_avg:141.85ms
step:4795/5100 train_loss:3.1793 train_time:678744ms step_avg:141.85ms
step:4796/5100 train_loss:3.3285 train_time:678883ms step_avg:141.85ms
step:4797/5100 train_loss:3.4031 train_time:679024ms step_avg:141.85ms
step:4798/5100 train_loss:3.0637 train_time:679164ms step_avg:141.85ms
step:4799/5100 train_loss:3.2462 train_time:679303ms step_avg:141.85ms
step:4800/5100 train_loss:3.2246 train_time:679444ms step_avg:141.85ms
step:4801/5100 train_loss:3.3269 train_time:679583ms step_avg:141.85ms
step:4802/5100 train_loss:3.1450 train_time:679724ms step_avg:141.85ms
step:4803/5100 train_loss:3.1855 train_time:679864ms step_avg:141.85ms
step:4804/5100 train_loss:3.3823 train_time:680004ms step_avg:141.84ms
step:4805/5100 train_loss:3.3280 train_time:680145ms step_avg:141.84ms
step:4806/5100 train_loss:3.3972 train_time:680283ms step_avg:141.84ms
step:4807/5100 train_loss:3.4090 train_time:680424ms step_avg:141.84ms
step:4808/5100 train_loss:3.1776 train_time:680563ms step_avg:141.84ms
step:4809/5100 train_loss:3.2931 train_time:680704ms step_avg:141.84ms
step:4810/5100 train_loss:3.2479 train_time:680845ms step_avg:141.84ms
step:4811/5100 train_loss:3.4815 train_time:680984ms step_avg:141.84ms
step:4812/5100 train_loss:3.2781 train_time:681125ms step_avg:141.84ms
step:4813/5100 train_loss:3.3142 train_time:681263ms step_avg:141.84ms
step:4814/5100 train_loss:3.2105 train_time:681405ms step_avg:141.84ms
step:4815/5100 train_loss:3.2579 train_time:681543ms step_avg:141.84ms
step:4816/5100 train_loss:3.6959 train_time:681683ms step_avg:141.84ms
step:4817/5100 train_loss:3.3639 train_time:681824ms step_avg:141.84ms
step:4818/5100 train_loss:3.3039 train_time:681964ms step_avg:141.84ms
step:4819/5100 train_loss:3.1551 train_time:682104ms step_avg:141.84ms
step:4820/5100 train_loss:3.2898 train_time:682244ms step_avg:141.84ms
step:4821/5100 train_loss:3.2935 train_time:682383ms step_avg:141.84ms
step:4822/5100 train_loss:3.3470 train_time:682524ms step_avg:141.84ms
step:4823/5100 train_loss:3.3990 train_time:682664ms step_avg:141.84ms
step:4824/5100 train_loss:3.2805 train_time:682804ms step_avg:141.84ms
step:4825/5100 train_loss:3.2612 train_time:682943ms step_avg:141.84ms
step:4826/5100 train_loss:3.1857 train_time:683083ms step_avg:141.84ms
step:4827/5100 train_loss:3.1511 train_time:683224ms step_avg:141.84ms
step:4828/5100 train_loss:3.3365 train_time:683362ms step_avg:141.84ms
step:4829/5100 train_loss:3.2211 train_time:683503ms step_avg:141.84ms
step:4830/5100 train_loss:3.3395 train_time:683644ms step_avg:141.83ms
step:4831/5100 train_loss:3.4997 train_time:683782ms step_avg:141.83ms
step:4832/5100 train_loss:3.2350 train_time:683928ms step_avg:141.83ms
step:4833/5100 train_loss:3.3207 train_time:684063ms step_avg:141.83ms
step:4834/5100 train_loss:3.2801 train_time:684204ms step_avg:141.83ms
step:4835/5100 train_loss:3.4687 train_time:684344ms step_avg:141.83ms
step:4836/5100 train_loss:3.2827 train_time:684484ms step_avg:141.83ms
step:4837/5100 train_loss:3.5415 train_time:684623ms step_avg:141.83ms
step:4838/5100 train_loss:3.4906 train_time:684763ms step_avg:141.83ms
step:4839/5100 train_loss:3.3176 train_time:684904ms step_avg:141.83ms
step:4840/5100 train_loss:3.3143 train_time:685044ms step_avg:141.83ms
step:4841/5100 train_loss:3.3016 train_time:685182ms step_avg:141.83ms
step:4842/5100 train_loss:3.3371 train_time:685329ms step_avg:141.83ms
step:4843/5100 train_loss:3.3324 train_time:685468ms step_avg:141.83ms
step:4844/5100 train_loss:3.1929 train_time:685604ms step_avg:141.83ms
step:4845/5100 train_loss:3.2264 train_time:685744ms step_avg:141.83ms
step:4846/5100 train_loss:3.1924 train_time:685884ms step_avg:141.83ms
step:4847/5100 train_loss:3.3625 train_time:686026ms step_avg:141.83ms
step:4848/5100 train_loss:3.2302 train_time:686162ms step_avg:141.83ms
step:4849/5100 train_loss:3.2662 train_time:686304ms step_avg:141.83ms
step:4850/5100 train_loss:3.3913 train_time:686444ms step_avg:141.83ms
step:4851/5100 train_loss:3.2827 train_time:686583ms step_avg:141.83ms
step:4852/5100 train_loss:3.0894 train_time:686724ms step_avg:141.83ms
step:4853/5100 train_loss:3.1856 train_time:686864ms step_avg:141.83ms
step:4854/5100 train_loss:3.3183 train_time:687004ms step_avg:141.83ms
step:4855/5100 train_loss:3.2751 train_time:687144ms step_avg:141.83ms
step:4856/5100 train_loss:3.4101 train_time:687283ms step_avg:141.82ms
step:4857/5100 train_loss:3.2512 train_time:687424ms step_avg:141.82ms
step:4858/5100 train_loss:3.2848 train_time:687564ms step_avg:141.82ms
step:4859/5100 train_loss:3.2299 train_time:687703ms step_avg:141.82ms
step:4860/5100 train_loss:3.3626 train_time:687845ms step_avg:141.82ms
step:4861/5100 train_loss:3.2335 train_time:687984ms step_avg:141.82ms
step:4862/5100 train_loss:3.2899 train_time:688124ms step_avg:141.82ms
step:4863/5100 train_loss:3.2932 train_time:688263ms step_avg:141.82ms
step:4864/5100 train_loss:3.2543 train_time:688404ms step_avg:141.82ms
step:4865/5100 train_loss:3.3332 train_time:688545ms step_avg:141.82ms
step:4866/5100 train_loss:2.9841 train_time:688683ms step_avg:141.82ms
step:4867/5100 train_loss:3.2006 train_time:688823ms step_avg:141.82ms
step:4868/5100 train_loss:3.2549 train_time:688964ms step_avg:141.82ms
step:4869/5100 train_loss:3.2724 train_time:689103ms step_avg:141.82ms
step:4870/5100 train_loss:3.2737 train_time:689243ms step_avg:141.82ms
step:4871/5100 train_loss:3.2810 train_time:689383ms step_avg:141.82ms
step:4872/5100 train_loss:3.3927 train_time:689524ms step_avg:141.82ms
step:4873/5100 train_loss:3.3857 train_time:689663ms step_avg:141.82ms
step:4874/5100 train_loss:3.4155 train_time:689804ms step_avg:141.82ms
step:4875/5100 train_loss:3.4876 train_time:689944ms step_avg:141.82ms
step:4875/5100 val_loss:3.2873 train_time:689999ms step_avg:141.83ms
step:4876/5100 train_loss:3.3107 train_time:690097ms step_avg:141.82ms
step:4877/5100 train_loss:3.2188 train_time:690242ms step_avg:141.82ms
step:4878/5100 train_loss:3.1814 train_time:690380ms step_avg:141.82ms
step:4879/5100 train_loss:3.2355 train_time:690518ms step_avg:141.82ms
step:4880/5100 train_loss:3.3677 train_time:690655ms step_avg:141.82ms
step:4881/5100 train_loss:3.2108 train_time:690793ms step_avg:141.82ms
step:4882/5100 train_loss:3.3372 train_time:690931ms step_avg:141.82ms
step:4883/5100 train_loss:3.3477 train_time:691074ms step_avg:141.82ms
step:4884/5100 train_loss:3.2544 train_time:691219ms step_avg:141.82ms
step:4885/5100 train_loss:3.2535 train_time:691358ms step_avg:141.82ms
step:4886/5100 train_loss:3.3685 train_time:691496ms step_avg:141.82ms
step:4887/5100 train_loss:3.3897 train_time:691635ms step_avg:141.82ms
step:4888/5100 train_loss:3.2602 train_time:691773ms step_avg:141.81ms
step:4889/5100 train_loss:3.2276 train_time:691912ms step_avg:141.81ms
step:4890/5100 train_loss:3.3161 train_time:692054ms step_avg:141.81ms
step:4891/5100 train_loss:3.2246 train_time:692195ms step_avg:141.81ms
step:4892/5100 train_loss:3.3218 train_time:692337ms step_avg:141.81ms
step:4893/5100 train_loss:3.3161 train_time:692476ms step_avg:141.81ms
step:4894/5100 train_loss:3.3439 train_time:692615ms step_avg:141.81ms
step:4895/5100 train_loss:3.4051 train_time:692755ms step_avg:141.81ms
step:4896/5100 train_loss:3.3097 train_time:692894ms step_avg:141.81ms
step:4897/5100 train_loss:3.2580 train_time:693035ms step_avg:141.81ms
step:4898/5100 train_loss:3.4536 train_time:693177ms step_avg:141.81ms
step:4899/5100 train_loss:3.2177 train_time:693318ms step_avg:141.81ms
step:4900/5100 train_loss:3.2735 train_time:693458ms step_avg:141.81ms
step:4901/5100 train_loss:3.1808 train_time:693597ms step_avg:141.81ms
step:4902/5100 train_loss:3.1510 train_time:693737ms step_avg:141.81ms
step:4903/5100 train_loss:3.2822 train_time:693879ms step_avg:141.81ms
step:4904/5100 train_loss:3.2337 train_time:694016ms step_avg:141.81ms
step:4905/5100 train_loss:3.3039 train_time:694158ms step_avg:141.81ms
step:4906/5100 train_loss:3.3458 train_time:694297ms step_avg:141.81ms
step:4907/5100 train_loss:3.2290 train_time:694438ms step_avg:141.81ms
step:4908/5100 train_loss:3.3151 train_time:694578ms step_avg:141.81ms
step:4909/5100 train_loss:3.2094 train_time:694717ms step_avg:141.81ms
step:4910/5100 train_loss:3.3311 train_time:694855ms step_avg:141.81ms
step:4911/5100 train_loss:3.3664 train_time:694995ms step_avg:141.81ms
step:4912/5100 train_loss:3.2756 train_time:695136ms step_avg:141.81ms
step:4913/5100 train_loss:3.2296 train_time:695279ms step_avg:141.81ms
step:4914/5100 train_loss:3.2359 train_time:695583ms step_avg:141.84ms
step:4915/5100 train_loss:3.1539 train_time:695719ms step_avg:141.84ms
step:4916/5100 train_loss:3.3565 train_time:695859ms step_avg:141.84ms
step:4917/5100 train_loss:3.3332 train_time:695997ms step_avg:141.84ms
step:4918/5100 train_loss:3.2573 train_time:696135ms step_avg:141.84ms
step:4919/5100 train_loss:3.2737 train_time:696273ms step_avg:141.84ms
step:4920/5100 train_loss:3.2752 train_time:696411ms step_avg:141.84ms
step:4921/5100 train_loss:3.3549 train_time:696557ms step_avg:141.84ms
step:4922/5100 train_loss:3.5141 train_time:696698ms step_avg:141.84ms
step:4923/5100 train_loss:3.3687 train_time:696838ms step_avg:141.84ms
step:4924/5100 train_loss:3.2378 train_time:696973ms step_avg:141.83ms
step:4925/5100 train_loss:3.5417 train_time:697112ms step_avg:141.83ms
step:4926/5100 train_loss:3.2806 train_time:697254ms step_avg:141.83ms
step:4927/5100 train_loss:3.2590 train_time:697393ms step_avg:141.83ms
step:4928/5100 train_loss:3.1868 train_time:697536ms step_avg:141.83ms
step:4929/5100 train_loss:3.1881 train_time:697678ms step_avg:141.83ms
step:4930/5100 train_loss:3.3402 train_time:697819ms step_avg:141.83ms
step:4931/5100 train_loss:3.5864 train_time:697957ms step_avg:141.83ms
step:4932/5100 train_loss:3.1933 train_time:698095ms step_avg:141.83ms
step:4933/5100 train_loss:3.2886 train_time:698234ms step_avg:141.83ms
step:4934/5100 train_loss:3.3572 train_time:698373ms step_avg:141.83ms
step:4935/5100 train_loss:3.1469 train_time:698516ms step_avg:141.83ms
step:4936/5100 train_loss:3.3047 train_time:698657ms step_avg:141.83ms
step:4937/5100 train_loss:3.3487 train_time:698797ms step_avg:141.83ms
step:4938/5100 train_loss:3.3346 train_time:698936ms step_avg:141.83ms
step:4939/5100 train_loss:3.3317 train_time:699075ms step_avg:141.83ms
step:4940/5100 train_loss:3.4476 train_time:699393ms step_avg:141.86ms
step:4941/5100 train_loss:3.2900 train_time:699528ms step_avg:141.86ms
step:4942/5100 train_loss:3.3024 train_time:699667ms step_avg:141.86ms
step:4943/5100 train_loss:3.0499 train_time:699807ms step_avg:141.86ms
step:4944/5100 train_loss:3.5371 train_time:699946ms step_avg:141.86ms
step:4945/5100 train_loss:3.4995 train_time:700084ms step_avg:141.86ms
step:4946/5100 train_loss:3.1179 train_time:700223ms step_avg:141.86ms
step:4947/5100 train_loss:3.3610 train_time:700366ms step_avg:141.86ms
step:4948/5100 train_loss:3.3953 train_time:700509ms step_avg:141.86ms
step:4949/5100 train_loss:3.2318 train_time:700651ms step_avg:141.86ms
step:4950/5100 train_loss:3.3557 train_time:700790ms step_avg:141.86ms
step:4951/5100 train_loss:3.2190 train_time:700930ms step_avg:141.86ms
step:4952/5100 train_loss:3.3425 train_time:701069ms step_avg:141.86ms
step:4953/5100 train_loss:3.3040 train_time:701209ms step_avg:141.86ms
step:4954/5100 train_loss:3.1939 train_time:701351ms step_avg:141.86ms
step:4955/5100 train_loss:3.3368 train_time:701492ms step_avg:141.86ms
step:4956/5100 train_loss:3.1609 train_time:701634ms step_avg:141.86ms
step:4957/5100 train_loss:3.2629 train_time:701774ms step_avg:141.86ms
step:4958/5100 train_loss:3.2374 train_time:701914ms step_avg:141.86ms
step:4959/5100 train_loss:3.2490 train_time:702054ms step_avg:141.86ms
step:4960/5100 train_loss:3.2899 train_time:702194ms step_avg:141.86ms
step:4961/5100 train_loss:3.4383 train_time:702335ms step_avg:141.86ms
step:4962/5100 train_loss:3.1930 train_time:702475ms step_avg:141.86ms
step:4963/5100 train_loss:3.3377 train_time:702618ms step_avg:141.86ms
step:4964/5100 train_loss:3.1787 train_time:702757ms step_avg:141.86ms
step:4965/5100 train_loss:3.9032 train_time:702896ms step_avg:141.86ms
step:4966/5100 train_loss:3.1685 train_time:703036ms step_avg:141.86ms
step:4967/5100 train_loss:3.3092 train_time:703176ms step_avg:141.86ms
step:4968/5100 train_loss:3.1224 train_time:703316ms step_avg:141.85ms
step:4969/5100 train_loss:3.8386 train_time:703457ms step_avg:141.85ms
step:4970/5100 train_loss:3.3744 train_time:703596ms step_avg:141.85ms
step:4971/5100 train_loss:3.2948 train_time:703736ms step_avg:141.85ms
step:4972/5100 train_loss:3.2400 train_time:703875ms step_avg:141.85ms
step:4973/5100 train_loss:3.3246 train_time:704018ms step_avg:141.85ms
step:4974/5100 train_loss:3.1918 train_time:704155ms step_avg:141.85ms
step:4975/5100 train_loss:3.1974 train_time:704295ms step_avg:141.85ms
step:4976/5100 train_loss:3.3487 train_time:704437ms step_avg:141.85ms
step:4977/5100 train_loss:3.2661 train_time:704576ms step_avg:141.85ms
step:4978/5100 train_loss:3.2261 train_time:704716ms step_avg:141.85ms
step:4979/5100 train_loss:3.2823 train_time:704855ms step_avg:141.85ms
step:4980/5100 train_loss:3.2231 train_time:704995ms step_avg:141.85ms
step:4981/5100 train_loss:3.3610 train_time:705136ms step_avg:141.85ms
step:4982/5100 train_loss:3.3313 train_time:705276ms step_avg:141.85ms
step:4983/5100 train_loss:3.1376 train_time:705418ms step_avg:141.85ms
step:4984/5100 train_loss:3.1657 train_time:705559ms step_avg:141.85ms
step:4985/5100 train_loss:3.4475 train_time:705696ms step_avg:141.85ms
step:4986/5100 train_loss:3.3380 train_time:705836ms step_avg:141.85ms
step:4987/5100 train_loss:3.2532 train_time:705976ms step_avg:141.85ms
step:4988/5100 train_loss:3.2832 train_time:706116ms step_avg:141.85ms
step:4989/5100 train_loss:3.2673 train_time:706256ms step_avg:141.85ms
step:4990/5100 train_loss:3.2446 train_time:706396ms step_avg:141.85ms
step:4991/5100 train_loss:3.3007 train_time:706537ms step_avg:141.85ms
step:4992/5100 train_loss:3.3384 train_time:706676ms step_avg:141.85ms
step:4993/5100 train_loss:3.1653 train_time:706816ms step_avg:141.85ms
step:4994/5100 train_loss:3.2968 train_time:706956ms step_avg:141.85ms
step:4995/5100 train_loss:3.2120 train_time:707096ms step_avg:141.84ms
step:4996/5100 train_loss:3.3669 train_time:707237ms step_avg:141.84ms
step:4997/5100 train_loss:3.2382 train_time:707376ms step_avg:141.84ms
step:4998/5100 train_loss:3.4056 train_time:707516ms step_avg:141.84ms
step:4999/5100 train_loss:3.2599 train_time:707657ms step_avg:141.84ms
step:5000/5100 train_loss:3.3897 train_time:707796ms step_avg:141.84ms
step:5000/5100 val_loss:3.2791 train_time:707851ms step_avg:141.85ms
step:5001/5100 train_loss:3.3109 train_time:707948ms step_avg:141.84ms
step:5002/5100 train_loss:3.3105 train_time:708096ms step_avg:141.85ms
step:5003/5100 train_loss:3.2050 train_time:708236ms step_avg:141.85ms
step:5004/5100 train_loss:3.2780 train_time:708375ms step_avg:141.85ms
step:5005/5100 train_loss:3.2885 train_time:708514ms step_avg:141.84ms
step:5006/5100 train_loss:3.1622 train_time:708655ms step_avg:141.84ms
step:5007/5100 train_loss:3.3958 train_time:708793ms step_avg:141.84ms
step:5008/5100 train_loss:3.2459 train_time:708932ms step_avg:141.84ms
step:5009/5100 train_loss:3.2667 train_time:709076ms step_avg:141.84ms
step:5010/5100 train_loss:3.2379 train_time:709220ms step_avg:141.84ms
step:5011/5100 train_loss:3.4398 train_time:709361ms step_avg:141.84ms
step:5012/5100 train_loss:3.2450 train_time:709500ms step_avg:141.84ms
step:5013/5100 train_loss:3.2213 train_time:709640ms step_avg:141.84ms
step:5014/5100 train_loss:3.1865 train_time:709780ms step_avg:141.84ms
step:5015/5100 train_loss:3.2956 train_time:709920ms step_avg:141.84ms
step:5016/5100 train_loss:3.2709 train_time:710063ms step_avg:141.84ms
step:5017/5100 train_loss:3.3084 train_time:710203ms step_avg:141.84ms
step:5018/5100 train_loss:3.3319 train_time:710345ms step_avg:141.84ms
step:5019/5100 train_loss:3.2892 train_time:710485ms step_avg:141.84ms
step:5020/5100 train_loss:3.7916 train_time:710625ms step_avg:141.84ms
step:5021/5100 train_loss:3.2359 train_time:710765ms step_avg:141.84ms
step:5022/5100 train_loss:3.3399 train_time:710904ms step_avg:141.84ms
step:5023/5100 train_loss:3.2617 train_time:711045ms step_avg:141.84ms
step:5024/5100 train_loss:3.4051 train_time:711185ms step_avg:141.84ms
step:5025/5100 train_loss:3.2012 train_time:711325ms step_avg:141.84ms
step:5026/5100 train_loss:3.3617 train_time:711465ms step_avg:141.84ms
step:5027/5100 train_loss:3.2131 train_time:711607ms step_avg:141.84ms
step:5028/5100 train_loss:3.4244 train_time:711744ms step_avg:141.84ms
step:5029/5100 train_loss:3.3312 train_time:711884ms step_avg:141.84ms
step:5030/5100 train_loss:3.3639 train_time:712025ms step_avg:141.84ms
step:5031/5100 train_loss:3.2127 train_time:712166ms step_avg:141.84ms
step:5032/5100 train_loss:3.2549 train_time:712305ms step_avg:141.84ms
step:5033/5100 train_loss:3.1911 train_time:712447ms step_avg:141.84ms
step:5034/5100 train_loss:3.3866 train_time:712586ms step_avg:141.84ms
step:5035/5100 train_loss:3.3787 train_time:712726ms step_avg:141.84ms
step:5036/5100 train_loss:3.2360 train_time:712866ms step_avg:141.84ms
step:5037/5100 train_loss:3.1618 train_time:713006ms step_avg:141.84ms
step:5038/5100 train_loss:3.1989 train_time:713147ms step_avg:141.84ms
step:5039/5100 train_loss:3.3318 train_time:713286ms step_avg:141.83ms
step:5040/5100 train_loss:3.2642 train_time:713426ms step_avg:141.83ms
step:5041/5100 train_loss:3.4305 train_time:713566ms step_avg:141.83ms
step:5042/5100 train_loss:3.2261 train_time:713704ms step_avg:141.83ms
step:5043/5100 train_loss:3.4137 train_time:713845ms step_avg:141.83ms
step:5044/5100 train_loss:3.3240 train_time:713985ms step_avg:141.83ms
step:5045/5100 train_loss:3.3836 train_time:714125ms step_avg:141.83ms
step:5046/5100 train_loss:3.2128 train_time:714265ms step_avg:141.83ms
step:5047/5100 train_loss:3.3563 train_time:714405ms step_avg:141.83ms
step:5048/5100 train_loss:3.1003 train_time:714545ms step_avg:141.83ms
step:5049/5100 train_loss:3.2604 train_time:714685ms step_avg:141.83ms
step:5050/5100 train_loss:3.2725 train_time:714826ms step_avg:141.83ms
step:5051/5100 train_loss:3.2110 train_time:714965ms step_avg:141.83ms
step:5052/5100 train_loss:3.2471 train_time:715105ms step_avg:141.83ms
step:5053/5100 train_loss:3.2938 train_time:715246ms step_avg:141.83ms
step:5054/5100 train_loss:3.3290 train_time:715384ms step_avg:141.83ms
step:5055/5100 train_loss:3.4104 train_time:715525ms step_avg:141.83ms
step:5056/5100 train_loss:3.3548 train_time:715666ms step_avg:141.83ms
step:5057/5100 train_loss:3.2402 train_time:715805ms step_avg:141.83ms
step:5058/5100 train_loss:3.1421 train_time:715946ms step_avg:141.83ms
step:5059/5100 train_loss:3.0588 train_time:716084ms step_avg:141.83ms
step:5060/5100 train_loss:3.2769 train_time:716225ms step_avg:141.83ms
step:5061/5100 train_loss:3.3636 train_time:716365ms step_avg:141.83ms
step:5062/5100 train_loss:3.2824 train_time:716505ms step_avg:141.83ms
step:5063/5100 train_loss:3.4109 train_time:716646ms step_avg:141.83ms
step:5064/5100 train_loss:3.4132 train_time:716785ms step_avg:141.83ms
step:5065/5100 train_loss:3.2706 train_time:716925ms step_avg:141.82ms
step:5066/5100 train_loss:3.3762 train_time:717065ms step_avg:141.82ms
step:5067/5100 train_loss:3.5765 train_time:717204ms step_avg:141.82ms
step:5068/5100 train_loss:3.2359 train_time:717345ms step_avg:141.82ms
step:5069/5100 train_loss:3.5765 train_time:717484ms step_avg:141.82ms
step:5070/5100 train_loss:3.2859 train_time:717625ms step_avg:141.82ms
step:5071/5100 train_loss:3.7010 train_time:717765ms step_avg:141.82ms
step:5072/5100 train_loss:3.2303 train_time:717905ms step_avg:141.82ms
step:5073/5100 train_loss:3.2865 train_time:718045ms step_avg:141.82ms
step:5074/5100 train_loss:3.4227 train_time:718185ms step_avg:141.82ms
step:5075/5100 train_loss:3.2627 train_time:718325ms step_avg:141.82ms
step:5076/5100 train_loss:3.2545 train_time:718466ms step_avg:141.82ms
step:5077/5100 train_loss:3.2229 train_time:718605ms step_avg:141.82ms
step:5078/5100 train_loss:3.3033 train_time:718746ms step_avg:141.82ms
step:5079/5100 train_loss:3.4382 train_time:718886ms step_avg:141.82ms
step:5080/5100 train_loss:3.4166 train_time:719025ms step_avg:141.82ms
step:5081/5100 train_loss:3.2366 train_time:719165ms step_avg:141.82ms
step:5082/5100 train_loss:3.3606 train_time:719304ms step_avg:141.82ms
step:5083/5100 train_loss:3.2208 train_time:719444ms step_avg:141.82ms
step:5084/5100 train_loss:3.3046 train_time:719585ms step_avg:141.82ms
step:5085/5100 train_loss:3.1964 train_time:719725ms step_avg:141.82ms
step:5086/5100 train_loss:4.0388 train_time:719866ms step_avg:141.82ms
step:5087/5100 train_loss:3.3294 train_time:720005ms step_avg:141.82ms
step:5088/5100 train_loss:3.2483 train_time:720145ms step_avg:141.82ms
step:5089/5100 train_loss:3.2622 train_time:720285ms step_avg:141.82ms
step:5090/5100 train_loss:3.3928 train_time:720425ms step_avg:141.82ms
step:5091/5100 train_loss:3.3126 train_time:720565ms step_avg:141.82ms
step:5092/5100 train_loss:3.2143 train_time:720705ms step_avg:141.82ms
step:5093/5100 train_loss:3.2273 train_time:720846ms step_avg:141.81ms
step:5094/5100 train_loss:3.2320 train_time:720985ms step_avg:141.81ms
step:5095/5100 train_loss:3.1577 train_time:721125ms step_avg:141.81ms
step:5096/5100 train_loss:3.2794 train_time:721266ms step_avg:141.81ms
step:5097/5100 train_loss:3.0628 train_time:721405ms step_avg:141.81ms
step:5098/5100 train_loss:3.3622 train_time:721545ms step_avg:141.81ms
step:5099/5100 train_loss:3.2208 train_time:721687ms step_avg:141.81ms
step:5100/5100 train_loss:3.2696 train_time:721825ms step_avg:141.81ms
step:5100/5100 val_loss:3.2755 train_time:721882ms step_avg:141.82ms