# Tests for simpletransformers language modeling (originally 72 lines, 2.0 KB).
import os

import pandas as pd
import pytest

from simpletransformers.language_modeling import LanguageModelingModel

@pytest.mark.parametrize(
    "model_type, model_name",
    [
        ("bert", "bert-base-uncased"),
        ("longformer", "allenai/longformer-base-4096"),
        ("bert", None),
        ("electra", None),
        ("longformer", None),
        # ("xlnet", "xlnet-base-cased"),
        # ("xlm", "xlm-mlm-17-1280"),
        ("roberta", "roberta-base"),
        # ("distilbert", "distilbert-base-uncased"),
        # ("albert", "albert-base-v1"),
        # ("camembert", "camembert-base"),
        # ("xlmroberta", "xlm-roberta-base"),
        # ("flaubert", "flaubert-base-cased"),
    ],
)
def test_language_modeling(model_type, model_name):
    """Smoke-test one epoch of language-model training on a tiny corpus.

    Covers both fine-tuning a pretrained checkpoint (model_name given) and
    training from scratch (model_name is None), for several architectures.
    """
    # Tiny training corpus: 100 identical lines. writelines takes the
    # iterable of lines once, instead of being called per iteration.
    with open("train.txt", "w") as f:
        f.writelines("Hello world with Simple Transformers! \n" for _ in range(100))

    # Args shared by every configuration.
    model_args = {
        "reprocess_input_data": True,
        "overwrite_output_dir": True,
        "num_train_epochs": 1,
        "no_save": True,
    }
    if model_type == "electra":
        # ELECTRA is trained from scratch here, so give it explicit, small
        # generator/discriminator configs to keep the test fast.
        model_args["generator_config"] = {
            "embedding_size": 512,
            "hidden_size": 256,
            "num_hidden_layers": 1,
        }
        model_args["discriminator_config"] = {
            "embedding_size": 512,
            "hidden_size": 256,
            "num_hidden_layers": 2,
        }
    if model_name is None or model_type == "electra":
        # Training a tokenizer/model from scratch needs an explicit vocab size.
        # (The original set this in three places; consolidated to one.)
        model_args["vocab_size"] = 100

    model = LanguageModelingModel(
        model_type,
        model_name,
        args=model_args,
        train_files="train.txt",
        use_cuda=False,
    )

    # Train the model
    model.train_model("train.txt")
73