# Source: simpletransformers (original listing: 44 lines, 1.4 KB)
"""Minimal multi-label text classification example using simpletransformers.

Builds two tiny in-memory DataFrames (train and eval), trains a
``MultiLabelClassificationModel`` ("roberta" / "roberta-base") on the
training frame, evaluates it on the eval frame, and prints predictions for
one unseen sentence.
"""
import pandas as pd

from simpletransformers.classification import MultiLabelClassificationModel

# Train and evaluation data need to be in a pandas DataFrame containing at
# least two columns: 'text' and 'labels'. The 'labels' column should contain
# multi-hot encoded lists (one 0/1 entry per label, six labels here).
train_data = [
    ["Example sentence 1 for multilabel classification.", [1, 1, 1, 1, 0, 1]],
    ["This is another example sentence. ", [0, 1, 1, 0, 0, 0]],
]
train_df = pd.DataFrame(train_data, columns=["text", "labels"])

eval_data = [
    ["Example eval sentence for multilabel classification.", [1, 1, 1, 1, 0, 1]],
    ["Example eval senntence belonging to class 2", [0, 1, 1, 0, 0, 0]],
]
# Name the columns explicitly so the eval frame follows the same
# 'text'/'labels' layout as the training frame.
eval_df = pd.DataFrame(eval_data, columns=["text", "labels"])

# Create a MultiLabelClassificationModel.
# num_labels must equal the length of each multi-hot label list above.
# NOTE(review): the original script carried a remark that class weights can
# be set through an optional weight argument; confirm the exact parameter
# name for MultiLabelClassificationModel in the library docs before using it.
model = MultiLabelClassificationModel(
    "roberta",
    "roberta-base",
    num_labels=6,
    args={
        "reprocess_input_data": True,
        "overwrite_output_dir": True,
        "num_train_epochs": 5,
    },
)

# Show the first rows of the training frame as a sanity check before training.
print(train_df.head())

# Train the model
model.train_model(train_df)

# Evaluate the model; only the metrics and raw outputs are printed below,
# wrong_predictions is unpacked because eval_model returns three values.
result, model_outputs, wrong_predictions = model.eval_model(eval_df)
print(result)
print(model_outputs)

# Predict labels for a single unseen sentence.
predictions, raw_outputs = model.predict(
    ["This thing is entirely different from the other thing. "]
)
print(predictions)
print(raw_outputs)