simpletransformers

Форк
0
/
multilabel_classification.py 
44 строки · 1.4 Кб
1
import pandas as pd
2

3
from simpletransformers.classification import MultiLabelClassificationModel
4

5
# Train and Evaluation data needs to be in a Pandas Dataframe containing at least two columns, a 'text' and a 'labels' column. The `labels` column should contain multi-hot encoded lists.
6
train_data = [
7
    ["Example sentence 1 for multilabel classification.", [1, 1, 1, 1, 0, 1]]
8
] + [["This is another example sentence. ", [0, 1, 1, 0, 0, 0]]]
9
train_df = pd.DataFrame(train_data, columns=["text", "labels"])
10

11
eval_data = [
12
    ["Example eval sentence for multilabel classification.", [1, 1, 1, 1, 0, 1]],
13
    ["Example eval senntence belonging to class 2", [0, 1, 1, 0, 0, 0]],
14
]
15
eval_df = pd.DataFrame(eval_data)
16

17
# Create a MultiLabelClassificationModel
18
model = MultiLabelClassificationModel(
19
    "roberta",
20
    "roberta-base",
21
    num_labels=6,
22
    args={
23
        "reprocess_input_data": True,
24
        "overwrite_output_dir": True,
25
        "num_train_epochs": 5,
26
    },
27
)
28

29
# You can set class weights by using the optional weight argument
30
print(train_df.head())
31

32
# Train the model
33
model.train_model(train_df)
34

35
# Evaluate the model
36
result, model_outputs, wrong_predictions = model.eval_model(eval_df)
37
print(result)
38
print(model_outputs)
39

40
predictions, raw_outputs = model.predict(
41
    ["This thing is entirely different from the other thing. "]
42
)
43
print(predictions)
44
print(raw_outputs)
45

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.