simpletransformers

Форк
0
/
test_language_modeling.py 
72 строки · 2.0 Кб
1
import os
2

3
import pandas as pd
4
import pytest
5

6
from simpletransformers.language_modeling import LanguageModelingModel
7

8

9
@pytest.mark.parametrize(
10
    "model_type, model_name",
11
    [
12
        ("bert", "bert-base-uncased"),
13
        ("longformer", "allenai/longformer-base-4096"),
14
        ("bert", None),
15
        ("electra", None),
16
        ("longformer", None),
17
        # ("xlnet", "xlnet-base-cased"),
18
        # ("xlm", "xlm-mlm-17-1280"),
19
        ("roberta", "roberta-base"),
20
        # ("distilbert", "distilbert-base-uncased"),
21
        # ("albert", "albert-base-v1"),
22
        # ("camembert", "camembert-base"),
23
        # ("xlmroberta", "xlm-roberta-base"),
24
        # ("flaubert", "flaubert-base-cased"),
25
    ],
26
)
27
def test_language_modeling(model_type, model_name):
28
    with open("train.txt", "w") as f:
29
        for i in range(100):
30
            f.writelines("Hello world with Simple Transformers! \n")
31

32
    if model_type == "electra":
33
        model_args = {
34
            "reprocess_input_data": True,
35
            "overwrite_output_dir": True,
36
            "num_train_epochs": 1,
37
            "no_save": True,
38
            "vocab_size": 100,
39
            "generator_config": {
40
                "embedding_size": 512,
41
                "hidden_size": 256,
42
                "num_hidden_layers": 1,
43
            },
44
            "discriminator_config": {
45
                "embedding_size": 512,
46
                "hidden_size": 256,
47
                "num_hidden_layers": 2,
48
            },
49
        }
50
    else:
51
        model_args = {
52
            "reprocess_input_data": True,
53
            "overwrite_output_dir": True,
54
            "num_train_epochs": 1,
55
            "no_save": True,
56
        }
57
        if model_name is None:
58
            model_args["vocab_size"] = 100
59

60
    if model_name is None:
61
        model_args["vocab_size"] = 100
62

63
    model = LanguageModelingModel(
64
        model_type,
65
        model_name,
66
        args=model_args,
67
        train_files="train.txt",
68
        use_cuda=False,
69
    )
70

71
    # Train the model
72
    model.train_model("train.txt")
73

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.