1
# Owner(s): ["module: cpp-extensions"]
6
from typing import Union
10
import torch.testing._internal.common_utils as common
11
from torch.testing._internal.common_utils import IS_ARM64, TEST_CUDA
13
import torch.utils.cpp_extension
14
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
17
TEST_CUDA = TEST_CUDA and CUDA_HOME is not None
18
TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None
21
def remove_build_path():
    """Delete the default ``torch.utils.cpp_extension`` build directory.

    The directory is obtained from
    ``torch.utils.cpp_extension.get_default_build_root()`` and removed with
    ``ignore_errors=True``, so a missing or partially-removable folder is not
    an error. On Windows nothing is deleted.
    """
    if sys.platform == "win32":
        # Deliberately not wiping the extensions build folder on Windows;
        # returning early is required, otherwise the wipe below would run.
        return
    default_build_root = torch.utils.cpp_extension.get_default_build_root()
    if os.path.exists(default_build_root):
        shutil.rmtree(default_build_root, ignore_errors=True)
33
def device_count() -> int:
    """Return the number of devices reported for the dummy ``foo`` backend.

    NOTE(review): ``1`` matches the registration test's expectation that
    ``_get_custom_mod_func("device_count")() == 1`` once the dummy module is
    registered as ``torch.foo`` — confirm this function is that hook.
    """
    return 1


def get_rng_state(device: Union[int, str, torch.device] = 'foo') -> torch.Tensor:
    """Return a placeholder RNG state tensor allocated on the ``foo`` device.

    ``device`` is accepted for interface compatibility but is not used; the
    tensor is always created on ``"foo"``, which must already be a registered
    device name or ``torch.empty`` raises.
    """
    # create a tensor using our custom device object.
    return torch.empty(4, 4, device="foo")


def set_rng_state(new_state: torch.Tensor, device: Union[int, str, torch.device] = 'foo') -> None:
    """Accept and discard an RNG state.

    ``get_rng_state`` always returns a fresh tensor, so there is no stored
    state to update; both arguments are accepted only for interface
    compatibility.
    NOTE(review): assumed to be a no-op stub — confirm against the intended
    dummy-module behaviour.
    """
    return None
53
@unittest.skipIf(IS_ARM64, "Does not work on arm")
54
@torch.testing._internal.common_utils.markDynamoStrictTest
55
class TestCppExtensionOpenRgistration(common.TestCase):
56
"""Tests Open Device Registration with C++ extensions.
62
# cpp extensions use relative paths. Those paths are relative to
63
# this file, so we'll change the working directory temporarily
64
self.old_working_dir = os.getcwd()
65
os.chdir(os.path.dirname(os.path.abspath(__file__)))
66
assert self.module is not None
70
# return the working directory (see setUp)
71
os.chdir(self.old_working_dir)
76
cls.module = torch.utils.cpp_extension.load(
77
name="custom_device_extension",
79
"cpp_extensions/open_registration_extension.cpp",
81
extra_include_paths=["cpp_extensions"],
87
def tearDownClass(cls):
90
def test_open_device_registration(self):
91
def test_base_device_registration():
92
torch.utils.rename_privateuse1_backend('foo')
93
self.assertFalse(self.module.custom_add_called())
94
# create a tensor using our custom device object
95
device = self.module.custom_device()
96
x = torch.empty(4, 4, device=device)
97
y = torch.empty(4, 4, device=device)
98
# Check that our device is correct.
99
self.assertTrue(x.device == device)
100
self.assertFalse(x.is_cpu)
101
self.assertFalse(self.module.custom_add_called())
102
# calls out custom add kernel, registered to the dispatcher
104
# check that it was called
105
self.assertTrue(self.module.custom_add_called())
106
z_cpu = z.to(device='cpu')
107
# Check that our cross-device copy correctly copied the data to cpu
108
self.assertTrue(z_cpu.is_cpu)
109
self.assertFalse(z.is_cpu)
110
self.assertTrue(z.device == device)
111
self.assertEqual(z, z_cpu)
114
# check whether the error can be reported correctly
115
def test_before_common_registration():
116
# check that register module name should be the same as custom backend
117
with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"):
118
torch._register_device_module('xxx', DummyModule)
119
# check generator registered before using
120
torch.utils.rename_privateuse1_backend('foo')
121
with self.assertRaisesRegex(RuntimeError, "torch has no module of"):
122
with torch.random.fork_rng(device_type="foo"):
124
# check attributes before registered
125
self.assertFalse(hasattr(torch.Tensor, 'is_foo'))
126
self.assertFalse(hasattr(torch.Tensor, 'foo'))
127
self.assertFalse(hasattr(torch.TypedStorage, 'is_foo'))
128
self.assertFalse(hasattr(torch.TypedStorage, 'foo'))
129
self.assertFalse(hasattr(torch.UntypedStorage, 'is_foo'))
130
self.assertFalse(hasattr(torch.UntypedStorage, 'foo'))
131
self.assertFalse(hasattr(torch.nn.Module, 'foo'))
133
def test_after_common_registration():
134
# check attributes after registered
135
self.assertTrue(hasattr(torch.Tensor, 'is_foo'))
136
self.assertTrue(hasattr(torch.Tensor, 'foo'))
137
self.assertTrue(hasattr(torch.TypedStorage, 'is_foo'))
138
self.assertTrue(hasattr(torch.TypedStorage, 'foo'))
139
self.assertTrue(hasattr(torch.UntypedStorage, 'is_foo'))
140
self.assertTrue(hasattr(torch.UntypedStorage, 'foo'))
141
self.assertTrue(hasattr(torch.nn.Module, 'foo'))
143
def test_common_registration():
144
# first rename custom backend
145
torch.utils.rename_privateuse1_backend('foo')
146
# backend name can only rename once
147
with self.assertRaisesRegex(RuntimeError, "torch.register_privateuse1_backend()"):
148
torch.utils.rename_privateuse1_backend('xxx')
149
# register foo module, torch.foo
150
torch._register_device_module('foo', DummyModule)
151
self.assertTrue(torch.utils.backend_registration._get_custom_mod_func("device_count")() == 1)
152
with self.assertRaisesRegex(RuntimeError, "Try to call torch.foo"):
153
torch.utils.backend_registration._get_custom_mod_func("func_name_")
154
# default set for_tensor and for_module are True, so only set for_storage is True
155
torch.utils.generate_methods_for_privateuse1_backend(for_storage=True)
156
# generator tensor and module can be registered only once
157
with self.assertRaisesRegex(RuntimeError, "The custom device module of"):
158
torch.utils.generate_methods_for_privateuse1_backend()
160
def test_open_device_generator_registration_and_hooks():
161
device = self.module.custom_device()
162
# None of our CPU operations should call the custom add function.
163
self.assertFalse(self.module.custom_add_called())
164
# check generator registered before using
165
with self.assertRaisesRegex(RuntimeError,
166
"Please register a generator to the PrivateUse1 dispatch key"):
167
gen_ = torch.Generator(device=device)
168
self.module.register_generator_first()
169
gen = torch.Generator(device=device)
170
self.assertTrue(gen.device == device)
171
# generator can be registered only once
172
with self.assertRaisesRegex(RuntimeError,
173
"Only can register a generator to the PrivateUse1 dispatch key once"):
174
self.module.register_generator_second()
175
self.module.register_hook()
176
default_gen = self.module.default_generator(0)
177
self.assertTrue(default_gen.device.type == torch._C._get_privateuse1_backend_name())
179
def test_open_device_dispatchstub():
180
# test kernels could be reused by privateuse1 backend through dispatchstub
181
torch.utils.rename_privateuse1_backend('foo')
182
input_data = torch.randn(3, 4, 5, dtype=torch.float32, device="cpu")
183
foo_input_data = input_data.to("foo")
184
self.assertFalse(self.module.custom_abs_called())
185
torch.abs(foo_input_data)
186
self.assertTrue(self.module.custom_abs_called())
188
def test_open_device_quantized():
189
torch.utils.rename_privateuse1_backend('foo')
190
input_data = torch.randn(3, 4, 5, dtype=torch.float32, device="cpu").to("foo")
191
quantized_tensor = torch.quantize_per_tensor(input_data, 0.1, 10, torch.qint8)
192
self.assertEqual(quantized_tensor.device, torch.device('foo:0'))
193
self.assertEqual(quantized_tensor.dtype, torch.qint8)
195
def test_open_device_random():
196
with torch.random.fork_rng(device_type="foo"):
199
def test_open_device_tensor():
200
device = self.module.custom_device()
201
# check whether print tensor.type() meets the expectation
203
torch.bool: 'torch.foo.BoolTensor',
204
torch.double: 'torch.foo.DoubleTensor',
205
torch.float32: 'torch.foo.FloatTensor',
206
torch.half: 'torch.foo.HalfTensor',
207
torch.int32: 'torch.foo.IntTensor',
208
torch.int64: 'torch.foo.LongTensor',
209
torch.int8: 'torch.foo.CharTensor',
210
torch.short: 'torch.foo.ShortTensor',
211
torch.uint8: 'torch.foo.ByteTensor',
213
for tt, dt in dtypes.items():
214
test_tensor = torch.empty(4, 4, dtype=tt, device=device)
215
self.assertTrue(test_tensor.type() == dt)
216
# check whether the attributes and methods of the corresponding custom backend are generated correctly
217
x = torch.empty(4, 4)
218
self.assertFalse(x.is_foo)
219
x = x.foo(torch.device("foo"))
220
self.assertFalse(self.module.custom_add_called())
221
self.assertTrue(x.is_foo)
222
# test different device type input
223
y = torch.empty(4, 4)
224
self.assertFalse(y.is_foo)
225
y = y.foo(torch.device("foo:0"))
226
self.assertFalse(self.module.custom_add_called())
227
self.assertTrue(y.is_foo)
228
# test different device type input
229
z = torch.empty(4, 4)
230
self.assertFalse(z.is_foo)
232
self.assertFalse(self.module.custom_add_called())
233
self.assertTrue(z.is_foo)
235
def test_open_device_storage():
236
# check whether the attributes and methods for storage of the corresponding custom backend are generated correctly
237
x = torch.empty(4, 4)
239
self.assertFalse(z1.is_foo)
241
self.assertFalse(self.module.custom_add_called())
242
self.assertTrue(z1.is_foo)
243
with self.assertRaisesRegex(RuntimeError, "Invalid device"):
244
z1.foo(torch.device("cpu"))
246
self.assertFalse(self.module.custom_add_called())
247
self.assertFalse(z1.is_foo)
248
z1 = z1.foo(device="foo:0", non_blocking=False)
249
self.assertFalse(self.module.custom_add_called())
250
self.assertTrue(z1.is_foo)
251
with self.assertRaisesRegex(RuntimeError, "Invalid device"):
252
z1.foo(device="cuda:0", non_blocking=False)
253
# check UntypedStorage
254
y = torch.empty(4, 4)
255
z2 = y.untyped_storage()
256
self.assertFalse(z2.is_foo)
258
self.assertFalse(self.module.custom_add_called())
259
self.assertTrue(z2.is_foo)
260
# check custom StorageImpl create
261
self.module.custom_storage_registry()
262
z3 = y.untyped_storage()
263
self.assertFalse(self.module.custom_storageImpl_called())
265
self.assertTrue(self.module.custom_storageImpl_called())
267
def test_open_device_storage_pin_memory():
268
torch.utils.rename_privateuse1_backend('foo')
269
with self.assertRaisesRegex(RuntimeError, "The custom device module of"):
270
torch.utils.generate_methods_for_privateuse1_backend(for_tensor=False, for_module=False, for_storage=True)
271
# Check if the pin_memory is functioning properly on custom device
272
cpu_tensor = torch.empty(3)
273
self.assertFalse(cpu_tensor.is_foo)
274
self.assertFalse(cpu_tensor.is_pinned("foo"))
275
cpu_tensor_pin = cpu_tensor.pin_memory("foo")
276
self.assertTrue(cpu_tensor_pin.is_pinned("foo"))
277
# Test storage pin_memory on custom device string
278
cpu_storage = cpu_tensor.storage()
279
foo_device = torch.device("foo")
280
self.assertFalse(cpu_storage.is_pinned("foo"))
281
cpu_storage_pin = cpu_storage.pin_memory("foo")
282
self.assertFalse(cpu_storage.is_pinned())
283
self.assertFalse(cpu_storage.is_pinned("foo"))
284
self.assertFalse(cpu_storage.is_pinned(foo_device))
285
self.assertFalse(cpu_storage_pin.is_pinned())
286
self.assertTrue(cpu_storage_pin.is_pinned("foo"))
287
self.assertTrue(cpu_storage_pin.is_pinned(foo_device))
288
cpu_storage_pin_already = cpu_storage_pin.pin_memory("foo")
289
self.assertTrue(cpu_storage_pin.is_pinned("foo"))
290
self.assertTrue(cpu_storage_pin.is_pinned(foo_device))
291
self.assertTrue(cpu_storage_pin_already.is_pinned("foo"))
292
self.assertTrue(cpu_storage_pin_already.is_pinned(foo_device))
294
# Test storage pin_memory on torch.device
295
self.assertFalse(cpu_storage.is_pinned("foo"))
296
cpu_storage_pinned = cpu_storage.pin_memory(foo_device)
297
self.assertFalse(cpu_storage.is_pinned())
298
self.assertFalse(cpu_storage.is_pinned("foo"))
299
self.assertFalse(cpu_storage.is_pinned(foo_device))
300
self.assertFalse(cpu_storage_pinned.is_pinned())
301
self.assertTrue(cpu_storage_pinned.is_pinned("foo"))
302
self.assertTrue(cpu_storage_pinned.is_pinned(foo_device))
304
# Test untyped storage pin_memory and is_pin
305
cpu_tensor = torch.randn([3, 2, 1, 4])
306
cpu_untyped_storage = cpu_tensor.untyped_storage()
307
self.assertFalse(cpu_untyped_storage.is_pinned())
308
self.assertFalse(cpu_untyped_storage.is_pinned("foo"))
309
cpu_untyped_storage_pinned = cpu_untyped_storage.pin_memory("foo")
310
self.assertFalse(cpu_untyped_storage_pinned.is_pinned())
311
self.assertTrue(cpu_untyped_storage_pinned.is_pinned("foo"))
312
self.assertTrue(cpu_untyped_storage_pinned.is_pinned(foo_device))
313
cpu_untyped_storage_pinned = cpu_untyped_storage.pin_memory(foo_device)
314
self.assertFalse(cpu_untyped_storage_pinned.is_pinned())
315
self.assertTrue(cpu_untyped_storage_pinned.is_pinned("foo"))
316
self.assertTrue(cpu_untyped_storage_pinned.is_pinned(foo_device))
317
with self.assertRaisesRegex(TypeError, "positional arguments but 3 were given"):
318
cpu_untyped_storage_pinned.is_pinned("foo1", "foo2")
320
# Test storage pin_memory on error device
321
self.assertFalse(cpu_storage_pinned.is_pinned("hpu"))
322
with self.assertRaisesRegex(NotImplementedError, "with arguments from the 'HPU' backend"):
323
cpu_storage.pin_memory("hpu")
324
self.assertFalse(cpu_untyped_storage_pinned.is_pinned("hpu"))
325
with self.assertRaisesRegex(NotImplementedError, "with arguments from the 'HPU' backend"):
326
cpu_untyped_storage.pin_memory("hpu")
327
invalid_device = torch.device("hpu")
328
self.assertFalse(cpu_untyped_storage_pinned.is_pinned(invalid_device))
329
with self.assertRaisesRegex(NotImplementedError, "with arguments from the 'HPU' backend"):
330
cpu_untyped_storage.pin_memory(invalid_device)
332
def test_open_device_serialization():
333
self.module.set_custom_device_index(-1)
334
storage = torch.UntypedStorage(4, device=torch.device('foo'))
335
self.assertEqual(torch.serialization.location_tag(storage), 'foo')
337
self.module.set_custom_device_index(0)
338
storage = torch.UntypedStorage(4, device=torch.device('foo'))
339
self.assertEqual(torch.serialization.location_tag(storage), 'foo:0')
341
cpu_storage = torch.empty(4, 4).storage()
342
foo_storage = torch.serialization.default_restore_location(cpu_storage, 'foo:0')
343
self.assertTrue(foo_storage.is_foo)
344
# test tensor MetaData serialization
345
x = torch.empty(4, 4).long()
347
self.assertFalse(self.module.check_backend_meta(y))
348
self.module.custom_set_backend_meta(y)
349
self.assertTrue(self.module.check_backend_meta(y))
351
self.module.custom_serialization_registry()
352
with tempfile.TemporaryDirectory() as tmpdir:
353
path = os.path.join(tmpdir, 'data.pt')
355
z1 = torch.load(path)
356
# loads correctly onto the foo backend device
357
self.assertTrue(z1.is_foo)
358
# loads BackendMeta data correctly
359
self.assertTrue(self.module.check_backend_meta(z1))
361
z2 = torch.load(path, map_location='cpu')
362
# loads correctly onto the cpu backend device
363
self.assertFalse(z2.is_foo)
364
# loads BackendMeta data correctly
365
self.assertFalse(self.module.check_backend_meta(z2))
367
def test_open_device_storage_resize():
368
torch.utils.rename_privateuse1_backend('foo')
369
cpu_tensor = torch.randn([8])
370
foo_tensor = cpu_tensor.foo()
371
foo_storage = foo_tensor.storage()
372
self.assertTrue(foo_storage.size() == 8)
373
foo_storage.resize_(8)
374
self.assertTrue(foo_storage.size() == 8)
375
with self.assertRaisesRegex(RuntimeError, 'Overflow'):
376
foo_storage.resize_(8**29)
378
def test_open_device_storage_type():
379
torch.utils.rename_privateuse1_backend('foo')
380
# test cpu float storage
381
cpu_tensor = torch.randn([8]).float()
382
cpu_storage = cpu_tensor.storage()
383
self.assertEqual(cpu_storage.type(), "torch.FloatStorage")
385
# test custom float storage before defining FloatStorage
386
foo_tensor = cpu_tensor.foo()
387
foo_storage = foo_tensor.storage()
388
self.assertEqual(foo_storage.type(), "torch.storage.TypedStorage")
390
class CustomFloatStorage:
392
def __module__(self):
393
return "torch." + torch._C._get_privateuse1_backend_name()
397
return "FloatStorage"
399
# test custom float storage after defining FloatStorage
401
torch.foo.FloatStorage = CustomFloatStorage()
402
self.assertEqual(foo_storage.type(), "torch.foo.FloatStorage")
404
# test custom int storage after defining FloatStorage
405
foo_tensor2 = torch.randn([8]).int().foo()
406
foo_storage2 = foo_tensor2.storage()
407
self.assertEqual(foo_storage2.type(), "torch.storage.TypedStorage")
409
torch.foo.FloatStorage = None
411
def test_open_device_faketensor():
412
torch.utils.rename_privateuse1_backend('foo')
413
with torch._subclasses.fake_tensor.FakeTensorMode.push():
414
a = torch.empty(1, device="foo")
415
b = torch.empty(1, device="foo:0")
418
def test_open_device_named_tensor():
419
torch.utils.rename_privateuse1_backend('foo')
420
a = torch.empty([2, 3, 4, 5], device="foo", names=["N", "C", "H", "W"])
422
# Not an open registration test - this file is just very convenient
423
# for testing torch.compile on custom C++ operators
424
def test_compile_autograd_function_returns_self():
425
x_ref = torch.randn(4, requires_grad=True)
426
out_ref = self.module.custom_autograd_fn_returns_self(x_ref)
427
out_ref.sum().backward()
429
x_test = x_ref.clone().detach().requires_grad_(True)
430
f_compiled = torch.compile(self.module.custom_autograd_fn_returns_self)
431
out_test = f_compiled(x_test)
432
out_test.sum().backward()
434
self.assertEqual(out_ref, out_test)
435
self.assertEqual(x_ref.grad, x_test.grad)
437
# Not an open registration test - this file is just very convenient
438
# for testing torch.compile on custom C++ operators
439
def test_compile_autograd_function_aliasing():
440
x_ref = torch.randn(4, requires_grad=True)
441
out_ref = torch.ops._test_funcs.custom_autograd_fn_aliasing(x_ref)
442
out_ref.sum().backward()
444
x_test = x_ref.clone().detach().requires_grad_(True)
445
f_compiled = torch.compile(torch.ops._test_funcs.custom_autograd_fn_aliasing)
446
out_test = f_compiled(x_test)
447
out_test.sum().backward()
449
self.assertEqual(out_ref, out_test)
450
self.assertEqual(x_ref.grad, x_test.grad)
452
def test_open_device_tensor_type_fallback():
453
torch.utils.rename_privateuse1_backend('foo')
454
# create tensors located in custom device
455
x = torch.Tensor([[1, 2, 3], [2, 3, 4]]).to('foo')
456
y = torch.Tensor([1, 0, 2]).to('foo')
457
# create result tensor located in cpu
458
z_cpu = torch.Tensor([[0, 2, 1], [1, 3, 2]])
459
# Check that our device is correct.
460
device = self.module.custom_device()
461
self.assertTrue(x.device == device)
462
self.assertFalse(x.is_cpu)
463
# call sub op, which will fallback to cpu
465
self.assertEqual(z_cpu, z)
466
# call index op, which will fallback to cpu
467
z_cpu = torch.Tensor([3, 1])
468
y = torch.Tensor([1, 0]).long().to('foo')
470
self.assertEqual(z_cpu, z)
472
def test_open_device_tensorlist_type_fallback():
473
torch.utils.rename_privateuse1_backend('foo')
474
# create tensors located in custom device
475
v_foo = torch.Tensor([1, 2, 3]).to('foo')
476
# create result tensor located in cpu
477
z_cpu = torch.Tensor([2, 4, 6])
478
# create tensorlist for foreach_add op
481
# Check that our device is correct.
482
device = self.module.custom_device()
483
self.assertTrue(v_foo.device == device)
484
self.assertFalse(v_foo.is_cpu)
485
# call _foreach_add op, which will fallback to cpu
486
z = torch._foreach_add(x, y)
488
self.assertEqual(z_cpu, z[0])
489
self.assertEqual(z_cpu, z[1])
491
test_base_device_registration()
492
test_before_common_registration()
493
test_common_registration()
494
test_after_common_registration()
495
test_open_device_generator_registration_and_hooks()
496
test_open_device_dispatchstub()
497
test_open_device_random()
498
test_open_device_tensor()
499
test_open_device_storage()
500
test_open_device_storage_pin_memory()
501
test_open_device_serialization()
502
test_open_device_storage_resize()
503
test_open_device_storage_type()
504
test_open_device_faketensor()
505
test_open_device_named_tensor()
506
test_open_device_quantized()
508
test_compile_autograd_function_returns_self()
509
test_compile_autograd_function_aliasing()
511
test_open_device_tensor_type_fallback()
512
test_open_device_tensorlist_type_fallback()
515
if __name__ == "__main__":