pytorch

Форк
0
/
torch_version.py 
63 строки · 2.4 Кб
1
from typing import Any, Iterable
2

3
from torch._vendor.packaging.version import InvalidVersion, Version
4
from torch.version import __version__ as internal_version
5

6

7
__all__ = ["TorchVersion"]
8

9

10
class TorchVersion(str):
11
    """A string with magic powers to compare to both Version and iterables!
12
    Prior to 1.10.0 torch.__version__ was stored as a str and so many did
13
    comparisons against torch.__version__ as if it were a str. In order to not
14
    break them we have TorchVersion which masquerades as a str while also
15
    having the ability to compare against both packaging.version.Version as
16
    well as tuples of values, eg. (1, 2, 1)
17
    Examples:
18
        Comparing a TorchVersion object to a Version object
19
            TorchVersion('1.10.0a') > Version('1.10.0a')
20
        Comparing a TorchVersion object to a Tuple object
21
            TorchVersion('1.10.0a') > (1, 2)    # 1.2
22
            TorchVersion('1.10.0a') > (1, 2, 1) # 1.2.1
23
        Comparing a TorchVersion object against a string
24
            TorchVersion('1.10.0a') > '1.2'
25
            TorchVersion('1.10.0a') > '1.2.1'
26
    """
27

28
    # fully qualified type names here to appease mypy
29
    def _convert_to_version(self, inp: Any) -> Any:
30
        if isinstance(inp, Version):
31
            return inp
32
        elif isinstance(inp, str):
33
            return Version(inp)
34
        elif isinstance(inp, Iterable):
35
            # Ideally this should work for most cases by attempting to group
36
            # the version tuple, assuming the tuple looks (MAJOR, MINOR, ?PATCH)
37
            # Examples:
38
            #   * (1)         -> Version("1")
39
            #   * (1, 20)     -> Version("1.20")
40
            #   * (1, 20, 1)  -> Version("1.20.1")
41
            return Version(".".join(str(item) for item in inp))
42
        else:
43
            raise InvalidVersion(inp)
44

45
    def _cmp_wrapper(self, cmp: Any, method: str) -> bool:
46
        try:
47
            return getattr(Version(self), method)(self._convert_to_version(cmp))
48
        except BaseException as e:
49
            if not isinstance(e, InvalidVersion):
50
                raise
51
            # Fall back to regular string comparison if dealing with an invalid
52
            # version like 'parrot'
53
            return getattr(super(), method)(cmp)
54

55

56
for cmp_method in ["__gt__", "__lt__", "__eq__", "__ge__", "__le__"]:
57
    setattr(
58
        TorchVersion,
59
        cmp_method,
60
        lambda x, y, method=cmp_method: x._cmp_wrapper(y, method),
61
    )
62

63
__version__ = TorchVersion(internal_version)
64

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.