1
# Owner(s): ["oncall: jit"]
6
from dataclasses import dataclass, field, InitVar
8
from typing import List, Optional
10
from hypothesis import given, settings, strategies as st
13
from torch.testing._internal.jit_utils import JitTestCase
16
# Example jittable dataclass
21
norm: Optional[torch.Tensor] = None
23
def __post_init__(self):
24
self.norm = (torch.tensor(self.x) ** 2 + torch.tensor(self.y) ** 2) ** 0.5
27
class MixupScheme(Enum):
32
"before_fusion_projection",
33
"after_fusion_projection",
34
"after_classifier_projection",
40
def __init__(self, alpha: float = 0.125, scheme: MixupScheme = MixupScheme.INPUT):
45
class MixupScheme2(Enum):
52
def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
59
def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
64
# Make sure the Meta internal tooling doesn't raise an overflow error
65
NonHugeFloats = st.floats(min_value=-1e4, max_value=1e4, allow_nan=False)
68
class TestDataclasses(JitTestCase):
70
def tearDownClass(cls):
71
torch._C._jit_clear_class_registry()
73
def test_init_vars(self):
75
@dataclass(order=True)
79
norm_p: InitVar[int] = 2
80
norm: Optional[torch.Tensor] = None
82
def __post_init__(self, norm_p: int):
84
torch.tensor(self.x) ** norm_p + torch.tensor(self.y) ** norm_p
87
def fn(x: float, y: float, p: int):
91
self.checkScript(fn, (1.0, 2.0, 3))
93
# Sort of tests both __post_init__ and optional fields
94
@settings(deadline=None)
95
@given(NonHugeFloats, NonHugeFloats)
96
def test__post_init__(self, x, y):
97
P = torch.jit.script(Point)
99
def fn(x: float, y: float):
103
self.checkScript(fn, [x, y])
105
@settings(deadline=None)
107
st.tuples(NonHugeFloats, NonHugeFloats), st.tuples(NonHugeFloats, NonHugeFloats)
109
def test_comparators(self, pt1, pt2):
112
P = torch.jit.script(Point)
114
def compare(x1: float, y1: float, x2: float, y2: float):
119
# pt1 != pt2, # TODO: Modify interpreter to auto-resolve (a != b) to not (a == b) when there's no __ne__
126
self.checkScript(compare, [x1, y1, x2, y2])
128
def test_default_factories(self):
131
x: List[int] = field(default_factory=list)
133
with self.assertRaises(NotImplementedError):
134
torch.jit.script(Foo)
140
torch.jit.script(fn)()
142
# The user should be able to write their own __eq__ implementation
143
# without us overriding it.
144
def test_custom__eq__(self):
151
def __eq__(self, other: "CustomEq") -> bool:
152
return self.a == other.a # ignore the b field
154
def fn(a: int, b1: int, b2: int):
155
pt1 = CustomEq(a, b1)
156
pt2 = CustomEq(a, b2)
159
self.checkScript(fn, [1, 2, 3])
161
def test_no_source(self):
162
with self.assertRaises(RuntimeError):
163
# using a list in an Enum is not supported
164
torch.jit.script(MixupParams)
166
torch.jit.script(MixupParams2) # don't throw
168
def test_use_unregistered_dataclass_raises(self):
169
def f(a: MixupParams3):
172
with self.assertRaises(OSError):