pytorch

Форк
0
/
test_attr.py 
68 строк · 2.1 Кб
1
# Owner(s): ["oncall: jit"]
2

3
from typing import NamedTuple, Tuple
4

5
import torch
6
from torch.testing import FileCheck
7
from torch.testing._internal.jit_utils import JitTestCase
8

9

10
if __name__ == "__main__":
11
    raise RuntimeError(
12
        "This test file is not meant to be run directly, use:\n\n"
13
        "\tpython test/test_jit.py TESTNAME\n\n"
14
        "instead."
15
    )
16

17

18
class TestGetDefaultAttr(JitTestCase):
19
    def test_getattr_with_default(self):
20
        class A(torch.nn.Module):
21
            def __init__(self) -> None:
22
                super().__init__()
23
                self.init_attr_val = 1.0
24

25
            def forward(self, x):
26
                y = getattr(self, "init_attr_val")  # noqa: B009
27
                w: list[float] = [1.0]
28
                z = getattr(self, "missing", w)  # noqa: B009
29
                z.append(y)
30
                return z
31

32
        result = A().forward(0.0)
33
        self.assertEqual(2, len(result))
34
        graph = torch.jit.script(A()).graph
35

36
        # The "init_attr_val" attribute exists
37
        FileCheck().check('prim::GetAttr[name="init_attr_val"]').run(graph)
38
        # The "missing" attribute does not exist, so there should be no corresponding GetAttr in AST
39
        FileCheck().check_not("missing").run(graph)
40
        # instead the getattr call will emit the default value, which is a list with one float element
41
        FileCheck().check("float[] = prim::ListConstruct").run(graph)
42

43
    def test_getattr_named_tuple(self):
44
        global MyTuple
45

46
        class MyTuple(NamedTuple):
47
            x: str
48
            y: torch.Tensor
49

50
        def fn(x: MyTuple) -> Tuple[str, torch.Tensor, int]:
51
            return (
52
                getattr(x, "x", "fdsa"),
53
                getattr(x, "y", torch.ones((3, 3))),
54
                getattr(x, "z", 7),
55
            )
56

57
        inp = MyTuple(x="test", y=torch.ones(3, 3) * 2)
58
        ref = fn(inp)
59
        fn_s = torch.jit.script(fn)
60
        res = fn_s(inp)
61
        self.assertEqual(res, ref)
62

63
    def test_getattr_tuple(self):
64
        def fn(x: Tuple[str, int]) -> int:
65
            return getattr(x, "x", 2)
66

67
        with self.assertRaisesRegex(RuntimeError, "but got a normal Tuple"):
68
            torch.jit.script(fn)
69

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.