allennlp

Форк
0
/
train_fixtures.py 
75 строк · 2.4 Кб
1
#!/usr/bin/env python
2

3
import glob
4
import logging
5
import os
6
import re
7
import shutil
8
import sys
9
import tempfile
10

11
sys.path.insert(0, os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir))))
12
from allennlp.commands.test_install import _get_module_root
13
from allennlp.commands.train import train_model_from_file, train_model
14
from allennlp.common import Params
15
from allennlp.common.util import pushd
16

17

18
# Module-level logger for this script; not currently referenced by the functions below.
logger = logging.getLogger(__name__)
19

20

21
def train_fixture(config_prefix: str, config_filename: str = "experiment.json") -> None:
    """Train a fixture model in place and prune its serialization directory.

    The model is trained from ``config_prefix + config_filename`` into
    ``config_prefix + "serialization"``. Afterwards the ``log`` subdirectory (if any),
    every top-level ``*.log`` / ``*.json`` file and every per-epoch ``epoch_<N>.th``
    checkpoint are deleted, so only the remaining artifacts are kept as the fixture.

    Args:
        config_prefix: Path prefix of the fixture directory. It is concatenated directly
            with file names, so it must end with a path separator
            (e.g. ``"allennlp/tests/fixtures/simple_tagger/"``).
        config_filename: Name of the experiment config file inside that directory.
    """
    config_file = config_prefix + config_filename
    serialization_dir = config_prefix + "serialization"

    # Training doesn't like incomplete serialization directories left over from a
    # previous run, so start from a clean slate.
    if os.path.exists(serialization_dir):
        shutil.rmtree(serialization_dir)

    train_model_from_file(config_file, serialization_dir)

    # Remove the log directory. Whether training creates it presumably depends on the
    # trainer configuration, so don't fail the whole run when it is absent.
    log_dir = os.path.join(serialization_dir, "log")
    if os.path.isdir(log_dir):
        shutil.rmtree(log_dir)

    # Remove unnecessary top-level files: logs, JSON side outputs and per-epoch
    # checkpoints. Directories are skipped because os.remove cannot delete them.
    epoch_checkpoint = re.compile(r"epoch_[0-9]+\.th$")
    for filename in glob.glob(os.path.join(serialization_dir, "*")):
        if not os.path.isfile(filename):
            continue
        if filename.endswith((".log", ".json")) or epoch_checkpoint.search(filename):
            os.remove(filename)
42

43

44
def train_fixture_gpu(config_prefix: str, config_filename: str = "experiment.json") -> None:
    """Train a fixture model on CUDA device 0 and store its GPU artifacts beside the CPU ones.

    The config at ``config_prefix + config_filename`` is loaded, its trainer is forced onto
    ``cuda_device = 0``, and the model is trained in a private temporary directory. The
    resulting ``best.th`` and ``model.tar.gz`` are copied into
    ``config_prefix + "serialization"`` as ``best_gpu.th`` and ``model_gpu.tar.gz``.

    Args:
        config_prefix: Path prefix of the fixture directory, including the trailing
            separator; it is concatenated directly with file names.
        config_filename: Name of the experiment config file inside that directory.
    """
    config_file = config_prefix + config_filename
    serialization_dir = config_prefix + "serialization"
    params = Params.from_file(config_file)
    params["trainer"]["cuda_device"] = 0

    # Train into a fresh, private directory. tempfile.gettempdir() is the shared system
    # temp directory itself: it is usually not empty (the same reason train_fixture wipes
    # any existing serialization directory first) and would be littered with output.
    # The TemporaryDirectory is removed automatically when the block exits.
    with tempfile.TemporaryDirectory() as tmp_root:
        tempdir = os.path.join(tmp_root, "serialization")
        train_model(params, tempdir)

        # Now copy back the weights and the archived model before the temp dir is removed.
        os.makedirs(serialization_dir, exist_ok=True)
        shutil.copy(
            os.path.join(tempdir, "best.th"), os.path.join(serialization_dir, "best_gpu.th")
        )
        shutil.copy(
            os.path.join(tempdir, "model.tar.gz"),
            os.path.join(serialization_dir, "model_gpu.tar.gz"),
        )
59

60

61
if __name__ == "__main__":
    # Fixtures to retrain. A bare name is trained from its default config file; a
    # (name, config_filename) pair selects a different config inside that fixture dir.
    fixture_specs = [
        ("basic_classifier", "experiment_seq2seq.jsonnet"),
        "simple_tagger",
        "simple_tagger_with_elmo",
        "simple_tagger_with_span_f1",
    ]

    # Run from the directory that contains the ``allennlp`` package tree so that the
    # relative fixture paths below resolve.
    repo_root = _get_module_root().parent
    with pushd(repo_root, verbose=True):
        for spec in fixture_specs:
            if not isinstance(spec, tuple):
                train_fixture(f"allennlp/tests/fixtures/{spec}/")
                continue
            fixture_name, config_name = spec
            train_fixture(f"allennlp/tests/fixtures/{fixture_name}/", config_name)
76

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.