# Source: simpletransformers examples
# (original listing: 28 lines, 832.0 bytes)
"""Fine-tune T5 (``t5-base``) on a mix of binary, multi-label and regression tasks.

Training and evaluation data are read from tab-separated files that must
contain the ``prefix``, ``input_text`` and ``target_text`` columns consumed by
``simpletransformers.t5.T5Model.train_model``.

Run as a script: ``python <this file>``.
"""
import pandas as pd

TRAIN_PATH = "data/train.tsv"
EVAL_PATH = "data/eval.tsv"

# Columns the simpletransformers T5 trainer reads from each DataFrame.
REQUIRED_COLUMNS = ("prefix", "input_text", "target_text")

model_args = {
    "max_seq_length": 196,
    "train_batch_size": 16,
    "eval_batch_size": 64,
    "num_train_epochs": 1,
    "evaluate_during_training": True,
    "evaluate_during_training_steps": 15000,
    "evaluate_during_training_verbose": True,
    "use_multiprocessing": False,
    "fp16": False,
    "save_steps": -1,
    "save_eval_checkpoints": False,
    "save_model_every_epoch": False,
    "reprocess_input_data": True,
    "overwrite_output_dir": True,
    "wandb_project": "T5 mixed tasks - Binary, Multi-Label, Regression",
}


def load_tsv(path):
    """Load a tab-separated dataset, keeping every cell as its raw string.

    ``dtype=str`` stops pandas from parsing numeric-looking columns as numbers
    (otherwise ``"1"`` could come back as ``"1.0"`` and ``"0.50"`` as ``"0.5"``
    after string conversion), and ``keep_default_na=False`` stops text such as
    ``"NA"``, ``"null"`` or an empty cell from becoming NaN and then the
    literal string ``"nan"``.

    Args:
        path: Path to a tab-separated file with a header row.

    Returns:
        A DataFrame whose cells are all ``str``.

    Raises:
        ValueError: If any column in ``REQUIRED_COLUMNS`` is missing.
    """
    df = pd.read_csv(path, sep="\t", dtype=str, keep_default_na=False)
    missing = [col for col in REQUIRED_COLUMNS if col not in df.columns]
    if missing:
        # Fail fast here instead of deep inside model training.
        raise ValueError(
            f"{path}: missing required column(s) {missing}; "
            f"found {list(df.columns)}"
        )
    return df


def main():
    """Load the train/eval data and fine-tune ``t5-base`` with ``model_args``."""
    # Imported here so load_tsv/model_args can be imported (e.g. by tests)
    # without simpletransformers being installed.
    from simpletransformers.t5 import T5Model

    train_df = load_tsv(TRAIN_PATH)
    eval_df = load_tsv(EVAL_PATH)

    model = T5Model("t5", "t5-base", args=model_args)
    model.train_model(train_df, eval_data=eval_df)


if __name__ == "__main__":
    main()