pytorch

Форк
0
/
test_jit_disabled.py 
91 строка · 2.3 Кб
1
# Owner(s): ["oncall: jit"]
2

3
import sys
4
import os
5
import contextlib
6
import subprocess
7
from torch.testing._internal.common_utils import TestCase, run_tests, TemporaryFileName
8

9

10
@contextlib.contextmanager
11
def _jit_disabled():
12
    cur_env = os.environ.get("PYTORCH_JIT", "1")
13
    os.environ["PYTORCH_JIT"] = "0"
14
    try:
15
        yield
16
    finally:
17
        os.environ["PYTORCH_JIT"] = cur_env
18

19

20
class TestJitDisabled(TestCase):
21
    """
22
    These tests are separate from the rest of the JIT tests because we need
23
    run a new subprocess and `import torch` with the correct environment
24
    variables set.
25
    """
26

27
    def compare_enabled_disabled(self, src):
28
        """
29
        Runs the script in `src` with PYTORCH_JIT enabled and disabled and
30
        compares their stdout for equality.
31
        """
32
        # Write `src` out to a temporary so our source inspection logic works
33
        # correctly.
34
        with TemporaryFileName() as fname:
35
            with open(fname, 'w') as f:
36
                f.write(src)
37
                with _jit_disabled():
38
                    out_disabled = subprocess.check_output([
39
                        sys.executable,
40
                        fname])
41
                out_enabled = subprocess.check_output([
42
                    sys.executable,
43
                    fname])
44
                self.assertEqual(out_disabled, out_enabled)
45

46
    def test_attribute(self):
47
        _program_string = """
48
import torch
49

50
class Foo(torch.jit.ScriptModule):
51
    def __init__(self, x):
52
        super().__init__()
53
        self.x = torch.jit.Attribute(x, torch.Tensor)
54

55
    def forward(self, input):
56
        return input
57

58
s = Foo(torch.ones(2, 3))
59
print(s.x)
60
"""
61
        self.compare_enabled_disabled(_program_string)
62

63
    def test_script_module_construction(self):
64
        _program_string = """
65
import torch
66

67
class AModule(torch.jit.ScriptModule):
68
    @torch.jit.script_method
69
    def forward(self, input):
70
        pass
71

72
AModule()
73
print("Didn't throw exception")
74
"""
75
        self.compare_enabled_disabled(_program_string)
76

77
    def test_recursive_script(self):
78
        _program_string = """
79
import torch
80

81
class AModule(torch.nn.Module):
82
    def forward(self, input):
83
        pass
84

85
sm = torch.jit.script(AModule())
86
print("Didn't throw exception")
87
"""
88
        self.compare_enabled_disabled(_program_string)
89

90
if __name__ == '__main__':
91
    run_tests()
92

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.