pytorch

Форк
0
/
autotune_process.py 
656 строк · 20.5 Кб
1
from __future__ import annotations
2

3
import contextlib
4
import dataclasses
5
import functools
6
import logging
7
import os
8
import queue
9
import time
10
import warnings
11
from concurrent.futures import ThreadPoolExecutor
12
from ctypes import byref, c_size_t, c_void_p
13
from multiprocessing.process import BaseProcess
14
from multiprocessing.queues import Queue
15
from typing import (
16
    Any,
17
    Callable,
18
    Dict,
19
    Iterable,
20
    List,
21
    Optional,
22
    Sequence,
23
    TYPE_CHECKING,
24
    Union,
25
)
26

27
import torch
28
from torch import multiprocessing
29
from torch._dynamo.testing import rand_strided
30

31
from torch._inductor import ir
32
from torch._inductor.codecache import CUDACodeCache, DLLWrapper, PyCodeCache
33

34
if TYPE_CHECKING:
35
    from torch._inductor.select_algorithm import TritonTemplateCaller
36

37
from . import config
38
from .utils import do_bench
39
from .virtualized import V
40

41
# Environment variable read by TuningProcessPool.get_device_list and
# temporarily overridden by set_cuda_visible_device to pin each benchmarking
# sub-process to a single GPU.
CUDA_VISIBLE_DEVICES: str = "CUDA_VISIBLE_DEVICES"

# Set to True the first time TuningProcessPool.initialize registers its atexit
# handler, so that handler is registered at most once per process.
EXIT_HANDLER_REGISTERED: bool = False

log: logging.Logger = logging.getLogger(__name__)
45

46

47
# Used to synchronize between parent and child processes
48
class Ping:
    """Liveness probe sent by the parent to a TuningProcess child.

    The child's work loop answers every Ping with a Pong.
    """
50

51

52
class Pong:
    """Reply a TuningProcess child sends back after receiving a Ping."""
54

55

56
@contextlib.contextmanager
def set_cuda_visible_device(device: Optional[int]):
    """
    Temporarily restrict CUDA_VISIBLE_DEVICES to a single device.

    While the context is active the variable holds ``str(device)``. On exit the
    previous value is restored, or the variable is removed again if it was not
    set before. With ``device=None`` the environment is left untouched.
    """
    if device is not None:
        saved = os.environ.get(CUDA_VISIBLE_DEVICES)
        os.environ[CUDA_VISIBLE_DEVICES] = str(device)
        try:
            yield
        finally:
            if saved is not None:
                os.environ[CUDA_VISIBLE_DEVICES] = saved
            else:
                del os.environ[CUDA_VISIBLE_DEVICES]
    else:
        yield
75

76

77
@dataclasses.dataclass
class TuningProcess:
    """
    Abstraction for launching a helper process to benchmark kernels. Spawns
    a child process and uses multiprocessing queues to send it benchmark
    requests and receive the results.

    Requests understood by the child (see ``workloop``): ``None`` asks it to
    exit, a ``Ping`` is answered with a ``Pong``, and a ``BenchmarkRequest``
    is answered with the value returned by its ``benchmark()`` method.
    """

    # Device made visible to the child through CUDA_VISIBLE_DEVICES; None keeps
    # the parent's environment.
    device: Optional[int] = None
    # Child process handle; None until initialize() runs or after clear().
    process: Optional[BaseProcess] = None
    # Parent -> child work items.
    request_queue: Optional[Queue[Any]] = None
    # Child -> parent responses.
    response_queue: Optional[Queue[Any]] = None

    @staticmethod
    def process_main(
        request_queue: Queue[Any],
        response_queue: Queue[Any],
    ) -> None:
        """
        Entry point for the child process.

        Exceptions escaping the work loop are logged. The child then exits,
        and the parent notices this in ``get()`` through the exit code.
        """
        log.debug(
            "Entering TuningProcess child. Visible devices = %s",
            os.environ.get(CUDA_VISIBLE_DEVICES),
        )
        try:
            TuningProcess.workloop(request_queue, response_queue)
        except Exception as ex:
            log.exception("Exception in TuningProcess: %s", ex)

    @staticmethod
    def workloop(request_queue: Queue[Any], response_queue: Queue[Any]) -> None:
        """
        Work loop for the benchmarking subprocess: serve requests until the
        ``None`` sentinel arrives.

        Raises:
            RuntimeError: if a request of an unknown type is received.
        """
        while True:
            obj = request_queue.get()

            if obj is None:
                break  # None is a sentinel for the child to terminate
            elif isinstance(obj, Ping):
                response_queue.put(Pong())
            elif isinstance(obj, BenchmarkRequest):
                response_queue.put(obj.benchmark())
            else:
                raise RuntimeError(f"Invalid request type {type(obj)}")

    def valid(self) -> bool:
        """
        True if the sub-process has been initialized and not cleared since.

        This does not check whether the child is still alive. A crash is only
        detected (and the state cleared) by ``get()``.
        """
        return (
            self.process is not None
            and self.request_queue is not None
            and self.response_queue is not None
        )

    def clear(self) -> None:
        """
        Reset to an uninitialized state.
        """
        self.process = self.request_queue = self.response_queue = None

    def initialize(self) -> None:
        """
        Create child process, request/response queues, and do the warm up.
        Set the environment to make only the provided GPU device visible
        to the process. Does nothing if already initialized.
        """
        if self.valid():
            return

        # cuda runtime does not work with "fork", use "spawn" to start processes.
        ctx = multiprocessing.get_context("spawn")
        self.request_queue = ctx.Queue()
        self.response_queue = ctx.Queue()

        process = ctx.Process(
            target=self.process_main,
            args=(
                self.request_queue,
                self.response_queue,
            ),
        )
        self.process = process
        # The device override only needs to be in effect while the child is
        # being started; the parent's environment is restored afterwards.
        with set_cuda_visible_device(self.device):
            process.start()

    def put(self, obj: Any) -> None:
        """
        Push a work item to the child process.
        """
        # In case of a prior crash, ensure the subprocess is running
        self.initialize()
        assert self.request_queue is not None
        self.request_queue.put(obj)

    def get(self) -> Any:
        """
        Get a response from the child process.

        Polls the response queue with a 1 second timeout.

        Raises:
            queue.Empty: if the queue is empty and the child has already
                exited. In that case the process state is cleared first, so
                the next ``put()`` starts a fresh child.
        """
        assert self.process is not None
        assert self.response_queue is not None
        while True:
            try:
                return self.response_queue.get(timeout=1.0)
            except queue.Empty:
                status = self.process.exitcode
                if status is None:
                    # child process is still running
                    continue
                # child process crashed
                self.clear()
                raise

    def terminate(self) -> None:
        """
        Signal the child process to terminate by sending the ``None`` sentinel.
        Does nothing if the process is not initialized.
        """
        if self.valid():
            assert self.process is not None
            assert self.request_queue is not None
            self.request_queue.put(None)

    def wait(self) -> None:
        """
        Wait for the child process to exit, then reset to an uninitialized state.
        """
        if self.process is not None:
            self.process.join()
            self.clear()
208

209

210
@dataclasses.dataclass
class TuningProcessPool:
    """
    Maintains a pool of TuningProcesses to benchmark kernels in parallel
    across devices. By default, we create one TuningProcess per device and
    set the sub-process environment to make only that device visible.
    """

    # Idle TuningProcesses. A worker thread removes one while it benchmarks a
    # choice and puts it back afterwards.
    processes: Optional[queue.Queue[TuningProcess]] = None
    # Worker threads that dispatch choices to the sub-processes.
    executor: Optional[ThreadPoolExecutor] = None

    def initialize(self) -> None:
        """
        Start the child processes and wait until each one has answered a Ping.
        Does nothing if the pool is already initialized.
        """
        assert (self.processes is None) == (self.executor is None)
        if self.processes is not None:
            return

        devices = self.get_device_list()
        log.debug("Sub-process autotune device list: %s", devices)

        # Launch the child processes and push a msg to "warm up"
        self.processes = queue.Queue()
        for device in devices:
            p = TuningProcess(device=device)
            p.initialize()
            p.put(Ping())
            self.processes.put(p)

        # Wait for the initialization to finish. Keep the get() call outside
        # the assert: asserts are stripped under `python -O`, which would leave
        # the Pong in the response queue to be mistaken for the result of the
        # first benchmark.
        for p in self.processes.queue:
            response = p.get()
            assert isinstance(response, Pong)

        # Use a thread pool to manage distributing work to the subprocesses.
        # Threads block on an available process, so it makes sense to match
        # the number of threads with the number of devices.
        self.executor = ThreadPoolExecutor(max_workers=len(devices))

        # Register the exit handler for the parent process so it will terminate
        # the child processes.
        global EXIT_HANDLER_REGISTERED
        if not EXIT_HANDLER_REGISTERED:
            EXIT_HANDLER_REGISTERED = True
            import atexit

            atexit.register(self.terminate)

    def get_device_list(self) -> Sequence[Optional[int]]:
        """
        Gather the list of devices to be used in the pool.

        Returns ``[None]`` (a single sub-process with an unmodified
        environment) when multi-device autotuning is disabled or no device is
        found. Otherwise returns the devices listed in CUDA_VISIBLE_DEVICES if
        that variable is set, or every device counted by
        ``torch.cuda.device_count()``.
        """
        if not config.autotune_multi_device:
            # Don't use multiple devices
            return [None]

        count = torch.cuda.device_count()

        # If the user specified the visible devices in the env, use those.
        if CUDA_VISIBLE_DEVICES in os.environ:
            # Skip empty entries (an empty value or a trailing comma), which
            # int() would reject.
            devices: List[Optional[int]] = [
                int(d)
                for d in os.environ[CUDA_VISIBLE_DEVICES].split(",")
                if d.strip()
            ]
            assert len(devices) <= count
        else:
            devices = list(range(count))

        # An empty pool cannot work (ThreadPoolExecutor rejects max_workers=0),
        # so fall back to one process that inherits the parent's environment.
        return devices or [None]

    def terminate(self) -> None:
        """
        Signal all child processes to terminate and wait for them to exit.
        """
        if self.executor is not None:
            self.executor.shutdown()
            self.executor = None

        if self.processes is not None:
            for p in self.processes.queue:
                p.terminate()
            for p in self.processes.queue:
                p.wait()
            self.processes = None

    def target(self, choice: TritonTemplateCaller) -> float:
        """
        Entry point for the thread-pool helper threads: Wait for an open TuningProcess,
        remove it from the queue, execute the benchmark in that subprocess, and return
        the TuningProcess to the queue.

        Returns ``float("inf")`` (with a warning) if the subprocess crashed
        while benchmarking, so the choice is ignored.
        """
        assert choice.bmreq is not None
        assert self.processes is not None

        process = self.processes.get()
        try:
            # put() may respawn a crashed subprocess and can raise. It is inside
            # the try so the process is always returned to the pool; otherwise
            # later callers would block forever on self.processes.get().
            process.put(choice.bmreq)
            return process.get()
        except queue.Empty:
            warnings.warn(
                f"Failed to benchmark choice '{choice}'. It will be ignored. "
                "Please debug the root cause in case the choice can bring perf gains."
            )
            # set to INF so this choice will be ignored
            return float("inf")
        finally:
            self.processes.put(process)

    def benchmark(
        self,
        choices: List[TritonTemplateCaller],
    ) -> Dict[TritonTemplateCaller, float]:
        """
        Benchmark each choice in a separate process and map each choice to
        its result.
        """
        assert self.processes is not None, "Tuning process pool is not initialized"
        assert self.executor is not None

        # Use a ThreadExecutorPool to spread the work across the subprocesses and
        # to grab subprocesses as soon as they're free. executor.map yields
        # results in the same order as `choices`.
        return dict(zip(choices, self.executor.map(self.target, choices)))
332

333

334
# Pool shared by all sub-process benchmarking in this module. No child
# processes exist until TuningProcessPool.initialize() is called on it.
tuning_pool: TuningProcessPool = TuningProcessPool()

# IR node kinds accepted by TensorMeta.from_irnodes.
LayoutOrBuffer = Union[ir.Layout, ir.Buffer]
338

339

340
@dataclasses.dataclass
class TensorMeta:
    """
    Plain-data description of a tensor: device, dtype, sizes, strides and
    storage offset. Benchmark requests carry these instead of real tensors and
    call ``to_tensor()`` to materialize inputs/outputs where the benchmark runs.
    """

    device: torch.device
    dtype: torch.dtype
    sizes: torch._prims_common.ShapeType
    strides: torch._prims_common.StrideType
    offset: int

    @classmethod
    def from_irnodes(
        cls, irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]]
    ) -> Union[TensorMeta, List[TensorMeta]]:
        """
        Build metadata from an IR ``Layout``/``Buffer``, or a list of metadata
        from a sequence of them.

        Sizes, strides and offset are converted to concrete hints through
        ``V.graph.sizevars``, with ``config.unbacked_symint_fallback`` passed as
        the fallback value.
        """
        if isinstance(irnodes, Sequence):
            result: List[Any] = [cls.from_irnodes(x) for x in irnodes]
            assert all(isinstance(x, TensorMeta) for x in result)
            return result

        node = irnodes
        if isinstance(node, ir.Layout):
            # Wrap a bare layout in a Buffer named "fake" so the Buffer
            # accessors below can be used.
            node = ir.Buffer("fake", node)

        dtype = node.get_dtype()
        assert dtype is not None

        # Construct via cls (not a hard-coded TensorMeta) so subclasses get
        # instances of their own type.
        return cls(
            device=node.get_device(),
            dtype=dtype,
            sizes=V.graph.sizevars.size_hints(
                node.get_size(),
                fallback=config.unbacked_symint_fallback,
            ),
            strides=V.graph.sizevars.size_hints(
                node.get_stride(),
                fallback=config.unbacked_symint_fallback,
            ),
            offset=V.graph.sizevars.size_hint(
                node.get_layout().offset,
                fallback=config.unbacked_symint_fallback,
            ),
        )

    def to_tensor(self) -> torch.Tensor:
        """
        Create a random tensor matching this metadata via ``rand_strided``.
        ``offset`` is forwarded as its ``extra_size`` argument.
        """
        return rand_strided(
            self.sizes,
            self.strides,
            device=self.device,
            dtype=self.dtype,
            extra_size=self.offset,
        )
389

390

391
# Intentionally not a @dataclass. The class defines its own __init__ and
# declares no dataclass fields, so the decorator would only generate an __eq__
# under which any two instances of the same class compare equal, and it would
# make instances unhashable.
class BenchmarkRequest:
    """
    Only handle triton template benchmark for now. The extern kernel benchmark
    can be done inside the same process since they usually don't cause crash.

    Important: Instances of this class and subclasses have to be serializable
    across process boundaries. Do not put CUDA Tensors in here!

    Subclasses implement ``make_run_fn`` and may override ``cleanup_run_fn``.
    """

    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        extra_args: Iterable[Any],
    ):
        """
        Args:
            kernel_name: the kernel name defined in the module.
            input_tensor_meta: metadata of the input tensors. A single
                TensorMeta is wrapped into a one-element list.
            output_tensor_meta: metadata of the single output tensor. A
                one-element list/tuple is unwrapped.
            extra_args: additional positional arguments passed to the kernel
                after the tensors.
        """
        # the kernel name defined in the module
        self.kernel_name = kernel_name

        if isinstance(input_tensor_meta, TensorMeta):
            input_tensor_meta = [input_tensor_meta]
        self.input_tensor_meta = input_tensor_meta

        if isinstance(output_tensor_meta, (tuple, list)):
            assert len(output_tensor_meta) == 1
            output_tensor_meta = output_tensor_meta[0]
        self.output_tensor_meta = output_tensor_meta

        self.extra_args = extra_args

    def make_run_fn(
        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
    ) -> Callable[[], None]:
        """Return a zero-argument callable that runs the kernel once."""
        raise NotImplementedError()

    def cleanup_run_fn(self) -> None:
        """Release resources acquired by ``make_run_fn``. No-op by default."""
        pass

    def benchmark(
        self,
        *input_tensors: torch.Tensor,
        output_tensor: Optional[torch.Tensor] = None,
    ) -> float:
        """
        Benchmark the kernel and return the value reported by ``do_bench``.

        If ``output_tensor`` is None, no input tensors may be passed. Fresh
        inputs and the output are then created from the stored TensorMeta;
        this is how the tuning sub-process calls it. Once the run function has
        been built, ``cleanup_run_fn`` is always called, even if benchmarking
        raises.
        """
        debug = log.isEnabledFor(logging.DEBUG)
        start_ts = time.time() if debug else 0.0
        create_tensor_elapse = 0.0
        load_elapse = 0.0

        # create args and out tensor
        if output_tensor is None:
            assert len(input_tensors) == 0
            input_tensors = tuple(x.to_tensor() for x in self.input_tensor_meta)
            output_tensor = self.output_tensor_meta.to_tensor()

        if debug:
            create_tensor_elapse = time.time() - start_ts
            start_ts = time.time()

        fn = self.make_run_fn(*input_tensors, output_tensor=output_tensor)

        if debug:
            load_elapse = time.time() - start_ts
            start_ts = time.time()

        try:
            out = do_bench(fn)
            torch.cuda.synchronize()  # shake out any CUDA errors

            if debug:
                bench_elapse = time.time() - start_ts
                log.debug(
                    "InChildProcess %s: load %f, create tensor %f, bench %f",
                    str(self),
                    load_elapse,
                    create_tensor_elapse,
                    bench_elapse,
                )
        finally:
            # Release per-run resources (e.g. the shared library loaded by
            # CUDABenchmarkRequest) even when benchmarking raised.
            self.cleanup_run_fn()
        return out
469

470

471
class TestBenchmarkRequest(BenchmarkRequest):
    """
    Benchmark request used by unit tests. It is defined in this module so
    that the TuningProcess sub-process is able to unpickle it.

    ``benchmark`` ignores its arguments and returns the stored ``value``. When
    ``value`` is None it raises instead, simulating a failing benchmark.
    """

    def __init__(self, value: Optional[float] = None) -> None:
        # BenchmarkRequest.__init__ is deliberately not called: no kernel or
        # tensor metadata is needed for a canned result.
        self.value = value

    def benchmark(
        self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None
    ) -> float:
        canned = self.value
        if canned is not None:
            return canned
        raise Exception("Failed to run")
486

487

488
class TritonBenchmarkRequest(BenchmarkRequest):
    # Important: Instances of this class have to be serializable
    # across process boundaries. Do not put CUDA Tensors in here!

    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        extra_args: Iterable[Any],
        module_path: str,  # the path of the module defining the triton kernel
        module_cache_key: str,
        grid: List[int],
        num_stages: int,
        num_warps: int,
        matrix_instr_nonkdim: int = 0,  # only used for hip to choose the shape of mfma instruction.
    ):
        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
        self.module_path = module_path
        self.module_cache_key = module_cache_key
        self.grid = grid
        self.num_stages = num_stages
        self.num_warps = num_warps
        self.matrix_instr_nonkdim = matrix_instr_nonkdim

    def make_run_fn(
        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
    ) -> Callable[[], None]:
        """
        Load the generated module from PyCodeCache and return a zero-argument
        callable that invokes ``<kernel_name>.run`` on the given tensors,
        followed by the extra args, with this request's grid and launch options.
        """
        mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
        log.debug(
            "benchmark module key: %s, path: %s",
            self.module_cache_key,
            self.module_path,
        )

        run_method = getattr(mod, self.kernel_name).run
        # extra_args is only typed as Iterable. Materialize it once and pass
        # this list below so a one-shot iterator cannot arrive empty.
        extra_args = list(self.extra_args)

        # Newer version of triton add warmup argument to JITFunction.run.
        # This code handles backward-compatibility.
        import inspect

        run_kwargs: Dict[str, Any] = {}
        if "warmup" in inspect.signature(run_method).parameters:
            run_kwargs["warmup"] = False
        run_kwargs["num_stages"] = self.num_stages
        run_kwargs["num_warps"] = self.num_warps
        # matrix_instr_nonkdim is only forwarded on HIP, and only when non-zero.
        if torch.version.hip and self.matrix_instr_nonkdim != 0:
            run_kwargs["matrix_instr_nonkdim"] = self.matrix_instr_nonkdim

        return functools.partial(
            run_method,
            *input_tensors,
            output_tensor,
            *extra_args,
            grid=self.grid,
            **run_kwargs,
        )

    def __str__(self) -> str:
        return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}"
560

561

562
class CUDABenchmarkRequest(BenchmarkRequest):
    # Important: Instances of this class have to be serializable
    # across process boundaries. Do not put CUDA Tensors in here!

    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
        extra_args: Iterable[Any],
        source_code: str,
    ):
        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
        self.source_code = source_code
        self.workspace_size: int = 0
        self.workspace: Optional[torch.Tensor] = None
        # Loaded library handle. Set by make_run_fn() and reset by cleanup_run_fn().
        self.DLL: Optional[DLLWrapper] = None
        self.hash_key: str = ""
        self.source_file: str = ""
        # Record the cache key and source path now so they are available (e.g.
        # to __str__) even before make_run_fn() runs.
        self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so")

    def precompile(self):
        # Prepopulate CUDACodeCache
        # may happen in separate Threadpool
        log.debug("Precompiling %s", self)
        CUDACodeCache.load(self.source_code, "so")
        log.debug("Done precompiling %s", self)

    def make_run_fn(
        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
    ) -> Callable[[], None]:
        """
        Load the kernel library and return a zero-argument callable that
        launches ``self.kernel_name`` on the given tensors.

        The entry point is first called with a workspace-size out-pointer to
        query the required workspace size. Only a zero-size workspace is
        currently supported.
        """
        self.DLL, self.hash_key, self.source_file = CUDACodeCache.load(
            self.source_code, "so"
        )
        args = [
            c_void_p(tensor.data_ptr())
            for tensor in list(input_tensors) + [output_tensor]
        ]
        # extra_args is only typed as Iterable and is unpacked twice below.
        # Materialize it once so a one-shot iterator is not exhausted by the
        # workspace-size query.
        extra_args = list(self.extra_args)
        log.debug(
            "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s",
            self.kernel_name,
            self.source_file,
            self.hash_key,
            self.DLL,
            args,
            self.extra_args,
        )
        run_method = getattr(self.DLL, self.kernel_name)
        stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)

        # Retrieve workspace_size and initialize workspace.
        c_workspace_size = c_size_t()
        run_method(
            *args,  # input ptrs and output ptrs
            *extra_args,
            byref(
                c_workspace_size
            ),  # set workspace size ptr to retrieve workspace size
            None,  # null workspace ptr
            stream_ptr,
        )
        self.workspace_size = c_workspace_size.value
        # TODO: Support non-zero workspace_size.
        assert self.workspace_size == 0, (
            "Things need to be fixed to support non-zero workspace_size: "
            "1) max autotune cache needs to store workspace size; "
            "2) memory allocation needs to allocate / deallocate workspace correctly; "
        )

        # Generate partial function.
        return functools.partial(
            run_method,
            *args,
            *extra_args,
            None,  # null workspace size ptr
            None,  # set workspace ptr, TODO: update it to a real ptr if workspace_size > 0
            stream_ptr,
        )

    def cleanup_run_fn(self) -> None:
        """Close the library loaded by make_run_fn and drop the workspace."""
        if self.DLL is not None:
            self.DLL.close()
            # Drop the closed handle so a repeated cleanup does not close it
            # twice and the request does not carry it along if serialized later.
            self.DLL = None
        self.workspace = None

    def __str__(self) -> str:
        return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}"
648

649

650
def benchmark_in_sub_process(
    choices: List[TritonTemplateCaller],
) -> Dict[TritonTemplateCaller, float]:
    """
    Benchmark ``choices`` using the module-level sub-process pool.

    Returns a mapping from each choice to its perf number (latency). A choice
    whose benchmarking sub-process crashed is mapped to ``float("inf")``.
    """
    pool = tuning_pool
    timings = pool.benchmark(choices)
    return timings
657

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.