import argparse
import json

import ftfy
import sentencepiece as spm


parser = argparse.ArgumentParser()
parser.add_argument('--vocab_size', type=int, default=50304, help='vocab size')
parser.add_argument('--model', type=str, default='wiki50304.model', help='output model file')
parser.add_argument('corpus_jsonl', type=str, help='JSONL corpus; each line is a JSON object with a "text" field')
args = parser.parse_args()
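
# Example invocation (the script filename here is illustrative):
#   python train_spm.py --vocab_size 50304 --model wiki50304.model corpus.jsonl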

# Stream the 'text' field of each JSONL record (one JSON object per line),
# repairing mojibake with ftfy before it reaches the trainer.
def read_jsonl(filename):
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            yield ftfy.fix_text(json.loads(line)['text'])

with open(args.model, 'wb') as model_writer:
    spm.SentencePieceTrainer.train(
        sentence_iterator=read_jsonl(args.corpus_jsonl),
        max_sentence_length=(1 << 31) - 1,  # INT32_MAX, i.e. effectively no length limit
        model_writer=model_writer,
        vocab_size=args.vocab_size,
        num_threads=32,
        character_coverage=0.9995,
        model_type='bpe',
        split_digits=True,
        # allow_whitespace_only_pieces=True,
        normalization_rule_name='nfkc',
        # normalization_rule_name='nmt_nfkc',
        byte_fallback=True,
        pad_id=0,
        # reserve the top three ids of the vocab for eos/bos/unk
        eos_id=args.vocab_size-1,
        bos_id=args.vocab_size-2,
        unk_id=args.vocab_size-3,
        add_dummy_prefix=True, # encode("Hello World") = encode(" Hello World")
        control_symbols=[
            '<|endofprompt|>',
            '<|fim_prefix|>',
            '<|fim_middle|>',
            '<|fim_suffix|>',
            '<|startofmeta|>',
            '<|endofmeta|>',
            '<|startoftranscript|>',
            '<|startoflm|>',
            '<|startofprev|>',
            '<|transcribe|>',
            '<|translate|>'])
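
# Quick smoke test of the freshly trained model (a minimal sketch; it assumes
# training above succeeded and args.model now exists on disk).
sp = spm.SentencePieceProcessor(model_file=args.model)
assert sp.vocab_size() == args.vocab_size
assert sp.eos_id() == args.vocab_size - 1
assert sp.bos_id() == args.vocab_size - 2
assert sp.unk_id() == args.vocab_size - 3
# With byte_fallback, characters missing from the learned vocab are encoded as
# byte pieces, so NFKC-stable text like this sample round-trips exactly.
sample = 'Hello World'
ids = sp.encode(sample)
assert sp.decode(ids) == sample
print(f'{len(ids)} tokens for {sample!r}:', ids)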