# Owner(s): ["oncall: jit"]
# flake8: noqa

import sys
import unittest
from dataclasses import dataclass, field, InitVar
from enum import Enum
from typing import List, Optional

from hypothesis import given, settings, strategies as st

import torch
from torch.testing._internal.jit_utils import JitTestCase


# Example jittable dataclass
@dataclass(order=True)
class Point:
    x: float
    y: float
    norm: Optional[torch.Tensor] = None

    def __post_init__(self):
        self.norm = (torch.tensor(self.x) ** 2 + torch.tensor(self.y) ** 2) ** 0.5
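
# A minimal usage sketch (illustrative, not part of the tests): scripting the
# dataclass makes TorchScript synthesize __init__, __eq__, and the ordering
# methods from the field annotations, while the hand-written __post_init__ is
# compiled as-is.
#
#   ScriptedPoint = torch.jit.script(Point)
#   p = ScriptedPoint(3.0, 4.0)
#   p.norm  # tensor(5.), filled in by __post_init__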


class MixupScheme(Enum):
    INPUT = ["input"]

    MANIFOLD = [
        "input",
        "before_fusion_projection",
        "after_fusion_projection",
        "after_classifier_projection",
    ]


@dataclass
class MixupParams:
    def __init__(self, alpha: float = 0.125, scheme: MixupScheme = MixupScheme.INPUT):
        self.alpha = alpha
        self.scheme = scheme


class MixupScheme2(Enum):
    A = 1
    B = 2


@dataclass
class MixupParams2:
    def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
        self.alpha = alpha
        self.scheme = scheme


@dataclass
class MixupParams3:
    def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
        self.alpha = alpha
        self.scheme = scheme


# Make sure the Meta internal tooling doesn't raise an overflow error
NonHugeFloats = st.floats(min_value=-1e4, max_value=1e4, allow_nan=False)
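# (Bounding the range and excluding NaN also keeps the norm computations in the
# property-based tests below finite.)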


class TestDataclasses(JitTestCase):
    @classmethod
    def tearDownClass(cls):
        torch._C._jit_clear_class_registry()

    def test_init_vars(self):
        @torch.jit.script
        @dataclass(order=True)
        class Point2:
            x: float
            y: float
            norm_p: InitVar[int] = 2
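            # (norm_p is an InitVar: it is passed to __post_init__ below but is
            # not stored as an attribute on the instance)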
            norm: Optional[torch.Tensor] = None

            def __post_init__(self, norm_p: int):
                self.norm = (
                    torch.tensor(self.x) ** norm_p + torch.tensor(self.y) ** norm_p
                ) ** (1 / norm_p)

        def fn(x: float, y: float, p: int):
            pt = Point2(x, y, p)
            return pt.norm

        self.checkScript(fn, (1.0, 2.0, 3))

    # Tests both __post_init__ and Optional fields
    @settings(deadline=None)
    @given(NonHugeFloats, NonHugeFloats)
    def test__post_init__(self, x, y):
        P = torch.jit.script(Point)

        def fn(x: float, y: float):
            pt = P(x, y)
            return pt.norm

        self.checkScript(fn, [x, y])

    @settings(deadline=None)
    @given(
        st.tuples(NonHugeFloats, NonHugeFloats), st.tuples(NonHugeFloats, NonHugeFloats)
    )
    def test_comparators(self, pt1, pt2):
        x1, y1 = pt1
        x2, y2 = pt2
        P = torch.jit.script(Point)

        def compare(x1: float, y1: float, x2: float, y2: float):
            pt1 = P(x1, y1)
            pt2 = P(x2, y2)
            return (
                pt1 == pt2,
                # pt1 != pt2,   # TODO: Modify interpreter to auto-resolve (a != b) to not (a == b) when there's no __ne__
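                # (until then, `not (pt1 == pt2)` is an equivalent spelling of
                # the inequality check)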
                pt1 < pt2,
                pt1 <= pt2,
                pt1 > pt2,
                pt1 >= pt2,
            )

        self.checkScript(compare, [x1, y1, x2, y2])

    def test_default_factories(self):
        @dataclass
        class Foo(object):
            x: List[int] = field(default_factory=list)

        with self.assertRaises(NotImplementedError):
            torch.jit.script(Foo)
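            # (scripting Foo raises the NotImplementedError asserted above, so
            # the statements below never actually execute)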

            def fn():
                foo = Foo()
                return foo.x

            torch.jit.script(fn)()

    # The user should be able to write their own __eq__ implementation
    # without us overriding it.
    def test_custom__eq__(self):
        @torch.jit.script
        @dataclass
        class CustomEq:
            a: int
            b: int

            def __eq__(self, other: "CustomEq") -> bool:
                return self.a == other.a  # ignore the b field

        def fn(a: int, b1: int, b2: int):
            pt1 = CustomEq(a, b1)
            pt2 = CustomEq(a, b2)
            return pt1 == pt2

        self.checkScript(fn, [1, 2, 3])

    def test_no_source(self):
        with self.assertRaises(RuntimeError):
            # MixupScheme uses lists as Enum values, which is not supported
            torch.jit.script(MixupParams)

        torch.jit.script(MixupParams2)  # should not raise

    def test_use_unregistered_dataclass_raises(self):
        def f(a: MixupParams3):
            return 0
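
        # MixupParams3 is only used as an annotation here and is never passed
        # through torch.jit.script itself, so it is not registered with
        # TorchScript; compiling `f` then fails while trying to resolve the
        # dataclass, which surfaces as the OSError asserted below.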
        with self.assertRaises(OSError):
            torch.jit.script(f)