simpletransformers

Форк
0
28 строк · 832.0 Байт
1
import pandas as pd
2

3
from simpletransformers.t5 import T5Model
4

5
train_df = pd.read_csv("data/train.tsv", sep="\t").astype(str)
6
eval_df = pd.read_csv("data/eval.tsv", sep="\t").astype(str)
7

8
model_args = {
9
    "max_seq_length": 196,
10
    "train_batch_size": 16,
11
    "eval_batch_size": 64,
12
    "num_train_epochs": 1,
13
    "evaluate_during_training": True,
14
    "evaluate_during_training_steps": 15000,
15
    "evaluate_during_training_verbose": True,
16
    "use_multiprocessing": False,
17
    "fp16": False,
18
    "save_steps": -1,
19
    "save_eval_checkpoints": False,
20
    "save_model_every_epoch": False,
21
    "reprocess_input_data": True,
22
    "overwrite_output_dir": True,
23
    "wandb_project": "T5 mixed tasks - Binary, Multi-Label, Regression",
24
}
25

26
model = T5Model("t5", "t5-base", args=model_args)
27

28
model.train_model(train_df, eval_data=eval_df)
29

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.