import sys
import textwrap
import traceback
from typing import List

import torch
import torch.cuda._sanitizer as csan
from torch.cuda._sanitizer import StreamId, DataPtr, EventId
from torch.testing._internal.common_utils import TestCase, run_tests, NoTest, TEST_CUDA


if not TEST_CUDA:
    print("CUDA not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811
class TestArgumentHandler(TestCase):
    """Tests for csan.ArgumentHandler: given an op schema and concrete
    arguments, it must classify each tensor's data pointer as read-from
    or written-to.

    NOTE(review): this chunk was recovered from a mangled extraction
    (interleaved line-number tokens, stripped indentation, dropped lines);
    statements marked "restored" were re-derived from the surrounding
    assertions — diff against the original file to confirm.
    """

    def test_add(self):
        # Plain binary op: both inputs are reads, the new result is a write.
        add_func = torch.ops.aten.add.Tensor
        a = torch.ones(5, 3, device="cuda")
        b = torch.randn(5, 3, device="cuda")

        argument_handler = csan.ArgumentHandler()
        argument_handler.parse_inputs(add_func._schema, (a, b), {})
        c = torch.add(a, b)  # restored: produces the output checked below
        argument_handler.parse_outputs(c)

        self.assertEqual({a.data_ptr(), b.data_ptr()}, argument_handler.dataptrs_read)
        self.assertEqual({c.data_ptr()}, argument_handler.dataptrs_written)

    def test_cat(self):
        # Tensor-list argument: every element of the list counts as a read.
        cat_func = torch.ops.aten.cat.default
        a = torch.ones(2, 4, 5, device="cuda")
        b = torch.zeros(2, 1, 5, device="cuda")
        c = torch.rand(2, 7, 5, device="cuda")

        argument_handler = csan.ArgumentHandler()
        argument_handler.parse_inputs(cat_func._schema, ([a, b, c], 1), {})
        d = torch.cat((a, b, c), dim=1)
        argument_handler.parse_outputs(d)

        self.assertEqual(
            {a.data_ptr(), b.data_ptr(), c.data_ptr()}, argument_handler.dataptrs_read
        )
        self.assertEqual({d.data_ptr()}, argument_handler.dataptrs_written)

    def test_split(self):
        # Multi-output op: every returned chunk's pointer counts as a write.
        split_func = torch.ops.aten.split.Tensor
        a = torch.arange(10, device="cuda").reshape(5, 2)

        argument_handler = csan.ArgumentHandler()
        argument_handler.parse_inputs(split_func._schema, (a, 2), {})
        out = torch.split(a, 2)
        argument_handler.parse_outputs(out)

        outputs = {out[0].data_ptr(), out[1].data_ptr(), out[2].data_ptr()}
        self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read)
        self.assertEqual(
            outputs,
            argument_handler.dataptrs_written,
        )

    def test_inplace(self):
        # In-place op: the tensor is reported only as written, not read.
        add_inplace_func = torch.ops.aten.add_.Tensor
        a = torch.rand(4, 2, device="cuda")

        argument_handler = csan.ArgumentHandler()
        argument_handler.parse_inputs(add_inplace_func._schema, (a, 5), {})
        a.add_(5)  # restored: the in-place mutation under test
        argument_handler.parse_outputs(a)

        self.assertEqual(set(), argument_handler.dataptrs_read)
        self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_written)

    def test_out(self):
        # `out=` keyword argument: the out tensor is the write target.
        mul_out_func = torch.ops.aten.mul.out
        a = torch.arange(8, device="cuda")
        b = torch.empty(8, device="cuda")

        argument_handler = csan.ArgumentHandler()
        argument_handler.parse_inputs(mul_out_func._schema, (a, 3), {"out": b})
        torch.mul(a, 3, out=b)
        argument_handler.parse_outputs(b)

        self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read)
        self.assertEqual({b.data_ptr()}, argument_handler.dataptrs_written)

    def test_nonzero(self):
        # as_tuple=True returns one index tensor per dimension; all written.
        nonzero_func = torch.ops.aten.nonzero.default
        a = torch.ones(5, 3, 2, device="cuda")

        argument_handler = csan.ArgumentHandler()
        argument_handler.parse_inputs(nonzero_func._schema, (a,), {"as_tuple": True})
        out = torch.nonzero(a, as_tuple=True)
        argument_handler.parse_outputs(out)

        outputs = {out[0].data_ptr(), out[1].data_ptr(), out[2].data_ptr()}
        self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read)
        self.assertEqual(outputs, argument_handler.dataptrs_written)

    def test_tensor_names(self):
        # Schema argument names are tracked per data pointer; a tensor passed
        # twice collects both of its argument names.
        addr_func = torch.ops.aten.addr.default
        vec = torch.arange(1, 4, device="cuda")
        M = torch.zeros(3, 3, device="cuda")

        argument_handler = csan.ArgumentHandler()
        argument_handler.parse_inputs(addr_func._schema, (M, vec, vec), {})
        out = torch.addr(M, vec, vec)
        argument_handler.parse_outputs(out)

        self.assertEqual(
            argument_handler.tensor_aliases,
            {
                M.data_ptr(): ["self"],
                vec.data_ptr(): ["vec1", "vec2"],
                out.data_ptr(): [],  # restored entry — TODO confirm against csan
            },
        )
        self.assertEqual({out.data_ptr()}, argument_handler.outputs)
def tensor_id(i: int) -> DataPtr:
    """Map a small test index to a fake tensor data pointer.

    Restored body: the identity mapping is grounded by the expected error
    text later in this file ("data pointer 1" for tensor_id(1)).
    """
    return i
def stream_id(i: int) -> StreamId:
    """Map a test index to a fake StreamId in its own 1000+ range.

    Restored body: the offset is grounded by the expected error text later
    in this file ("stream 1001" / "stream 1000" for indices 1 / 0).
    """
    return 1000 + i
def event_id(i: int) -> EventId:
    """Map a test index to a fake EventId in its own range, disjoint from
    tensor and stream ids.

    NOTE(review): the body was lost in extraction; ``2000 + i`` is restored
    by analogy with stream_id — confirm against the original file.
    """
    return 2000 + i
class TestEventHandler(TestCase):
138
self.handler = csan.EventHandler()
143
read_only: List[DataPtr] = None,
144
read_write: List[DataPtr] = None,
145
) -> List[csan.SynchronizationError]:
146
if read_only is None:
148
if read_write is None:
150
return self.handler._handle_kernel_launch(
156
{k: [""] for k in read_only + read_write},
159
def assert_good_kernel_launch(
162
read_only: List[DataPtr] = None,
163
read_write: List[DataPtr] = None,
165
self.assertEqual(self.kernel_launch(stream, read_only, read_write), [])
167
def assert_bad_kernel_launch(
169
number_of_errors: int,
171
read_only: List[DataPtr] = None,
172
read_write: List[DataPtr] = None,
174
errors = self.kernel_launch(stream, read_only, read_write)
175
self.assertEqual(len(errors), number_of_errors)
177
def test_empty_kernel_launch(self):
178
self.assert_good_kernel_launch(stream_id(0))
180
def test_simple_passing(self):
181
self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
182
self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
184
def test_simple_error(self):
185
self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
186
self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])
188
def test_simple_sync(self):
189
self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
190
self.handler._handle_event_record(event_id(0), stream_id(1))
191
self.handler._handle_event_wait(event_id(0), stream_id(2))
192
self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)])
194
def test_reads_check_last_write(self):
198
self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
199
self.handler._handle_event_record(event_id(0), stream_id(1))
200
self.handler._handle_event_wait(event_id(0), stream_id(2))
201
self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
203
self.assert_bad_kernel_launch(1, stream_id(3), read_only=[tensor_id(1)])
205
def test_branch_sync(self):
209
self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
210
self.handler._handle_event_record(event_id(0), stream_id(1))
211
self.handler._handle_event_wait(event_id(0), stream_id(2))
212
self.handler._handle_event_wait(event_id(0), stream_id(3))
213
self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
214
self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)])
216
self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])
218
def test_chain_sync(self):
221
self.assert_good_kernel_launch(stream_id(0), read_only=[tensor_id(1)])
222
for i in range(iterations):
223
self.handler._handle_event_record(event_id(i), stream_id(i))
224
self.handler._handle_event_wait(event_id(i), stream_id(i + 1))
225
self.assert_good_kernel_launch(stream_id(iterations), read_write=[tensor_id(1)])
227
def test_expired_record(self):
228
self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
229
self.handler._handle_event_record(event_id(0), stream_id(1))
230
self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
231
self.handler._handle_event_wait(event_id(0), stream_id(2))
233
self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])
235
def test_deleted_record(self):
236
for should_delete, should_create in [
242
with self.subTest(should_delete=should_delete, should_create=should_create):
243
self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
244
self.handler._handle_event_record(event_id(0), stream_id(1))
247
self.handler._handle_event_deletion(event_id(0))
249
self.handler._handle_event_creation(event_id(0))
251
self.handler._handle_event_wait(event_id(0), stream_id(2))
252
self.assert_bad_kernel_launch(
253
1, stream_id(2), read_write=[tensor_id(1)]
256
def test_all_reads_checked_failing(self):
258
for i in range(1, iterations):
259
self.assert_good_kernel_launch(stream_id(i), read_only=[tensor_id(1)])
260
self.handler._handle_event_record(event_id(i), stream_id(i))
262
for i in range(1, iterations):
263
self.handler._handle_event_wait(event_id(i), stream_id(0))
265
self.assert_good_kernel_launch(stream_id(iterations), read_only=[tensor_id(1)])
266
self.handler._handle_event_record(event_id(iterations), stream_id(i))
269
self.assert_bad_kernel_launch(1, stream_id(0), read_write=[tensor_id(1)])
271
def test_all_reads_checked_passing(self):
273
for i in range(1, iterations):
274
self.assert_good_kernel_launch(stream_id(i), read_only=[tensor_id(1)])
275
self.handler._handle_event_record(event_id(i), stream_id(i))
277
for i in range(1, iterations):
278
self.handler._handle_event_wait(event_id(i), stream_id(0))
280
self.assert_good_kernel_launch(stream_id(0), read_write=[tensor_id(1)])
282
def test_multiple_errors(self):
284
self.assert_good_kernel_launch(
285
stream_id(0), read_write=[tensor_id(i) for i in range(iterations)]
287
self.assert_bad_kernel_launch(
290
read_write=[tensor_id(i) for i in range(iterations)],
293
def test_correct_state_merging(self):
297
self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
298
self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(2)])
299
self.handler._handle_event_record(event_id(1), stream_id(1))
300
self.handler._handle_event_record(event_id(2), stream_id(2))
302
self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
303
self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(2)])
304
self.handler._handle_event_wait(event_id(1), stream_id(2))
305
self.handler._handle_event_wait(event_id(2), stream_id(1))
307
self.handler._handle_event_record(event_id(3), stream_id(2))
308
self.handler._handle_event_wait(event_id(3), stream_id(1))
309
self.assert_good_kernel_launch(
310
stream_id(1), read_write=[tensor_id(1), tensor_id(2)]
313
def test_record_override(self):
314
self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
315
self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(2)])
316
self.handler._handle_event_record(event_id(1), stream_id(1))
317
self.handler._handle_event_record(event_id(1), stream_id(2))
319
self.handler._handle_event_wait(event_id(1), stream_id(3))
320
self.assert_bad_kernel_launch(1, stream_id(3), read_write=[tensor_id(1)])
322
def test_multiple_wait(self):
326
self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
327
self.handler._handle_event_record(event_id(1), stream_id(1))
328
self.handler._handle_event_wait(event_id(1), stream_id(2))
329
self.handler._handle_event_wait(event_id(1), stream_id(3))
331
self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
332
self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)])
334
def test_device_synchronize(self):
339
for i in range(1, iterations):
340
self.assert_good_kernel_launch(stream_id(i), read_write=[tensor_id(i)])
342
self.handler._handle_device_synchronization()
343
self.assert_good_kernel_launch(
344
stream_id(0), read_write=[tensor_id(i) for i in range(1, iterations)]
347
def test_device_synchronization_expired(self):
349
self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
350
self.handler._handle_device_synchronization()
351
self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
353
self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])
355
def test_new_stream_is_synchronized(self):
359
self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
360
self.handler._handle_device_synchronization()
361
self.handler._handle_stream_creation(stream_id(2))
362
self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)])
364
def test_stream_synchronize(self):
368
self.assert_good_kernel_launch(stream_id(0), read_write=[tensor_id(1)])
369
self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(2)])
370
self.handler._handle_stream_synchronization(stream_id(0))
372
self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
373
self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)])
374
self.assert_bad_kernel_launch(1, stream_id(4), read_only=[tensor_id(2)])
376
def test_event_synchronize(self):
381
self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
382
self.handler._handle_event_record(event_id(1), stream_id(1))
383
self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(2)])
385
self.handler._handle_event_synchronization(event_id(1))
386
self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)])
387
self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(2)])
390
class TestMessages(TestCase):
392
self.handler = csan.EventHandler()
394
def test_ensure_exists(self):
398
self.handler._handle_event_deletion,
399
f"Found Event with id: {ARG}, but no matching event "
400
"creation in the trace. Backfilling the trace now. "
401
"Perhaps the sanitizer was enabled after some torch operations?",
404
self.handler._handle_memory_deallocation,
405
f"Found tensor with pointer: {ARG}, but no matching tensor "
406
"allocation in the trace. Backfilling the trace now. "
407
"Perhaps the sanitizer was enabled after some torch operations?",
410
with self.subTest(func=func, out=out):
411
with self.assertLogs() as captured:
413
self.assertEqual(captured.records[0].getMessage(), out)
415
def test_ensure_does_not_exist(self):
417
self.handler._handle_event_creation(ARG)
418
self.handler._handle_stream_creation(ARG)
421
self.handler._handle_event_creation,
422
"Found duplicate event creation in the trace for event with "
423
f"id: {ARG}. Assuming the trace for event deletion wasn't caught "
424
"and backfilling it now. "
425
"Perhaps the sanitizer was enabled after some torch operations?",
428
self.handler._handle_stream_creation,
429
"Found duplicate Stream creation in the trace for Stream with "
430
f"id: {ARG}. PyTorch Streams are only created once, so this "
431
"trace entry is ignored.",
434
with self.subTest(func=func, out=out):
435
with self.assertLogs() as captured:
437
self.assertEqual(captured.records[0].getMessage(), out)
439
def test_error_message(self):
440
current_access = csan.Access(
441
type=csan.AccessType.WRITE,
447
stack_trace=traceback.StackSummary.from_list(
448
[("file", 0, "name", "trace a")]
451
previous_access = csan.Access(
452
type=csan.AccessType.READ,
458
stack_trace=traceback.StackSummary.from_list(
459
[("file", 0, "name", "trace b")]
462
error = csan.UnsynchronizedAccessError(
463
data_ptr=tensor_id(1),
464
allocation_stack_trace=traceback.StackSummary.from_list(
465
[("file", 0, "name", "alloc")]
467
current_access=current_access,
468
previous_access=previous_access,
474
============================
475
CSAN detected a possible data race on tensor with data pointer 1
476
Access by stream 1001 during kernel:
478
writing to argument(s) b, and to the output
480
File "file", line 0, in name
483
Previous access by stream 1000 during kernel:
485
reading from argument(s) a
487
File "file", line 0, in name
490
Tensor was allocated with stack trace:
491
File "file", line 0, in name
498
if __name__ == "__main__":