simpletransformers

Форк
0
48 строк · 1.3 Кб
1
import logging
2
import sacrebleu
3
import pandas as pd
4
from simpletransformers.t5 import T5Model, T5Args
5

6

7
logging.basicConfig(level=logging.INFO)
8
transformers_logger = logging.getLogger("transformers")
9
transformers_logger.setLevel(logging.WARNING)
10

11
model_args = T5Args()
12
model_args.max_length = 512
13
model_args.length_penalty = 1
14
model_args.num_beams = 10
15

16
model = T5Model("mt5", "outputs_base", args=model_args)
17

18
eval_df = pd.read_csv("data/eval.tsv", sep="\t").astype(str)
19

20
sinhala_truth = [
21
    eval_df.loc[eval_df["prefix"] == "translate english to sinhala"][
22
        "target_text"
23
    ].tolist()
24
]
25
to_sinhala = eval_df.loc[eval_df["prefix"] == "translate english to sinhala"][
26
    "input_text"
27
].tolist()
28

29
english_truth = [
30
    eval_df.loc[eval_df["prefix"] == "translate sinhala to english"][
31
        "target_text"
32
    ].tolist()
33
]
34
to_english = eval_df.loc[eval_df["prefix"] == "translate sinhala to english"][
35
    "input_text"
36
].tolist()
37

38
# Predict
39
sinhala_preds = model.predict(to_sinhala)
40

41
eng_sin_bleu = sacrebleu.corpus_bleu(sinhala_preds, sinhala_truth)
42
print("--------------------------")
43
print("English to Sinhalese: ", eng_sin_bleu.score)
44

45
english_preds = model.predict(to_english)
46

47
sin_eng_bleu = sacrebleu.corpus_bleu(english_preds, english_truth)
48
print("Sinhalese to English: ", sin_eng_bleu.score)
49

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.