# transformers
# 64 lines · 2.1 KB
# The fixture function defined below is "used" by tests that reference its name
# (e.g. as a fixture argument or via ``pytest.mark.usefixtures``).
3
import os

import pytest
from attr import dataclass
9
# Force the default AWS region, overriding any value already set in the
# environment. The training image URIs in this file point at the us-east-1
# ECR registry, so the region is pinned to match.
os.environ.update({"AWS_DEFAULT_REGION": "us-east-1"})
12
@dataclass
class SageMakerTestEnvironment:
    """Settings shared by the SageMaker integration tests for one framework.

    ``framework`` selects the training-script directory, the Docker image and
    the metric regexes. ``"pytorch"`` selects the PyTorch values; any other
    value falls into the TensorFlow branch of the properties below.

    Only ``framework`` is annotated, so it is the only generated constructor
    field. The unannotated attributes are plain class attributes shared by all
    instances; the dicts are shared by reference, so copy before mutating.
    """

    framework: str

    # Execution-role ARN used for the SageMaker jobs.
    role = "arn:aws:iam::558105141721:role/sagemaker_execution_role"
    # Hyperparameters for single-node training jobs.
    hyperparameters = {
        "task_name": "mnli",
        "per_device_train_batch_size": 16,
        "per_device_eval_batch_size": 16,
        "do_train": True,
        "do_eval": True,
        "do_predict": True,
        "output_dir": "/opt/ml/model",
        "overwrite_output_dir": True,
        "max_steps": 500,
        "save_steps": 5500,
    }
    # Same settings as ``hyperparameters`` with ``max_steps`` raised to 1000.
    distributed_hyperparameters = {**hyperparameters, "max_steps": 1000}

    @property
    def metric_definitions(self) -> list:
        """Return ``{"Name": ..., "Regex": ...}`` metric definitions for this framework.

        NOTE(review): presumably passed to the SageMaker estimator's
        ``metric_definitions``, which scrapes the first capture group from the
        job logs — confirm at the call site.
        """
        if self.framework == "pytorch":
            return [
                {"Name": "train_runtime", "Regex": r"train_runtime.*=\D*(.*?)$"},
                {"Name": "eval_accuracy", "Regex": r"eval_accuracy.*=\D*(.*?)$"},
                {"Name": "eval_loss", "Regex": r"eval_loss.*=\D*(.*?)$"},
            ]
        # Each regex must sit under the metric it measures: accuracy is read
        # from the ``sparse_categorical_accuracy`` line and loss from the
        # ``loss`` line (these two regexes were previously swapped).
        return [
            {"Name": "train_runtime", "Regex": r"train_runtime.*=\D*(.*?)$"},
            {"Name": "eval_accuracy", "Regex": r"sparse_categorical_accuracy.*=\D*(.*?)]?$"},
            {"Name": "eval_loss", "Regex": r"loss.*=\D*(.*?)]?$"},
        ]

    @property
    def base_job_name(self) -> str:
        """Prefix for the names of the SageMaker jobs started by these tests."""
        return f"{self.framework}-transformers-test"

    @property
    def test_path(self) -> str:
        """Relative path (from the working directory) of this framework's training scripts."""
        return f"./tests/sagemaker/scripts/{self.framework}"

    @property
    def image_uri(self) -> str:
        """Hugging Face training image for this framework (hosted in us-east-1 ECR)."""
        if self.framework == "pytorch":
            return "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04"
        return "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-tensorflow-training:2.4.1-transformers4.6.1-gpu-py37-cu110-ubuntu18.04"
61
@pytest.fixture(scope="class")
def sm_env(request):
    """Class-scoped fixture that attaches a ``SageMakerTestEnvironment`` to the test class.

    The environment is built from the ``framework`` attribute declared on the
    requesting test class. It is stored as the class attribute ``env``, so test
    methods can read it as ``self.env``.
    """
    test_cls = request.cls
    test_cls.env = SageMakerTestEnvironment(framework=test_cls.framework)