simpletransformers
56 lines · 1.3 KB
import logging

import pandas as pd

from simpletransformers.seq2seq import Seq2SeqModel

# Show INFO-level progress from this script, but quiet the very chatty
# `transformers` library down to warnings only.
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)
10
11
# Tiny toy dataset: the task is mapping number words to their digit form.
# simpletransformers Seq2Seq models expect the columns
# "input_text" and "target_text".
train_data = [
    ["one", "1"],
    ["two", "2"],
]

train_df = pd.DataFrame(train_data, columns=["input_text", "target_text"])

eval_data = [
    ["three", "3"],
    ["four", "4"],
]

eval_df = pd.DataFrame(eval_data, columns=["input_text", "target_text"])
25
26model_args = {
27"reprocess_input_data": True,
28"overwrite_output_dir": True,
29"max_seq_length": 10,
30"train_batch_size": 2,
31"num_train_epochs": 100,
32"save_eval_checkpoints": False,
33"save_model_every_epoch": False,
34# "silent": True,
35"evaluate_generated_text": True,
36"evaluate_during_training": True,
37"evaluate_during_training_verbose": True,
38"use_multiprocessing": False,
39"save_best_model": False,
40"max_length": 15,
41}
42
# Encoder-decoder model with both encoder and decoder initialized from
# bert-base-cased checkpoints.
model = Seq2SeqModel("bert", "bert-base-cased", "bert-base-cased", args=model_args)
44
45
def count_matches(labels, preds):
    """Return the number of positions where the prediction equals the label.

    Passed as a custom metric (``matches=...``) to train/eval below. The
    prints echo the raw labels/predictions for inspection during training.
    """
    print(labels)
    print(preds)
    return sum(1 for label, pred in zip(labels, preds) if label == pred)
50
51
# Train with the custom `matches` metric (evaluated during training per the
# args above), then report eval metrics and run inference on new inputs.
model.train_model(train_df, eval_data=eval_df, matches=count_matches)

print(model.eval_model(eval_df, matches=count_matches))

print(model.predict(["four", "five"]))
57