# Owner(s): ["module: intel"]

import collections
import subprocess
import sys
import tempfile
import unittest

import torch
import torch.xpu._gpu_trace as gpu_trace
from torch.testing._internal.autocast_test_lists import AutocastTestLists
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyXPU,
    OpDTypes,
    ops,
)
from torch.testing._internal.common_methods_invocations import ops_and_refs
from torch.testing._internal.common_utils import (
    NoTest,
    run_tests,
    suppress_warnings,
    TEST_WITH_UBSAN,
    TEST_XPU,
    TestCase,
)
from torch.utils.checkpoint import checkpoint_sequential


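# If no XPU device is available, swap TestCase for NoTest so the whole
# module is skipped instead of failing at collection time.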
if not TEST_XPU:
    print("XPU not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811

TEST_MULTIXPU = torch.xpu.device_count() > 1

cpu_device = torch.device("cpu")
xpu_device = torch.device("xpu")

any_common_cpu_xpu_one = OpDTypes.any_common_cpu_cuda_one
_xpu_computation_op_list = [
    "fill",
    "zeros",
    "zeros_like",
    "clone",
    "view_as_real",
    "view_as_complex",
    "view",
    "resize_",
    "resize_as_",
    "add",
    "sub",
    "mul",
    "div",
    "abs",
]
_xpu_tensor_factory_op_list = [
    "as_strided",
    "empty",
    "empty_strided",
]
_xpu_not_test_dtype_op_list = [
    "resize_",  # Skipped by CPU
    "resize_as_",  # Skipped by CPU
    "abs",  # Not aligned dtype
]
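# Resolve the name lists above against the OpInfo database so the @ops
# decorator below receives concrete OpInfo entries rather than bare strings.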
_xpu_all_op_list = _xpu_computation_op_list + _xpu_tensor_factory_op_list
_xpu_all_ops = [op for op in ops_and_refs if op.name in _xpu_all_op_list]
_xpu_computation_ops = [
    op for op in ops_and_refs if op.name in _xpu_computation_op_list
]


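# Core runtime tests: device selection, device properties, streams and
# events, RNG state, and CPU-vs-XPU numerical parity for a small op set.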
class TestXpu(TestCase):
    def test_device_behavior(self):
        current_device = torch.xpu.current_device()
        torch.xpu.set_device(current_device)
        self.assertEqual(current_device, torch.xpu.current_device())

    @unittest.skipIf(not TEST_MULTIXPU, "only one GPU detected")
    def test_multi_device_behavior(self):
        current_device = torch.xpu.current_device()
        target_device = (current_device + 1) % torch.xpu.device_count()

        with torch.xpu.device(target_device):
            self.assertEqual(target_device, torch.xpu.current_device())
        self.assertEqual(current_device, torch.xpu.current_device())

        with torch.xpu._DeviceGuard(target_device):
            self.assertEqual(target_device, torch.xpu.current_device())
        self.assertEqual(current_device, torch.xpu.current_device())

    def test_get_device_properties(self):
        current_device = torch.xpu.current_device()
        device_properties = torch.xpu.get_device_properties(current_device)
        self.assertEqual(device_properties, torch.xpu.get_device_properties(None))
        self.assertEqual(device_properties, torch.xpu.get_device_properties())

        device_name = torch.xpu.get_device_name(current_device)
        self.assertEqual(device_name, torch.xpu.get_device_name(None))
        self.assertEqual(device_name, torch.xpu.get_device_name())

        device_capability = torch.xpu.get_device_capability(current_device)
        self.assertTrue(device_capability["max_work_group_size"] > 0)
        self.assertTrue(device_capability["max_num_sub_groups"] > 0)
        self.assertEqual(
            device_properties.driver_version, device_capability["driver_version"]
        )
        self.assertEqual(device_properties.has_fp16, device_capability["has_fp16"])
        self.assertEqual(device_properties.has_fp64, device_capability["has_fp64"])
        self.assertEqual(
            device_properties.has_atomic64, device_capability["has_atomic64"]
        )
        self.assertEqual(
            device_properties.has_bfloat16_conversions,
            device_capability["has_bfloat16_conversions"],
        )
        self.assertEqual(
            device_properties.has_subgroup_matrix_multiply_accumulate,
            device_capability["has_subgroup_matrix_multiply_accumulate"],
        )
        self.assertEqual(
            device_properties.has_subgroup_matrix_multiply_accumulate_tensor_float32,
            device_capability["has_subgroup_matrix_multiply_accumulate_tensor_float32"],
        )
        self.assertEqual(
            device_properties.has_subgroup_2d_block_io,
            device_capability["has_subgroup_2d_block_io"],
        )

    def test_wrong_xpu_fork(self):
        stderr = TestCase.runWithPytorchAPIUsageStderr(
            """\
import torch
from torch.multiprocessing import Process
def run(rank):
    torch.xpu.set_device(rank)
if __name__ == "__main__":
    size = 2
    processes = []
    for rank in range(size):
        # it would work fine without the line below
        torch.xpu.set_device(0)
        p = Process(target=run, args=(rank,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
"""
        )
        self.assertRegex(stderr, "Cannot re-initialize XPU in forked subprocess.")

    def test_lazy_init(self):
        """Validate that no XPU calls are made during `import torch` call"""

        def check_output(script: str) -> str:
            return (
                subprocess.check_output([sys.executable, "-c", script])
                .decode("ascii")
                .strip()
            )

        test_script = """\
import torch
from torch.multiprocessing import Process
import copy

def run_model(model, input):
    input_xpu = input.clone().to('xpu')
    model_xpu = copy.deepcopy(model).to('xpu')
    loss_xpu = model_xpu(input_xpu).sum()
    loss = model(input).sum()
    torch.testing.assert_close(loss_xpu.cpu(), loss)

def test_multi_process(model, input):
    p = Process(target=run_model, args=(model, input))
    p.start()
    p.join()
    assert p.exitcode == 0

input = torch.rand(32, 3, 224, 224)
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 64, 3, stride=2),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2, 2),
)
test_multi_process(model, input)
test_multi_process(model, input)
print(torch.xpu.device_count())
"""
        rc = check_output(test_script)
        self.assertEqual(rc, str(torch.xpu.device_count()))

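    # torch.xpu.set_stream switches the current stream outright, while the
    # torch.xpu.stream context manager restores the previous stream on exit.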
    def test_streams(self):
        s0 = torch.xpu.Stream()
        torch.xpu.set_stream(s0)
        s1 = torch.xpu.current_stream()
        self.assertEqual(s0, s1)
        s2 = torch.xpu.Stream()
        self.assertFalse(s0 == s2)
        torch.xpu.set_stream(s2)
        with torch.xpu.stream(s0):
            self.assertEqual(s0, torch.xpu.current_stream())
        self.assertEqual(s2, torch.xpu.current_stream())

    def test_stream_priority(self):
        low, high = torch.xpu.Stream.priority_range()
        s0 = torch.xpu.Stream(device=0, priority=low)

        self.assertEqual(low, s0.priority)
        self.assertEqual(torch.device("xpu:0"), s0.device)

        s1 = torch.xpu.Stream(device=0, priority=high)

        self.assertEqual(high, s1.priority)
        self.assertEqual(torch.device("xpu:0"), s1.device)

    def test_stream_event_repr(self):
        s = torch.xpu.current_stream()
        self.assertTrue("torch.xpu.Stream" in str(s))
        e = torch.xpu.Event()
        self.assertTrue("torch.xpu.Event(uninitialized)" in str(e))
        s.record_event(e)
        self.assertTrue("torch.xpu.Event" in str(e))

    def test_events(self):
        stream = torch.xpu.current_stream()
        event = torch.xpu.Event()
        self.assertTrue(event.query())
        stream.record_event(event)
        event.synchronize()
        self.assertTrue(event.query())

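    # torch.Stream / torch.Event are the device-generic wrappers; they should
    # interoperate with torch.xpu.Stream, and elapsed_time is expected to
    # raise NotImplementedError on the XPU backend.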
    def test_generic_stream_event(self):
        stream = torch.Stream("xpu")
        self.assertEqual(stream.device_index, torch.xpu.current_device())
        xpu_stream = torch.xpu.Stream(
            stream_id=stream.stream_id,
            device_index=stream.device_index,
            device_type=stream.device_type,
        )
        self.assertEqual(stream.stream_id, xpu_stream.stream_id)
        self.assertNotEqual(stream.stream_id, torch.xpu.current_stream().stream_id)

        event1 = torch.Event("xpu")
        event2 = torch.Event("xpu")
        self.assertEqual(event1.event_id, 0)
        a = torch.randn(1000)
        b = torch.randn(1000)
        with torch.xpu.stream(xpu_stream):
            a_xpu = a.to("xpu", non_blocking=True)
            b_xpu = b.to("xpu", non_blocking=True)
            self.assertEqual(stream.stream_id, torch.xpu.current_stream().stream_id)
        event1.record(stream)
        event1.synchronize()
        self.assertTrue(event1.query())
        c_xpu = a_xpu + b_xpu
        event2.record()
        event2.synchronize()
        self.assertTrue(event2.query())
        self.assertNotEqual(event1.event_id, event2.event_id)
        self.assertEqual(c_xpu.cpu(), a + b)
        with self.assertRaisesRegex(
            NotImplementedError, "elapsedTime is not supported by XPU backend."
        ):
            event1.elapsed_time(event2)

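    # torch.manual_seed seeds all devices, including XPU, so the XPU RNG
    # state must track the global seed as well as the XPU-specific
    # manual_seed / set_rng_state round-trips.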
    def test_generator(self):
        torch.manual_seed(2024)
        g_state0 = torch.xpu.get_rng_state()
        torch.manual_seed(1234)
        g_state1 = torch.xpu.get_rng_state()
        self.assertNotEqual(g_state0, g_state1)

        torch.xpu.manual_seed(2024)
        g_state2 = torch.xpu.get_rng_state()
        self.assertEqual(g_state0, g_state2)

        torch.xpu.set_rng_state(g_state1)
        self.assertEqual(g_state1, torch.xpu.get_rng_state())

        torch.manual_seed(1234)
        torch.xpu.set_rng_state(g_state0)
        self.assertEqual(2024, torch.xpu.initial_seed())

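    # Parity check: run each sampled input through the op on XPU and on CPU
    # and require the results to agree within a loose 1e-4 tolerance.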
    @onlyXPU
    @suppress_warnings
    @ops(_xpu_computation_ops, dtypes=any_common_cpu_xpu_one)
    def test_compare_cpu(self, device, dtype, op):
        def to_cpu(arg):
            if isinstance(arg, torch.Tensor):
                return arg.to(device="cpu")
            return arg

        samples = op.reference_inputs(device, dtype)

        for sample in samples:
            cpu_sample = sample.transform(to_cpu)
            xpu_results = op(sample.input, *sample.args, **sample.kwargs)
            cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)

            xpu_results = sample.output_process_fn_grad(xpu_results)
            cpu_results = cpu_sample.output_process_fn_grad(cpu_results)

            # Lower tolerance because we are running this as a `@slowTest`
            # Don't want the periodic tests to fail frequently
            self.assertEqual(xpu_results, cpu_results, atol=1e-4, rtol=1e-4)

    @onlyXPU
    @ops(_xpu_computation_ops, allowed_dtypes=(torch.bool,))
    @unittest.skipIf(TEST_WITH_UBSAN, "Test uses undefined behavior")
    def test_non_standard_bool_values(self, device, dtype, op):
        # Test boolean values other than 0x00 and 0x01 (gh-54789)
        def convert_boolean_tensors(x):
            if not isinstance(x, torch.Tensor) or x.dtype != torch.bool:
                return x

            # Map False -> 0 and True -> Random value in [2, 255]
            true_vals = torch.randint(
                2, 255, x.shape, dtype=torch.uint8, device=x.device
            )
            false_vals = torch.zeros((), dtype=torch.uint8, device=x.device)
            x_int = torch.where(x, true_vals, false_vals)

            ret = x_int.view(torch.bool)
            self.assertEqual(ret, x)
            return ret

        for sample in op.sample_inputs(device, dtype):
            expect = op(sample.input, *sample.args, **sample.kwargs)

            transformed = sample.transform(convert_boolean_tensors)
            actual = op(transformed.input, *transformed.args, **transformed.kwargs)

            self.assertEqual(expect, actual)

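    # Serialization round-trips: saving XPU tensors (and a TypedStorage) and
    # loading them back must preserve values, dtypes, and storage sharing.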
    def test_serialization_array_with_storage(self):
        x = torch.randn(5, 5).xpu()
        y = torch.zeros(2, 5, dtype=torch.int, device="xpu")
        q = [x, y, x, y.storage()]
        with tempfile.NamedTemporaryFile() as f:
            torch.save(q, f)
            f.seek(0)
            q_copy = torch.load(f)
        self.assertEqual(q_copy, q, atol=0, rtol=0)
        q_copy[0].fill_(5)
        self.assertEqual(q_copy[0], q_copy[2], atol=0, rtol=0)
        self.assertEqual(q_copy[0].dtype, torch.float)
        self.assertEqual(q_copy[1].dtype, torch.int)
        self.assertEqual(q_copy[2].dtype, torch.float)
        self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage))
        self.assertTrue(isinstance(q_copy[3]._untyped_storage, torch.UntypedStorage))
        q_copy[1].fill_(10)
        y.fill_(10)
        self.assertEqual(q_copy[3], y.storage())

    def test_serialization_array_with_empty(self):
        x = [
            torch.randn(4, 4).xpu(),
            torch.tensor([], dtype=torch.float, device=torch.device("xpu")),
        ]
        with tempfile.NamedTemporaryFile() as f:
            torch.save(x, f)
            f.seek(0)
            x_copy = torch.load(f)
        for original, copy in zip(x, x_copy):
            self.assertEqual(copy, original)
            self.assertIs(type(copy), type(original))
            self.assertEqual(copy.get_device(), original.get_device())


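# Generate the device-specific variants of TestXpu so the @onlyXPU and @ops
# decorators above are instantiated against a concrete XPU device.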
instantiate_device_type_tests(TestXpu, globals(), only_for="xpu", allow_xpu=True)


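# Autocast coverage: _run_autocast_outofplace mirrors the CUDA autocast test
# harness, exercising both the torch.* and Tensor.* variants of each op and
# comparing them against a control computed with autocast disabled.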
class TestXpuAutocast(TestCase):
    # These operators are not implemented on the XPU backend and cannot be
    # fallen back to CPU, so they have to be skipped for now.
    # TODO: remove these operators from the skip list once they are implemented on the XPU backend.
    skip_list = ["gru_cell"]

    def setUp(self):
        super().setUp()
        self.autocast_lists = AutocastTestLists(torch.device("xpu"))

    def tearDown(self):
        del self.autocast_lists
        super().tearDown()

    def _run_autocast_outofplace(
        self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None
    ):
        # Helper to cast floating-point args (and nested iterables of them) to to_type
        def cast(val, to_type):
            if isinstance(val, torch.Tensor):
                return val.to(to_type) if val.is_floating_point() else val
            elif isinstance(val, collections.abc.Iterable):
                return type(val)(cast(v, to_type) for v in val)
            else:
                return val

        if add_kwargs is None:
            add_kwargs = {}
        fast_dtype = torch.bfloat16 if run_as_type == torch.bfloat16 else torch.float16
        self.assertFalse(torch.is_autocast_enabled("xpu"))
        with torch.amp.autocast("xpu", dtype=fast_dtype):
            self.assertTrue(torch.is_autocast_enabled("xpu"))

            out_type = out_type if out_type is not None else run_as_type
            output = output_method = None

            # Try module.* variant, if requested:
            if module is not None and hasattr(module, op):
                output = getattr(module, op)(*args, **add_kwargs)
                if isinstance(output, torch.Tensor):
                    self.assertTrue(
                        out_type == output.dtype,
                        f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
                    )

            # Try Tensor.* variant:
            if hasattr(torch.Tensor, op):
                output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
                if isinstance(output_method, torch.Tensor):
                    self.assertTrue(
                        out_type == output_method.dtype,
                        f"autocast for torch.{op} produced {output_method.dtype}, should produce {out_type}",
                    )

            self.assertTrue(
                (output is not None) or (output_method is not None),
                f"{op} not found as an attribute on either Tensor or the requested module {module}",
            )

            # Accounts for ops that return Tensors, iterables, and other non-Tensors.
            # For example, lstm_cell returns a tuple and equal returns bool.
            def compare(first, second):
                if isinstance(first, torch.Tensor):
                    return torch.equal(first, second)
                elif isinstance(first, collections.abc.Iterable):
                    return all(compare(f, s) for f, s in zip(first, second))
                else:
                    return first == second

            # If both torch.* and Tensor.* variants were found, check outputs are identical
            if (output is not None) and (output_method is not None):
                self.assertTrue(type(output) == type(output_method))
                comparison = compare(output, output_method)
                self.assertTrue(
                    comparison, f"torch.{op} result did not match Tensor.{op} result"
                )

            # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
            # as the C++-side autocasting, and should be bitwise accurate.
            output_to_compare = output if output is not None else output_method
            with torch.amp.autocast("xpu", enabled=False):
                self.assertFalse(torch.is_autocast_enabled("xpu"))

                if module is not None and hasattr(module, op):
                    control = getattr(module, op)(
                        *cast(args, run_as_type), **add_kwargs
                    )
                else:
                    control = getattr(args[0].to(run_as_type), op)(
                        *cast(args[1:], run_as_type), **add_kwargs
                    )
                self.assertTrue(type(output_to_compare) == type(control))
                comparison = compare(output_to_compare, control)
                self.assertTrue(comparison, f"torch.{op} result did not match control")
            self.assertTrue(torch.is_autocast_enabled("xpu"))
        self.assertFalse(torch.is_autocast_enabled("xpu"))

    def test_autocast_torch_fp16(self):
        for op_with_args in self.autocast_lists.torch_fp16:
            skip_test = False
            op, args = op_with_args[0], op_with_args[1]
            if op in self.skip_list:
                skip_test = True  # skip unimplemented op
            if len(op_with_args) == 3:
                skip_test = True  # skip cudnn op
            if not skip_test:
                self._run_autocast_outofplace(op, args, torch.float16)

    def test_autocast_torch_bf16(self):
        for op_with_args in self.autocast_lists.torch_fp16:
            skip_test = False
            op, args = op_with_args[0], op_with_args[1]
            if op in self.skip_list:
                skip_test = True  # skip unimplemented op
            if len(op_with_args) == 3:
                skip_test = True  # skip cudnn op
            if not skip_test:
                self._run_autocast_outofplace(op, args, torch.bfloat16)

    def test_autocast_torch_need_autocast_promote(self):
        for op, args in self.autocast_lists.torch_need_autocast_promote:
            self._run_autocast_outofplace(op, args, torch.float32)

    def test_autocast_torch_expect_builtin_promote(self):
        for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
            self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type)

    def test_autocast_checkpointing(self):
        model = torch.nn.Sequential(
            torch.nn.Linear(8, 8), torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)
        ).xpu()
        input = torch.rand(
            (8, 8), device="xpu", dtype=torch.float16, requires_grad=True
        )
        for reentrant in (True, False):
            with torch.autocast("xpu"):
                output = checkpoint_sequential(model, 2, input, use_reentrant=reentrant)
            self.assertTrue(output.requires_grad)
            self.assertTrue(output.dtype is torch.float16)
            output.sum().backward()

    def test_xpu_autocast_dtype(self):
        dtype = torch.get_autocast_dtype("xpu")
        self.assertEqual(dtype, torch.float16)
        mat0_fp32 = torch.randn((10, 10), dtype=torch.float32, device="xpu")
        mat1_fp32 = torch.randn((10, 10), dtype=torch.float32, device="xpu")
        with torch.amp.autocast("xpu"):
            result = torch.mm(mat0_fp32, mat1_fp32)
            self.assertEqual(result.dtype, torch.float16)


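# GPU-trace hooks: each test registers a MagicMock callback through
# torch.xpu._gpu_trace and verifies it fires with the event handle
# (event._as_parameter_.value) or the stream's underlying sycl_queue.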
class TestXpuTrace(TestCase):
    def setUp(self):
        torch._C._activate_gpu_trace()
        self.mock = unittest.mock.MagicMock()

    def test_event_creation_callback(self):
        gpu_trace.register_callback_for_event_creation(self.mock)

        event = torch.xpu.Event()
        event.record()
        self.mock.assert_called_once_with(event._as_parameter_.value)

    def test_event_deletion_callback(self):
        gpu_trace.register_callback_for_event_deletion(self.mock)

        event = torch.xpu.Event()
        event.record()
        event_id = event._as_parameter_.value
        del event
        self.mock.assert_called_once_with(event_id)

    def test_event_record_callback(self):
        gpu_trace.register_callback_for_event_record(self.mock)

        event = torch.xpu.Event()
        event.record()
        self.mock.assert_called_once_with(
            event._as_parameter_.value, torch.xpu.current_stream().sycl_queue
        )

    def test_event_wait_callback(self):
        gpu_trace.register_callback_for_event_wait(self.mock)

        event = torch.xpu.Event()
        event.record()
        event.wait()
        self.mock.assert_called_once_with(
            event._as_parameter_.value, torch.xpu.current_stream().sycl_queue
        )

    def test_device_synchronization_callback(self):
        gpu_trace.register_callback_for_device_synchronization(self.mock)

        torch.xpu.synchronize()
        self.mock.assert_called()

    def test_stream_synchronization_callback(self):
        gpu_trace.register_callback_for_stream_synchronization(self.mock)

        stream = torch.xpu.Stream()
        stream.synchronize()
        self.mock.assert_called_once_with(stream.sycl_queue)

    def test_event_synchronization_callback(self):
        gpu_trace.register_callback_for_event_synchronization(self.mock)

        event = torch.xpu.Event()
        event.record()
        event.synchronize()
        self.mock.assert_called_once_with(event._as_parameter_.value)


if __name__ == "__main__":
    run_tests()