simpletransformers

Форк
0
/
test_classification.py 
231 строка · 7.5 Кб
import pandas as pd
import pytest

from simpletransformers.classification import (
    ClassificationModel,
    MultiLabelClassificationModel,
)


@pytest.mark.parametrize(
    "model_type, model_name",
    [
        ("bert", "bert-base-uncased"),
        ("bigbird", "google/bigbird-roberta-base"),
        # NOTE(review): the architectures below are disabled, presumably to
        # keep test runtime and model downloads small -- TODO confirm.
        # ("longformer", "allenai/longformer-base-4096"),
        # ("electra", "google/electra-small-discriminator"),
        # ("mobilebert", "google/mobilebert-uncased"),
        # ("bertweet", "vinai/bertweet-base"),
        # ("deberta", "microsoft/deberta-base"),
        # ("xlnet", "xlnet-base-cased"),
        # ("xlm", "xlm-mlm-17-1280"),
        # ("roberta", "roberta-base"),
        # ("distilbert", "distilbert-base-uncased"),
        # ("albert", "albert-base-v1"),
        # ("camembert", "camembert-base"),
        # ("xlmroberta", "xlm-roberta-base"),
        # ("flaubert", "flaubert-base-cased"),
    ],
)
def test_binary_classification(model_type, model_name):
    """Smoke-test train / eval / predict for a two-class ClassificationModel.

    A ``ClassificationModel`` built from ``model_type`` and ``model_name`` is
    trained on two toy sentences, evaluated on two others, and asked to
    predict one unseen sentence, all with ``use_cuda=False``.
    """
    # Train and evaluation data need to be a pandas DataFrame of two columns:
    # the first column is the text (type str), the second is the label
    # (type int).
    train_data = [
        ["Example sentence belonging to class 1", 1],
        ["Example sentence belonging to class 0", 0],
    ]
    train_df = pd.DataFrame(train_data, columns=["text", "labels"])

    eval_data = [
        ["Example eval sentence belonging to class 1", 1],
        ["Example eval sentence belonging to class 0", 0],
    ]
    eval_df = pd.DataFrame(eval_data, columns=["text", "labels"])

    # Create a ClassificationModel
    model = ClassificationModel(
        model_type,
        model_name,
        use_cuda=False,
        args={
            "no_save": True,
            "reprocess_input_data": True,
            "overwrite_output_dir": True,
            "scheduler": "constant_schedule",
            "max_seq_length": 20,
        },
    )

    # Train the model
    model.train_model(train_df)

    # Evaluate the model: expect a metrics dict and one output per eval row.
    result, model_outputs, wrong_predictions = model.eval_model(eval_df)
    assert isinstance(result, dict)
    assert len(model_outputs) == len(eval_df)

    # Predict a single unseen sentence: expect exactly one prediction.
    predictions, raw_outputs = model.predict(["Some arbitary sentence"])
    assert len(predictions) == 1


@pytest.mark.parametrize(
    "model_type, model_name",
    [
        # NOTE(review): commented-out architectures are disabled, presumably
        # to keep test runtime and model downloads small -- TODO confirm.
        # ("bert", "bert-base-uncased"),
        # ("xlnet", "xlnet-base-cased"),
        ("bigbird", "google/bigbird-roberta-base"),
        # ("xlm", "xlm-mlm-17-1280"),
        ("roberta", "roberta-base"),
        # ("distilbert", "distilbert-base-uncased"),
        # ("albert", "albert-base-v1"),
        # ("camembert", "camembert-base"),
        # ("xlmroberta", "xlm-roberta-base"),
        # ("flaubert", "flaubert-base-cased"),
    ],
)
def test_multiclass_classification(model_type, model_name):
    """Smoke-test train / eval / predict for a three-class ClassificationModel.

    The model is created with ``num_labels=3`` and ``use_cuda=False``. It is
    trained and evaluated on one toy sentence per class, then asked to
    predict one unseen sentence.
    """
    # Train and evaluation data need to be a pandas DataFrame containing at
    # least two columns. If the DataFrame has a header, it should contain a
    # 'text' and a 'labels' column. If no header is present, the DataFrame
    # should contain at least two columns, with the first column being the
    # text (type str) and the second column being the label (type int).
    train_data = [
        ["Example sentence belonging to class 1", 1],
        ["Example sentence belonging to class 0", 0],
        ["Example eval senntence belonging to class 2", 2],
    ]
    train_df = pd.DataFrame(train_data, columns=["text", "labels"])

    eval_data = [
        ["Example eval sentence belonging to class 1", 1],
        ["Example eval sentence belonging to class 0", 0],
        ["Example eval senntence belonging to class 2", 2],
    ]
    eval_df = pd.DataFrame(eval_data, columns=["text", "labels"])

    # Create a ClassificationModel with one output per class.
    model = ClassificationModel(
        model_type,
        model_name,
        num_labels=3,
        args={
            "no_save": True,
            "reprocess_input_data": True,
            "overwrite_output_dir": True,
            "max_seq_length": 20,
        },
        use_cuda=False,
    )

    # Train the model
    model.train_model(train_df)

    # Evaluate the model: expect a metrics dict and one output per eval row.
    result, model_outputs, wrong_predictions = model.eval_model(eval_df)
    assert isinstance(result, dict)
    assert len(model_outputs) == len(eval_df)

    # Predict a single unseen sentence: expect exactly one prediction.
    predictions, raw_outputs = model.predict(["Some arbitary sentence"])
    assert len(predictions) == 1


@pytest.mark.parametrize(
    "model_type, model_name",
    [
        ("bert", "bert-base-uncased"),
        # NOTE(review): commented-out architectures are disabled, presumably
        # to keep test runtime and model downloads small -- TODO confirm.
        # ("xlnet", "xlnet-base-cased"),
        # ("xlm", "xlm-mlm-17-1280"),
        # ("roberta", "roberta-base"),
        # ("distilbert", "distilbert-base-uncased"),
        # ("albert", "albert-base-v1"),
        # ("camembert", "camembert-base")
    ],
)
def test_multilabel_classification(model_type, model_name):
    """Smoke-test train / eval / predict for a six-label MultiLabelClassificationModel.

    Labels are multi-hot lists of length 6, matching ``num_labels=6``. The
    model is trained for one epoch on CPU (``use_cuda=False``), evaluated,
    and asked to predict one unseen sentence.
    """
    # Train and evaluation data need to be a pandas DataFrame containing at
    # least two columns, a 'text' and a 'labels' column. The 'labels' column
    # should contain multi-hot encoded lists.
    train_data = [
        ["Example sentence 1 for multilabel classification.", [1, 1, 1, 1, 0, 1]],
        ["This is another example sentence. ", [0, 1, 1, 0, 0, 0]],
    ]
    train_df = pd.DataFrame(train_data, columns=["text", "labels"])

    eval_data = [
        ["Example eval sentence for multilabel classification.", [1, 1, 1, 1, 0, 1]],
        ["Example eval senntence belonging to class 2", [0, 1, 1, 0, 0, 0]],
    ]
    eval_df = pd.DataFrame(eval_data, columns=["text", "labels"])

    # Create a MultiLabelClassificationModel; num_labels must equal the
    # length of each multi-hot label list above.
    model = MultiLabelClassificationModel(
        model_type,
        model_name,
        num_labels=6,
        args={
            "no_save": True,
            "reprocess_input_data": True,
            "overwrite_output_dir": True,
            "num_train_epochs": 1,
            "max_seq_length": 20,
        },
        use_cuda=False,
    )

    # Train the model
    model.train_model(train_df)

    # Evaluate the model: expect a metrics dict and one output per eval row.
    result, model_outputs, wrong_predictions = model.eval_model(eval_df)
    assert isinstance(result, dict)
    assert len(model_outputs) == len(eval_df)

    # Predict a single unseen sentence: expect exactly one prediction row.
    predictions, raw_outputs = model.predict(
        ["This thing is entirely different from the other thing. "]
    )
    assert len(predictions) == 1


@pytest.mark.parametrize(
    "model_type, model_name",
    [
        # NOTE(review): commented-out architectures are disabled, presumably
        # to keep test runtime and model downloads small -- TODO confirm.
        # ("bert", "bert-base-uncased"),
        # ("xlnet", "xlnet-base-cased"),
        ("bigbird", "google/bigbird-roberta-base"),
        # ("xlm", "xlm-mlm-17-1280"),
        ("roberta", "roberta-base"),
        # ("distilbert", "distilbert-base-uncased"),
        # ("albert", "albert-base-v1"),
        # ("camembert", "camembert-base"),
        # ("xlmroberta", "xlm-roberta-base"),
        # ("flaubert", "flaubert-base-cased"),
    ],
)
def test_sliding_window(model_type, model_name):
    """Smoke-test a ClassificationModel with ``sliding_window`` enabled.

    One training and one evaluation text are a sentence repeated 10 times,
    far longer than ``max_seq_length=20``, presumably so that they get split
    into several windows. The DataFrames are built without a header, so this
    also exercises positional columns (text first, label second).
    """
    # Train and evaluation data need to be a pandas DataFrame of two columns:
    # the first column is the text (type str), the second is the label
    # (type int). No column names are given here on purpose.
    train_data = [
        ["Example sentence belonging to class 1" * 10, 1],
        ["Example sentence belonging to class 0", 0],
    ]
    train_df = pd.DataFrame(train_data)

    eval_data = [
        ["Example eval sentence belonging to class 1", 1],
        ["Example eval sentence belonging to class 0" * 10, 0],
    ]
    eval_df = pd.DataFrame(eval_data)

    # Create a ClassificationModel
    model = ClassificationModel(
        model_type,
        model_name,
        use_cuda=False,
        args={
            "no_save": True,
            "reprocess_input_data": True,
            "overwrite_output_dir": True,
            "max_seq_length": 20,
            "sliding_window": True,
        },
    )

    # Train the model
    model.train_model(train_df)

    # Evaluate the model: even with several windows per text, expect a
    # metrics dict and one output entry per eval row.
    result, model_outputs, wrong_predictions = model.eval_model(eval_df)
    assert isinstance(result, dict)
    assert len(model_outputs) == len(eval_df)

    # Predict a single unseen sentence: expect exactly one prediction.
    predictions, raw_outputs = model.predict(["Some arbitary sentence"])
    assert len(predictions) == 1

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.