pytorch-lightning
20 строк · 967.0 Байт
1from typing import Callable
2
3import torchmetrics
4from lightning_utilities.core.imports import compare_version as _compare_version
5
6from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_8_0
7from lightning.pytorch.utilities.migration.utils import _patch_pl_to_mirror_if_necessary
8
9
def compare_version(
    package: str,
    op: Callable,
    version: str,
    use_base_version: bool = False,
) -> bool:
    """Check the installed version of ``package`` against ``version`` using the comparison callable ``op``.

    The package name is first passed through ``_patch_pl_to_mirror_if_necessary``. The actual check is then
    delegated to ``lightning_utilities.core.imports.compare_version``. NOTE(review): the helper presumably
    rewrites the standalone Lightning package name to the mirrored/unified package name when needed.
    Confirm this against its implementation.

    Args:
        package: Name of the package whose installed version is checked.
        op: Binary comparison callable applied to the installed and requested versions (e.g. ``operator.ge``).
        version: Version string to compare against.
        use_base_version: Forwarded unchanged to the underlying ``compare_version``.

    Returns:
        The result of the underlying ``compare_version`` call.
    """
    resolved_package = _patch_pl_to_mirror_if_necessary(package)
    return _compare_version(resolved_package, op, version, use_base_version)
13
14
if not _TORCHMETRICS_GREATER_EQUAL_0_8_0:
    # torchmetrics releases *before* v0.8.0 (strictly older; 0.8.0 itself is not patched) call `_compare_version`
    # with a hardcoded Lightning package name. That name has to be redirected to the unified package.
    # NOTE(review): upstream this presumably referred to the standalone `pytorch_lightning` name; confirm.
    # Both module-level `_compare_version` bindings are replaced with the module-level `compare_version` wrapper.
    # The wrapper passes the package name through `_patch_pl_to_mirror_if_necessary` before comparing.
    # `torchmetrics.metric` is patched separately because it presumably holds its own imported binding of
    # `_compare_version`, which patching the helper module alone would not affect. TODO confirm.
    # The hardcoded reference was removed in
    # https://github.com/Lightning-AI/torchmetrics/commit/b225889b34b83272117b758cbc28772a5c2356d9
    torchmetrics.utilities.imports._compare_version = compare_version
    torchmetrics.metric._compare_version = compare_version
21