import torch.xpu._gpu_trace as gpu_trace
from torch.testing._internal.autocast_test_lists import AutocastTestLists
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    OpDTypes,
    ops,
)
from torch.testing._internal.common_methods_invocations import ops_and_refs
from torch.testing._internal.common_utils import (
    TEST_WITH_UBSAN,
    TEST_XPU,
    TestCase,
    # ...
)
from torch.utils.checkpoint import checkpoint_sequential

if not TEST_XPU:
    print("XPU not available, skipping tests", file=sys.stderr)

TEST_MULTIXPU = torch.xpu.device_count() > 1

cpu_device = torch.device("cpu")
xpu_device = torch.device("xpu")

any_common_cpu_xpu_one = OpDTypes.any_common_cpu_cuda_one

_xpu_computation_op_list = [
    # ...
]
_xpu_tensor_factory_op_list = [
    # ...
]
_xpu_not_test_dtype_op_list = [
    # ...
]
_xpu_all_op_list = _xpu_computation_op_list + _xpu_tensor_factory_op_list
_xpu_all_ops = [op for op in ops_and_refs if op.name in _xpu_all_op_list]
_xpu_computation_ops = [
    op for op in ops_and_refs if op.name in _xpu_computation_op_list
]
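

# Sanity tests for the torch.xpu runtime: device selection and properties,
# streams/events, RNG state, OpInfo parity with CPU, and serialization.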
class TestXpu(TestCase):
    def test_device_behavior(self):
        current_device = torch.xpu.current_device()
        torch.xpu.set_device(current_device)
        self.assertEqual(current_device, torch.xpu.current_device())

    @unittest.skipIf(not TEST_MULTIXPU, "only one GPU detected")
    def test_multi_device_behavior(self):
        current_device = torch.xpu.current_device()
        target_device = (current_device + 1) % torch.xpu.device_count()

        with torch.xpu.device(target_device):
            self.assertEqual(target_device, torch.xpu.current_device())
        self.assertEqual(current_device, torch.xpu.current_device())

        with torch.xpu._DeviceGuard(target_device):
            self.assertEqual(target_device, torch.xpu.current_device())
        self.assertEqual(current_device, torch.xpu.current_device())

    def test_get_device_properties(self):
        current_device = torch.xpu.current_device()
        device_properties = torch.xpu.get_device_properties(current_device)
        self.assertEqual(device_properties, torch.xpu.get_device_properties(None))
        self.assertEqual(device_properties, torch.xpu.get_device_properties())

        device_name = torch.xpu.get_device_name(current_device)
        self.assertEqual(device_name, torch.xpu.get_device_name(None))
        self.assertEqual(device_name, torch.xpu.get_device_name())

        device_capability = torch.xpu.get_device_capability(current_device)
        self.assertTrue(device_capability["max_work_group_size"] > 0)
        self.assertTrue(device_capability["max_num_sub_groups"] > 0)
        self.assertEqual(
            device_properties.driver_version, device_capability["driver_version"]
        )
        self.assertEqual(device_properties.has_fp16, device_capability["has_fp16"])
        self.assertEqual(device_properties.has_fp64, device_capability["has_fp64"])
        self.assertEqual(
            device_properties.has_atomic64, device_capability["has_atomic64"]
        )
        self.assertEqual(
            device_properties.has_bfloat16_conversions,
            device_capability["has_bfloat16_conversions"],
        )
        self.assertEqual(
            device_properties.has_subgroup_matrix_multiply_accumulate,
            device_capability["has_subgroup_matrix_multiply_accumulate"],
        )
        self.assertEqual(
            device_properties.has_subgroup_matrix_multiply_accumulate_tensor_float32,
            device_capability["has_subgroup_matrix_multiply_accumulate_tensor_float32"],
        )
        self.assertEqual(
            device_properties.has_subgroup_2d_block_io,
            device_capability["has_subgroup_2d_block_io"],
        )
    def test_wrong_xpu_fork(self):
        stderr = TestCase.runWithPytorchAPIUsageStderr(
            """\
import torch
from torch.multiprocessing import Process

def run(rank):
    torch.xpu.set_device(rank)

if __name__ == "__main__":
    size = 2
    processes = []
    for rank in range(size):
        # it would work fine without the line below
        torch.xpu.set_device(0)
        p = Process(target=run, args=(rank,))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
"""
        )
        self.assertRegex(stderr, "Cannot re-initialize XPU in forked subprocess.")
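
    # The script below forks a child process that runs the model on XPU; this only
    # works if `import torch` has not already initialized the XPU runtime in the
    # parent interpreter (see test_wrong_xpu_fork above).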
    def test_lazy_init(self):
        """Validate that no XPU calls are made during `import torch` call"""

        def check_output(script: str) -> str:
            return (
                subprocess.check_output([sys.executable, "-c", script])
                .decode("ascii")
                .strip()
            )

        test_script = """\
import torch
from torch.multiprocessing import Process
import copy

def run_model(model, input):
    input_xpu = input.clone().to('xpu')
    model_xpu = copy.deepcopy(model).to('xpu')
    loss_xpu = model_xpu(input_xpu).sum()
    loss = model(input).sum()
    torch.testing.assert_close(loss_xpu.cpu(), loss)

def test_multi_process(model, input):
    p = Process(target=run_model, args=(model, input))
    p.start()
    p.join()
    assert p.exitcode == 0

input = torch.rand(32, 3, 224, 224)
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 64, 3, stride=2),
    # ...
    torch.nn.MaxPool2d(2, 2),
)
test_multi_process(model, input)
test_multi_process(model, input)
print(torch.xpu.device_count())
"""
        rc = check_output(test_script)
        self.assertEqual(rc, str(torch.xpu.device_count()))

    def test_streams(self):
        s0 = torch.xpu.Stream()
        torch.xpu.set_stream(s0)
        s1 = torch.xpu.current_stream()
        self.assertEqual(s0, s1)
        s2 = torch.xpu.Stream()
        self.assertFalse(s0 == s2)
        torch.xpu.set_stream(s2)
        with torch.xpu.stream(s0):
            self.assertEqual(s0, torch.xpu.current_stream())
        self.assertEqual(s2, torch.xpu.current_stream())

    def test_stream_priority(self):
        low, high = torch.xpu.Stream.priority_range()
        s0 = torch.xpu.Stream(device=0, priority=low)

        self.assertEqual(low, s0.priority)
        self.assertEqual(torch.device("xpu:0"), s0.device)

        s1 = torch.xpu.Stream(device=0, priority=high)

        self.assertEqual(high, s1.priority)
        self.assertEqual(torch.device("xpu:0"), s1.device)

    def test_stream_event_repr(self):
        s = torch.xpu.current_stream()
        self.assertTrue("torch.xpu.Stream" in str(s))
        e = torch.xpu.Event()
        self.assertTrue("torch.xpu.Event(uninitialized)" in str(e))
        s.record_event(e)
        self.assertTrue("torch.xpu.Event" in str(e))

    def test_events(self):
        stream = torch.xpu.current_stream()
        event = torch.xpu.Event()
        self.assertTrue(event.query())
        stream.record_event(event)
        event.synchronize()
        self.assertTrue(event.query())
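
    # The device-generic torch.Stream / torch.Event API should interoperate with
    # their torch.xpu counterparts and share stream ids.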
    def test_generic_stream_event(self):
        stream = torch.Stream("xpu")
        self.assertEqual(stream.device_index, torch.xpu.current_device())
        xpu_stream = torch.xpu.Stream(
            stream_id=stream.stream_id,
            device_index=stream.device_index,
            device_type=stream.device_type,
        )
        self.assertEqual(stream.stream_id, xpu_stream.stream_id)
        self.assertNotEqual(stream.stream_id, torch.xpu.current_stream().stream_id)

        event1 = torch.Event("xpu")
        event2 = torch.Event("xpu")
        self.assertEqual(event1.event_id, 0)
        a = torch.randn(1000)
        b = torch.randn(1000)
        with torch.xpu.stream(xpu_stream):
            a_xpu = a.to("xpu", non_blocking=True)
            b_xpu = b.to("xpu", non_blocking=True)
            self.assertEqual(stream.stream_id, torch.xpu.current_stream().stream_id)
            event1.record(stream)
            event1.synchronize()
            self.assertTrue(event1.query())
            c_xpu = a_xpu + b_xpu
            event2.record()
            event2.synchronize()
            self.assertTrue(event2.query())
            self.assertNotEqual(event1.event_id, event2.event_id)
            self.assertEqual(c_xpu.cpu(), a + b)
            with self.assertRaisesRegex(
                NotImplementedError, "elapsedTime is not supported by XPU backend."
            ):
                event1.elapsed_time(event2)

    def test_generator(self):
        torch.manual_seed(2024)
        g_state0 = torch.xpu.get_rng_state()
        torch.manual_seed(1234)
        g_state1 = torch.xpu.get_rng_state()
        self.assertNotEqual(g_state0, g_state1)

        torch.xpu.manual_seed(2024)
        g_state2 = torch.xpu.get_rng_state()
        self.assertEqual(g_state0, g_state2)

        torch.xpu.set_rng_state(g_state1)
        self.assertEqual(g_state1, torch.xpu.get_rng_state())

        torch.manual_seed(1234)
        torch.xpu.set_rng_state(g_state0)
        self.assertEqual(2024, torch.xpu.initial_seed())

    @ops(_xpu_computation_ops, dtypes=any_common_cpu_xpu_one)
    def test_compare_cpu(self, device, dtype, op):
        def to_cpu(arg):
            if isinstance(arg, torch.Tensor):
                return arg.to(device="cpu")
            return arg

        samples = op.reference_inputs(device, dtype)

        for sample in samples:
            cpu_sample = sample.transform(to_cpu)
            xpu_results = op(sample.input, *sample.args, **sample.kwargs)
            cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)

            xpu_results = sample.output_process_fn_grad(xpu_results)
            cpu_results = cpu_sample.output_process_fn_grad(cpu_results)

            self.assertEqual(xpu_results, cpu_results, atol=1e-4, rtol=1e-4)
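
    # Boolean tensors are stored as bytes; any non-zero byte must behave as True,
    # so ops are checked against inputs whose True values use arbitrary bytes in [2, 255].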
    @ops(_xpu_computation_ops, allowed_dtypes=(torch.bool,))
    @unittest.skipIf(TEST_WITH_UBSAN, "Test uses undefined behavior")
    def test_non_standard_bool_values(self, device, dtype, op):
        def convert_boolean_tensors(x):
            if not isinstance(x, torch.Tensor) or x.dtype != torch.bool:
                return x

            true_vals = torch.randint(
                2, 255, x.shape, dtype=torch.uint8, device=x.device
            )
            false_vals = torch.zeros((), dtype=torch.uint8, device=x.device)
            x_int = torch.where(x, true_vals, false_vals)

            ret = x_int.view(torch.bool)
            self.assertEqual(ret, x)
            return ret

        for sample in op.sample_inputs(device, dtype):
            expect = op(sample.input, *sample.args, **sample.kwargs)

            transformed = sample.transform(convert_boolean_tensors)
            actual = op(transformed.input, *transformed.args, **transformed.kwargs)

            self.assertEqual(expect, actual)

    def test_serialization_array_with_storage(self):
        x = torch.randn(5, 5).xpu()
        y = torch.zeros(2, 5, dtype=torch.int, device="xpu")
        q = [x, y, x, y.storage()]
        with tempfile.NamedTemporaryFile() as f:
            torch.save(q, f)
            f.seek(0)
            q_copy = torch.load(f)
        self.assertEqual(q_copy, q, atol=0, rtol=0)
        q_copy[0].fill_(5)
        self.assertEqual(q_copy[0], q_copy[2], atol=0, rtol=0)
        self.assertEqual(q_copy[0].dtype, torch.float)
        self.assertEqual(q_copy[1].dtype, torch.int)
        self.assertEqual(q_copy[2].dtype, torch.float)
        self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage))
        self.assertTrue(isinstance(q_copy[3]._untyped_storage, torch.UntypedStorage))
        self.assertEqual(q_copy[3], y.storage())

    def test_serialization_array_with_empty(self):
        x = [
            torch.randn(4, 4).xpu(),
            torch.tensor([], dtype=torch.float, device=torch.device("xpu")),
        ]
        with tempfile.NamedTemporaryFile() as f:
            torch.save(x, f)
            f.seek(0)
            x_copy = torch.load(f)
        for original, copy in zip(x, x_copy):
            self.assertEqual(copy, original)
            self.assertIs(type(copy), type(original))
            self.assertEqual(copy.get_device(), original.get_device())


instantiate_device_type_tests(TestXpu, globals(), only_for="xpu", allow_xpu=True)


class TestXpuAutocast(TestCase):
    skip_list = ["gru_cell"]

    def setUp(self):
        super().setUp()
        self.autocast_lists = AutocastTestLists(torch.device("xpu"))

    def tearDown(self):
        del self.autocast_lists
        super().tearDown()
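
    # Runs `op` under torch.amp.autocast("xpu") via both torch.* and Tensor.* (when
    # available), checks the produced dtype, and compares the results against an
    # explicitly casted control computed with autocast disabled.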
    def _run_autocast_outofplace(
        self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None
    ):
        def cast(val, to_type):
            if isinstance(val, torch.Tensor):
                return val.to(to_type) if val.is_floating_point() else val
            elif isinstance(val, collections.abc.Iterable):
                return type(val)(cast(v, to_type) for v in val)
            else:
                return val

        if add_kwargs is None:
            add_kwargs = {}
        fast_dtype = torch.bfloat16 if run_as_type == torch.bfloat16 else torch.float16
        self.assertFalse(torch.is_autocast_enabled("xpu"))
        with torch.amp.autocast("xpu", dtype=fast_dtype):
            self.assertTrue(torch.is_autocast_enabled("xpu"))

            out_type = out_type if out_type is not None else run_as_type
            output = output_method = None

            if module is not None and hasattr(module, op):
                output = getattr(module, op)(*args, **add_kwargs)
                if isinstance(output, torch.Tensor):
                    self.assertTrue(
                        out_type == output.dtype,
                        f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
                    )

            if hasattr(torch.Tensor, op):
                output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
                if isinstance(output_method, torch.Tensor):
                    self.assertTrue(
                        out_type == output_method.dtype,
                        f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}",
                    )

            self.assertTrue(
                (output is not None) or (output_method is not None),
                f"{op} not found as an attribute on either Tensor or the requested module {module}",
            )

            def compare(first, second):
                if isinstance(first, torch.Tensor):
                    return torch.equal(first, second)
                elif isinstance(first, collections.abc.Iterable):
                    return all(compare(f, s) for f, s in zip(first, second))
                else:
                    return first == second

            if (output is not None) and (output_method is not None):
                self.assertTrue(type(output) == type(output_method))
                comparison = compare(output, output_method)
                self.assertTrue(
                    comparison, f"torch.{op} result did not match Tensor.{op} result"
                )

            output_to_compare = output if output is not None else output_method
            with torch.amp.autocast("xpu", enabled=False):
                self.assertFalse(torch.is_autocast_enabled("xpu"))

                if module is not None and hasattr(module, op):
                    control = getattr(module, op)(
                        *cast(args, run_as_type), **add_kwargs
                    )
                else:
                    control = getattr(args[0].to(run_as_type), op)(
                        *cast(args[1:], run_as_type), **add_kwargs
                    )
                self.assertTrue(type(output_to_compare) == type(control))
                comparison = compare(output_to_compare, control)
                self.assertTrue(comparison, f"torch.{op} result did not match control")
            self.assertTrue(torch.is_autocast_enabled("xpu"))
        self.assertFalse(torch.is_autocast_enabled("xpu"))

    def test_autocast_torch_fp16(self):
        for op_with_args in self.autocast_lists.torch_fp16:
            skip_test = False
            op, args = op_with_args[0], op_with_args[1]
            if op in self.skip_list:
                skip_test = True
            if len(op_with_args) == 3:
                skip_test = True
            if not skip_test:
                self._run_autocast_outofplace(op, args, torch.float16)

    def test_autocast_torch_bf16(self):
        for op_with_args in self.autocast_lists.torch_fp16:
            skip_test = False
            op, args = op_with_args[0], op_with_args[1]
            if op in self.skip_list:
                skip_test = True
            if len(op_with_args) == 3:
                skip_test = True
            if not skip_test:
                self._run_autocast_outofplace(op, args, torch.bfloat16)

    def test_autocast_torch_need_autocast_promote(self):
        for op, args in self.autocast_lists.torch_need_autocast_promote:
            self._run_autocast_outofplace(op, args, torch.float32)

    def test_autocast_torch_expect_builtin_promote(self):
        for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
            self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type)

    def test_autocast_checkpointing(self):
        model = torch.nn.Sequential(
            torch.nn.Linear(8, 8), torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)
        ).xpu()
        input = torch.rand(
            (8, 8), device="xpu", dtype=torch.float16, requires_grad=True
        )
        for reentrant in (True, False):
            with torch.autocast("xpu"):
                output = checkpoint_sequential(model, 2, input, use_reentrant=reentrant)
            self.assertTrue(output.requires_grad)
            self.assertTrue(output.dtype is torch.float16)
            output.sum().backward()

    def test_xpu_autocast_dtype(self):
        dtype = torch.get_autocast_dtype("xpu")
        self.assertEqual(dtype, torch.float16)
        mat0_fp32 = torch.randn((10, 10), dtype=torch.float32, device="xpu")
        mat1_fp32 = torch.randn((10, 10), dtype=torch.float32, device="xpu")
        with torch.amp.autocast("xpu"):
            result = torch.mm(mat0_fp32, mat1_fp32)
            self.assertEqual(result.dtype, torch.float16)
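

# Verifies that the gpu_trace callbacks fire for XPU event creation/deletion/record/wait
# and for device, stream, and event synchronization.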
class TestXpuTrace(TestCase):
    def setUp(self):
        torch._C._activate_gpu_trace()
        self.mock = unittest.mock.MagicMock()

    def test_event_creation_callback(self):
        gpu_trace.register_callback_for_event_creation(self.mock)

        event = torch.xpu.Event()
        event.record()
        self.mock.assert_called_once_with(event._as_parameter_.value)

    def test_event_deletion_callback(self):
        gpu_trace.register_callback_for_event_deletion(self.mock)

        event = torch.xpu.Event()
        event.record()
        event_id = event._as_parameter_.value
        del event
        self.mock.assert_called_once_with(event_id)

    def test_event_record_callback(self):
        gpu_trace.register_callback_for_event_record(self.mock)

        event = torch.xpu.Event()
        event.record()
        self.mock.assert_called_once_with(
            event._as_parameter_.value, torch.xpu.current_stream().sycl_queue
        )

    def test_event_wait_callback(self):
        gpu_trace.register_callback_for_event_wait(self.mock)

        event = torch.xpu.Event()
        event.record()
        event.wait()
        self.mock.assert_called_once_with(
            event._as_parameter_.value, torch.xpu.current_stream().sycl_queue
        )

    def test_device_synchronization_callback(self):
        gpu_trace.register_callback_for_device_synchronization(self.mock)

        torch.xpu.synchronize()
        self.mock.assert_called()

    def test_stream_synchronization_callback(self):
        gpu_trace.register_callback_for_stream_synchronization(self.mock)

        stream = torch.xpu.Stream()
        stream.synchronize()
        self.mock.assert_called_once_with(stream.sycl_queue)

    def test_event_synchronization_callback(self):
        gpu_trace.register_callback_for_event_synchronization(self.mock)

        event = torch.xpu.Event()
        event.record()
        event.synchronize()
        self.mock.assert_called_once_with(event._as_parameter_.value)


if __name__ == "__main__":