from __future__ import annotations

import atexit
import contextlib
import dataclasses
import functools
import inspect
import logging
import os
import queue
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from ctypes import byref, c_size_t, c_void_p
from multiprocessing.process import BaseProcess
from multiprocessing.queues import Queue
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union

import torch
from torch import multiprocessing
from torch._dynamo.testing import rand_strided
from torch._inductor import ir
from torch._inductor.codecache import CUDACodeCache, DLLWrapper, PyCodeCache
from torch._inductor.select_algorithm import TritonTemplateCaller

from . import config
from .utils import do_bench
from .virtualized import V
# Name of the environment variable controlling which GPUs a process may see.
CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
# Guards against registering the pool's atexit terminate handler twice.
EXIT_HANDLER_REGISTERED = False

log = logging.getLogger(__name__)


@dataclasses.dataclass
class Ping:
    """Sentinel request sent to a benchmark subprocess to warm it up / check health."""


@dataclasses.dataclass
class Pong:
    """Sentinel response confirming a benchmark subprocess is alive."""


@contextlib.contextmanager
def set_cuda_visible_device(device: Optional[int]):
    """
    Context manager to set the CUDA_VISIBLE_DEVICES environment variable to the
    specified single device. If device is None, don't manipulate the environment.

    On exit the previous value is restored; if the variable was previously
    unset, it is removed again.
    """
    if device is None:
        yield
        return

    current = os.environ.get(CUDA_VISIBLE_DEVICES)
    os.environ[CUDA_VISIBLE_DEVICES] = str(device)
    try:
        yield
    finally:
        if current is None:
            del os.environ[CUDA_VISIBLE_DEVICES]
        else:
            os.environ[CUDA_VISIBLE_DEVICES] = current
@dataclasses.dataclass
class TuningProcess:
    """
    Abstraction for launching a helper process to benchmark kernels. Spawns
    the parent process and uses multiprocessing queues to send benchmark
    requests and return results.
    """

    # GPU index this process is pinned to via CUDA_VISIBLE_DEVICES (None = don't pin).
    device: Optional[int] = None
    process: Optional[BaseProcess] = None
    request_queue: Optional[Queue[Any]] = None
    response_queue: Optional[Queue[Any]] = None

    @staticmethod
    def process_main(
        request_queue: Queue[Any],
        response_queue: Queue[Any],
    ) -> None:
        """
        Entry point for the child process.
        """
        log.debug(
            "Entering TuningProcess child. Visible devices = %s",
            os.environ.get(CUDA_VISIBLE_DEVICES),
        )
        try:
            TuningProcess.workloop(request_queue, response_queue)
        except Exception as ex:
            log.exception("Exception in TuningProcess: %s", ex)

    @staticmethod
    def workloop(request_queue: Queue[Any], response_queue: Queue[Any]) -> None:
        """
        Work loop for the benchmarking subprocess.
        """
        while True:
            obj = request_queue.get()

            if obj is None:
                # None is a sentinel telling the child to exit.
                break
            elif isinstance(obj, Ping):
                response_queue.put(Pong())
            elif isinstance(obj, BenchmarkRequest):
                response_queue.put(obj.benchmark())
            else:
                raise RuntimeError(f"Invalid request type {type(obj)}")

    def valid(self) -> bool:
        """
        True if the sub-process has been initialized.
        """
        return (
            self.process is not None
            and self.request_queue is not None
            and self.response_queue is not None
        )

    def clear(self) -> None:
        """
        Reset to an uninitialized state.
        """
        self.process = self.request_queue = self.response_queue = None

    def initialize(self) -> None:
        """
        Create child process, request/response queues, and do the warm up.
        Set the environment to make only the provided GPU device visible
        to the process.
        """
        if self.valid():
            return

        # The CUDA runtime does not play well with "fork"; use "spawn".
        ctx = multiprocessing.get_context("spawn")
        self.request_queue = ctx.Queue()
        self.response_queue = ctx.Queue()

        self.process = ctx.Process(
            target=self.process_main,
            args=(
                self.request_queue,
                self.response_queue,
            ),
        )
        assert self.process is not None
        with set_cuda_visible_device(self.device):
            self.process.start()

    def put(self, obj: Any) -> None:
        """
        Push a work item to the child process.
        """
        # In case a previous request crashed the child, make sure it's running.
        self.initialize()
        assert self.request_queue is not None
        self.request_queue.put(obj)

    def get(self) -> Any:
        """
        Get a response from the child process.
        """
        assert self.process is not None
        assert self.response_queue is not None
        while True:
            try:
                # Poll with a timeout so we can detect a dead child.
                return self.response_queue.get(timeout=1.0)
            except queue.Empty:
                status = self.process.exitcode
                if status is None:
                    # Child is still running; keep waiting for the response.
                    continue
                # Child crashed; reset to uninitialized so the next put()
                # respawns it, and let the Empty exception propagate.
                self.clear()
                raise

    def terminate(self) -> None:
        """
        Signal the child process to terminate.
        """
        if self.valid():
            assert self.process is not None
            assert self.request_queue is not None
            # None is the termination sentinel understood by workloop().
            self.request_queue.put(None)

    def wait(self) -> None:
        """
        Wait for the child process to exit.
        """
        if self.process is not None:
            self.process.join()
            self.clear()
@dataclasses.dataclass
class TuningProcessPool:
    """
    Maintains a pool of TuningProcesses to benchmark kernels in parallel
    across devices. By default, we create one TuningProcess per device and
    set the sub-process environment to make only that device visible.
    """

    processes: Optional[queue.Queue[TuningProcess]] = None
    executor: Optional[ThreadPoolExecutor] = None

    def initialize(self) -> None:
        """
        Start the child processes.
        """
        # processes and executor are created/destroyed together.
        assert (self.processes is None) == (self.executor is None)
        if self.processes is not None:
            # Already initialized.
            return

        devices = self.get_device_list()
        log.debug("Sub-process autotune device list: %s", devices)

        # Launch the child processes and push a Ping to each to warm it up.
        self.processes = queue.Queue()
        for device in devices:
            p = TuningProcess(device=device)
            p.initialize()
            p.put(Ping())
            self.processes.put(p)

        # Wait for the initialization (Pong responses) to finish.
        for p in self.processes.queue:
            assert isinstance(p.get(), Pong)

        # Use a thread pool to distribute work to the subprocesses; one worker
        # thread per device so a thread blocks only on its own subprocess.
        self.executor = ThreadPoolExecutor(max_workers=len(devices))

        # Register an exit handler (once per parent process) so the children
        # are terminated when the parent exits.
        global EXIT_HANDLER_REGISTERED
        if not EXIT_HANDLER_REGISTERED:
            EXIT_HANDLER_REGISTERED = True
            import atexit

            atexit.register(self.terminate)

    def get_device_list(self) -> Sequence[Optional[int]]:
        """
        Gather the list of devices to be used in the pool.
        """
        if not config.autotune_multi_device:
            # Don't use multiple devices.
            return [None]

        count = torch.cuda.device_count()

        # If the user limited the visible devices, honor that restriction.
        if CUDA_VISIBLE_DEVICES in os.environ:
            devices = [int(d) for d in os.environ[CUDA_VISIBLE_DEVICES].split(",")]
            assert len(devices) <= count
            return devices

        return list(range(count))

    def terminate(self) -> None:
        """
        Signal all child processes to terminate.
        """
        if self.executor is not None:
            self.executor.shutdown()
            self.executor = None

        if self.processes is not None:
            # First signal every child, then join them, so shutdown overlaps.
            for p in self.processes.queue:
                p.terminate()
            for p in self.processes.queue:
                p.wait()
            self.processes = None

    def target(self, choice: TritonTemplateCaller) -> float:
        """
        Entry point for the thread-pool helper threads: Wait for an open TuningProcess,
        remove it from the queue, execute the benchmark in that subprocess, and return
        the TuningProcess to the queue.
        """
        assert choice.bmreq is not None
        assert self.processes is not None

        process = self.processes.get()
        process.put(choice.bmreq)
        try:
            return process.get()
        except queue.Empty:
            warnings.warn(
                f"Failed to benchmark choice '{choice}'. It will be ignored. "
                "Please debug the root cause in case the choice can bring perf gains."
            )
            # Return INF so this choice loses the autotune comparison.
            return float("inf")
        finally:
            # Always hand the subprocess back to the pool.
            self.processes.put(process)

    def benchmark(
        self,
        choices: List[TritonTemplateCaller],
    ) -> Dict[TritonTemplateCaller, float]:
        """
        Benchmark each choice in a separate process.
        """
        assert self.processes is not None, "Tuning process pool is not initialized"
        assert self.executor is not None

        results = {}

        # Spread the work across the subprocesses via the thread pool,
        # grabbing subprocesses as soon as they're free.
        for choice, result in zip(choices, self.executor.map(self.target, choices)):
            results[choice] = result

        return results
# Singleton pool shared by all benchmark_in_sub_process() callers.
tuning_pool = TuningProcessPool()


# IR nodes that carry enough layout information to build a TensorMeta.
LayoutOrBuffer = Union[ir.Layout, ir.Buffer]
@dataclasses.dataclass
class TensorMeta:
    """
    Serializable description of a tensor (device, dtype, geometry) sufficient
    to materialize a random tensor with the same layout in a benchmark
    subprocess. Instances cross process boundaries; never hold real tensors.
    """

    device: torch.device
    dtype: torch.dtype
    sizes: torch._prims_common.ShapeType
    strides: torch._prims_common.StrideType
    offset: int

    @classmethod
    def from_irnodes(
        cls, irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]]
    ) -> Union[TensorMeta, List[TensorMeta]]:
        """
        Build a TensorMeta (or a list of them) from inductor IR nodes/layouts.
        Symbolic sizes/strides/offsets are resolved to concrete hints.
        """
        if isinstance(irnodes, Sequence):
            result: List[Any] = [cls.from_irnodes(x) for x in irnodes]
            assert all(isinstance(x, TensorMeta) for x in result)
            return result

        node = irnodes
        if isinstance(node, ir.Layout):
            # Wrap a bare layout in a buffer so the accessors below work uniformly.
            node = ir.Buffer("fake", node)

        dtype = node.get_dtype()
        assert dtype is not None

        return TensorMeta(
            device=node.get_device(),
            dtype=dtype,
            sizes=V.graph.sizevars.size_hints(
                node.get_size(),
                fallback=config.unbacked_symint_fallback,
            ),
            strides=V.graph.sizevars.size_hints(
                node.get_stride(),
                fallback=config.unbacked_symint_fallback,
            ),
            offset=V.graph.sizevars.size_hint(
                node.get_layout().offset,
                fallback=config.unbacked_symint_fallback,
            ),
        )

    def to_tensor(self) -> torch.Tensor:
        """Create a random tensor matching this metadata (used in the child process)."""
        return rand_strided(
            self.sizes,
            self.strides,
            device=self.device,
            dtype=self.dtype,
            extra_size=self.offset,
        )
@dataclasses.dataclass
class BenchmarkRequest:
    """
    Only handle triton template benchmark for now. The extern kernel benchmark
    can be done inside the same process since they usually don't cause crash.

    Important: Instances of this class and subclasses have to be serializable
    across process boundaries. Do not put CUDA Tensors in here!
    """

    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        extra_args: Iterable[Any],
    ):
        # The kernel name defined in the generated module.
        self.kernel_name = kernel_name

        # Normalize inputs to a list of TensorMeta.
        if isinstance(input_tensor_meta, TensorMeta):
            input_tensor_meta = [input_tensor_meta]
        self.input_tensor_meta = input_tensor_meta

        # There is only ever a single output; unwrap a singleton sequence.
        if isinstance(output_tensor_meta, (tuple, list)):
            assert len(output_tensor_meta) == 1
            output_tensor_meta = output_tensor_meta[0]
        self.output_tensor_meta = output_tensor_meta

        self.extra_args = extra_args

    def make_run_fn(
        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
    ) -> Callable[[], None]:
        """Return a zero-arg callable that runs the kernel once. Subclass hook."""
        raise NotImplementedError()

    def cleanup_run_fn(self) -> None:
        """Release any resources acquired by make_run_fn. Subclass hook; no-op here."""
        pass

    def benchmark(
        self,
        *input_tensors: torch.Tensor,
        output_tensor: Optional[torch.Tensor] = None,
    ) -> float:
        """
        Benchmark the kernel and return its latency. When called with no
        tensors (the subprocess path), inputs and output are materialized
        from the stored TensorMeta descriptions.
        """
        debug = log.isEnabledFor(logging.DEBUG)
        if debug:
            start_ts = time.time()

        # Create args and out tensor if the caller didn't supply them.
        if output_tensor is None:
            assert len(input_tensors) == 0
            input_tensors = tuple(x.to_tensor() for x in self.input_tensor_meta)
            output_tensor = self.output_tensor_meta.to_tensor()

        if debug:
            create_tensor_elapse = time.time() - start_ts
            start_ts = time.time()

        fn = self.make_run_fn(*input_tensors, output_tensor=output_tensor)

        if debug:
            load_elapse = time.time() - start_ts
            start_ts = time.time()

        out = do_bench(fn)
        # Shake out any CUDA errors raised asynchronously by the kernel.
        torch.cuda.synchronize()

        if debug:
            bench_elapse = time.time() - start_ts
            log.debug(
                "InChildProcess %s: load %f, create tensor %f, bench %f",
                str(self),
                load_elapse,
                create_tensor_elapse,
                bench_elapse,
            )
        self.cleanup_run_fn()
        return out
class TestBenchmarkRequest(BenchmarkRequest):
    """
    Supports unit testing. Defined in this file so that the TuningProcess
    sub-process knows how to unpickle these objects.
    """

    def __init__(self, value: Optional[float] = None) -> None:
        # value=None simulates a benchmark failure (see benchmark()).
        self.value = value

    def benchmark(
        self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None
    ) -> float:
        if self.value is None:
            raise Exception("Failed to run")
        return self.value
class TritonBenchmarkRequest(BenchmarkRequest):
    """Benchmark request for a Triton template kernel compiled via PyCodeCache."""

    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        extra_args: Iterable[Any],
        module_path: str,  # the path of the module defining the triton kernel
        module_cache_key: str,
        grid: List[int],
        num_stages: int,
        num_warps: int,
        matrix_instr_nonkdim: int = 0,  # only used for hip to choose the shape of mfma instruction.
    ):
        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
        self.module_path = module_path
        self.module_cache_key = module_cache_key
        self.grid = grid
        self.num_stages = num_stages
        self.num_warps = num_warps
        self.matrix_instr_nonkdim = matrix_instr_nonkdim

    def make_run_fn(
        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
    ) -> Callable[[], None]:
        """Load the generated module and bind the kernel's run method to its args."""
        mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
        log.debug(
            "benchmark module key: %s, path: %s",
            self.module_cache_key,
            self.module_path,
        )

        run_method = getattr(mod, self.kernel_name).run
        extra_args = list(self.extra_args)

        # Newer versions of triton add a "warmup" argument to JITFunction.run;
        # pass it only when the installed triton accepts it, for
        # backward compatibility.
        warmup_arg = {}
        if "warmup" in inspect.signature(run_method).parameters:
            warmup_arg["warmup"] = False

        if torch.version.hip and self.matrix_instr_nonkdim != 0:
            return functools.partial(
                run_method,
                *input_tensors,
                output_tensor,
                *extra_args,
                grid=self.grid,
                **warmup_arg,
                num_stages=self.num_stages,
                num_warps=self.num_warps,
                matrix_instr_nonkdim=self.matrix_instr_nonkdim,
            )
        else:
            return functools.partial(
                run_method,
                *input_tensors,
                output_tensor,
                *extra_args,
                grid=self.grid,
                **warmup_arg,
                num_stages=self.num_stages,
                num_warps=self.num_warps,
            )

    def __str__(self) -> str:
        return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}"
class CUDABenchmarkRequest(BenchmarkRequest):
    """Benchmark request for a CUDA kernel compiled to a shared library via CUDACodeCache."""

    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        extra_args: Iterable[Any],
        source_code: str,
    ):
        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
        self.source_code = source_code
        self.workspace_size: int = 0
        self.workspace: Optional[torch.Tensor] = None
        self.DLL: Optional[DLLWrapper] = None
        self.hash_key: str = ""
        self.source_file: str = ""
        # Write (but don't compile) the source now so hash_key/source_file
        # are stable across process boundaries.
        self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so")

    def precompile(self):
        """Compile ahead of time; CUDACodeCache.load caches, so this is idempotent."""
        log.debug("Precompiling %s", self)
        CUDACodeCache.load(self.source_code, "so")
        log.debug("Done precompiling %s", self)

    def make_run_fn(
        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
    ) -> Callable[[], None]:
        """Load the compiled DLL and bind the kernel to the tensor pointers."""
        self.DLL, self.hash_key, self.source_file = CUDACodeCache.load(
            self.source_code, "so"
        )
        args = [
            c_void_p(tensor.data_ptr())
            for tensor in list(input_tensors) + [output_tensor]
        ]
        log.debug(
            "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s",
            self.kernel_name,
            self.source_file,
            self.hash_key,
            self.DLL,
            args,
            self.extra_args,
        )
        run_method = getattr(self.DLL, self.kernel_name)
        stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)

        # Retrieve workspace_size by calling the kernel with a size out-param
        # and a null workspace pointer.
        c_workspace_size = c_size_t()
        run_method(
            *args,  # input ptrs and output ptrs
            *self.extra_args,
            byref(c_workspace_size),  # set workspace size ptr to retrieve workspace size
            None,  # null workspace ptr
            stream_ptr,
        )
        self.workspace_size = c_workspace_size.value
        assert self.workspace_size == 0, (
            "Things need to be fixed to support non-zero workspace_size: "
            "1) max autotune cache needs to store workspace size; "
            "2) memory allocation needs to allocate / deallocate workspace correctly; "
        )

        # Generate the partial function that runs the actual benchmark.
        return functools.partial(
            run_method,
            *args,
            *self.extra_args,
            None,  # null workspace size ptr
            None,  # null workspace ptr; would need a real one if workspace_size > 0
            stream_ptr,
        )

    def cleanup_run_fn(self) -> None:
        if self.DLL is not None:
            self.DLL.close()
        self.workspace = None

    def __str__(self) -> str:
        return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}"
def benchmark_in_sub_process(
    choices: List[TritonTemplateCaller],
) -> Dict[TritonTemplateCaller, float]:
    """
    Do benchmarking in a subprocess and return the perf number (latency).
    """
    # Lazily start the pool; initialize() is a no-op if it's already running.
    tuning_pool.initialize()

    return tuning_pool.benchmark(choices)