11
# Prepend the directory two levels above this script (the parent of the script's own
# directory) to sys.path. The `allennlp` imports below then resolve against that
# tree — presumably the local repository checkout rather than an installed package.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
12
from allennlp.commands.test_install import _get_module_root
13
from allennlp.commands.train import train_model_from_file, train_model
14
from allennlp.common import Params
15
from allennlp.common.util import pushd
18
# Logger for this script, named after the module (``__name__``).
logger = logging.getLogger(name=__name__)
21
def train_fixture(config_prefix: str, config_filename: str = "experiment.json") -> None:
    """Train a fixture model from its config and prune the serialization directory.

    Any existing ``<config_prefix>serialization`` directory is deleted first. The model
    is then trained into it with ``train_model_from_file``. Finally, files the fixture
    does not need are removed: the ``log`` subdirectory (if present), top-level ``*.log``
    and ``*.json`` files, and per-epoch ``epoch_<N>.th`` checkpoints.

    Parameters
    ----------
    config_prefix : str
        Path prefix of the fixture directory. It is concatenated directly with file and
        directory names, so it must end with a path separator
        (e.g. ``"allennlp/tests/fixtures/simple_tagger/"``).
    config_filename : str, optional
        Name of the config file inside the fixture directory
        (default ``"experiment.json"``).
    """
    config_file = config_prefix + config_filename
    serialization_dir = config_prefix + "serialization"

    # Train model doesn't like it if we have incomplete serialization
    # directories, so remove them if they exist.
    if os.path.exists(serialization_dir):
        shutil.rmtree(serialization_dir)

    train_model_from_file(config_file, serialization_dir)

    # Remove unnecessary files. Only delete the "log" directory if training actually
    # produced one, instead of crashing after an otherwise successful run.
    log_dir = os.path.join(serialization_dir, "log")
    if os.path.isdir(log_dir):
        shutil.rmtree(log_dir)

    # Per-epoch checkpoints are dropped; only the final artifacts are kept.
    epoch_checkpoint = re.compile(r"epoch_[0-9]+\.th$")
    for filename in glob.glob(os.path.join(serialization_dir, "*")):
        if filename.endswith((".log", ".json")) or epoch_checkpoint.search(filename):
            os.remove(filename)
44
def train_fixture_gpu(config_prefix: str) -> None:
    """Train the fixture at ``config_prefix`` on GPU 0 and copy back its GPU artifacts.

    The fixture's ``experiment.json`` is loaded with ``trainer.cuda_device`` set to ``0``
    and trained in a fresh temporary directory. The resulting ``best.th`` weights and
    ``model.tar.gz`` archive are copied into ``<config_prefix>serialization`` as
    ``best_gpu.th`` and ``model_gpu.tar.gz``. The temporary directory is always removed
    afterwards.

    Parameters
    ----------
    config_prefix : str
        Path prefix of the fixture directory. It is concatenated directly with file names,
        so it must end with a path separator.
    """
    config_file = config_prefix + "experiment.json"
    serialization_dir = config_prefix + "serialization"
    params = Params.from_file(config_file)
    params["trainer"]["cuda_device"] = 0

    # Train into a fresh, empty directory. ``tempfile.gettempdir()`` returns the shared
    # system temp directory itself (e.g. /tmp). That directory is typically non-empty,
    # which training does not accept as a serialization directory, and it would be
    # littered with training output.
    tempdir = tempfile.mkdtemp()
    try:
        train_model(params, tempdir)

        # Now copy back the weights and the archived model.
        os.makedirs(serialization_dir, exist_ok=True)
        shutil.copy(
            os.path.join(tempdir, "best.th"), os.path.join(serialization_dir, "best_gpu.th")
        )
        shutil.copy(
            os.path.join(tempdir, "model.tar.gz"), os.path.join(serialization_dir, "model_gpu.tar.gz")
        )
    finally:
        # Best-effort cleanup; a failure here must not mask a training/copy error.
        shutil.rmtree(tempdir, ignore_errors=True)
61
# Script entry point: retrain each listed model fixture, working from the checkout root.
if __name__ == "__main__":
62
# NOTE(review): presumably _get_module_root() is the allennlp package directory, so its
# parent is the checkout root that the relative "allennlp/tests/fixtures/..." paths assume.
module_root = _get_module_root().parent
63
# pushd presumably switches the working directory to module_root for the duration of
# the block, so the relative fixture paths below resolve — confirm.
with pushd(module_root, verbose=True):
65
# Each entry is either a fixture directory name (trained with train_fixture's default
# "experiment.json") or a (directory name, config filename) tuple.
("basic_classifier", "experiment_seq2seq.jsonnet"),
67
"simple_tagger_with_elmo",
68
"simple_tagger_with_span_f1",
71
if isinstance(model, tuple):
72
# Tuple form: unpack to train with a non-default config file.
model, config_filename = model
73
train_fixture(f"allennlp/tests/fixtures/{model}/", config_filename)
75
# Plain name: train with train_fixture's default config filename.
train_fixture(f"allennlp/tests/fixtures/{model}/")