simpletransformers

Форк
0
/
minimal_seq2seq.py 
56 строк · 1.3 Кб
1
import logging
2

3
import pandas as pd
4

5
from simpletransformers.seq2seq import Seq2SeqModel
6

7
logging.basicConfig(level=logging.INFO)
8
transformers_logger = logging.getLogger("transformers")
9
transformers_logger.setLevel(logging.WARNING)
10

11

12
train_data = [
13
    ["one", "1"],
14
    ["two", "2"],
15
]
16

17
train_df = pd.DataFrame(train_data, columns=["input_text", "target_text"])
18

19
eval_data = [
20
    ["three", "3"],
21
    ["four", "4"],
22
]
23

24
eval_df = pd.DataFrame(eval_data, columns=["input_text", "target_text"])
25

26
model_args = {
27
    "reprocess_input_data": True,
28
    "overwrite_output_dir": True,
29
    "max_seq_length": 10,
30
    "train_batch_size": 2,
31
    "num_train_epochs": 100,
32
    "save_eval_checkpoints": False,
33
    "save_model_every_epoch": False,
34
    # "silent": True,
35
    "evaluate_generated_text": True,
36
    "evaluate_during_training": True,
37
    "evaluate_during_training_verbose": True,
38
    "use_multiprocessing": False,
39
    "save_best_model": False,
40
    "max_length": 15,
41
}
42

43
model = Seq2SeqModel("bert", "bert-base-cased", "bert-base-cased", args=model_args)
44

45

46
def count_matches(labels, preds):
47
    print(labels)
48
    print(preds)
49
    return sum([1 if label == pred else 0 for label, pred in zip(labels, preds)])
50

51

52
model.train_model(train_df, eval_data=eval_df, matches=count_matches)
53

54
print(model.eval_model(eval_df, matches=count_matches))
55

56
print(model.predict(["four", "five"]))
57

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.