1
# Owner(s): ["oncall: jit"]
6
from pathlib import Path
7
from typing import Sequence
8
from unittest import skip
14
import torch._lazy.config
15
import torch._lazy.ir_cache
16
import torch._lazy.metrics
17
import torch._lazy.ts_backend
18
from torch.testing._internal.common_device_type import (
19
instantiate_device_type_tests,
22
from torch.testing._internal.common_methods_invocations import op_db
23
from torch.testing._internal.common_utils import run_tests, TestCase
24
from torch.testing._internal.jit_utils import JitTestCase
27
# Initialize the lazy ts_backend once, at import time, so that the "lazy"
# device used throughout these tests is available before any test runs.
torch._lazy.ts_backend.init()
31
return "cuda" if "LTC_TS_CUDA" in os.environ else "cpu"
34
def remove_suffixes(l):
    """Return a new list with every name in ``l`` cut off at its first ``"."``.

    Names without a dot are returned unchanged, e.g.
    ``["lazy::add.Tensor", "aten::clamp"] -> ["lazy::add", "aten::clamp"]``.
    Applied to lazy metrics counter names so they can be compared against bare
    ``"<prefix>::<op>"`` strings.
    """
    return [name.partition(".")[0] for name in l]
39
path_to_script = Path(os.path.abspath(os.path.dirname(__file__)))
40
TS_NATIVE_FUNCTIONS_PATH = (
41
path_to_script.parent.parent / "aten/src/ATen/native/ts_native_functions.yaml"
43
with open(TS_NATIVE_FUNCTIONS_PATH) as f:
44
yaml_ts = yaml.load(f, yaml.SafeLoader)
48
yaml_ts["full_codegen"], yaml_ts["supported"], yaml_ts["autograd"]
52
# Op names listed under the yaml "symint" key; test_dispatched_to_lazy expects
# their lazy metrics counters to carry a "_symint" suffix.
HAS_SYMINT_SUFFIX = yaml_ts["symint"]
53
# Ops whose metrics counter is looked up under the "aten::" prefix instead of
# "lazy::" in test_dispatched_to_lazy.
FALLBACK_LIST = {"clamp"}
54
SKIP_RUNTIME_ERROR_LIST = {
55
"index_select", # Empty output_sizes is not supported
56
"clone", # is clone decomposed?
57
# General ASAN failure related to generating bool values.
58
# https://github.com/pytorch/pytorch/issues/74519
59
# https://github.com/pytorch/pytorch/issues/63034
60
"nonzero", # ASAN failure (paste: P501906539)
63
"logdet", # ASAN failure
65
SKIP_INCORRECT_RESULTS_LIST = {
66
"squeeze", # Value out of range
67
"t", # Value out of range
68
"transpose", # Value out of range
69
"bernoulli", # incorrect results
70
"pow", # incorrect results
71
"addcdiv", # incorrect results (on CI not locally?)
73
# The following ops all show up directly in ts_native_functions.yaml,
74
# but run functionalized versions of the composite kernels in core.
75
# This means we don't expect these ops to show up directly in the LTC metrics.
76
FUNCTIONAL_DECOMPOSE_LIST = {
86
"linalg_pinv.atol_rtol_tensor",
89
# For some ops, we don't support all variants. Here we use formatted_name
90
# to uniquely identify the variant.
91
# Compared against op.formatted_name (not op.name), so each entry names one
# specific variant rather than every variant of an op.
SKIP_VARIANT_LIST = {"norm_nuc", "min_reduction_with_dim"}
96
SKIP_RUNTIME_ERROR_LIST,
97
SKIP_INCORRECT_RESULTS_LIST,
98
FUNCTIONAL_DECOMPOSE_LIST,
107
SKIP_RUNTIME_ERROR_LIST,
108
SKIP_INCORRECT_RESULTS_LIST,
109
FUNCTIONAL_DECOMPOSE_LIST,
119
copy_t = t.detach().clone().requires_grad_(True).to(device=dev)
123
class TestLazyTensor(JitTestCase):
124
@skip("Disable until autograd supports symints")
125
def testConvolutionBackward(self):
126
test_device = get_test_device()
127
inp = torch.rand(1, 3, 128, 128, device=test_device, requires_grad=True)
128
inp_copy = clone_move(inp)
129
grad = torch.rand(1, 32, 121, 121, device=test_device) # no requires_grad
130
grad_copy = clone_move(grad)
131
weight = torch.rand(32, 3, 8, 8, device=test_device, requires_grad=True)
132
weight_copy = clone_move(weight)
133
bias = torch.rand(32, device=test_device, requires_grad=True)
134
bias_copy = clone_move(bias)
137
conv_out = torch.nn.functional.conv2d(inp, weight, bias)
138
(inp_grad, weight_grad, bias_grad) = torch.autograd.grad(
139
[conv_out], [inp, weight, bias], [grad]
143
conv_copy_out = torch.nn.functional.conv2d(inp_copy, weight_copy, bias_copy)
144
(inp_copy_grad, weight_copy_grad, bias_copy_grad) = torch.autograd.grad(
145
[conv_copy_out], [inp_copy, weight_copy, bias_copy], [grad_copy]
149
torch.testing.assert_close(bias_copy_grad.cpu(), bias_grad.cpu())
151
torch.testing.assert_close(weight_copy_grad.cpu(), weight_grad.cpu())
152
torch.testing.assert_close(inp_copy_grad.cpu(), inp_grad.cpu())
154
def test_view_mark_step_preserved(self):
155
test_device = get_test_device()
156
inp = torch.rand(4, device=test_device)
157
inp_lazy = clone_move(inp)
159
def foo(x, *, mark_step):
165
torch._lazy.mark_step()
167
# y and x should continue to be aliased after the mark_step call.
171
out_ref = foo(inp, mark_step=False)
172
out = foo(inp_lazy, mark_step=True)
173
# out will have some pending mutations, which will be synced by the .cpu() call.
174
torch.testing.assert_close(out_ref.cpu(), out.cpu())
176
def test_tensor_ctr(self):
177
test_device = get_test_device()
178
inp = torch.tensor([[1, 2, 3, 4, 5]], device=test_device)
179
inp_lazy = torch.tensor([[1, 2, 3, 4, 5]], device="lazy")
182
# Calling a view op to ensure that functionalization wrapping occurs.
187
torch.testing.assert_close(out_ref.cpu(), out.cpu())
190
class TestLazyOpInfo(TestCase):
195
if op.name in LAZY_OPS_LIST
196
and op.name not in SKIP_RUNTIME_ERROR_LIST
197
and op.name not in FUNCTIONAL_DECOMPOSE_LIST
198
and op.formatted_name not in SKIP_VARIANT_LIST
200
allowed_dtypes=(torch.float,),
202
def test_dispatched_to_lazy(self, device, dtype, op):
205
if op.variant_test_name != "":
206
l.append(op.variant_test_name)
209
global HAS_SYMINT_SUFFIX, FALLBACK_LIST
210
samples = op.sample_inputs("lazy", dtype, requires_grad=False)
211
sample = next(iter(samples))
212
args = [sample.input] + list(sample.args)
213
kwargs = sample.kwargs
214
torch._lazy.mark_step()
215
torch._lazy.wait_device_ops()
216
torch._lazy.metrics.reset()
218
r = op(*args, **kwargs)
219
torch._lazy.mark_step()
220
torch._lazy.wait_device_ops()
221
prefix = "aten" if op.name in FALLBACK_LIST else "lazy"
222
symint_suffix = "_symint" if op.name in HAS_SYMINT_SUFFIX else ""
223
found = f"{prefix}::{op.name}{symint_suffix}" in remove_suffixes(
224
torch._lazy.metrics.counter_names()
228
for alias in op.aliases:
230
f"{prefix}::{alias.name}{symint_suffix}"
231
in remove_suffixes(torch._lazy.metrics.counter_names())
233
found = found or alias_found
236
self.assertTrue(found)
242
if op.name in LAZY_OPS_LIST
243
and op.name not in SKIP_RUNTIME_ERROR_LIST | SKIP_INCORRECT_RESULTS_LIST
245
allowed_dtypes=(torch.float,),
247
def test_correctness(self, device, dtype, op):
248
test_device = get_test_device()
250
def clone_to_device(input, dev):
251
if isinstance(input, torch.Tensor):
252
return input.detach().clone().to(device=dev)
253
if isinstance(input, Sequence) and not isinstance(input, str):
254
return tuple(map(functools.partial(clone_to_device, dev=dev), input))
257
def assert_allclose_rec(t):
259
self.assertEqual(type(a), type(b))
260
if isinstance(a, torch.Tensor):
262
torch.allclose(clone_to_device(a, test_device), b, atol=1e-4)
265
if isinstance(a, Sequence):
266
map(assert_allclose_rec, zip(a, b))
268
samples = op.sample_inputs("lazy", dtype, requires_grad=False)
269
for sample in samples:
270
# Need to run mark step so that all random ops are computed in the right order
271
torch._lazy.mark_step()
273
args = [sample.input] + list(sample.args)
274
kwargs = sample.kwargs
275
copy_args = clone_to_device(args, test_device)
277
r_exp = op(*copy_args, **kwargs)
278
r_actual = op(*args, **kwargs)
280
torch._lazy.mark_step()
281
assert_allclose_rec((r_actual, r_exp))
287
if op.name in LAZY_OPS_LIST
288
and op.name not in SKIP_RUNTIME_ERROR_LIST | SKIP_INCORRECT_RESULTS_LIST
290
allowed_dtypes=(torch.float,),
292
def test_correctness_with_reusing_ir(self, device, dtype, op):
293
torch._lazy.config.set_reuse_ir(True)
294
test_device = get_test_device()
296
def clone_to_device(input, dev):
297
if isinstance(input, torch.Tensor):
298
return input.detach().clone().to(device=dev)
299
if isinstance(input, Sequence) and not isinstance(input, str):
300
return tuple(map(functools.partial(clone_to_device, dev=dev), input))
303
def assert_allclose_rec(t):
305
self.assertEqual(type(a), type(b))
306
if isinstance(a, torch.Tensor):
308
torch.allclose(clone_to_device(a, test_device), b, atol=1e-4)
311
if isinstance(a, Sequence):
312
map(assert_allclose_rec, zip(a, b))
314
samples = op.sample_inputs("lazy", dtype, requires_grad=False)
315
for sample in samples:
316
# Need to run mark step so that all random ops are computed in the right order
317
torch._lazy.mark_step()
319
args = [sample.input] + list(sample.args)
320
kwargs = sample.kwargs
321
copy_args = clone_to_device(args, test_device)
323
r_exp = op(*copy_args, **kwargs)
324
r_actual = op(*args, **kwargs)
326
torch._lazy.mark_step()
327
assert_allclose_rec((r_actual, r_exp))
329
torch._lazy.ir_cache.reset()
330
torch._lazy.config.set_reuse_ir(False)
333
# TODO: after we move to master, add Lazy as a new Device here:
334
# https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_device_type.py#L532
335
instantiate_device_type_tests(TestLazyOpInfo, globals(), only_for="cpu")
338
class TestLazyDynamicOps(TestCase):
340
def setUpClass(cls) -> None:
341
# Setup the dynamic shape mode
342
cls.old_ssa_mode = torch._C._lazy._get_symbolic_shape_mode()
343
torch._C._lazy._set_symbolic_shape_mode(True)
344
return super().setUpClass()
347
def tearDownClass(cls) -> None:
348
torch._C._lazy._set_symbolic_shape_mode(cls.old_ssa_mode)
349
return super().tearDownClass()
351
def test_nonzero_dynamic(self):
352
# Test that nonzero gives upper bounds sizes when symbolic shape mode is enabled
353
test_device = get_test_device()
355
[[0, 1.0, 2.0], [3.0, 0, 0]], device=test_device, requires_grad=True
357
x1_lazy = clone_move(x1)
358
x2_lazy = torch.nonzero(x1_lazy)
360
# FIXME: Add bindings to get upper bounds
361
# self.assertEqual(tuple(x2_lazy.size()), (6, 2))
363
# We should still be able to instantiate it and get the actual result
364
x2_eager = x2_lazy.cpu()
365
self.assertEqual(tuple(x2_eager.size()), (3, 2))
367
def test_adaptiveavgpool3d_dynamic(self):
    """Check that AdaptiveAvgPool3d gives the same output shape on the lazy device as on CPU.

    The second tensor must live on the "lazy" device. get_test_device() only ever
    returns "cpu" or "cuda", so using it here would compare two eager runs and
    never exercise the lazy backend or the symbolic shape mode this class enables
    in setUpClass.
    """
    # Eager CPU reference result.
    img_cpu = torch.zeros([2, 3, 4, 5, 6], device="cpu")
    out_cpu = torch.nn.AdaptiveAvgPool3d(2).to(device="cpu")(img_cpu)

    # Same computation routed through the lazy backend.
    img_lazy = torch.zeros([2, 3, 4, 5, 6], device="lazy")
    out_lazy = torch.nn.AdaptiveAvgPool3d(2).to(device="lazy")(img_lazy)

    self.assertEqual(out_cpu.shape, out_lazy.shape)
379
if __name__ == "__main__":