import os
import sys

import torch
from torch._C import parse_ir
from torch.testing import FileCheck

# Put the parent of this file's directory (presumably PyTorch's test/ dir) on
# sys.path so helper modules living there can be imported by these tests.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase  # noqa: E402
if __name__ == "__main__":
    # This module only defines test cases and has no runner of its own; the
    # tests are meant to be driven through test/test_jit.py, so refuse to run
    # directly rather than exiting successfully without testing anything.
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
class TestIgnorableArgs(JitTestCase):
    """Check which call arguments the TorchScript code printer omits or keeps.

    Trailing arguments equal to their defaults are expected to be dropped from
    the printed code, while semantically meaningful keyword arguments such as
    ``out=`` must still be emitted.
    """

    def test_slice_ignorable_args_for_slice(self):
        graph_str = """graph():
            %13 : int = prim::Constant[value=0]()
            %10 : bool = prim::Constant[value=0]()
            %8 : NoneType = prim::Constant()
            %0 : int = prim::Constant[value=1]()
            %1 : int = prim::Constant[value=2]()
            %2 : int = prim::Constant[value=3]()
            %3 : int = prim::Constant[value=4]()
            %4 : int = prim::Constant[value=9]()
            %5 : int[] = prim::ListConstruct(%0, %1, %2, %3, %4, %4)
            %6 : int[] = prim::ListConstruct(%0, %1, %2, %3, %4, %4)
            %7 : int[][] = prim::ListConstruct(%5, %6)
            %val.1 : Tensor = aten::tensor(%7, %8, %8, %10)
            %16 : Tensor = aten::slice(%val.1, %13, %1, %8, %0)
            %20 : Tensor = aten::slice(%16, %0, %8, %0, %0)
            return (%20)"""
        graph = parse_ir(graph_str)
        function = self.createFunctionFromGraph(graph)
        function_copy = self.getExportImportCopy(function)
        src = str(function.code)
        # aten::slice takes (self, dim, start, end, step); assumed defaults are
        # end=None and step=1 -- confirm against the aten::slice schema.
        # %16 = slice(dim=0, start=2, end=None, step=1): the trailing end and
        #   step match their defaults, so only "0, 2" is printed.
        # %20 = slice(dim=1, start=None, end=1, step=1): only the trailing
        #   step is dropped; start=None must stay because end follows it.
        FileCheck().check(
            "torch.slice(torch.slice(torch.tensor(_0), 0, 2), 1, None, 1)"
        ).run(src)
        # The function must give the same result after an export/import round trip.
        self.assertEqual(function(), function_copy())

    def test_add_out_ignorable_args(self):
        # Scripting is required: ``.code`` exists only on TorchScript functions.
        @torch.jit.script
        def fn(x: torch.Tensor, y: torch.Tensor):
            torch.add(x, y, out=y)

        # ``out=`` determines where the result is written (fn discards the
        # return value), so the printer must keep it.
        FileCheck().check("torch.add(x, y, out=y)").run(fn.code)