10
from typing import Union
11
from unittest.mock import patch
16
import torch.testing._internal.common_utils as common
17
import torch.utils.cpp_extension
18
from torch.serialization import safe_globals
19
from torch.testing._internal.common_utils import (
26
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
29
# NOTE(review): bare numeric lines ("30", "29", ...) scattered through this
# file are original line numbers left behind by a broken extraction, not code.
# CUDA tests additionally require an installed toolkit (CUDA_HOME set).
# assumes TEST_CUDA comes from the truncated common_utils import above — confirm.
TEST_CUDA = TEST_CUDA and CUDA_HOME is not None
30
# ROCm builds expose torch.version.hip; also require ROCM_HOME for extensions.
TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None
33
def remove_build_path():
    """Delete the default cpp-extension build root so each run starts clean.

    On Windows the build directory is left in place (cleanup is skipped).
    """
    if sys.platform == "win32":
        # NOTE(review): the original win32 branch body was lost in extraction;
        # an early return (skip cleanup) is the interface-preserving
        # reconstruction — confirm against upstream.
        return
    default_build_root = torch.utils.cpp_extension.get_default_build_root()
    if os.path.exists(default_build_root):
        # Best-effort cleanup: a concurrently-removed or locked dir is fine.
        shutil.rmtree(default_build_root, ignore_errors=True)
42
def generate_faked_module():
    """Build a minimal stand-in for a ``torch.foo`` device module.

    Returns a ``types.ModuleType`` named ``"foo"`` exposing the hooks that
    ``torch._register_device_module`` / the privateuse1 backend machinery
    look up (device_count, RNG state accessors, availability, init flags).
    """

    def device_count() -> int:
        # The suite asserts _get_custom_mod_func("device_count")() == 1.
        return 1

    def get_rng_state(device: Union[int, str, torch.device] = "foo") -> torch.Tensor:
        # Create a tensor on the custom device as the fake RNG state.
        return torch.empty(4, 4, device="foo")

    def set_rng_state(
        new_state: torch.Tensor, device: Union[int, str, torch.device] = "foo"
    ) -> None:
        # NOTE(review): original body lost in extraction; the fake backend has
        # no real RNG state, so a no-op preserves the visible contract.
        pass

    def is_available() -> bool:
        # NOTE(review): reconstructed — the fake device is always "available".
        return True

    def current_device() -> int:
        # NOTE(review): reconstructed — one fake device, index 0 (the
        # serialization test expects location tag "foo:0").
        return 0

    # Fabricate the module object and attach the hooks torch looks up.
    foo = types.ModuleType("foo")
    foo.device_count = device_count
    foo.get_rng_state = get_rng_state
    foo.set_rng_state = set_rng_state
    foo.is_available = is_available
    foo.current_device = current_device
    foo._lazy_init = lambda: None
    foo.is_initialized = lambda: True
    return foo
75
# Open-device-registration tests for a custom "foo" PrivateUse1 backend
# implemented by a JIT-compiled C++ extension.
@unittest.skipIf(IS_ARM64, "Does not work on arm")
76
@unittest.skipIf(TEST_XPU, "XPU does not support cppextension currently")
77
@torch.testing._internal.common_utils.markDynamoStrictTest
78
# NOTE(review): class name misspells "Registration"; kept as-is since test
# selection/reporting keys on the name — confirm before renaming.
class TestCppExtensionOpenRgistration(common.TestCase):
79
    """Tests Open Device Registration with C++ extensions."""
88
# NOTE(review): fixture fragments — the enclosing `def setUp/tearDown/
# setUpClass/tearDownClass` headers were lost in extraction; the grouping
# suggested by the comments below is inferred and should be confirmed.
# setUp: remember cwd, then run from this test file's directory so the
# extension sources resolve.
self.old_working_dir = os.getcwd()
89
os.chdir(os.path.dirname(os.path.abspath(__file__)))
91
assert self.module is not None
97
# tearDown: restore the original working directory.
os.chdir(self.old_working_dir)
103
# setUpClass: compile and load the open-registration C++ extension once.
cls.module = torch.utils.cpp_extension.load(
104
name="custom_device_extension",
106
"cpp_extensions/open_registration_extension.cpp",
108
extra_include_paths=["cpp_extensions"],
114
# Rename PrivateUse1 to "foo", generate Tensor/Storage helper methods,
# and install the faked torch.foo runtime module.
torch.utils.rename_privateuse1_backend("foo")
115
torch.utils.generate_methods_for_privateuse1_backend(for_storage=True)
116
torch._register_device_module("foo", generate_faked_module())
118
def test_base_device_registration(self):
119
# The custom add kernel must not have run before any "foo" op dispatches.
self.assertFalse(self.module.custom_add_called())
121
device = self.module.custom_device()
122
x = torch.empty(4, 4, device=device)
123
y = torch.empty(4, 4, device=device)
125
self.assertTrue(x.device == device)
126
self.assertFalse(x.is_cpu)
127
self.assertFalse(self.module.custom_add_called())
131
# NOTE(review): the statement that triggers the custom add kernel and
# defines `z` (presumably `z = x + y`) was lost in extraction — `z` is
# used below but never defined in the visible source.
self.assertTrue(self.module.custom_add_called())
132
z_cpu = z.to(device="cpu")
134
self.assertTrue(z_cpu.is_cpu)
135
self.assertFalse(z.is_cpu)
136
self.assertTrue(z.device == device)
137
self.assertEqual(z, z_cpu)
139
def test_common_registration(self):
141
# Only recognized device names (or privateuse1) may be registered.
with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"):
142
torch._register_device_module("dev", generate_faked_module())
143
# Re-registering an already-registered runtime module must fail.
with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
144
torch._register_device_module("foo", generate_faked_module())
147
# Renaming to the same backend name again is allowed (no-op).
torch.utils.rename_privateuse1_backend("foo")
150
# Renaming a second time to a different name must fail.
# NOTE(review): the closing `):` of this call was lost in extraction.
with self.assertRaisesRegex(
151
RuntimeError, "torch.register_privateuse1_backend()"
153
torch.utils.rename_privateuse1_backend("dev")
156
with self.assertRaisesRegex(RuntimeError, "The custom device module of"):
157
torch.utils.generate_methods_for_privateuse1_backend()
161
# NOTE(review): the surrounding self.assertTrue( ... ) wrapper lines were
# lost in extraction; as written this comparison has no effect.
torch.utils.backend_registration._get_custom_mod_func("device_count")() == 1
163
with self.assertRaisesRegex(RuntimeError, "Try to call torch.foo"):
164
torch.utils.backend_registration._get_custom_mod_func("func_name_")
167
# Generated helpers must exist on Tensor, both storages, Module and
# PackedSequence.
self.assertTrue(hasattr(torch.Tensor, "is_foo"))
168
self.assertTrue(hasattr(torch.Tensor, "foo"))
169
self.assertTrue(hasattr(torch.TypedStorage, "is_foo"))
170
self.assertTrue(hasattr(torch.TypedStorage, "foo"))
171
self.assertTrue(hasattr(torch.UntypedStorage, "is_foo"))
172
self.assertTrue(hasattr(torch.UntypedStorage, "foo"))
173
self.assertTrue(hasattr(torch.nn.Module, "foo"))
174
self.assertTrue(hasattr(torch.nn.utils.rnn.PackedSequence, "is_foo"))
175
self.assertTrue(hasattr(torch.nn.utils.rnn.PackedSequence, "foo"))
177
def test_open_device_generator_registration_and_hooks(self):
178
device = self.module.custom_device()
180
self.assertFalse(self.module.custom_add_called())
183
# Creating a Generator before one is registered must fail.
# NOTE(review): the `RuntimeError,` argument line and closing `):` of this
# call were lost in extraction.
with self.assertRaisesRegex(
185
"Please register a generator to the PrivateUse1 dispatch key",
187
torch.Generator(device=device)
189
self.module.register_generator_first()
190
gen = torch.Generator(device=device)
191
self.assertTrue(gen.device == device)
194
# Registering a second generator for PrivateUse1 must fail.
with self.assertRaisesRegex(
196
"Only can register a generator to the PrivateUse1 dispatch key once",
198
self.module.register_generator_second()
200
if self.module.is_register_hook() is False:
201
self.module.register_hook()
202
default_gen = self.module.default_generator(0)
204
# NOTE(review): presumably wrapped in self.assertTrue(...) originally —
# a bare comparison has no effect as written.
default_gen.device.type == torch._C._get_privateuse1_backend_name()
207
def test_open_device_dispatchstub(self):
    """Dispatchstub-backed kernels on "foo" must match CPU results.

    NOTE(review): reconstructed from a garbled extraction; the visible code
    is self-consistent, but lost comment/blank lines may have carried intent
    notes — confirm against upstream.
    """
    input_data = torch.randn(2, 2, 3, dtype=torch.float32, device="cpu")
    foo_input_data = input_data.to("foo")
    output_data = torch.abs(input_data)
    foo_output_data = torch.abs(foo_input_data)
    self.assertEqual(output_data, foo_output_data.cpu())

    # out= variant writing into a strided (step-2) view of the output.
    output_data = torch.randn(2, 2, 6, dtype=torch.float32, device="cpu")
    foo_input_data = input_data.to("foo")
    foo_output_data = output_data.to("foo")
    torch.abs(input_data, out=output_data[:, :, 0:6:2])
    torch.abs(foo_input_data, out=foo_output_data[:, :, 0:6:2])
    self.assertEqual(output_data, foo_output_data.cpu())

    # out= variant writing into a step-3 strided view of the output.
    output_data = torch.randn(2, 2, 6, dtype=torch.float32, device="cpu")
    foo_input_data = input_data.to("foo")
    foo_output_data = output_data.to("foo")
    torch.abs(input_data, out=output_data[:, :, 0:6:3])
    torch.abs(foo_input_data, out=foo_output_data[:, :, 0:6:3])
    self.assertEqual(output_data, foo_output_data.cpu())
233
def test_open_device_quantized(self):
    """quantize_per_tensor on the custom backend keeps device and dtype."""
    input_data = torch.randn(3, 4, 5, dtype=torch.float32, device="cpu").to("foo")
    quantized_tensor = torch.quantize_per_tensor(input_data, 0.1, 10, torch.qint8)
    # The quantized result must stay on "foo:0" with the requested qint8 dtype.
    self.assertEqual(quantized_tensor.device, torch.device("foo:0"))
    self.assertEqual(quantized_tensor.dtype, torch.qint8)
239
def test_open_device_random(self):
241
# Exercises torch.random.fork_rng for the custom backend (requires the
# faked module's get/set_rng_state hooks).
# NOTE(review): the body of this `with` block was lost in extraction —
# restore before running (upstream it is effectively a no-op body).
with torch.random.fork_rng(device_type="foo"):
244
def test_open_device_tensor(self):
245
device = self.module.custom_device()
249
# NOTE(review): the `dtypes = {` opener and the closing `}` of this
# dtype -> expected type-string mapping were lost in extraction.
torch.bool: "torch.foo.BoolTensor",
250
torch.double: "torch.foo.DoubleTensor",
251
torch.float32: "torch.foo.FloatTensor",
252
torch.half: "torch.foo.HalfTensor",
253
torch.int32: "torch.foo.IntTensor",
254
torch.int64: "torch.foo.LongTensor",
255
torch.int8: "torch.foo.CharTensor",
256
torch.short: "torch.foo.ShortTensor",
257
torch.uint8: "torch.foo.ByteTensor",
259
# Each dtype's Tensor.type() string must carry the renamed backend.
for tt, dt in dtypes.items():
260
test_tensor = torch.empty(4, 4, dtype=tt, device=device)
261
self.assertTrue(test_tensor.type() == dt)
264
x = torch.empty(4, 4)
265
self.assertFalse(x.is_foo)
267
# Moving via the generated .foo() helper must not invoke the add kernel.
x = x.foo(torch.device("foo"))
268
self.assertFalse(self.module.custom_add_called())
269
self.assertTrue(x.is_foo)
272
y = torch.empty(4, 4)
273
self.assertFalse(y.is_foo)
275
y = y.foo(torch.device("foo:0"))
276
self.assertFalse(self.module.custom_add_called())
277
self.assertTrue(y.is_foo)
280
z = torch.empty(4, 4)
281
self.assertFalse(z.is_foo)
284
# NOTE(review): the statement moving `z` onto the "foo" device was lost
# in extraction — without it the assertTrue below cannot hold.
self.assertFalse(self.module.custom_add_called())
285
self.assertTrue(z.is_foo)
287
def test_open_device_packed_sequence(self):
288
device = self.module.custom_device()
290
# NOTE(review): the definition of `a` (the PackedSequence data tensor)
# was lost in extraction; `a` is used below but never defined here.
b = torch.tensor([1, 1, 1, 1, 1])
291
input = torch.nn.utils.rnn.PackedSequence(a, b)
292
self.assertFalse(input.is_foo)
293
# The generated .foo() helper must move the whole PackedSequence.
input_foo = input.foo()
294
self.assertTrue(input_foo.is_foo)
296
def test_open_device_storage(self):
298
x = torch.empty(4, 4)
300
# NOTE(review): the creation of `z1` (presumably `z1 = x.storage()`) was
# lost in extraction.
self.assertFalse(z1.is_foo)
303
self.assertFalse(self.module.custom_add_called())
304
self.assertTrue(z1.is_foo)
306
# Moving a "foo" storage to CPU via .foo() must be rejected.
with self.assertRaisesRegex(RuntimeError, "Invalid device"):
307
z1.foo(torch.device("cpu"))
310
self.assertFalse(self.module.custom_add_called())
311
self.assertFalse(z1.is_foo)
313
z1 = z1.foo(device="foo:0", non_blocking=False)
314
self.assertFalse(self.module.custom_add_called())
315
self.assertTrue(z1.is_foo)
317
# .foo() only accepts the custom backend's own device strings.
with self.assertRaisesRegex(RuntimeError, "Invalid device"):
318
z1.foo(device="cuda:0", non_blocking=False)
321
y = torch.empty(4, 4)
322
z2 = y.untyped_storage()
323
self.assertFalse(z2.is_foo)
326
# NOTE(review): the statement moving `z2` onto "foo" was lost here.
self.assertFalse(self.module.custom_add_called())
327
self.assertTrue(z2.is_foo)
330
# Exercise the custom StorageImpl registration path.
self.module.custom_storage_registry()
332
z3 = y.untyped_storage()
333
self.assertFalse(self.module.custom_storageImpl_called())
336
self.assertTrue(self.module.custom_storageImpl_called())
337
# The "called" flag appears to reset after being read — confirm with the
# C++ extension's implementation.
self.assertFalse(self.module.custom_storageImpl_called())
340
self.assertTrue(self.module.custom_storageImpl_called())
342
@skipIfTorchDynamo("unsupported aten.is_pinned.default")
343
def test_open_device_storage_pin_memory(self):
345
cpu_tensor = torch.empty(3)
346
self.assertFalse(cpu_tensor.is_foo)
347
self.assertFalse(cpu_tensor.is_pinned("foo"))
349
# pin_memory("foo") routes through the custom backend's pinning hook.
cpu_tensor_pin = cpu_tensor.pin_memory("foo")
350
self.assertTrue(cpu_tensor_pin.is_pinned("foo"))
353
cpu_storage = cpu_tensor.storage()
358
# NOTE(review): the lines that pinned this storage were lost in
# extraction, so this assertion looks unsupported as written — confirm.
self.assertTrue(cpu_storage.is_pinned("foo"))
360
cpu_storage_pinned = cpu_storage.pin_memory("foo")
361
self.assertTrue(cpu_storage_pinned.is_pinned("foo"))
364
cpu_tensor = torch.randn([3, 2, 1, 4])
365
cpu_untyped_storage = cpu_tensor.untyped_storage()
366
# NOTE(review): same here — the pinning step preceding this assert is
# missing from the visible source.
self.assertTrue(cpu_untyped_storage.is_pinned("foo"))
368
cpu_untyped_storage_pinned = cpu_untyped_storage.pin_memory("foo")
369
self.assertTrue(cpu_untyped_storage_pinned.is_pinned("foo"))
372
# NOTE(review): this string is the reason argument of a skip decorator whose
# surrounding call lines (e.g. `@unittest.skip(`) were lost in extraction.
"Temporarily disable due to the tiny differences between clang++ and g++ in defining static variable in inline function"
374
def test_open_device_serialization(self):
375
# Index -1: location tag omits the device index.
self.module.set_custom_device_index(-1)
376
storage = torch.UntypedStorage(4, device=torch.device("foo"))
377
self.assertEqual(torch.serialization.location_tag(storage), "foo")
379
# Index 0: location tag carries the explicit index.
self.module.set_custom_device_index(0)
380
storage = torch.UntypedStorage(4, device=torch.device("foo"))
381
self.assertEqual(torch.serialization.location_tag(storage), "foo:0")
383
cpu_storage = torch.empty(4, 4).storage()
384
foo_storage = torch.serialization.default_restore_location(cpu_storage, "foo:0")
385
self.assertTrue(foo_storage.is_foo)
388
x = torch.empty(4, 4).long()
390
# NOTE(review): the definition of `y` (presumably `y = x.foo()`) was
# lost in extraction.
self.assertFalse(self.module.check_backend_meta(y))
391
self.module.custom_set_backend_meta(y)
392
self.assertTrue(self.module.check_backend_meta(y))
394
self.module.custom_serialization_registry()
395
with tempfile.TemporaryDirectory() as tmpdir:
396
path = os.path.join(tmpdir, "data.pt")
398
# NOTE(review): the save step (presumably `torch.save(y, path)`) was
# lost in extraction.
z1 = torch.load(path)
400
self.assertTrue(z1.is_foo)
402
self.assertTrue(self.module.check_backend_meta(z1))
405
# map_location="cpu" drops the backend and its metadata on load.
z2 = torch.load(path, map_location="cpu")
407
self.assertFalse(z2.is_foo)
409
self.assertFalse(self.module.check_backend_meta(z2))
411
def test_open_device_storage_resize(self):
    """resize_ on a custom-backend tensor keeps its storage size in sync."""
    cpu_tensor = torch.randn([8])
    foo_tensor = cpu_tensor.foo()
    foo_storage = foo_tensor.storage()
    self.assertTrue(foo_storage.size() == 8)

    # A same-size resize_ must leave the registered storage unchanged.
    foo_tensor.resize_(8)
    self.assertTrue(foo_storage.size() == 8)

    # An absurdly large request overflows the size computation and raises.
    with self.assertRaisesRegex(TypeError, "Overflow"):
        foo_tensor.resize_(8**29)
424
def test_open_device_storage_type(self):
426
cpu_tensor = torch.randn([8]).float()
427
cpu_storage = cpu_tensor.storage()
428
self.assertEqual(cpu_storage.type(), "torch.FloatStorage")
431
foo_tensor = cpu_tensor.foo()
432
foo_storage = foo_tensor.storage()
433
# Without a registered torch.foo.FloatStorage, the generic name is used.
self.assertEqual(foo_storage.type(), "torch.storage.TypedStorage")
435
class CustomFloatStorage:
437
# NOTE(review): the @property decorators — and the companion __name__
# property returning "FloatStorage" — were lost in extraction; the bare
# `return "FloatStorage"` below belonged to that missing property.
def __module__(self):
438
return "torch." + torch._C._get_privateuse1_backend_name()
442
return "FloatStorage"
446
# Register the custom storage class so type() resolves the pretty name.
torch.foo.FloatStorage = CustomFloatStorage()
447
self.assertEqual(foo_storage.type(), "torch.foo.FloatStorage")
450
# Non-float storages still fall back to the generic TypedStorage name.
foo_tensor2 = torch.randn([8]).int().foo()
451
foo_storage2 = foo_tensor2.storage()
452
self.assertEqual(foo_storage2.type(), "torch.storage.TypedStorage")
454
# Clean up the registration so later tests are unaffected.
torch.foo.FloatStorage = None
456
def test_open_device_faketensor(self):
457
# FakeTensorMode must accept tensors on the custom "foo" backend.
with torch._subclasses.fake_tensor.FakeTensorMode.push():
458
a = torch.empty(1, device="foo")
459
b = torch.empty(1, device="foo:0")
# NOTE(review): any trailing statements of this test (e.g. combining
# `a` and `b`) were lost in extraction — `b` is otherwise unused.
462
def test_open_device_named_tensor(self):
    """Named-tensor construction must not raise on the custom backend."""
    # Only checks that the factory accepts names= together with device="foo".
    torch.empty([2, 3, 4, 5], device="foo", names=["N", "C", "H", "W"])
467
def test_compile_autograd_function_returns_self(self):
    """torch.compile of an autograd.Function returning self matches eager.

    Runs the extension's custom autograd op eagerly and compiled on cloned
    inputs and compares both outputs and input gradients.
    """
    x_ref = torch.randn(4, requires_grad=True)
    out_ref = self.module.custom_autograd_fn_returns_self(x_ref)
    out_ref.sum().backward()

    # Independent leaf with identical values for the compiled run.
    x_test = x_ref.clone().detach().requires_grad_(True)
    f_compiled = torch.compile(self.module.custom_autograd_fn_returns_self)
    out_test = f_compiled(x_test)
    out_test.sum().backward()

    self.assertEqual(out_ref, out_test)
    self.assertEqual(x_ref.grad, x_test.grad)
482
@skipIfTorchDynamo("Temporary disabled due to torch._ops.OpOverloadPacket")
def test_compile_autograd_function_aliasing(self):
    """Compiled aliasing autograd op agrees with eager values and grads."""
    x_ref = torch.randn(4, requires_grad=True)
    out_ref = torch.ops._test_funcs.custom_autograd_fn_aliasing(x_ref)
    out_ref.sum().backward()

    # Independent leaf with identical values for the compiled run.
    x_test = x_ref.clone().detach().requires_grad_(True)
    f_compiled = torch.compile(torch.ops._test_funcs.custom_autograd_fn_aliasing)
    out_test = f_compiled(x_test)
    out_test.sum().backward()

    self.assertEqual(out_ref, out_test)
    self.assertEqual(x_ref.grad, x_test.grad)
496
def test_open_device_scalar_type_fallback(self):
    """An op with no "foo" kernel falls back to CPU and matches its result."""
    z_cpu = torch.Tensor([[0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]]).to(torch.int64)
    # triu_indices has no custom kernel, so this exercises the CPU fallback.
    z = torch.triu_indices(3, 3, device="foo")
    self.assertEqual(z_cpu, z)
501
def test_open_device_tensor_type_fallback(self):
503
x = torch.Tensor([[1, 2, 3], [2, 3, 4]]).to("foo")
504
y = torch.Tensor([1, 0, 2]).to("foo")
506
z_cpu = torch.Tensor([[0, 2, 1], [1, 3, 2]])
508
device = self.module.custom_device()
509
self.assertTrue(x.device == device)
510
self.assertFalse(x.is_cpu)
514
# NOTE(review): the CPU-fallback op producing `z` (presumably
# `z = torch.sub(x, y)`) was lost in extraction.
self.assertEqual(z_cpu, z)
517
z_cpu = torch.Tensor([3, 1])
518
y = torch.Tensor([1, 0]).long().to("foo")
520
# NOTE(review): the indexing op recomputing `z` was lost here as well.
self.assertEqual(z_cpu, z)
522
def test_open_device_tensorlist_type_fallback(self):
524
v_foo = torch.Tensor([1, 2, 3]).to("foo")
526
z_cpu = torch.Tensor([2, 4, 6])
531
# NOTE(review): the tensor-list inputs `x`/`y` (presumably tuples of
# v_foo) were lost in extraction; _foreach_add below depends on them.
device = self.module.custom_device()
532
self.assertTrue(v_foo.device == device)
533
self.assertFalse(v_foo.is_cpu)
536
# Tensor-list op with no "foo" kernel — exercises the CPU fallback.
z = torch._foreach_add(x, y)
537
self.assertEqual(z_cpu, z[0])
538
self.assertEqual(z_cpu, z[1])
541
# The fallback must also tolerate undefined tensors in the list.
self.module.fallback_with_undefined_tensor()
543
def test_open_device_numpy_serialization(self):
544
torch.utils.rename_privateuse1_backend("foo")
545
device = self.module.custom_device()
546
default_protocol = torch.serialization.DEFAULT_PROTOCOL
548
# Force the no-storage path so serialization takes the numpy rebuild route.
with patch.object(torch._C, "_has_storage", return_value=False):
549
x = torch.randn(2, 3)
552
# NOTE(review): the step creating `x_foo` (presumably `x.to(device)`)
# was lost in extraction.
rebuild_func = x_foo._reduce_ex_internal(default_protocol)[0]
554
# NOTE(review): presumably wrapped in self.assertTrue(...) originally.
rebuild_func is torch._utils._rebuild_device_tensor_from_numpy
557
with TemporaryFileName() as f:
561
# NOTE(review): the fragment below belongs to a safe_globals([...])
# allowlist for weights_only torch.load; the save call and the
# surrounding context-manager lines were lost in extraction.
np.core.multiarray._reconstruct,
565
type(np.dtype(np.float32))
566
if np.__version__ < "1.25.0"
567
else np.dtypes.Float32DType,
570
sd_loaded = torch.load(f, map_location="cpu")
571
self.assertTrue(sd_loaded["x"].is_cpu)
574
with TemporaryFileName() as f:
575
# skip_data cannot serialize tensors on storage-less backends.
with self.assertRaisesRegex(
577
"Cannot serialize tensors on backends with no storage under skip_data context manager",
579
with torch.serialization.skip_data():
# NOTE(review): the torch.save(...) call inside this block was lost in
# extraction.
583
if __name__ == "__main__":
    # NOTE(review): the guarded body (presumably common.run_tests()) was lost
    # in extraction — restore it before running this file directly.