1
# Owner(s): ["module: cuda"]
9
import torch.cuda._sanitizer as csan
10
from torch.cuda._sanitizer import DataPtr, EventId, StreamId
11
from torch.testing._internal.common_utils import NoTest, run_tests, TEST_CUDA, TestCase
15
print("CUDA not available, skipping tests", file=sys.stderr)
16
TestCase = NoTest # noqa: F811
19
class TestArgumentHandler(TestCase):
21
add_func = torch.ops.aten.add.Tensor
22
a = torch.ones(5, 3, device="cuda")
23
b = torch.randn(5, 3, device="cuda")
25
argument_handler = csan.ArgumentHandler()
26
argument_handler.parse_inputs(add_func._schema, (a, b), {})
28
argument_handler.parse_outputs(c)
30
self.assertEqual({a.data_ptr(), b.data_ptr()}, argument_handler.dataptrs_read)
31
self.assertEqual({c.data_ptr()}, argument_handler.dataptrs_written)
34
cat_func = torch.ops.aten.cat.default
35
a = torch.ones(2, 4, 5, device="cuda")
36
b = torch.zeros(2, 1, 5, device="cuda")
37
c = torch.rand(2, 7, 5, device="cuda")
39
argument_handler = csan.ArgumentHandler()
40
argument_handler.parse_inputs(cat_func._schema, ([a, b, c], 1), {})
41
d = torch.cat((a, b, c), dim=1)
42
argument_handler.parse_outputs(d)
45
{a.data_ptr(), b.data_ptr(), c.data_ptr()}, argument_handler.dataptrs_read
47
self.assertEqual({d.data_ptr()}, argument_handler.dataptrs_written)
50
split_func = torch.ops.aten.split.Tensor
51
a = torch.arange(10, device="cuda").reshape(5, 2)
53
argument_handler = csan.ArgumentHandler()
54
argument_handler.parse_inputs(split_func._schema, (a, 2), {})
55
out = torch.split(a, 2)
56
argument_handler.parse_outputs(out)
58
outputs = {out[0].data_ptr(), out[1].data_ptr(), out[2].data_ptr()}
59
self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read)
60
self.assertEqual(outputs, argument_handler.dataptrs_written)
62
def test_inplace(self):
63
add_inplace_func = torch.ops.aten.add_.Tensor
64
a = torch.rand(4, 2, device="cuda")
66
argument_handler = csan.ArgumentHandler()
67
argument_handler.parse_inputs(add_inplace_func._schema, (a, 5), {})
69
argument_handler.parse_outputs(a)
71
self.assertEqual(set(), argument_handler.dataptrs_read)
72
self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_written)
75
mul_out_func = torch.ops.aten.mul.out
76
a = torch.arange(8, device="cuda")
77
b = torch.empty(8, device="cuda")
79
argument_handler = csan.ArgumentHandler()
80
argument_handler.parse_inputs(mul_out_func._schema, (a, 3), {"out": b})
81
torch.mul(a, 3, out=b)
82
argument_handler.parse_outputs(b)
84
self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read)
85
self.assertEqual({b.data_ptr()}, argument_handler.dataptrs_written)
87
def test_nonzero(self):
    """parse_inputs/parse_outputs must track every tensor of a tuple result.

    `nonzero` with ``as_tuple=True`` returns one tensor per input dimension;
    all three must show up as written data pointers.
    """
    nonzero_func = torch.ops.aten.nonzero.default
    a = torch.ones(5, 3, 2, device="cuda")

    handler = csan.ArgumentHandler()
    handler.parse_inputs(nonzero_func._schema, (a,), {"as_tuple": True})
    out = torch.nonzero(a, as_tuple=True)
    handler.parse_outputs(out)

    written = {out[0].data_ptr(), out[1].data_ptr(), out[2].data_ptr()}
    self.assertEqual({a.data_ptr()}, handler.dataptrs_read)
    self.assertEqual(written, handler.dataptrs_written)
100
def test_tensor_names(self):
101
addr_func = torch.ops.aten.addr.default
102
vec = torch.arange(1, 4, device="cuda")
103
M = torch.zeros(3, 3, device="cuda")
105
argument_handler = csan.ArgumentHandler()
106
argument_handler.parse_inputs(addr_func._schema, (M, vec, vec), {})
107
out = torch.addr(M, vec, vec)
108
argument_handler.parse_outputs(out)
111
argument_handler.tensor_aliases,
113
M.data_ptr(): ["self"],
114
vec.data_ptr(): ["vec1", "vec2"],
118
self.assertEqual({out.data_ptr()}, argument_handler.outputs)
121
def tensor_id(i: int) -> DataPtr:
125
def stream_id(i: int) -> StreamId:
129
def event_id(i: int) -> EventId:
133
class TestEventHandler(TestCase):
135
self.handler = csan.EventHandler()
140
read_only: List[DataPtr] = None,
141
read_write: List[DataPtr] = None,
142
) -> List[csan.SynchronizationError]:
143
if read_only is None:
145
if read_write is None:
147
return self.handler._handle_kernel_launch(
153
{k: [""] for k in read_only + read_write},
156
def assert_good_kernel_launch(
159
read_only: List[DataPtr] = None,
160
read_write: List[DataPtr] = None,
162
self.assertEqual(self.kernel_launch(stream, read_only, read_write), [])
164
def assert_bad_kernel_launch(
166
number_of_errors: int,
168
read_only: List[DataPtr] = None,
169
read_write: List[DataPtr] = None,
171
errors = self.kernel_launch(stream, read_only, read_write)
172
self.assertEqual(len(errors), number_of_errors)
174
def test_empty_kernel_launch(self):
    """A kernel that touches no tensors must never be flagged."""
    self.assert_good_kernel_launch(stream_id(0))
177
def test_simple_passing(self):
    """Two unsynchronized streams may both read the same tensor."""
    for stream in (stream_id(1), stream_id(2)):
        self.assert_good_kernel_launch(stream, read_only=[tensor_id(1)])
181
def test_simple_error(self):
    """A write racing with a prior unsynchronized read is one error."""
    self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
    self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])
185
def test_simple_sync(self):
    """A record/wait pair makes a later write on another stream safe."""
    handler = self.handler
    self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
    handler._handle_event_record(event_id(0), stream_id(1))
    handler._handle_event_wait(event_id(0), stream_id(2))
    self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)])
191
def test_reads_check_last_write(self):
    """Every read — not only the first one — is checked against the last write.

    Stream 2 reads after synchronizing with the writing stream, so it is
    fine; stream 3 reads with no synchronization at all and must be flagged.
    """
    handler = self.handler
    self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
    handler._handle_event_record(event_id(0), stream_id(1))
    handler._handle_event_wait(event_id(0), stream_id(2))
    self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])

    self.assert_bad_kernel_launch(1, stream_id(3), read_only=[tensor_id(1)])
202
def test_branch_sync(self):
    """Two streams may read after waiting on a third stream's event, but a
    subsequent write still needs further synchronization against the other
    reader.
    """
    handler = self.handler
    self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
    handler._handle_event_record(event_id(0), stream_id(1))
    for waiter in (stream_id(2), stream_id(3)):
        handler._handle_event_wait(event_id(0), waiter)
    self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
    self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)])

    self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])
215
def test_chain_sync(self):
218
self.assert_good_kernel_launch(stream_id(0), read_only=[tensor_id(1)])
219
for i in range(iterations):
220
self.handler._handle_event_record(event_id(i), stream_id(i))
221
self.handler._handle_event_wait(event_id(i), stream_id(i + 1))
222
self.assert_good_kernel_launch(stream_id(iterations), read_write=[tensor_id(1)])
224
def test_expired_record(self):
    """Waiting on an event only covers work submitted before the record;
    a read launched after the record is not synchronized by the wait.
    """
    handler = self.handler
    self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
    handler._handle_event_record(event_id(0), stream_id(1))
    self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
    handler._handle_event_wait(event_id(0), stream_id(2))

    self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])
232
def test_deleted_record(self):
233
for should_delete, should_create in [
239
with self.subTest(should_delete=should_delete, should_create=should_create):
240
self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
241
self.handler._handle_event_record(event_id(0), stream_id(1))
244
self.handler._handle_event_deletion(event_id(0))
246
self.handler._handle_event_creation(event_id(0))
248
self.handler._handle_event_wait(event_id(0), stream_id(2))
249
self.assert_bad_kernel_launch(
250
1, stream_id(2), read_write=[tensor_id(1)]
253
def test_all_reads_checked_failing(self):
255
for i in range(1, iterations):
256
self.assert_good_kernel_launch(stream_id(i), read_only=[tensor_id(1)])
257
self.handler._handle_event_record(event_id(i), stream_id(i))
259
for i in range(1, iterations):
260
self.handler._handle_event_wait(event_id(i), stream_id(0))
262
self.assert_good_kernel_launch(stream_id(iterations), read_only=[tensor_id(1)])
263
self.handler._handle_event_record(event_id(iterations), stream_id(i))
265
# Does not synchronize with the last read.
266
self.assert_bad_kernel_launch(1, stream_id(0), read_write=[tensor_id(1)])
268
def test_all_reads_checked_passing(self):
270
for i in range(1, iterations):
271
self.assert_good_kernel_launch(stream_id(i), read_only=[tensor_id(1)])
272
self.handler._handle_event_record(event_id(i), stream_id(i))
274
for i in range(1, iterations):
275
self.handler._handle_event_wait(event_id(i), stream_id(0))
277
self.assert_good_kernel_launch(stream_id(0), read_write=[tensor_id(1)])
279
def test_multiple_errors(self):
281
self.assert_good_kernel_launch(
282
stream_id(0), read_write=[tensor_id(i) for i in range(iterations)]
284
self.assert_bad_kernel_launch(
287
read_write=[tensor_id(i) for i in range(iterations)],
290
def test_correct_state_merging(self):
    """After a wait, a stream's state must become the pointwise maximum of
    its own state and the recorded state — neither side may be lost.
    """
    handler = self.handler
    self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
    self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(2)])
    handler._handle_event_record(event_id(1), stream_id(1))
    handler._handle_event_record(event_id(2), stream_id(2))

    # New work after the records, then cross-wait the two streams.
    self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
    self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(2)])
    handler._handle_event_wait(event_id(1), stream_id(2))
    handler._handle_event_wait(event_id(2), stream_id(1))

    handler._handle_event_record(event_id(3), stream_id(2))
    handler._handle_event_wait(event_id(3), stream_id(1))
    self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1), tensor_id(2)])
310
def test_record_override(self):
    """Re-recording an event replaces its old state; a waiter only
    synchronizes with the latest record, not the first one.
    """
    handler = self.handler
    self.assert_good_kernel_launch(stream_id(1), read_only=[tensor_id(1)])
    self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(2)])
    handler._handle_event_record(event_id(1), stream_id(1))
    handler._handle_event_record(event_id(1), stream_id(2))

    handler._handle_event_wait(event_id(1), stream_id(3))
    self.assert_bad_kernel_launch(1, stream_id(3), read_write=[tensor_id(1)])
319
def test_multiple_wait(self):
    """One recorded event can be waited on by several distinct streams,
    each of which then reads safely.
    """
    handler = self.handler
    self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
    handler._handle_event_record(event_id(1), stream_id(1))
    for waiter in (stream_id(2), stream_id(3)):
        handler._handle_event_wait(event_id(1), waiter)

    self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
    self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)])
331
def test_device_synchronize(self):
332
# Tests that a device synchronization does correctly cause all streams
333
# to synchronize with each other.
336
for i in range(1, iterations):
337
self.assert_good_kernel_launch(stream_id(i), read_write=[tensor_id(i)])
339
self.handler._handle_device_synchronization()
340
self.assert_good_kernel_launch(
341
stream_id(0), read_write=[tensor_id(i) for i in range(1, iterations)]
344
def test_device_synchronization_expired(self):
    """A device synchronization is one-shot: writes submitted after it are
    not covered by it.
    """
    self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
    self.handler._handle_device_synchronization()
    # This write happens after the sync, so stream 2 races with it.
    self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])

    self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(1)])
352
def test_new_stream_is_synchronized(self):
    """A stream created after a host synchronization starts out
    synchronized with all previously launched work.
    """
    handler = self.handler
    self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
    handler._handle_device_synchronization()
    handler._handle_stream_creation(stream_id(2))
    self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)])
361
def test_stream_synchronize(self):
    """Synchronizing one stream orders everyone after *that* stream's work
    only — it does not synchronize unrelated streams with each other.
    """
    self.assert_good_kernel_launch(stream_id(0), read_write=[tensor_id(1)])
    self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(2)])
    self.handler._handle_stream_synchronization(stream_id(0))

    # tensor 1 was written by the synchronized stream 0: reads are safe.
    self.assert_good_kernel_launch(stream_id(2), read_only=[tensor_id(1)])
    self.assert_good_kernel_launch(stream_id(3), read_only=[tensor_id(1)])
    # tensor 2 was written by stream 1, which was not synchronized.
    self.assert_bad_kernel_launch(1, stream_id(4), read_only=[tensor_id(2)])
373
def test_event_synchronize(self):
    """Event synchronization orders all streams after the recorded event,
    but not after work the recording stream submitted later.
    """
    handler = self.handler
    self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(1)])
    handler._handle_event_record(event_id(1), stream_id(1))
    # Submitted after the record — not covered by synchronizing the event.
    self.assert_good_kernel_launch(stream_id(1), read_write=[tensor_id(2)])

    handler._handle_event_synchronization(event_id(1))
    self.assert_good_kernel_launch(stream_id(2), read_write=[tensor_id(1)])
    self.assert_bad_kernel_launch(1, stream_id(2), read_write=[tensor_id(2)])
387
class TestMessages(TestCase):
389
self.handler = csan.EventHandler()
391
def test_ensure_exists(self):
395
self.handler._handle_event_deletion,
396
f"Found Event with id: {ARG}, but no matching event "
397
"creation in the trace. Backfilling the trace now. "
398
"Perhaps the sanitizer was enabled after some torch operations?",
401
self.handler._handle_memory_deallocation,
402
f"Found tensor with pointer: {ARG}, but no matching tensor "
403
"allocation in the trace. Backfilling the trace now. "
404
"Perhaps the sanitizer was enabled after some torch operations?",
407
with self.subTest(func=func, out=out):
408
with self.assertLogs() as captured:
410
self.assertEqual(captured.records[0].getMessage(), out)
412
def test_ensure_does_not_exist(self):
414
self.handler._handle_event_creation(ARG)
415
self.handler._handle_stream_creation(ARG)
418
self.handler._handle_event_creation,
419
"Found duplicate event creation in the trace for event with "
420
f"id: {ARG}. Assuming the trace for event deletion wasn't caught "
421
"and backfilling it now. "
422
"Perhaps the sanitizer was enabled after some torch operations?",
425
self.handler._handle_stream_creation,
426
"Found duplicate Stream creation in the trace for Stream with "
427
f"id: {ARG}. PyTorch Streams are only created once, so this "
428
"trace entry is ignored.",
431
with self.subTest(func=func, out=out):
432
with self.assertLogs() as captured:
434
self.assertEqual(captured.records[0].getMessage(), out)
436
def test_error_message(self):
437
current_access = csan.Access(
438
type=csan.AccessType.WRITE,
444
stack_trace=traceback.StackSummary.from_list(
445
[("file", 0, "name", "trace a")]
448
previous_access = csan.Access(
449
type=csan.AccessType.READ,
455
stack_trace=traceback.StackSummary.from_list(
456
[("file", 0, "name", "trace b")]
459
error = csan.UnsynchronizedAccessError(
460
data_ptr=tensor_id(1),
461
allocation_stack_trace=traceback.StackSummary.from_list(
462
[("file", 0, "name", "alloc")]
464
current_access=current_access,
465
previous_access=previous_access,
471
============================
472
CSAN detected a possible data race on tensor with data pointer 1
473
Access by stream 1001 during kernel:
475
writing to argument(s) b, and to the output
477
File "file", line 0, in name
480
Previous access by stream 1000 during kernel:
482
reading from argument(s) a
484
File "file", line 0, in name
487
Tensor was allocated with stack trace:
488
File "file", line 0, in name
495
if __name__ == "__main__":