# simpletransformers
# 48 lines · 1.3 KB
# Evaluate a fine-tuned mT5 model on English <-> Sinhala translation.
#
# Loads the model from "outputs_base", reads the evaluation split from
# "data/eval.tsv", translates in both directions and prints the corpus-level
# BLEU score for each direction.
import logging

import pandas as pd
import sacrebleu
from simpletransformers.t5 import T5Model, T5Args


logging.basicConfig(level=logging.INFO)
# Keep the transformers library quiet; only warnings and above are shown.
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# Values of the "prefix" column that identify each translation direction.
EN_TO_SI_PREFIX = "translate english to sinhala"
SI_TO_EN_PREFIX = "translate sinhala to english"

model_args = T5Args()
model_args.max_length = 512
model_args.length_penalty = 1
model_args.num_beams = 10

model = T5Model("mt5", "outputs_base", args=model_args)

# Every column is cast to str so the text columns are always strings.
eval_df = pd.read_csv("data/eval.tsv", sep="\t").astype(str)

# Split the evaluation data by direction once, then pick columns from each.
en_to_si_rows = eval_df.loc[eval_df["prefix"] == EN_TO_SI_PREFIX]
si_to_en_rows = eval_df.loc[eval_df["prefix"] == SI_TO_EN_PREFIX]

# sacrebleu.corpus_bleu takes a list of reference streams, hence the extra
# outer list around the single set of references.
sinhala_truth = [en_to_si_rows["target_text"].tolist()]
to_sinhala = en_to_si_rows["input_text"].tolist()

english_truth = [si_to_en_rows["target_text"].tolist()]
to_english = si_to_en_rows["input_text"].tolist()

# Predict
# NOTE(review): the inputs are passed to predict() without the task prefix
# prepended -- confirm this matches how the model expects its inputs.
sinhala_preds = model.predict(to_sinhala)

eng_sin_bleu = sacrebleu.corpus_bleu(sinhala_preds, sinhala_truth)
print("--------------------------")
print("English to Sinhalese: ", eng_sin_bleu.score)

english_preds = model.predict(to_english)

sin_eng_bleu = sacrebleu.corpus_bleu(english_preds, english_truth)
print("Sinhalese to English: ", sin_eng_bleu.score)