optuna
57 строк · 1.6 Кб
1from __future__ import annotations2
3import copy4import io5import shutil6import threading7from typing import TYPE_CHECKING8
9from optuna.artifacts.exceptions import ArtifactNotFound10
11
12if TYPE_CHECKING:13from typing import BinaryIO14
15
class FailArtifactStore:
    """Artifact-store stub in which every operation fails.

    Each public method raises a plain ``Exception`` with the same message.
    This lets tests drive the error-handling paths of code that talks to an
    artifact store.
    """

    @staticmethod
    def _error() -> Exception:
        """Build the exception that every operation of this stub raises."""
        return Exception("something error raised")

    def open_reader(self, artifact_id: str) -> BinaryIO:
        """Always raise; no reader is ever returned."""
        raise self._error()

    def write(self, artifact_id: str, content_body: BinaryIO) -> None:
        """Always raise; ``content_body`` is never read."""
        raise self._error()

    def remove(self, artifact_id: str) -> None:
        """Always raise; nothing is removed."""
        raise self._error()
26
class InMemoryArtifactStore:
    """Artifact store that keeps every artifact in process memory.

    Artifacts are stored as ``io.BytesIO`` snapshots keyed by artifact ID.
    Every read or mutation of the internal dict happens while holding a
    single ``threading.Lock``.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        # artifact_id -> snapshot of the written bytes, rewound to offset 0.
        self._data: dict[str, io.BytesIO] = {}

    def open_reader(self, artifact_id: str) -> BinaryIO:
        """Return an independent copy of the artifact stored as ``artifact_id``.

        The returned stream is a deep copy, so reading from or writing to it
        never changes the stored snapshot.

        Raises:
            ArtifactNotFound: If nothing is stored under ``artifact_id``.
        """
        with self._lock:
            if artifact_id not in self._data:
                raise ArtifactNotFound("not found")
            return copy.deepcopy(self._data[artifact_id])

    def write(self, artifact_id: str, content_body: BinaryIO) -> None:
        """Store the bytes of ``content_body`` under ``artifact_id``.

        Data is read from the stream's current position until EOF. Any
        artifact already stored under the same ID is replaced.
        """
        # Drain the source before taking the lock, so reading the caller's
        # stream does not hold up other store operations.
        snapshot = io.BytesIO()
        shutil.copyfileobj(content_body, snapshot)
        snapshot.seek(0)  # rewind so copies handed out by open_reader start at byte 0
        with self._lock:
            self._data[artifact_id] = snapshot

    def remove(self, artifact_id: str) -> None:
        """Delete the artifact stored as ``artifact_id``.

        Raises:
            ArtifactNotFound: If nothing is stored under ``artifact_id``.
        """
        with self._lock:
            # Stored values are always BytesIO objects, so ``None`` means "absent".
            removed = self._data.pop(artifact_id, None)
        if removed is None:
            raise ArtifactNotFound("not found")
52
if TYPE_CHECKING:
    from optuna.artifacts._protocol import ArtifactStore

    # Static-only conformance checks. This branch never runs at runtime
    # (TYPE_CHECKING is False), but a type checker verifies that each stub
    # can be assigned to a variable annotated with the ArtifactStore protocol.
    _inmemory: ArtifactStore = InMemoryArtifactStore()
    _fail: ArtifactStore = FailArtifactStore()