# simpletransformers · 57 lines · 1.6 KB
# Sentence-pair regression example for simpletransformers with lazy loading.
#
# Two tiny datasets are built in memory, written to tab-separated files under
# "data/", and the model is then trained and evaluated from those file paths
# rather than from the DataFrames themselves.
import os

import pandas as pd

from simpletransformers.classification import ClassificationModel

# Column layout shared by both TSV files; the lazy_*_column settings below are
# positions within this list.
COLUMNS = ["text_a", "text_b", "labels"]

DATA_DIR = "data"
TRAIN_PATH = "data/regression_train.tsv"
EVAL_PATH = "data/regression_eval.tsv"

# Each row is [first sentence, second sentence, numeric target].
train_data = [
    ["Example sentence belonging to class 1", "Yep, this is 1", 0.8],
    ["Example sentence belonging to class 0", "Yep, this is 0", 0.2],
    [
        "This is an entirely different phrase altogether and should be treated so.",
        "Is this being picked up?",
        1000.5,
    ],
]
eval_data = [
    ["Example sentence belonging to class 1", "Yep, this is 1", 1.9],
    ["Example sentence belonging to class 0", "Yep, this is 0", 0.1],
    ["Example 2 sentence belonging to class 0", "Yep, this is 0", 5],
]

train_df = pd.DataFrame(train_data, columns=COLUMNS)
eval_df = pd.DataFrame(eval_data, columns=COLUMNS)

# Persist both frames as header-first TSV files without the index column.
os.makedirs(DATA_DIR, exist_ok=True)
for frame, path in ((train_df, TRAIN_PATH), (eval_df, EVAL_PATH)):
    frame.to_csv(path, sep="\t", index=False)

train_args = dict(
    reprocess_input_data=True,
    overwrite_output_dir=True,
    # Column positions in the TSV files written above (0, 1 and 2).
    lazy_text_a_column=COLUMNS.index("text_a"),
    lazy_text_b_column=COLUMNS.index("text_b"),
    lazy_labels_column=COLUMNS.index("labels"),
    # to_csv() emitted a header line, so the files start with column names.
    lazy_header_row=True,
    regression=True,
    lazy_loading=True,
)

# Build a ClassificationModel with a single output, configured for regression.
model = ClassificationModel(
    "bert",
    "bert-base-cased",
    num_labels=1,
    args=train_args,
)

# Train from the TSV path instead of a DataFrame.
model.train_model(TRAIN_PATH)

# Evaluate from the TSV path and report the metrics.
eval_output = model.eval_model(EVAL_PATH)
result, model_outputs, wrong_predictions = eval_output
print(result)

# Predict on one sentence pair.
preds, out = model.predict([["Test sentence", "Other sentence"]])
print(preds)