optuna
94 строки · 3.0 Кб
1from __future__ import annotations
2
3import io
4from typing import TYPE_CHECKING
5
6import boto3
7import pytest
8
9from optuna.artifacts import Boto3ArtifactStore
10from optuna.artifacts.exceptions import ArtifactNotFound
11
12
13try:
    # TODO(nabenabe0928): Replace this try/except with a plain `from moto import mock_aws`
    # once Python 3.7 support is dropped (the `mock_s3` fallback exists only for old moto).
15from moto import mock_aws
16except ImportError:
17from moto import mock_s3 as mock_aws # type: ignore[attr-defined,no-redef]
18
19if TYPE_CHECKING:
20from collections.abc import Iterator
21
22from mypy_boto3_s3 import S3Client
23from typing_extensions import Annotated
24
    # TODO(Shinichi): Import `Annotated` from `typing` once Python 3.8 support is dropped.
26
27
@pytest.fixture()
def init_mock_client() -> Iterator[tuple[str, S3Client]]:
    """Provide a freshly created, mocked S3 bucket and client to a test.

    Inside a moto ``mock_aws`` context, creates the bucket ``"moto-bucket"`` and
    yields ``(bucket_name, s3_client)``. After the test body finishes, every
    object still in the bucket is deleted and then the bucket itself is removed.

    Yields:
        A ``(bucket_name, s3_client)`` tuple bound to the mocked S3 backend.
    """
    with mock_aws():
        # Runs before each test.
        bucket_name = "moto-bucket"
        s3_client = boto3.client("s3")
        s3_client.create_bucket(Bucket=bucket_name)

        yield bucket_name, s3_client

        # Runs after each test. S3 only deletes empty buckets, so clear out any
        # objects the test left behind first.
        objects = s3_client.list_objects(Bucket=bucket_name).get("Contents", [])
        if objects:
            s3_client.delete_objects(
                Bucket=bucket_name,
                # `delete_objects` expects one {"Key": ...} dict per object; a dict
                # comprehension here would collapse them into a single entry.
                Delete={"Objects": [{"Key": obj["Key"]} for obj in objects], "Quiet": True},
            )
        s3_client.delete_bucket(Bucket=bucket_name)
46
47
@pytest.mark.parametrize("avoid_buf_copy", [True, False])
def test_upload_download(
    init_mock_client: Annotated[tuple[str, S3Client], pytest.fixture],
    avoid_buf_copy: bool,
) -> None:
    """Round-trip a small payload through ``Boto3ArtifactStore``.

    The payload is written through the store. Two checks then run against the raw S3
    client: the bucket holds exactly one object, and that object's body matches.
    Finally the payload is read back via ``open_reader``. When ``avoid_buf_copy`` is
    False, the caller's source buffer must still be open afterwards.
    """
    bucket, client = init_mock_client
    store = Boto3ArtifactStore(bucket, avoid_buf_copy=avoid_buf_copy)

    key = "dummy-uuid"
    payload = b"Hello World"
    source = io.BytesIO(payload)

    store.write(key, source)
    listing = client.list_objects(Bucket=bucket)["Contents"]
    assert len(listing) == 1

    # Check the stored bytes directly through S3, independent of the store.
    stored = client.get_object(Bucket=bucket, Key=key)["Body"].read()
    assert stored == payload

    with store.open_reader(key) as reader:
        round_tripped = reader.read()
    assert round_tripped == payload

    if not avoid_buf_copy:
        assert not source.closed
71
72
def test_remove(init_mock_client: Annotated[tuple[str, S3Client], pytest.fixture]) -> None:
    """An artifact written to the store disappears from the bucket after ``remove``."""
    bucket, client = init_mock_client
    store = Boto3ArtifactStore(bucket)

    key = "dummy-uuid"
    store.write(key, io.BytesIO(b"Hello"))
    before = client.list_objects(Bucket=bucket)["Contents"]
    assert sum(1 for entry in before if entry["Key"] == key) == 1

    store.remove(key)
    # The listing may omit "Contents" entirely once the bucket is empty.
    after = client.list_objects(Bucket=bucket).get("Contents", [])
    assert sum(1 for entry in after if entry["Key"] == key) == 0
85
86
def test_file_not_found_exception(
    init_mock_client: Annotated[tuple[str, S3Client], pytest.fixture]
) -> None:
    """Opening a reader for a key that was never written raises ``ArtifactNotFound``."""
    bucket, _client = init_mock_client
    store = Boto3ArtifactStore(bucket)

    missing_key = "not-found-id"
    with pytest.raises(ArtifactNotFound):
        store.open_reader(missing_key)
95