7
from torch.testing._internal.common_utils import TestCase, run_tests, TemporaryFileName
10
@contextlib.contextmanager
12
cur_env = os.environ.get("PYTORCH_JIT", "1")
13
os.environ["PYTORCH_JIT"] = "0"
17
os.environ["PYTORCH_JIT"] = cur_env
20
class TestJitDisabled(TestCase):
22
These tests are separate from the rest of the JIT tests because we need
23
to run a new subprocess and `import torch` with the correct environment
27
def compare_enabled_disabled(self, src):
29
Runs the script in `src` with PYTORCH_JIT enabled and disabled and
30
compares their stdout for equality.
34
with TemporaryFileName() as fname:
35
with open(fname, 'w') as f:
38
out_disabled = subprocess.check_output([
41
out_enabled = subprocess.check_output([
44
self.assertEqual(out_disabled, out_enabled)
46
def test_attribute(self):
50
class Foo(torch.jit.ScriptModule):
51
def __init__(self, x):
53
self.x = torch.jit.Attribute(x, torch.Tensor)
55
def forward(self, input):
58
s = Foo(torch.ones(2, 3))
61
self.compare_enabled_disabled(_program_string)
63
def test_script_module_construction(self):
67
class AModule(torch.jit.ScriptModule):
68
@torch.jit.script_method
69
def forward(self, input):
73
print("Didn't throw exception")
75
self.compare_enabled_disabled(_program_string)
77
def test_recursive_script(self):
81
class AModule(torch.nn.Module):
82
def forward(self, input):
85
sm = torch.jit.script(AModule())
86
print("Didn't throw exception")
88
self.compare_enabled_disabled(_program_string)
90
if __name__ == '__main__':