# Owner(s): ["module: cuda"]

import sys
import textwrap
import traceback
from typing import List

import torch
import torch.cuda._sanitizer as csan
from torch.cuda._sanitizer import StreamId, DataPtr, EventId
from torch.testing._internal.common_utils import TestCase, run_tests, NoTest, TEST_CUDA


if not TEST_CUDA:
    print("CUDA not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811


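# Note: the tests in this file exercise the sanitizer's building blocks
# (ArgumentHandler and EventHandler) directly, so the sanitizer itself does not
# need to be enabled. For reference, in regular use it is turned on globally,
# roughly like this (a sketch; not required for the unit tests below):
#
#     import torch.cuda._sanitizer as csan
#     csan.enable_cuda_sanitizer()  # or launch with TORCH_CUDA_SANITIZER=1
#
# after which kernels that access the same memory from unsynchronized CUDA
# streams are reported as errors at runtime.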
class TestArgumentHandler(TestCase):
    def test_add(self):
        add_func = torch.ops.aten.add.Tensor
        a = torch.ones(5, 3, device="cuda")
        b = torch.randn(5, 3, device="cuda")

        argument_handler = csan.ArgumentHandler()
        argument_handler.parse_inputs(add_func._schema, (a, b), {})
        c = torch.add(a, b)
        argument_handler.parse_outputs(c)

        self.assertEqual({a.data_ptr(), b.data_ptr()}, argument_handler.dataptrs_read)
        self.assertEqual({c.data_ptr()}, argument_handler.dataptrs_written)

    def test_cat(self):
        cat_func = torch.ops.aten.cat.default
        a = torch.ones(2, 4, 5, device="cuda")
        b = torch.zeros(2, 1, 5, device="cuda")
        c = torch.rand(2, 7, 5, device="cuda")

        argument_handler = csan.ArgumentHandler()
        argument_handler.parse_inputs(cat_func._schema, ([a, b, c], 1), {})
        d = torch.cat((a, b, c), dim=1)
        argument_handler.parse_outputs(d)

        self.assertEqual(
            {a.data_ptr(), b.data_ptr(), c.data_ptr()}, argument_handler.dataptrs_read
        )
        self.assertEqual({d.data_ptr()}, argument_handler.dataptrs_written)

    def test_split(self):
        split_func = torch.ops.aten.split.Tensor
        a = torch.arange(10, device="cuda").reshape(5, 2)

        argument_handler = csan.ArgumentHandler()
        argument_handler.parse_inputs(split_func._schema, (a, 2), {})
        out = torch.split(a, 2)
        argument_handler.parse_outputs(out)

        outputs = {out[0].data_ptr(), out[1].data_ptr(), out[2].data_ptr()}
        self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read)
        self.assertEqual(
            outputs,
            argument_handler.dataptrs_written,
        )

    def test_inplace(self):
        add_inplace_func = torch.ops.aten.add_.Tensor
        a = torch.rand(4, 2, device="cuda")

        argument_handler = csan.ArgumentHandler()
        argument_handler.parse_inputs(add_inplace_func._schema, (a, 5), {})
        a.add_(5)
        argument_handler.parse_outputs(a)

        self.assertEqual(set(), argument_handler.dataptrs_read)
        self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_written)

    def test_out(self):
        mul_out_func = torch.ops.aten.mul.out
        a = torch.arange(8, device="cuda")
        b = torch.empty(8, device="cuda")

        argument_handler = csan.ArgumentHandler()
        argument_handler.parse_inputs(mul_out_func._schema, (a, 3), {"out": b})
        torch.mul(a, 3, out=b)
        argument_handler.parse_outputs(b)

        self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read)
        self.assertEqual({b.data_ptr()}, argument_handler.dataptrs_written)

    def test_nonzero(self):
        nonzero_func = torch.ops.aten.nonzero.default
        a = torch.ones(5, 3, 2, device="cuda")

        argument_handler = csan.ArgumentHandler()
        argument_handler.parse_inputs(nonzero_func._schema, (a,), {"as_tuple": True})
        out = torch.nonzero(a, as_tuple=True)
        argument_handler.parse_outputs(out)

        outputs = {out[0].data_ptr(), out[1].data_ptr(), out[2].data_ptr()}
        self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read)
        self.assertEqual(outputs, argument_handler.dataptrs_written)

    def test_tensor_names(self):
        addr_func = torch.ops.aten.addr.default
        vec = torch.arange(1, 4, device="cuda")
        M = torch.zeros(3, 3, device="cuda")

        argument_handler = csan.ArgumentHandler()
        argument_handler.parse_inputs(addr_func._schema, (M, vec, vec), {})
        out = torch.addr(M, vec, vec)
        argument_handler.parse_outputs(out)

        self.assertEqual(
            argument_handler.tensor_aliases,
            {
                M.data_ptr(): ["self"],
                vec.data_ptr(): ["vec1", "vec2"],
                out.data_ptr(): [],
            },
        )
        self.assertEqual({out.data_ptr()}, argument_handler.outputs)


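# Helpers that map small integers onto disjoint, human-readable ID ranges so
# that data pointers, stream IDs, and event IDs cannot be accidentally mixed up
# in the tests below.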
def tensor_id(i: int) -> DataPtr:
    return i


def stream_id(i: int) -> StreamId:
    return 1000 + i


def event_id(i: int) -> EventId:
    return 2000 + i


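# TestEventHandler drives the EventHandler with synthetic traces of kernel
# launches, event records/waits, and synchronizations. As a rough illustration
# (a sketch only; not executed by these tests), the simplest failing trace in
# test_simple_error corresponds to user code such as:
#
#     side_stream = torch.cuda.Stream()
#     a = torch.zeros(8, device="cuda")
#     with torch.cuda.stream(side_stream):
#         b = a + 1          # reads `a` on side_stream
#     a.add_(1)              # writes `a` on the default stream without syncing
#
# where the write on the default stream is not ordered after the read on
# side_stream by any event or synchronization.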
class TestEventHandler(TestCase):
    def setUp(self):
        self.handler = csan.EventHandler()

    def kernel_launch(
        self,
        stream: StreamId,
        read_only: List[DataPtr] = None,
        read_write: List[DataPtr] = None,
    ) -> List[csan.SynchronizationError]:
        if read_only is None:
            read_only = []
        if read_write is None:
            read_write = []
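        # The remaining positional arguments carry reporting metadata (which
        # tensors are outputs, the operator description, and per-tensor
        # argument aliases). These tests only check synchronization, so
        # placeholder values are passed below.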
        return self.handler._handle_kernel_launch(
            stream,
            read_only,
            read_write,
            {},
            "",
            {k: [""] for k in read_only + read_write},
        )

    def assert_good_kernel_launch(
        self,
        stream: StreamId,
        read_only: List[DataPtr] = None,
        read_write: List[DataPtr] = None,
    ) -> None:
        self.assertEqual(self.kernel_launch(stream, read_only, read_write), [])

    def assert_bad_kernel_launch(
        self,
        number_of_errors: int,
        stream: StreamId,
        read_only: List[DataPtr] = None,
        read_write: List[DataPtr] = None,
    ) -> None:
        errors = self.kernel_launch(stream, read_only, read_write)
        self.assertEqual(len(errors), number_of_errors)

    def test_empty_kernel_launch(self):
        self.assert_good_kernel_launch(stream_id(0))

    def test_simple_passing(self):
        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])

    def test_simple_error(self):
        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
        self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])

    def test_simple_sync(self):
        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
        self.handler._handle_event_record(event_id(0), stream_id(1))
        self.handler._handle_event_wait(event_id(0), stream_id(2))
        self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)])

    def test_reads_check_last_write(self):
        # Tests that every read operation, not just the first one, is checked
        # for a conflict with the last write operation.

        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.handler._handle_event_record(event_id(0), stream_id(1))
        self.handler._handle_event_wait(event_id(0), stream_id(2))
        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])

        self.assert_bad_kernel_launch(1, stream_id(3), read_only=[tensor_id(1)])

    def test_branch_sync(self):
        # Tests that two streams can read after both waiting for a third, but they
        # cannot write without further synchronization.

        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.handler._handle_event_record(event_id(0), stream_id(1))
        self.handler._handle_event_wait(event_id(0), stream_id(2))
        self.handler._handle_event_wait(event_id(0), stream_id(3))
        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)])

        self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])

    def test_chain_sync(self):
        iterations = 10

        self.assert_good_kernel_launch(stream_id(0), read_only=[tensor_id(1)])
        for i in range(iterations):
            self.handler._handle_event_record(event_id(i), stream_id(i))
            self.handler._handle_event_wait(event_id(i), stream_id(i + 1))
        self.assert_good_kernel_launch(stream_id(iterations), read_write=[tensor_id(1)])

    def test_expired_record(self):
        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
        self.handler._handle_event_record(event_id(0), stream_id(1))
        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
        self.handler._handle_event_wait(event_id(0), stream_id(2))

        self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])

    def test_deleted_record(self):
        for should_delete, should_create in [
            (True, True),
            (True, False),
            (False, True),
        ]:
            self.setUp()
            with self.subTest(should_delete=should_delete, should_create=should_create):
                self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
                self.handler._handle_event_record(event_id(0), stream_id(1))

                if should_delete:
                    self.handler._handle_event_deletion(event_id(0))
                if should_create:
                    self.handler._handle_event_creation(event_id(0))

                self.handler._handle_event_wait(event_id(0), stream_id(2))
                self.assert_bad_kernel_launch(
                    1, stream_id(2), read_write=[tensor_id(1)]
                )

    def test_all_reads_checked_failing(self):
        iterations = 10
        for i in range(1, iterations):
            self.assert_good_kernel_launch(stream_id(i), read_only=[tensor_id(1)])
            self.handler._handle_event_record(event_id(i), stream_id(i))

        for i in range(1, iterations):
            self.handler._handle_event_wait(event_id(i), stream_id(0))

        self.assert_good_kernel_launch(stream_id(iterations), read_only=[tensor_id(1)])
        self.handler._handle_event_record(event_id(iterations), stream_id(i))

        # Does not synchronize with the last read.
        self.assert_bad_kernel_launch(1, stream_id(0), read_write=[tensor_id(1)])

    def test_all_reads_checked_passing(self):
        iterations = 10
        for i in range(1, iterations):
            self.assert_good_kernel_launch(stream_id(i), read_only=[tensor_id(1)])
            self.handler._handle_event_record(event_id(i), stream_id(i))

        for i in range(1, iterations):
            self.handler._handle_event_wait(event_id(i), stream_id(0))

        self.assert_good_kernel_launch(stream_id(0), read_write=[tensor_id(1)])

    def test_multiple_errors(self):
        iterations = 10
        self.assert_good_kernel_launch(
            stream_id(0), read_write=[tensor_id(i) for i in range(iterations)]
        )
        self.assert_bad_kernel_launch(
            iterations,
            stream_id(1),
            read_write=[tensor_id(i) for i in range(iterations)],
        )

    def test_correct_state_merging(self):
        # Tests that after waiting for an event, a stream's state is indeed set
        # to the pointwise maximum of its old state and the recorded state.

        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(2)])
        self.handler._handle_event_record(event_id(1), stream_id(1))
        self.handler._handle_event_record(event_id(2), stream_id(2))

        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(2)])
        self.handler._handle_event_wait(event_id(1), stream_id(2))
        self.handler._handle_event_wait(event_id(2), stream_id(1))

        self.handler._handle_event_record(event_id(3), stream_id(2))
        self.handler._handle_event_wait(event_id(3), stream_id(1))
        self.assert_good_kernel_launch(
            stream_id(1), read_write=[tensor_id(1), tensor_id(2)]
        )

    def test_record_override(self):
        self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(2)])
        self.handler._handle_event_record(event_id(1), stream_id(1))
        self.handler._handle_event_record(event_id(1), stream_id(2))

        self.handler._handle_event_wait(event_id(1), stream_id(3))
        self.assert_bad_kernel_launch(1, stream_id(3), read_write=[tensor_id(1)])

    def test_multiple_wait(self):
        # Tests that a wait operation can be performed multiple times on the same event
        # by different streams.

        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.handler._handle_event_record(event_id(1), stream_id(1))
        self.handler._handle_event_wait(event_id(1), stream_id(2))
        self.handler._handle_event_wait(event_id(1), stream_id(3))

        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)])

    def test_device_synchronize(self):
        # Tests that a device synchronization correctly causes all streams
        # to synchronize with each other.

        iterations = 10
        for i in range(1, iterations):
            self.assert_good_kernel_launch(stream_id(i), read_write=[tensor_id(i)])

        self.handler._handle_device_synchronization()
        self.assert_good_kernel_launch(
            stream_id(0), read_write=[tensor_id(i) for i in range(1, iterations)]
        )

    def test_device_synchronization_expired(self):
        # Tests that a device synchronization is a one-time synchronization.
        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.handler._handle_device_synchronization()
        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])

        self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])

    def test_new_stream_is_synchronized(self):
        # Tests that after synchronizing operations with the host, any newly created
        # stream is guaranteed to be synchronized with them as well.

        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.handler._handle_device_synchronization()
        self.handler._handle_stream_creation(stream_id(2))
        self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)])

    def test_stream_synchronize(self):
        # Tests that a stream synchronization correctly causes all streams to wait
        # for one specific stream, but does not synchronize all streams with each other.

        self.assert_good_kernel_launch(stream_id(0), read_write=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(2)])
        self.handler._handle_stream_synchronization(stream_id(0))

        self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
        self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)])
        self.assert_bad_kernel_launch(1, stream_id(4), read_only=[tensor_id(2)])

    def test_event_synchronize(self):
        # Tests that an event synchronization correctly causes all streams to wait
        # for a recorded event, but does not guarantee synchronization with the current
        # state of the stream that recorded the event.

        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
        self.handler._handle_event_record(event_id(1), stream_id(1))
        self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(2)])

        self.handler._handle_event_synchronization(event_id(1))
        self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)])
        self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(2)])


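# TestMessages checks the user-facing side of the sanitizer: the log messages
# emitted when the trace has to be backfilled or contains duplicate entries,
# and the formatted report produced for an UnsynchronizedAccessError.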
class TestMessages(TestCase):
    def setUp(self):
        self.handler = csan.EventHandler()

    def test_ensure_exists(self):
        ARG = 0
        for func, out in [
            (
                self.handler._handle_event_deletion,
                f"Found Event with id: {ARG}, but no matching event "
                "creation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
            ),
            (
                self.handler._handle_memory_deallocation,
                f"Found tensor with pointer: {ARG}, but no matching tensor "
                "allocation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
            ),
        ]:
            with self.subTest(func=func, out=out):
                with self.assertLogs() as captured:
                    func(ARG)
                self.assertEqual(captured.records[0].getMessage(), out)

    def test_ensure_does_not_exist(self):
        ARG = 0
        self.handler._handle_event_creation(ARG)
        self.handler._handle_stream_creation(ARG)
        for func, out in [
            (
                self.handler._handle_event_creation,
                "Found duplicate event creation in the trace for event with "
                f"id: {ARG}. Assuming the trace for event deletion wasn't caught "
                "and backfilling it now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
            ),
            (
                self.handler._handle_stream_creation,
                "Found duplicate Stream creation in the trace for Stream with "
                f"id: {ARG}. PyTorch Streams are only created once, so this "
                "trace entry is ignored.",
            ),
        ]:
            with self.subTest(func=func, out=out):
                with self.assertLogs() as captured:
                    func(ARG)
                self.assertEqual(captured.records[0].getMessage(), out)

    def test_error_message(self):
        current_access = csan.Access(
            type=csan.AccessType.WRITE,
            seq_num=1,
            stream=stream_id(1),
            operator="schema",
            aliases=["b"],
            is_output=True,
            stack_trace=traceback.StackSummary.from_list(
                [("file", 0, "name", "trace a")]
            ),
        )
        previous_access = csan.Access(
            type=csan.AccessType.READ,
            seq_num=2,
            stream=stream_id(0),
            operator="schema",
            aliases=["a"],
            is_output=False,
            stack_trace=traceback.StackSummary.from_list(
                [("file", 0, "name", "trace b")]
            ),
        )
        error = csan.UnsynchronizedAccessError(
            data_ptr=tensor_id(1),
            allocation_stack_trace=traceback.StackSummary.from_list(
                [("file", 0, "name", "alloc")]
            ),
            current_access=current_access,
            previous_access=previous_access,
        )
        self.assertEqual(
            str(error),
            textwrap.dedent(
                """\
                ============================
                CSAN detected a possible data race on tensor with data pointer 1
                Access by stream 1001 during kernel:
                schema
                writing to argument(s) b, and to the output
                With stack trace:
                  File "file", line 0, in name
                    trace a

                Previous access by stream 1000 during kernel:
                schema
                reading from argument(s) a
                With stack trace:
                  File "file", line 0, in name
                    trace b

                Tensor was allocated with stack trace:
                  File "file", line 0, in name
                    alloc
                """
            ),
        )


if __name__ == "__main__":
    run_tests()