test_cuda_sanitizer.py
# Owner(s): ["module: cuda"]

import sys
import textwrap
import traceback
from typing import List

import torch
import torch.cuda._sanitizer as csan
from torch.cuda._sanitizer import DataPtr, EventId, StreamId
from torch.testing._internal.common_utils import NoTest, run_tests, TEST_CUDA, TestCase


if not TEST_CUDA:
    print("CUDA not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811


class TestArgumentHandler(TestCase):
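    # Exercises csan.ArgumentHandler directly: parse_inputs/parse_outputs should
    # classify the data pointers of operator arguments as read or written based on
    # the operator schema, and record tensor argument names and outputs.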
    def test_add(self):
        add_func = torch.ops.aten.add.Tensor
        a = torch.ones(5, 3, device="cuda")
        b = torch.randn(5, 3, device="cuda")

        argument_handler = csan.ArgumentHandler()
        argument_handler.parse_inputs(add_func._schema, (a, b), {})
        c = torch.add(a, b)
        argument_handler.parse_outputs(c)

        self.assertEqual({a.data_ptr(), b.data_ptr()}, argument_handler.dataptrs_read)
        self.assertEqual({c.data_ptr()}, argument_handler.dataptrs_written)

    def test_cat(self):
        cat_func = torch.ops.aten.cat.default
        a = torch.ones(2, 4, 5, device="cuda")
        b = torch.zeros(2, 1, 5, device="cuda")
        c = torch.rand(2, 7, 5, device="cuda")

        argument_handler = csan.ArgumentHandler()
        argument_handler.parse_inputs(cat_func._schema, ([a, b, c], 1), {})
        d = torch.cat((a, b, c), dim=1)
        argument_handler.parse_outputs(d)

        self.assertEqual(
            {a.data_ptr(), b.data_ptr(), c.data_ptr()}, argument_handler.dataptrs_read
        )
        self.assertEqual({d.data_ptr()}, argument_handler.dataptrs_written)

    def test_split(self):
        split_func = torch.ops.aten.split.Tensor
        a = torch.arange(10, device="cuda").reshape(5, 2)

        argument_handler = csan.ArgumentHandler()
        argument_handler.parse_inputs(split_func._schema, (a, 2), {})
        out = torch.split(a, 2)
        argument_handler.parse_outputs(out)

        outputs = {out[0].data_ptr(), out[1].data_ptr(), out[2].data_ptr()}
        self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read)
        self.assertEqual(outputs, argument_handler.dataptrs_written)

    def test_inplace(self):
        add_inplace_func = torch.ops.aten.add_.Tensor
        a = torch.rand(4, 2, device="cuda")

        argument_handler = csan.ArgumentHandler()
        argument_handler.parse_inputs(add_inplace_func._schema, (a, 5), {})
        a.add_(5)
        argument_handler.parse_outputs(a)

        self.assertEqual(set(), argument_handler.dataptrs_read)
        self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_written)

    def test_out(self):
        mul_out_func = torch.ops.aten.mul.out
        a = torch.arange(8, device="cuda")
        b = torch.empty(8, device="cuda")

        argument_handler = csan.ArgumentHandler()
        argument_handler.parse_inputs(mul_out_func._schema, (a, 3), {"out": b})
        torch.mul(a, 3, out=b)
        argument_handler.parse_outputs(b)

        self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read)
        self.assertEqual({b.data_ptr()}, argument_handler.dataptrs_written)

    def test_nonzero(self):
        nonzero_func = torch.ops.aten.nonzero.default
        a = torch.ones(5, 3, 2, device="cuda")

        argument_handler = csan.ArgumentHandler()
        argument_handler.parse_inputs(nonzero_func._schema, (a,), {"as_tuple": True})
        out = torch.nonzero(a, as_tuple=True)
        argument_handler.parse_outputs(out)

        outputs = {out[0].data_ptr(), out[1].data_ptr(), out[2].data_ptr()}
        self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read)
        self.assertEqual(outputs, argument_handler.dataptrs_written)

    def test_tensor_names(self):
        addr_func = torch.ops.aten.addr.default
        vec = torch.arange(1, 4, device="cuda")
        M = torch.zeros(3, 3, device="cuda")

        argument_handler = csan.ArgumentHandler()
        argument_handler.parse_inputs(addr_func._schema, (M, vec, vec), {})
        out = torch.addr(M, vec, vec)
        argument_handler.parse_outputs(out)

        self.assertEqual(
            argument_handler.tensor_aliases,
            {
                M.data_ptr(): ["self"],
                vec.data_ptr(): ["vec1", "vec2"],
                out.data_ptr(): [],
            },
        )
        self.assertEqual({out.data_ptr()}, argument_handler.outputs)


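# Helpers for the EventHandler tests below: they fabricate distinct fake ids so that
# tensor data pointers, stream ids and event ids used in the tests never collide.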
def tensor_id(i: int) -> DataPtr:
    return i


def stream_id(i: int) -> StreamId:
    return 1000 + i


def event_id(i: int) -> EventId:
    return 2000 + i


class TestEventHandler(TestCase):
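    # Drives csan.EventHandler directly with synthetic kernel launches and CUDA
    # stream/event operations, checking which launches are reported as
    # synchronization errors (possible data races) and which are considered safe.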
    def setUp(self):
        self.handler = csan.EventHandler()

    def kernel_launch(
        self,
        stream: StreamId,
        read_only: List[DataPtr] = None,
        read_write: List[DataPtr] = None,
    ) -> List[csan.SynchronizationError]:
        if read_only is None:
            read_only = []
        if read_write is None:
            read_write = []
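        # Placeholder values are passed for the remaining metadata arguments of
        # _handle_kernel_launch; these tests only count the returned errors.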
        return self.handler._handle_kernel_launch(
            stream,
            read_only,
            read_write,
            {},
            "",
            {k: [""] for k in read_only + read_write},
        )

    def assert_good_kernel_launch(
        self,
        stream: StreamId,
        read_only: List[DataPtr] = None,
        read_write: List[DataPtr] = None,
    ) -> None:
        self.assertEqual(self.kernel_launch(stream, read_only, read_write), [])

    def assert_bad_kernel_launch(
        self,
        number_of_errors: int,
        stream: StreamId,
        read_only: List[DataPtr] = None,
        read_write: List[DataPtr] = None,
    ) -> None:
        errors = self.kernel_launch(stream, read_only, read_write)
        self.assertEqual(len(errors), number_of_errors)

    def test_empty_kernel_launch(self):
        self.assert_good_kernel_launch(stream_id(0))

    def test_simple_passing(self):
        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])

    def test_simple_error(self):
        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
        self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])

    def test_simple_sync(self):
        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
        self.handler._handle_event_record(event_id(0), stream_id(1))
        self.handler._handle_event_wait(event_id(0), stream_id(2))
        self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)])

    def test_reads_check_last_write(self):
        # Tests that not only the first read operation checks if it is in conflict
        # with the last write operation, but all read operations do.

        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.handler._handle_event_record(event_id(0), stream_id(1))
        self.handler._handle_event_wait(event_id(0), stream_id(2))
        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])

        self.assert_bad_kernel_launch(1, stream_id(3), read_only=[tensor_id(1)])

    def test_branch_sync(self):
        # Tests that two streams can read after both waiting for a third, but they
        # cannot write without further synchronization.

        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.handler._handle_event_record(event_id(0), stream_id(1))
        self.handler._handle_event_wait(event_id(0), stream_id(2))
        self.handler._handle_event_wait(event_id(0), stream_id(3))
        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)])

        self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])

    def test_chain_sync(self):
        iterations = 10

        self.assert_good_kernel_launch(stream_id(0), read_only=[tensor_id(1)])
        for i in range(iterations):
            self.handler._handle_event_record(event_id(i), stream_id(i))
            self.handler._handle_event_wait(event_id(i), stream_id(i + 1))
        self.assert_good_kernel_launch(stream_id(iterations), read_write=[tensor_id(1)])

    def test_expired_record(self):
        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
        self.handler._handle_event_record(event_id(0), stream_id(1))
        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
        self.handler._handle_event_wait(event_id(0), stream_id(2))

        self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])

    def test_deleted_record(self):
        for should_delete, should_create in [
            (True, True),
            (True, False),
            (False, True),
        ]:
            self.setUp()
            with self.subTest(should_delete=should_delete, should_create=should_create):
                self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
                self.handler._handle_event_record(event_id(0), stream_id(1))

                if should_delete:
                    self.handler._handle_event_deletion(event_id(0))
                if should_create:
                    self.handler._handle_event_creation(event_id(0))

                self.handler._handle_event_wait(event_id(0), stream_id(2))
                self.assert_bad_kernel_launch(
                    1, stream_id(2), read_write=[tensor_id(1)]
                )

    def test_all_reads_checked_failing(self):
        iterations = 10
        for i in range(1, iterations):
            self.assert_good_kernel_launch(stream_id(i), read_only=[tensor_id(1)])
            self.handler._handle_event_record(event_id(i), stream_id(i))

        for i in range(1, iterations):
            self.handler._handle_event_wait(event_id(i), stream_id(0))

        self.assert_good_kernel_launch(stream_id(iterations), read_only=[tensor_id(1)])
        self.handler._handle_event_record(event_id(iterations), stream_id(i))

        # Does not synchronize with the last read.
        self.assert_bad_kernel_launch(1, stream_id(0), read_write=[tensor_id(1)])

    def test_all_reads_checked_passing(self):
        iterations = 10
        for i in range(1, iterations):
            self.assert_good_kernel_launch(stream_id(i), read_only=[tensor_id(1)])
            self.handler._handle_event_record(event_id(i), stream_id(i))

        for i in range(1, iterations):
            self.handler._handle_event_wait(event_id(i), stream_id(0))

        self.assert_good_kernel_launch(stream_id(0), read_write=[tensor_id(1)])

    def test_multiple_errors(self):
        iterations = 10
        self.assert_good_kernel_launch(
            stream_id(0), read_write=[tensor_id(i) for i in range(iterations)]
        )
        self.assert_bad_kernel_launch(
            iterations,
            stream_id(1),
            read_write=[tensor_id(i) for i in range(iterations)],
        )

    def test_correct_state_merging(self):
        # Tests that after waiting for an event, a stream's state is indeed set
        # to the pointwise maximum of its old state and the recorded state.

        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(2)])
        self.handler._handle_event_record(event_id(1), stream_id(1))
        self.handler._handle_event_record(event_id(2), stream_id(2))

        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(2)])
        self.handler._handle_event_wait(event_id(1), stream_id(2))
        self.handler._handle_event_wait(event_id(2), stream_id(1))

        self.handler._handle_event_record(event_id(3), stream_id(2))
        self.handler._handle_event_wait(event_id(3), stream_id(1))
        self.assert_good_kernel_launch(
            stream_id(1), read_write=[tensor_id(1), tensor_id(2)]
        )

    def test_record_override(self):
        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(2)])
        self.handler._handle_event_record(event_id(1), stream_id(1))
        self.handler._handle_event_record(event_id(1), stream_id(2))

        self.handler._handle_event_wait(event_id(1), stream_id(3))
        self.assert_bad_kernel_launch(1, stream_id(3), read_write=[tensor_id(1)])

    def test_multiple_wait(self):
        # Tests that a wait operation can be performed multiple times on the same event
        # by different streams.

        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.handler._handle_event_record(event_id(1), stream_id(1))
        self.handler._handle_event_wait(event_id(1), stream_id(2))
        self.handler._handle_event_wait(event_id(1), stream_id(3))

        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)])

    def test_device_synchronize(self):
        # Tests that a device synchronization does correctly cause all streams
        # to synchronize with each other.

        iterations = 10
        for i in range(1, iterations):
            self.assert_good_kernel_launch(stream_id(i), read_write=[tensor_id(i)])

        self.handler._handle_device_synchronization()
        self.assert_good_kernel_launch(
            stream_id(0), read_write=[tensor_id(i) for i in range(1, iterations)]
        )

    def test_device_synchronization_expired(self):
        # Tests that a device synchronization is a one-time synchronization.
        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.handler._handle_device_synchronization()
        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])

        self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])

    def test_new_stream_is_synchronized(self):
        # Tests that after synchronizing operations with the host, any newly created
        # stream is guaranteed to be synchronized with them as well.

        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.handler._handle_device_synchronization()
        self.handler._handle_stream_creation(stream_id(2))
        self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)])

    def test_stream_synchronize(self):
        # Tests that a stream synchronization does correctly cause all streams to wait
        # for one specific stream, but does not synchronize all streams with each other.

        self.assert_good_kernel_launch(stream_id(0), read_write=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(2)])
        self.handler._handle_stream_synchronization(stream_id(0))

        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)])
        self.assert_bad_kernel_launch(1, stream_id(4), read_only=[tensor_id(2)])

    def test_event_synchronize(self):
        # Tests that an event synchronization does correctly cause all streams to wait
        # for a recorded event, but does not guarantee synchronization with the current
        # state of the stream that recorded the event.

        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.handler._handle_event_record(event_id(1), stream_id(1))
        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(2)])

        self.handler._handle_event_synchronization(event_id(1))
        self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)])
        self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(2)])


class TestMessages(TestCase):
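    # Checks the log messages emitted by csan.EventHandler when the trace is
    # inconsistent, and the formatted report produced by UnsynchronizedAccessError.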
    def setUp(self):
        self.handler = csan.EventHandler()

    def test_ensure_exists(self):
        ARG = 0
        for func, out in [
            (
                self.handler._handle_event_deletion,
                f"Found Event with id: {ARG}, but no matching event "
                "creation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
            ),
            (
                self.handler._handle_memory_deallocation,
                f"Found tensor with pointer: {ARG}, but no matching tensor "
                "allocation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
            ),
        ]:
            with self.subTest(func=func, out=out):
                with self.assertLogs() as captured:
                    func(ARG)
                self.assertEqual(captured.records[0].getMessage(), out)

    def test_ensure_does_not_exist(self):
        ARG = 0
        self.handler._handle_event_creation(ARG)
        self.handler._handle_stream_creation(ARG)
        for func, out in [
            (
                self.handler._handle_event_creation,
                "Found duplicate event creation in the trace for event with "
                f"id: {ARG}. Assuming the trace for event deletion wasn't caught "
                "and backfilling it now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
            ),
            (
                self.handler._handle_stream_creation,
                "Found duplicate Stream creation in the trace for Stream with "
                f"id: {ARG}. PyTorch Streams are only created once, so this "
                "trace entry is ignored.",
            ),
        ]:
            with self.subTest(func=func, out=out):
                with self.assertLogs() as captured:
                    func(ARG)
                self.assertEqual(captured.records[0].getMessage(), out)

    def test_error_message(self):
        current_access = csan.Access(
            type=csan.AccessType.WRITE,
            seq_num=1,
            stream=stream_id(1),
            operator="schema",
            aliases=["b"],
            is_output=True,
            stack_trace=traceback.StackSummary.from_list(
                [("file", 0, "name", "trace a")]
            ),
        )
        previous_access = csan.Access(
            type=csan.AccessType.READ,
            seq_num=2,
            stream=stream_id(0),
            operator="schema",
            aliases=["a"],
            is_output=False,
            stack_trace=traceback.StackSummary.from_list(
                [("file", 0, "name", "trace b")]
            ),
        )
        error = csan.UnsynchronizedAccessError(
            data_ptr=tensor_id(1),
            allocation_stack_trace=traceback.StackSummary.from_list(
                [("file", 0, "name", "alloc")]
            ),
            current_access=current_access,
            previous_access=previous_access,
        )
        self.assertEqual(
            str(error),
            textwrap.dedent(
                """\
                ============================
                CSAN detected a possible data race on tensor with data pointer 1
                Access by stream 1001 during kernel:
                schema
                writing to argument(s) b, and to the output
                With stack trace:
                  File "file", line 0, in name
                    trace a

                Previous access by stream 1000 during kernel:
                schema
                reading from argument(s) a
                With stack trace:
                  File "file", line 0, in name
                    trace b

                Tensor was allocated with stack trace:
                  File "file", line 0, in name
                    alloc
                """
            ),
        )


if __name__ == "__main__":
    run_tests()