from __future__ import annotations

import base64
import copyreg
import dataclasses
import functools
import hashlib
import importlib
import io
import json
import logging
import multiprocessing
import os
import pathlib
import pickle
import pkgutil
import platform
import re
import shlex
import shutil
import signal
import subprocess
import sys
import sysconfig
import tempfile
import textwrap
import threading
import warnings
import weakref
from bisect import bisect_right
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
from copy import copy
from ctypes import c_void_p, cdll, CDLL
from dataclasses import field
from functools import partial
from pathlib import Path
from threading import Thread
from time import sleep, time
from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union

import torch

from torch._dynamo.device_interface import (
    get_interface_for_device,
    get_registered_device_interfaces,
)
from torch._dynamo.utils import counters, dynamo_timed
from torch._inductor import config, exc
from torch._inductor.codegen.cuda import cuda_env
from torch._inductor.utils import cache_dir, developer_warning, is_linux
from torch._subclasses.fake_tensor import (
    extract_tensor_metadata,
    FakeTensor,
    TensorMetadata,
)
from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv

if TYPE_CHECKING:
    from torch._inductor.graph import GraphLowering
    from torch._inductor.select_algorithm import ChoiceCaller

from torch.hub import _Faketqdm, tqdm

_HERE = os.path.abspath(__file__)
_TORCH_PATH = os.path.dirname(os.path.dirname(_HERE))
_LINKER_SCRIPT = os.path.join(_TORCH_PATH, "_inductor/script.ld")

if config.is_fbcode():
    from triton.fb import build_paths
    from triton.fb.build import _run_build_command

    from torch._inductor.fb.utils import (
        log_global_cache_errors,
        log_global_cache_stats,
        log_global_cache_vals,
        use_global_cache,
    )
else:

    def log_global_cache_errors(*args, **kwargs):
        pass

    def log_global_cache_stats(*args, **kwargs):
        pass

    def log_global_cache_vals(*args, **kwargs):
        pass

    def use_global_cache() -> bool:
        return False


LOCK_TIMEOUT = 600

# timing metrics for time spent in the compilation
_cumulative_compile_time = 0.0
_t0: Optional[float] = None


def _compile_start() -> None:
    global _t0
    if _t0 is None:
        _t0 = time()


def _compile_end() -> None:
    global _cumulative_compile_time, _t0
    if _t0 is not None:
        t1 = time()
        _cumulative_compile_time += t1 - _t0
        _t0 = None
        # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time)


log = logging.getLogger(__name__)


def cpp_wrapper_cache_dir(name: str) -> str:
    cu_str = (
        "cpu"
        if torch.version.cuda is None
        else f'cu{torch.version.cuda.replace(".", "")}'
    )
    python_version = f"py{sys.version_info.major}{sys.version_info.minor}"
    build_folder = f"{python_version}_{cu_str}"

    cpp_wrapper_dir = os.path.join(cache_dir(), build_folder)
    cpp_wrapper_build_directory = os.path.join(cpp_wrapper_dir, name)
    os.makedirs(cpp_wrapper_build_directory, exist_ok=True)
    return cpp_wrapper_build_directory


def get_cpp_wrapper_cubin_path_name():
    return "cubin_path" if torch.version.hip is None else "hsaco_path"


class CacheBase:
    @staticmethod
    @functools.lru_cache(None)
    def get_system() -> Dict[str, Any]:
        try:
            import triton

            triton_version = triton.__version__
        except ModuleNotFoundError:
            triton_version = None

        try:
            system: Dict[str, Any] = {
                "device": {
                    "name": torch.cuda.get_device_properties(
                        torch.cuda.current_device()
                    ).name,
                },
                "version": {
                    "cuda": torch.version.cuda,
                    "triton": triton_version,
                },
            }
        except (AssertionError, RuntimeError):
            # If cuda is not installed, none of the above config is relevant.
            system = {}

        system["hash"] = hashlib.sha256(
            json.dumps(system, sort_keys=True).encode("utf-8")
        ).hexdigest()

        return system

    @staticmethod
    @functools.lru_cache(None)
    def get_local_cache_path() -> Path:
        return Path(os.path.join(cache_dir(), "cache", CacheBase.get_system()["hash"]))

    @staticmethod
    @functools.lru_cache(None)
    def get_global_cache_path() -> Optional[Path]:
        return (
            Path(os.path.join(config.global_cache_dir, CacheBase.get_system()["hash"]))
            if config.global_cache_dir is not None
            else None
        )

    def __init__(self) -> None:
        if not torch.cuda.is_available():
            return

        self.system = CacheBase.get_system()

        self.local_cache_path = CacheBase.get_local_cache_path()
        self.global_cache_path = CacheBase.get_global_cache_path()

    def get_local_cache(self) -> Dict[str, Any]:
        if not self.local_cache_path.is_file():
            return {}
        with open(self.local_cache_path) as local_cache_fp:
            local_cache = json.load(local_cache_fp)
        return local_cache["cache"]

    def update_local_cache(self, local_cache: Dict[str, Any]) -> None:
        if not os.path.exists(self.local_cache_path.parent):
            os.makedirs(self.local_cache_path.parent, exist_ok=True)

        write_atomic(
            str(self.local_cache_path),
            json.dumps({"system": self.system, "cache": local_cache}, indent=4),
        )


class LocalCache(CacheBase):
    def lookup(self, *keys: str) -> Optional[Dict[str, Any]]:
        cache = self.get_local_cache()

        sub_cache = cache
        for key in keys:
            if key in cache:
                sub_cache = cache[key]
            else:
                return None

        return sub_cache

    def set_value(self, *keys: str, value: Any) -> None:
        cache = self.get_local_cache()

        sub_cache = cache
        for key in keys[0:-1]:
            sub_cache.setdefault(key, {})
            sub_cache = sub_cache[key]
        sub_cache[keys[-1]] = value

        self.update_local_cache(cache)


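# Illustrative sketch (not part of the upstream module): LocalCache persists a
# nested JSON dict on disk, keyed by arbitrary string keys. Assuming a writable
# cache_dir() and an available CUDA device (CacheBase.__init__ returns early
# otherwise), single-key usage looks roughly like:
#
#     cache = LocalCache()
#     cache.set_value("pad_mm", value={"choice": "padded"})
#     cache.lookup("pad_mm")        # -> {"choice": "padded"}
#     cache.lookup("missing-key")   # -> None
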
class PersistentCache(CacheBase):
    @functools.lru_cache(None)
    def get_global_cache(self):
        if self.global_cache_path is None or not self.global_cache_path.is_file():
            return {}
        with open(self.global_cache_path) as global_cache_fp:
            global_cache = json.load(global_cache_fp)
        return global_cache["cache"]

    def lookup(
        self,
        choices: List[ChoiceCaller],
        op: str,
        inputs: str,
        benchmark: Callable[[Any], Dict[ChoiceCaller, float]],
    ) -> Dict[ChoiceCaller, float]:
        """
        Check to see if we have benchmarked the given choice callers. For each
        choice caller:

            1. Check global_cache[op][inputs][choice][precision], return benchmark if cached.
            2. Check local_cache[op][inputs][choice][precision], return benchmark if cached.
            3.
                a. `max_autotune_gemm=True`: benchmark the choice, update
                    local_cache[op][inputs][choice], and return the benchmark.
                b. `max_autotune_gemm=False`: don't benchmark the choice, return nothing.
        """
        precision = torch.get_float32_matmul_precision()

        log_stats = partial(log_global_cache_stats, self.system, op, inputs, precision)
        log_vals = partial(log_global_cache_vals, self.system, op, inputs, precision)
        log_errors = partial(
            log_global_cache_errors, self.system, op, inputs, precision
        )
        timings = {}

        def check_cache(cache, callback=None) -> bool:
            """Check if `cache` contains data for all the choices"""
            hit = True
            for choice in choices:
                choice_hash = choice.hash_key()
                if choice_hash in cache.get(op, {}).get(inputs, {}).get(precision, {}):
                    # cache hit
                    timings[choice] = cache[op][inputs][precision][choice_hash]
                else:
                    # cache miss
                    hit = False
                    break
            if callback:
                callback(cached=hit)
            return hit

        if config.max_autotune or config.max_autotune_gemm:
            local_cache = self.get_local_cache()
            # check local cache first since it is data specific to the current machine
            if not check_cache(local_cache) and not (
                use_global_cache()
                and check_cache(self.get_global_cache(), callback=log_stats)
            ):
                try:
                    # re-benchmark everything to try to get consistent numbers from the same machine
                    timings = benchmark(choices)
                    assert all(choice in timings for choice in choices)
                    local_cache.setdefault(op, {})
                    local_cache[op].setdefault(inputs, {}).setdefault(precision, {})
                    for choice, timing in timings.items():
                        local_cache[op][inputs][precision][choice.hash_key()] = timing
                except RuntimeError as e:
                    # catch and log autotuning failures
                    log_errors(e)
                    raise e

                self.update_local_cache(local_cache)

                timings_to_log = {
                    choice.hash_key(): timings[choice] for choice in choices
                }
                log_vals(timings_to_log)
        elif use_global_cache():
            # only check global cache, not local one
            check_cache(self.get_global_cache(), callback=log_stats)
            # may have a partial cache hit, where not everything is benchmarked

        return timings


def get_lock_dir() -> str:
    lock_dir = os.path.join(cache_dir(), "locks")
    if not os.path.exists(lock_dir):
        os.makedirs(lock_dir, exist_ok=True)
    return lock_dir


def sha256_hash(data: bytes) -> str:
    # [:51] to strip off the "Q====" suffix common to every hash value.
    return base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower()


def code_hash(code: Union[str, bytes], extra: str = ""):
    hashing_str = code if isinstance(code, bytes) else code.encode("utf-8")
    if extra != "":
        hashing_str = hashing_str + b"||" + extra.encode("utf-8")
    return "c" + sha256_hash(hashing_str)


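# Illustrative note (not part of the upstream module): code_hash() produces keys
# of the form "c" followed by 51 lowercase base32 characters. A rough sketch of
# how such a key feeds into get_path() below:
#
#     key = code_hash("extern C ...")        # "c" + 51-char digest
#     basename, subdir, path = get_path(key, "so")
#     # subdir is <cache_dir()>/<key[1:3]>, path is <subdir>/<key>.so
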
def get_path(
    basename: str, extension: str, specified_dir: str = ""
) -> Tuple[str, str, str]:
    if specified_dir:
        if os.path.isabs(specified_dir):
            subdir = specified_dir
        else:
            subdir = os.path.join(cache_dir(), specified_dir)
    else:
        subdir = os.path.join(cache_dir(), basename[1:3])
    path = os.path.join(subdir, f"{basename}.{extension}")
    return basename, subdir, path


def get_hash(content: Union[str, bytes], extra: str = "", hash_type: str = "code"):
    if hash_type == "code":
        return code_hash(content, extra)
    if hash_type in ["cubin", "hsaco"]:
        return code_hash(repr(content))
    raise AssertionError(f"Unknown hash type {hash_type}")


def write(
    content: Union[str, bytes],
    extension: str,
    extra: str = "",
    hash_type: str = "code",
    specified_dir: str = "",
) -> Tuple[str, str]:
    # use stripped content to compute the hash so we don't end up with different
    # hashes just because the content begins/ends with a different number of
    # spaces.
    key: str = get_hash(content.strip(), extra, hash_type)
    basename, subdir, path = get_path(key, extension, specified_dir)
    if not os.path.exists(subdir):
        os.makedirs(subdir, exist_ok=True)
    if not os.path.exists(path):
        write_atomic(path, content)
    return basename, path


def write_atomic(path: str, content: Union[str, bytes]) -> None:
    # Write into a temporary file first to avoid conflicts between threads.
    # Avoid using a named temporary file, as those have restricted permissions.
    assert isinstance(
        content, (str, bytes)
    ), "Only strings and byte arrays can be saved in the cache"
    path = pathlib.Path(path)
    tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp"
    write_mode = "w" if isinstance(content, str) else "wb"
    with tmp_path.open(write_mode) as f:
        f.write(content)
    tmp_path.rename(path)


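# Illustrative sketch (not part of the upstream module): write() is the common
# entry point for putting generated source and artifacts into the on-disk cache.
# The content itself determines the file name, so identical content is only
# written once:
#
#     basename, path = write("int main() { return 0; }", "cpp")
#     # path == <cache_dir()>/<key[1:3]>/<key>.cpp; a second call with the same
#     # content returns the same path without rewriting the file.
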
@dataclasses.dataclass
class TensorMetadataAndValues:
    """
    TensorMetadata plus the elements as a list of raw values.
    Used for hashing inlined constants.
    """

    tensor_metadata: TensorMetadata
    values: List[Any]


def _ident(x: Any) -> Any:
    return x


def _reduce_fake_tensor(t):
    """
    See FxGraphCachePickler. Custom reducer to pickle FakeTensors.
    """
    metadata = extract_tensor_metadata(t)
    return (_ident, (metadata,))


def _reduce_tensor(t):
    """
    See FxGraphCachePickler. Custom reducer to pickle Tensors.
    """
    # If we see tensors, we know they're constants stored as attributes on
    # the GraphModule. See tensor lowering; small constants are inlined. If
    # we see a small tensor, therefore, no reference will ultimately remain
    # in the generated code. So we need to include its value in the cache key.
    # Large constants are effectively treated as inputs and we consider only
    # their metadata.
    metadata = extract_tensor_metadata(t)
    if len(t.shape) == 0 or torch._inductor.graph.GraphLowering.can_inline_constant(t):
        return (_ident, (TensorMetadataAndValues(metadata, t.tolist()),))
    else:
        return (_ident, (metadata,))


def _reduce_symint(s):
    """
    See FxGraphCachePickler. Custom reducer to pickle SymInts.
    """
    # For hashing purposes, we only care about the name of the symbol and
    # not the backed value. We evaluate guards stored with a cached graph
    # to ensure a cached entity with SymInt args is safe to reuse.
    return (_ident, (str(s),))


class FxGraphCachePickler(pickle.Pickler):
    """
    Custom pickler to customize the pickling of some objects (Tensors), only for the
    purpose of computing a hash for keying into the FxGraphCache. Tensors contain
    objects that don't pickle and/or vary between runs, and we want to capture the
    data that allow us to compute a stable, but safe hash.
    """

    dispatch_table = copyreg.dispatch_table.copy()
    dispatch_table[FakeTensor] = _reduce_fake_tensor
    dispatch_table[torch.Tensor] = _reduce_tensor
    dispatch_table[torch.SymInt] = _reduce_symint

    @staticmethod
    def dumps(obj) -> bytes:
        """
        Pickle an object using the FxGraphCachePickler.
        """
        with io.BytesIO() as stream:
            pickler = FxGraphCachePickler(stream)
            pickler.dump(obj)
            return stream.getvalue()

    @staticmethod
    def get_hash(obj: Any) -> str:
        """
        Serialize an object using the FxGraphCachePickler and return a hash
        of the pickled object.
        """
        serialized_data = FxGraphCachePickler.dumps(obj)
        return sha256_hash(serialized_data)


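# Illustrative sketch (not part of the upstream module): the dispatch table
# above means that hashing an object graph containing tensors reduces them to
# their metadata (plus values only for small, inlinable constants), so, roughly:
#
#     key1 = FxGraphCachePickler.get_hash(torch.randn(1024))
#     key2 = FxGraphCachePickler.get_hash(torch.randn(1024))
#     # For tensors large enough not to be inlined, key1 == key2: only the
#     # dtype/shape/stride-style metadata contributes, not the random values.
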
@functools.lru_cache(None)
def get_inductor_code_hash() -> bytes:
    """
    Compute a hash of all inductor code modules. Used by the FxGraph cache
    so any inductor code changes would result in new cache keys.
    """
    inductor_root = os.path.dirname(__file__)

    contents: Dict[str, bytes] = {}
    for lib in pkgutil.iter_modules([inductor_root]):
        spec = lib.module_finder.find_spec(lib.name, None)
        assert spec is not None
        module = spec.origin
        assert module is not None
        with open(module, "rb") as f:
            contents[module] = f.read()

    return hashlib.sha256(pickle.dumps(contents)).digest()


@dataclasses.dataclass
class OrderedSetHolder:
    """
    See FxGraphHashDetails. Holds a sorted list to support stable hashing
    of set kwargs.
    """

    items: List[Any]


class FxGraphHashDetails:
    """
    Object to capture all the details for a compiled FX graph relevant to computing
    a safe and stable cache key.
    """

    # Excluded kwargs params that are not stable between runs
    EXCLUDED_KWARGS = ["graph_id"]

    def __init__(
        self,
        gm: torch.fx.GraphModule,
        example_inputs: List[torch.Tensor],
        fx_kwargs: Dict[str, Any],
    ):
        self.gm = gm
        self.example_inputs = example_inputs

        # Order kwargs so hashing is stable to changes in kwarg order.
        self.fx_kwargs = {}
        for k in sorted(fx_kwargs):
            if k not in self.EXCLUDED_KWARGS:
                if type(fx_kwargs[k]) is set:
                    # Special case to handle set params. Python sets can't be
                    # ordered, so sort the elements and store them in a proxy.
                    self.fx_kwargs[k] = OrderedSetHolder(sorted(fx_kwargs[k]))
                else:
                    self.fx_kwargs[k] = fx_kwargs[k]

        # Also hash on various system info (including the triton compiler version), as
        # well as the inductor configuration and code.
        self.torch_version = torch.__version__
        self.system_info = CacheBase.get_system()

        self.inductor_config = config.save_config()
        self.inductor_code_hash = get_inductor_code_hash()

    def debug_str(self) -> str:
        """
        Get a printable string describing in more detail all the attributes
        comprising this object. Useful for debugging when one graph hashes
        to a different value than another.
        """

        def get_str(obj) -> str:
            if isinstance(obj, torch.Tensor):
                return str(extract_tensor_metadata(obj))
            elif isinstance(obj, bytes):
                return "<bytes>"
            else:
                return str(obj)

        lines = []
        for attr, obj in vars(self).items():
            if isinstance(obj, list):
                for ii in range(len(obj)):
                    h = FxGraphCachePickler.get_hash(obj[ii])
                    lines.append(f"[{h}] {attr}[{ii}]: {get_str(obj[ii])}")
            elif isinstance(obj, dict):
                for k, v in obj.items():
                    h = FxGraphCachePickler.get_hash(v)
                    lines.append(f"[{h}] {attr}[{k}]: {get_str(v)}")
            else:
                h = FxGraphCachePickler.get_hash(obj)
                lines.append(f"[{h}] {attr}: {get_str(obj)}")
        return "\n".join(lines)


def compiled_fx_graph_hash(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    fx_kwargs: Dict[str, Any],
) -> str:
    """
    Generate a unique hash of the FX graph for caching.
    """
    details = FxGraphHashDetails(gm, example_inputs, fx_kwargs)
    # The prefix distinguishes among the other kinds of objects we
    # cache in this module.
    key = "f" + FxGraphCachePickler.get_hash(details)
    log.debug("FX graph cache hash details for key %s:\n%s", key, details.debug_str())
    return key


class FxGraphCache:
    """
    Supports caching and reusing compiled Fx graphs.

    The overall strategy is as follows:
    - This cache stores entries on disk. When saving an entry, we can't
      serialize callables (that could be C++, Triton, etc.), so we serialize
      their disk cache location instead. We then recreate the compiled artifact
      after fetching from disk.
    - For indexing the cache, we gather the fields relevant to identifying an
      FxGraph (the graph module, graph inputs, system settings etc.) into an
      FxGraphHashDetails object, pickle it, and compute a hash for the key.
      See FxGraphCachePickler.
    - Among the metadata we store, we also include a guards expression that's
      appropriate for validating any symbols for Tensor arguments that have
      symbolic bounds. On cache lookup then, we evaluate those guards in the
      current context to validate that a cached entry can be served.
    - A given graph could have multiple compiled versions, corresponding to
      different sets of guards. Therefore, we store cache entries in the form:
          <temp dir>/<fx graph hash>/<serialized metadata>
    - On lookup, we compute the key from the graph details, iterate over all
      leaf files in the corresponding subdirectory, deserialize the entry, and
      evaluate its guards expression. If the evaluation succeeds, we have a
      cache hit. If it fails, we compile the graph and store a new entry.
    - Finally, on a cache hit, we need to make sure any guards that would
      have been created during compilation are added to the current context.
    """

    # TODO(masnesral): Investigate whether it's beneficial to store compiled graphs
    # in an in-memory cache after loading from disk.
    @staticmethod
    def _get_tmp_dir() -> str:
        """
        Get the toplevel temporary directory for storing compiled graphs.
        """
        return os.path.join(cache_dir(), "fxgraph")

    @staticmethod
    def _get_tmp_dir_for_key(key: str) -> str:
        """
        Return the disk location for a given cache key.
        """
        return os.path.join(FxGraphCache._get_tmp_dir(), key[1:3], key)

    @staticmethod
    def _filter_symints(inputs: List[Any]) -> List[torch.SymInt]:
        """
        Get the SymInt objects from the input list.
        """
        return [s for s in inputs if isinstance(s, torch.SymInt)]

    @staticmethod
    def _get_shape_env() -> ShapeEnv:
        """
        Helper to get the shape env from the tracing context.
        """
        return torch._guards.TracingContext.get().fake_mode.shape_env

    @staticmethod
    def _lookup_graph(
        key: str,
        example_inputs: List[torch.Tensor],
    ) -> Optional[CompiledFxGraph]:
        """
        Look up a compiled graph in the cache by key. On a hit, return the
        deserialized CompiledFxGraph object. On a miss, return None.
        """
        subdir = FxGraphCache._get_tmp_dir_for_key(key)
        if not os.path.exists(subdir):
            return None

        # Iterate over any entries in the subdir for this key and evaluate
        # their guards to determine whether there's a hit.
        for path in sorted(os.listdir(subdir)):
            with open(os.path.join(subdir, path), "rb") as f:
                graph: CompiledFxGraph = pickle.load(f)

            guards_expr = graph.guards_expr
            if not guards_expr:
                # No guards to evaluate
                return graph

            # Evaluate the guard expression in the current context.
            shape_env = FxGraphCache._get_shape_env()
            symints = FxGraphCache._filter_symints(example_inputs)

            # If there's not a cache hit, we don't want the evaluation to
            # affect the current env, e.g., cause the creation of new guards,
            # so we evaluate with the hints instead of the symbols.
            assert all(has_hint(s) for s in symints)
            hints = [hint_int(s) for s in symints]
            hit = bool(shape_env.evaluate_guards_expression(guards_expr, hints))
            log.debug(
                "fx graph cache key %s evaluating guards for %s with values %s => %s",
                key,
                guards_expr,
                hints,
                hit,
            )
            if hit:
                # Now re-evaluate with the symints to add any guards to the current env.
                check = bool(shape_env.evaluate_guards_expression(guards_expr, symints))
                assert check is True
                log.debug(
                    "fx graph cache key %s post-load guards: %s",
                    key,
                    shape_env.guards,
                )
                return graph

        return None

    @staticmethod
    def _save_graph(
        key: str, compiled_graph: CompiledFxGraph, example_inputs: List[torch.Tensor]
    ):
        """
        Store a serialized CompiledFxGraph on disk.
        """
        disk_compiled_graph = copy(compiled_graph)
        # Important as compiled models are not pickleable:
        disk_compiled_graph.compiled_artifact = None

        # Before serializing, compute the guard expression that will be used to
        # ensure that a CompiledFxGraph is valid when loaded from the cache. It's
        # sufficient to consider only the SymInt args to the fx graph since the
        # Tensor shapes are already captured in the hash for the cache key. Any
        # Tensor arg with a symbolic shape will have a SymInt arg for the graph.
        shape_env = FxGraphCache._get_shape_env()
        symints = FxGraphCache._filter_symints(example_inputs)
        disk_compiled_graph.guards_expr = shape_env.produce_guards_expression(symints)

        content = pickle.dumps(disk_compiled_graph)

        subdir = FxGraphCache._get_tmp_dir_for_key(key)
        if not os.path.exists(subdir):
            os.makedirs(subdir, exist_ok=True)

        # Use a hash of the serialized CompiledFxGraph to get a unique file
        # name. The specific name doesn't matter since a lookup involves
        # iterating over all entries in the parent subdir.
        path = os.path.join(subdir, sha256_hash(content))
        write_atomic(path, content)

    @staticmethod
    def load(
        compile_fx_fn: Callable[..., Any],
        gm: torch.fx.GraphModule,
        example_inputs: List[torch.Tensor],
        fx_kwargs: Dict[str, Any],
    ):
        """
        Load a compiled graph from the cache. If a cached entry does not exist,
        compile the graph and save it to the cache.
        """
        from filelock import FileLock

        key = compiled_fx_graph_hash(gm, example_inputs, fx_kwargs)

        lock_dir = get_lock_dir()
        lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
        with lock:
            compiled_graph = FxGraphCache._lookup_graph(key, example_inputs)
            if compiled_graph is None:
                log.debug("fx graph cache miss for key %s", key)
                counters["inductor"]["fxgraph_cache_miss"] += 1
                compiled_graph = compile_fx_fn(gm, example_inputs, **fx_kwargs)
                FxGraphCache._save_graph(key, compiled_graph, example_inputs)
            else:
                log.debug("fx graph cache hit for key %s", key)
                counters["inductor"]["fxgraph_cache_hit"] += 1

            return compiled_graph

    @staticmethod
    def clear():
        """
        Clear out the on-disk cache.
        """
        shutil.rmtree(FxGraphCache._get_tmp_dir())


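# Illustrative sketch (not part of the upstream module): a caller wires the
# cache in roughly this shape, where `fx_codegen_and_compile` stands in for
# whatever callable actually performs Inductor compilation (name assumed here):
#
#     if config.fx_graph_cache:
#         compiled_graph = FxGraphCache.load(
#             fx_codegen_and_compile, gm, example_inputs, fx_kwargs
#         )
#     else:
#         compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
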
@dataclasses.dataclass
class CompiledFxGraph:
    """
    Class holding a compiled FX graph. This is the object serialized on disk
    to support FxGraph caching.
    """

    compiled_artifact: Optional[Callable[..., Any]] = None
    current_callable: Optional[Callable[..., Any]] = None
    cache_key: Optional[str] = None
    artifact_path: Optional[str] = None
    cache_linemap: Optional[List[Tuple[int, str]]] = None
    device_types: Set[str] = field(default_factory=set)
    device_idxs: Set[int] = field(default_factory=set)
    mutated_inputs: Set[str] = field(default_factory=set)
    mutated_input_idxs: Set[int] = field(default_factory=set)
    constants: Dict[str, torch.Tensor] = field(default_factory=dict)
    output_strides: Optional[List[Optional[Tuple[int, ...]]]] = None
    # This is a string representation of an expression we serialize
    # with the object so the guards can be evaluated in a different
    # context in order to verify the validity of serving a cached
    # fx graph. The expression must be generated by:
    # ShapeEnv.produce_guards_expression()
    guards_expr: Optional[str] = None

    _boxed_call: Optional[bool] = None

    disabled_cudagraphs_reason: Optional[str] = None

    def __init__(
        self,
        compiled_artifact: Optional[Callable[..., Any]],
        graph: GraphLowering,
        output_strides: List[Optional[Tuple[int, ...]]],
        disabled_cudagraphs_reason: Optional[str],
    ):
        self.compiled_artifact = compiled_artifact
        self.cache_key = graph.cache_key
        self.artifact_path = graph.cache_path
        self.cache_linemap = graph.cache_linemap
        self.device_types = graph.device_types
        self.device_idxs = graph.device_idxs
        self.mutated_inputs = graph.mutated_inputs
        self.mutated_input_idxs = set(graph.mutated_input_idxs)
        self.constants = graph.constants
        self.output_strides = output_strides
        self.guards_expr = None
        self.disabled_cudagraphs_reason = disabled_cudagraphs_reason

    def __call__(self, inputs: List[Any]) -> Any:
        return self.get_current_callable()(inputs)

    def get_current_callable(self) -> Callable[..., Any]:
        if self.current_callable is None:
            # This prevents a circular reference that makes CompiledFxGraph
            # get stuck without getting garbage collected
            return functools.partial(_run_from_cache, weakref.proxy(self))
        else:
            return self.current_callable


def _run_from_cache(compiled_graph: CompiledFxGraph, inputs: List[Any]) -> Any:
    # We can't really serialize callables that may be C++/Triton/etc.,
    # so we serialize their disk cache location instead
    # TODO: When making an API that can save compiled models e2e to disk
    # this will need to be better
    if compiled_graph.compiled_artifact is None:
        from .codecache import PyCodeCache

        assert compiled_graph.cache_key
        assert compiled_graph.artifact_path
        compiled_graph.compiled_artifact = PyCodeCache.load_by_key_path(
            compiled_graph.cache_key,
            compiled_graph.artifact_path,
            compiled_graph.cache_linemap,
            compiled_graph.constants,
        ).call

    return compiled_graph.compiled_artifact(inputs)


def cpp_compiler() -> str:
    if config.is_fbcode():
        return build_paths.cc()
    if isinstance(config.cpp.cxx, (list, tuple)):
        search = tuple(config.cpp.cxx)
    else:
        search = (config.cpp.cxx,)
    return cpp_compiler_search(search)


@functools.lru_cache(1)
def cpp_compiler_search(search: str) -> str:
    for cxx in search:
        try:
            if cxx is None:
                # gxx package is only available for Linux
                # according to https://anaconda.org/conda-forge/gxx/
                if sys.platform != "linux":
                    continue
                # Do not install GXX by default
                if not os.getenv("TORCH_INDUCTOR_INSTALL_GXX"):
                    continue
                from filelock import FileLock

                lock_dir = get_lock_dir()
                lock = FileLock(
                    os.path.join(lock_dir, "g++.lock"), timeout=LOCK_TIMEOUT
                )
                with lock:
                    cxx = install_gcc_via_conda()
            subprocess.check_output([cxx, "--version"])
            return cxx
        except (subprocess.SubprocessError, FileNotFoundError, ImportError):
            continue
    raise exc.InvalidCxxCompiler()


def install_gcc_via_conda() -> str:
    """On older systems, this is a quick way to get a modern compiler"""
    prefix = os.path.join(cache_dir(), "gcc")
    cxx_path = os.path.join(prefix, "bin", "g++")
    if not os.path.exists(cxx_path):
        log.info("Downloading GCC via conda")
        conda = os.environ.get("CONDA_EXE", "conda")
        if conda is None:
            conda = shutil.which("conda")
        if conda is not None:
            subprocess.check_call(
                [
                    conda,
                    "create",
                    f"--prefix={prefix}",
                    "--channel=conda-forge",
                    "--quiet",
                    "-y",
                    "python=3.8",
                    "gxx",
                ],
                stdout=subprocess.PIPE,
            )
    return cxx_path


def is_gcc() -> bool:
    return bool(re.search(r"(gcc|g\+\+)", cpp_compiler()))


def is_clang() -> bool:
    return bool(re.search(r"(clang|clang\+\+)", cpp_compiler()))


@functools.lru_cache(None)
def is_apple_clang() -> bool:
    cxx = cpp_compiler()
    version_string = subprocess.check_output([cxx, "--version"]).decode("utf8")
    return "Apple" in version_string.splitlines()[0]


class VecISA:
    _bit_width: int
    _macro: str
    _arch_flags: str
    _dtype_nelements: Dict[torch.dtype, int]

    # Note [Checking for Vectorized Support in Inductor]
    # TorchInductor CPU vectorization reuses PyTorch vectorization utility functions.
    # Hence, TorchInductor depends on Sleef* to accelerate mathematical functions
    # like exp, pow, sin, cos, etc.
    # But PyTorch and TorchInductor might use different compilers to build code. If
    # PyTorch uses gcc-7/g++-7 to build the release package, the libtorch_cpu.so
    # will not expose the Sleef* AVX512 symbols since gcc-7/g++-7 cannot pass the
    # avx512 check in CMake - FindAVX.cmake. But TorchInductor installs the latest
    # gcc/g++ compiler by default, which could support AVX512 compilation.
    # Therefore, there could be a Sleef version conflict between PyTorch and
    # TorchInductor. Hence, we dry-compile the following code to check whether the
    # current HW platform and PyTorch both support AVX512 or AVX2. Presumably ARM
    # also needs this logic.
    # In fbcode however, we are using the same compiler for pytorch and for inductor codegen,
    # making the runtime check unnecessary.
    _avx_code = """
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON)
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#endif

__attribute__((aligned(64))) float in_out_ptr0[16] = {0.0};

extern "C" void __avx_chk_kernel() {
    auto tmp0 = at::vec::Vectorized<float>(1);
    auto tmp1 = tmp0.exp();
    tmp1.store(in_out_ptr0);
}
"""  # noqa: B950

    _avx_py_load = """
import torch
from ctypes import cdll
cdll.LoadLibrary("__lib_path__")
"""

    def bit_width(self) -> int:
        return self._bit_width

    def nelements(self, dtype: torch.dtype = torch.float) -> int:
        return self._dtype_nelements[dtype]

    def build_macro(self) -> str:
        return self._macro

    def build_arch_flags(self) -> str:
        return self._arch_flags

    def __hash__(self) -> int:
        return hash(str(self))

    @functools.lru_cache(None)
    def __bool__(self) -> bool:
        if config.cpp.vec_isa_ok is not None:
            return config.cpp.vec_isa_ok

        if config.is_fbcode():
            return True

        key, input_path = write(VecISA._avx_code, "cpp")
        from filelock import FileLock

        lock_dir = get_lock_dir()
        lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
        with lock:
            output_path = input_path[:-3] + "so"
            build_cmd = shlex.split(
                cpp_compile_command(
                    input_path, output_path, warning_all=False, vec_isa=self
                )
            )
            try:
                # Check build result
                compile_file(input_path, output_path, build_cmd)
                subprocess.check_call(
                    [
                        sys.executable,
                        "-c",
                        VecISA._avx_py_load.replace("__lib_path__", output_path),
                    ],
                    stderr=subprocess.DEVNULL,
                    env={**os.environ, "PYTHONPATH": ":".join(sys.path)},
                )
            except Exception as e:
                return False

            return True


@dataclasses.dataclass
class VecNEON(VecISA):
    _bit_width = 256  # This is required to leverage the compute implemented in aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h
    _macro = "-DCPU_CAPABILITY_NEON"
    _arch_flags = ""  # Unused
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16}

    def __str__(self) -> str:
        return "neon"  # Unused

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


@dataclasses.dataclass
class VecAVX512(VecISA):
    _bit_width = 512
    _macro = "-DCPU_CAPABILITY_AVX512"
    _arch_flags = "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma"
    _dtype_nelements = {torch.float: 16, torch.bfloat16: 32, torch.float16: 32}

    def __str__(self) -> str:
        return "avx512"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


@dataclasses.dataclass
class VecAVX2(VecISA):
    _bit_width = 256
    _macro = "-DCPU_CAPABILITY_AVX2"
    _arch_flags = "-mavx2 -mfma"
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}

    def __str__(self) -> str:
        return "avx2"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


@dataclasses.dataclass
class VecZVECTOR(VecISA):
    _bit_width = 256
    _macro = "-DCPU_CAPABILITY_ZVECTOR -DCPU_CAPABILITY=ZVECTOR -DHAVE_ZVECTOR_CPU_DEFINITION"
    _arch_flags = "-mvx -mzvector"
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}

    def __str__(self) -> str:
        return "zvector"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


class InvalidVecISA(VecISA):
    _bit_width = 0
    _macro = ""
    _arch_flags = ""
    _dtype_nelements = {}

    def __str__(self) -> str:
        return "INVALID_VEC_ISA"

    def __bool__(self) -> bool:  # type: ignore[override]
        return False

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


invalid_vec_isa = InvalidVecISA()
supported_vec_isa_list = [
    VecAVX512(),
    VecAVX2(),
    VecNEON(),
]  # This order matters for test_cpu_repro


# Cache the cpuinfo to avoid I/O overhead. The cpuinfo content also has a lot
# of redundant content that is useless for the ISA check, so we only cache some
# key ISA information.
@functools.lru_cache(None)
def valid_vec_isa_list() -> List[VecISA]:
    if sys.platform != "linux":
        return []

    if platform.machine() == "s390x":
        return [VecZVECTOR()]

    isa_list = []
    with open("/proc/cpuinfo") as _cpu_info:
        _cpu_info_content = _cpu_info.read()
        for isa in supported_vec_isa_list:
            # cpuinfo does not reveal info about NEON support. All aarch64 processors do support NEON though.
            if (
                (str(isa) in _cpu_info_content)
                or (isinstance(isa, VecNEON) and platform.processor() == "aarch64")
                and isa
            ):
                isa_list.append(isa)
        return isa_list


def pick_vec_isa() -> VecISA:
    if config.is_fbcode():
        return VecAVX2()

    _valid_vec_isa_list: List[VecISA] = valid_vec_isa_list()
    if not _valid_vec_isa_list:
        return invalid_vec_isa

    # If simdlen is None, determine the vectorization length automatically
    if config.cpp.simdlen is None:
        assert _valid_vec_isa_list
        return _valid_vec_isa_list[0]

    for isa in _valid_vec_isa_list:
        if config.cpp.simdlen == isa.bit_width():
            return isa

    return invalid_vec_isa


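# Illustrative sketch (not part of the upstream module): outside of fbcode, on a
# Linux x86-64 machine whose /proc/cpuinfo advertises both avx512 and avx2, the
# selection above behaves roughly like this:
#
#     valid_vec_isa_list()   # [VecAVX512(), VecAVX2()]
#     pick_vec_isa()         # VecAVX512() when config.cpp.simdlen is None
#     # With config.cpp.simdlen = 256, pick_vec_isa() returns VecAVX2() instead.
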
def get_compile_only(compile_only: bool = True) -> str:
    return "-c" if compile_only else ""


def get_shared(shared: bool = True, compile_only: bool = False) -> str:
    if not shared:
        return ""
    if compile_only:
        return "-fPIC"
    if platform.system() == "Darwin" and "clang" in cpp_compiler():
        # This causes undefined symbols to behave the same as linux
        return "-shared -fPIC -undefined dynamic_lookup"
    else:
        return "-shared -fPIC"


def get_warning_all_flag(warning_all: bool = True) -> str:
    return "-Wall" if warning_all else ""


def get_glibcxx_abi_build_flags() -> str:
    return "-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))


def cpp_flags() -> str:
    flags = ["-std=c++17", "-Wno-unused-variable", "-Wno-unknown-pragmas"]
    if is_clang():
        flags.append("-Werror=ignored-optimization-argument")
    return " ".join(flags)


def cpp_wrapper_flags() -> str:
    return "-DTORCH_INDUCTOR_CPP_WRAPPER"


def optimization_flags() -> str:
    base_flags = "-O0 -g" if config.aot_inductor.debug_compile else "-O3 -DNDEBUG"
    base_flags += " -ffast-math -fno-finite-math-only"
    if not config.cpp.enable_unsafe_math_opt_flag:
        base_flags += " -fno-unsafe-math-optimizations"
    if not config.cpp.enable_floating_point_contract_flag:
        base_flags += " -ffp-contract=off"

    if config.is_fbcode():
        # FIXME: passing `-fopenmp` adds libgomp.so to the generated shared library's dependencies.
        # This causes `ldopen` to fail in fbcode, because libgomp does not exist in the default paths.
        # We will fix it later by exposing the lib path.
        return base_flags

    if sys.platform == "darwin":
        # Per https://mac.r-project.org/openmp/ the right way to pass `openmp` flags to macOS is via `-Xclang`
        # Also, `-march=native` is an unrecognized option on M1
        base_flags += " -Xclang"
    else:
        if platform.machine() == "ppc64le":
            base_flags += " -mcpu=native"
        else:
            base_flags += " -march=native"

    # Internal cannot find libgomp.so
    if not config.is_fbcode():
        base_flags += " -fopenmp"
    return base_flags


def use_custom_generated_macros() -> str:
    return "-D C10_USING_CUSTOM_GENERATED_MACROS"


def use_fb_internal_macros() -> str:
    if config.is_fbcode():
        openmp_lib = build_paths.openmp_lib()
        preprocessor_flags = " ".join(
            (
                "-D C10_USE_GLOG",
                "-D C10_USE_MINIMAL_GLOG",
                "-D C10_DISABLE_TENSORIMPL_EXTENSIBILITY",
            )
        )
        return f"-Wp,-fopenmp {openmp_lib} {preprocessor_flags}"
    else:
        return ""


def use_standard_sys_dir_headers() -> str:
    if config.is_fbcode():
        return "-nostdinc"
    else:
        return ""


@functools.lru_cache(None)
def is_conda_llvm_openmp_installed() -> bool:
    try:
        command = "conda list llvm-openmp --json"
        output = subprocess.check_output(command.split()).decode("utf8")
        return len(json.loads(output)) > 0
    except subprocess.SubprocessError:
        return False


@functools.lru_cache(None)
def homebrew_libomp() -> Tuple[bool, str]:
    try:
        # check if `brew` is installed
        subprocess.check_output(["which", "brew"])
        # get the location of `libomp` if it is installed
        # this is the location where `libomp` **would** be installed
        # see https://github.com/Homebrew/brew/issues/10261#issuecomment-756563567 for details
        libomp_path = (
            subprocess.check_output(["brew", "--prefix", "libomp"])
            .decode("utf8")
            .strip()
        )
        # check if `libomp` is installed
        omp_available = os.path.exists(libomp_path)
        return omp_available, libomp_path
    except subprocess.SubprocessError:
        return False, ""


def get_include_and_linking_paths(
    include_pytorch: bool = False,
    vec_isa: VecISA = invalid_vec_isa,
    cuda: bool = False,
    aot_mode: bool = False,
) -> Tuple[List[str], str, str, str, str]:
    if (
        config.is_fbcode()
        and "CUDA_HOME" not in os.environ
        and "CUDA_PATH" not in os.environ
    ):
        os.environ["CUDA_HOME"] = os.path.dirname(build_paths.cuda())
    from torch.utils import cpp_extension

    macros = ""
    build_arch_flags = ""
    if sys.platform == "linux" and (
        include_pytorch
        or vec_isa != invalid_vec_isa
        or cuda
        or config.cpp.enable_kernel_profile
    ):
        # Note - We include pytorch only on linux right now. There is more work
        # to do to enable OMP build on darwin where PyTorch is built with IOMP
        # and we need a way to link to what PyTorch links.
        ipaths = cpp_extension.include_paths(cuda) + [sysconfig.get_path("include")]
        lpaths = cpp_extension.library_paths(cuda) + [
            sysconfig.get_config_var("LIBDIR")
        ]

        libs = []

        # No need to manually specify libraries in fbcode.
        if not config.is_fbcode():
            libs += ["torch", "torch_cpu"]
            libs += ["gomp"]
            if not aot_mode:
                libs += ["torch_python"]
        else:
            # internal remote execution is able to find omp, but not gomp
            libs += ["omp"]
            if aot_mode:
                ipaths += [os.path.dirname(cpp_prefix_path())]
                if cuda:
                    # This is a special treatment for Meta internal cuda-12 where all libs
                    # are in lib/cuda-12 and lib/cuda-12/stubs
                    for i, path in enumerate(lpaths):
                        if path.startswith(
                            os.environ["CUDA_HOME"]
                        ) and not os.path.exists(f"{path}/libcudart_static.a"):
                            for root, dirs, files in os.walk(path):
                                if "libcudart_static.a" in files:
                                    lpaths[i] = os.path.join(path, root)
                                    lpaths.append(os.path.join(lpaths[i], "stubs"))
                                    break
        macros = vec_isa.build_macro()
        if macros:
            if config.is_fbcode() and vec_isa != invalid_vec_isa:
                cap = str(vec_isa).upper()
                macros = " ".join(
                    [
                        vec_isa.build_arch_flags(),
                        f"-D CPU_CAPABILITY={cap}",
                        f"-D CPU_CAPABILITY_{cap}",
                        f"-D HAVE_{cap}_CPU_DEFINITION",
                    ]
                )

        if aot_mode and cuda:
            if macros is None:
                macros = ""
            macros += " -D USE_ROCM" if torch.version.hip else " -D USE_CUDA"

        if cuda:
            if torch.version.hip is not None:
                libs += ["c10_hip", "torch_hip"]
                macros += " -D __HIP_PLATFORM_AMD__"
            else:
                if config.is_fbcode():
                    libs += ["cuda"]
                else:
                    libs += ["c10_cuda", "cuda", "torch_cuda"]
        build_arch_flags = vec_isa.build_arch_flags()
    else:
        # Note - this is effectively a header only inclusion. Usage of some header files may result in
        # symbol not found, if those header files require a library.
        # For those cases, include the lpath and libs command as we do for pytorch above.
        # This approach allows us to only pay for what we use.
        ipaths = cpp_extension.include_paths(cuda) + [sysconfig.get_path("include")]
        if aot_mode:
            ipaths += [os.path.dirname(cpp_prefix_path())]
        lpaths = []
        if sys.platform == "darwin":
            # only Apple builtin compilers (Apple Clang++) require openmp
            omp_available = not is_apple_clang()

            # check the `OMP_PREFIX` environment first
            if os.getenv("OMP_PREFIX") is not None:
                header_path = os.path.join(os.getenv("OMP_PREFIX"), "include", "omp.h")  # type: ignore[arg-type]
                valid_env = os.path.exists(header_path)
                if valid_env:
                    ipaths.append(os.path.join(os.getenv("OMP_PREFIX"), "include"))  # type: ignore[arg-type]
                    lpaths.append(os.path.join(os.getenv("OMP_PREFIX"), "lib"))  # type: ignore[arg-type]
                else:
                    warnings.warn("environment variable `OMP_PREFIX` is invalid.")
                omp_available = omp_available or valid_env

            libs = [] if omp_available else ["omp"]

            # prefer to use openmp from `conda install llvm-openmp`
            if not omp_available and os.getenv("CONDA_PREFIX") is not None:
                omp_available = is_conda_llvm_openmp_installed()
                if omp_available:
                    conda_lib_path = os.path.join(os.getenv("CONDA_PREFIX"), "lib")  # type: ignore[arg-type]
                    ipaths.append(os.path.join(os.getenv("CONDA_PREFIX"), "include"))  # type: ignore[arg-type]
                    lpaths.append(conda_lib_path)
                    # Prefer Intel OpenMP on x86 machine
                    if os.uname().machine == "x86_64" and os.path.exists(
                        os.path.join(conda_lib_path, "libiomp5.dylib")
                    ):
                        libs = ["iomp5"]

            # next, try to use openmp from `brew install libomp`
            if not omp_available:
                omp_available, libomp_path = homebrew_libomp()
                if omp_available:
                    ipaths.append(os.path.join(libomp_path, "include"))
                    lpaths.append(os.path.join(libomp_path, "lib"))

            # if openmp is still not available, let the compiler have a try,
            # and raise an error together with instructions when compilation fails later
        else:
            libs = ["omp"] if config.is_fbcode() else ["gomp"]

    # Unconditionally import c10 for non-abi-compatible mode to use TORCH_CHECK - See PyTorch #108690
    if not config.abi_compatible:
        libs += ["c10"]
        lpaths += [cpp_extension.TORCH_LIB_PATH]

    # third party libs
    if config.is_fbcode():
        ipaths.append(build_paths.sleef())
        ipaths.append(build_paths.openmp())
        ipaths.append(build_paths.cc_include())
        ipaths.append(build_paths.libgcc())
        ipaths.append(build_paths.libgcc_arch())
        ipaths.append(build_paths.libgcc_backward())
        ipaths.append(build_paths.glibc())
        ipaths.append(build_paths.linux_kernel())
        ipaths.append(build_paths.cuda())
        # We also need to bundle includes with absolute paths into a remote directory
        # (later on, we copy the include paths from cpp_extensions into our remote dir)
        ipaths.append("include")

    static_link_libs = []
    if aot_mode and cuda and config.is_fbcode():
        # For Meta internal cuda-12, it is recommended to static link cudart
        static_link_libs = ["-Wl,-Bstatic", "-lcudart_static", "-Wl,-Bdynamic"]

    lpaths_str = " ".join(["-L" + p for p in lpaths])
    libs_str = " ".join(static_link_libs + ["-l" + p for p in libs])
    return ipaths, lpaths_str, libs_str, macros, build_arch_flags


def cpp_compile_command(
1435
    input: Union[str, List[str]],
1436
    output: str,
1437
    warning_all: bool = True,
1438
    shared: bool = True,
1439
    include_pytorch: bool = False,
1440
    vec_isa: VecISA = invalid_vec_isa,
1441
    cuda: bool = False,
1442
    aot_mode: bool = False,
1443
    compile_only: bool = False,
1444
    use_absolute_path: bool = False,
1445
) -> str:
1446
    ipaths, lpaths, libs, macros, build_arch_flags = get_include_and_linking_paths(
1447
        include_pytorch, vec_isa, cuda, aot_mode
1448
    )
1449
    if isinstance(input, str):
1450
        input = [input]
1451
    ipaths_str = " ".join(["-I" + p for p in ipaths])
1452
    clang_flags = ""
1453
    if config.is_fbcode():
1454
        if aot_mode and not use_absolute_path:
1455
            inp_name = input
1456
            out_name = output
1457
            linker_script = _LINKER_SCRIPT
1458
        else:
1459
            # We need to copy any absolute-path torch includes
1460
            inp_name = [os.path.basename(i) for i in input]
1461
            out_name = os.path.basename(output)
1462
            linker_script = os.path.basename(_LINKER_SCRIPT)
1463
        assert is_clang()
1464
        # Use clang runtime instead of libgcc
1465
        clang_flags += " --rtlib=compiler-rt"
1466
        clang_flags += " -fuse-ld=lld"
1467
        clang_flags += f" -Wl,--script={linker_script}"
1468
        linker_paths = "-B" + build_paths.glibc_lib()
1469
        linker_paths += " -L" + build_paths.glibc_lib()
1470
    else:
1471
        inp_name = input
1472
        out_name = output
1473
        linker_paths = ""  # let the compiler pick
1474
    if compile_only:
1475
        libs, lpaths = "", ""
1476
    inp_name_str = " ".join(inp_name)
1477
    return re.sub(
1478
        r"[ \n]+",
1479
        " ",
1480
        f"""
1481
            {cpp_compiler()} {inp_name_str} {get_shared(shared, compile_only)}
1482
            {get_warning_all_flag(warning_all)} {cpp_flags()}
1483
            {get_glibcxx_abi_build_flags()}
1484
            {ipaths_str} {lpaths} {libs} {build_arch_flags}
1485
            {macros} {linker_paths} {clang_flags}
1486
            {optimization_flags()}
1487
            {use_custom_generated_macros()}
1488
            {use_fb_internal_macros()}
1489
            {use_standard_sys_dir_headers()}
1490
            {get_compile_only(compile_only)}
1491
            -o {out_name}
1492
        """,
1493
    ).strip()
1494

1495
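# --- Illustrative sketch (not part of the original file) ---
# cpp_compile_command() only builds the compiler invocation string; it does not
# run anything. A minimal call, using hypothetical file names, might look like:
def _example_cpp_compile_command() -> str:
    # With the defaults this yields a single "<compiler> kernel.cpp -shared ...
    # -o kernel.so" style command with the include/link flags computed above.
    return cpp_compile_command(input="kernel.cpp", output="kernel.so")
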

1496
def run_command_and_check(cmd: str):
1497
    cmd = shlex.split(cmd)
1498
    try:
1499
        subprocess.check_call(cmd)
1500
    except subprocess.CalledProcessError as e:
1501
        raise exc.CppCompileError(cmd, e.output) from e
1502

1503

1504
@functools.lru_cache(None)
1505
def split_aot_inductor_output_path(path: str) -> Tuple[str, str]:
1506
    """Returns the directory and .so file name (if any) where the AOT Inductor compiled kernels are stored."""
1507
    if path.endswith(".so"):
1508
        return os.path.split(path)
1509
    else:
1510
        return path, ""
1511

1512
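# --- Illustrative sketch (not part of the original file) ---
# Demonstrates the two branches of split_aot_inductor_output_path() with
# hypothetical paths: a ".so" path is split into (directory, filename), while
# any other path is treated as a directory and paired with an empty name.
def _example_split_aot_inductor_output_path() -> None:
    assert split_aot_inductor_output_path("/tmp/model.so") == ("/tmp", "model.so")
    assert split_aot_inductor_output_path("/tmp/output_dir") == ("/tmp/output_dir", "")
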

1513
class CudaKernelParamCache:
1514
    cache: Dict[str, Dict[str, str]] = dict()
1515
    clear = staticmethod(cache.clear)
1516

1517
    @classmethod
1518
    def set(cls, key: str, params: Dict[str, str], cubin: str) -> None:
1519
        bin_type = "cubin" if torch.version.hip is None else "hsaco"
1520
        _, path = write(
1521
            cubin,
1522
            bin_type,
1523
            hash_type=bin_type,
1524
            specified_dir=split_aot_inductor_output_path(
1525
                config.aot_inductor.output_path
1526
            )[0],
1527
        )
1528

1529
        params[get_cpp_wrapper_cubin_path_name()] = path
1530

1531
        cls.cache[key] = params
1532

1533
    @classmethod
1534
    def get(cls, key: str) -> Optional[Dict[str, str]]:
1535
        return cls.cache.get(key, None)
1536

1537
    @classmethod
1538
    def get_keys(cls):
1539
        return cls.cache.keys()
1540

1541
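# --- Illustrative sketch (not part of the original file) ---
# CudaKernelParamCache maps a kernel hash key to its launch params; set() also
# persists the compiled binary via write() and records its path in the params.
# The key passed in below is a hypothetical placeholder.
def _example_cuda_kernel_param_cache_lookup(key: str) -> Optional[Dict[str, str]]:
    # Returns None if no kernel with this key has been registered via set().
    return CudaKernelParamCache.get(key)
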

1542
class AotCodeCompiler:
1543
    @classmethod
1544
    def compile(
1545
        cls,
1546
        graph: GraphLowering,
1547
        source_code: str,
1548
        serialized_extern_kernel_nodes: Optional[str],
1549
        cuda: bool,
1550
    ) -> str:
1551
        picked_vec_isa = pick_vec_isa()
1552
        cpp_command = repr(
1553
            cpp_compile_command(
1554
                "i", "o", vec_isa=picked_vec_isa, cuda=cuda, aot_mode=graph.aot_mode
1555
            )
1556
        )
1557
        fbcode_aot_cpu_re = False
1558
        use_absolute_path = False
1559
        if config.is_fbcode():
1560
            ld_command = build_paths.ld()
1561
            if not cuda and graph.aot_mode:  # Meta internal AOTInductor CPU
1562
                objcopy_command = build_paths.objcopy_fallback()
1563
                fbcode_aot_cpu_re = True
1564
                use_absolute_path = True
1565
            else:
1566
                objcopy_command = build_paths.objcopy()
1567
        else:
1568
            ld_command = "ld"
1569
            objcopy_command = "objcopy"
1570

1571
        (
1572
            specified_output_path,
1573
            specified_so_name,
1574
        ) = split_aot_inductor_output_path(config.aot_inductor.output_path)
1575
        key, input_path = write(
1576
            source_code,
1577
            "cpp",
1578
            extra=cpp_command,
1579
            specified_dir=specified_output_path,
1580
        )
1581

1582
        def _compile_consts_linux(consts: bytes) -> str:
1583
            _, consts_path = write(
1584
                consts,
1585
                "bin",
1586
                specified_dir=specified_output_path,
1587
            )
1588

1589
            consts_o = os.path.splitext(consts_path)[0] + ".o"
1590
            if fbcode_aot_cpu_re:
1591
                cmd = f"{ld_command} -r -b binary -o {os.path.basename(consts_o)} {os.path.basename(consts_path)}"
1592
                compile_file(consts_path, consts_o, cmd.split())
1593
                os.chmod(consts_o, 0o644)
1594
            else:
1595
                cmd = f"{ld_command} -r -b binary -o {consts_o} {consts_path}"
1596
                run_command_and_check(cmd)
1597
            log.debug("aot constant binary command: %s", cmd)
1598

1599
            cmd = (
1600
                f"{objcopy_command} --rename-section"
1601
                " .data=.lrodata,alloc,load,readonly,data,contents"
1602
                f" {consts_o} {consts_o}"
1603
            )
1604
            log.debug("aot constant obj command: %s", cmd)
1605
            run_command_and_check(cmd)
1606

1607
            cmd = f"rm {consts_path}"
1608
            log.debug("aot constant bin removal command: %s", cmd)
1609
            run_command_and_check(cmd)
1610

1611
            if fbcode_aot_cpu_re:
1612
                body = re.sub(r"[\W]", "_", os.path.basename(consts_path))
1613
            else:
1614
                body = re.sub(r"[\W]", "_", consts_path)
1615

1616
            symbol_list = []
1617
            symbol_list.append(
1618
                f"{objcopy_command} --redefine-sym _binary_{body}_start=_binary_constants_bin_start {consts_o}"
1619
            )
1620
            symbol_list.append(
1621
                f"{objcopy_command} --redefine-sym _binary_{body}_size=_binary_constants_bin_size {consts_o}"
1622
            )
1623
            symbol_list.append(
1624
                f"{objcopy_command} --redefine-sym _binary_{body}_end=_binary_constants_bin_end {consts_o}"
1625
            )
1626
            log.debug("aot constant binary redefine symbol: %s", " ".join(symbol_list))
1627
            for cmd in symbol_list:
1628
                run_command_and_check(cmd)
1629
            return consts_o
1630

1631
        def _compile_consts_darwin(consts: bytes) -> str:
1632
            is_large_consts = len(consts) > 1024
1633
            consts_asm = "\t.section\t__TEXT,__const\n"
1634
            consts_asm += "\t.globl\t__binary_constants_bin_start\n"
1635
            consts_asm += "__binary_constants_bin_start:\n"
1636
            if not is_large_consts:
1637
                for c in consts:
1638
                    consts_asm += f"\t.byte {c}\n"
1639
                # Add one element even if constants are empty
1640
                # Otherwise the assembler will not put them in the data section
1641
                if not consts:
1642
                    consts_asm += "\t.space 1\n"
1643
            else:
1644
                consts_asm += "\t.quad 0x1234567899abcdef\n"
1645
                consts_asm += f"\t.space {len(consts) - 8}\n"
1646
            consts_asm += ".globl\t__binary_constants_bin_end\n"
1647
            consts_asm += "__binary_constants_bin_end:\n"
1648
            _, consts_path = write(
1649
                consts_asm,
1650
                "S",
1651
                specified_dir=specified_output_path,
1652
            )
1653
            consts_o = os.path.splitext(consts_path)[0] + ".o"
1654
            cmd = f"{cpp_compiler()} -c -o {consts_o} {consts_path}"
1655
            run_command_and_check(cmd)
1656
            if is_large_consts:
1657
                with open(consts_o, "r+b") as f:
1658
                    f.seek(0)
1659
                    hdr = f.read(1024)
1660
                    # Search for magic number and write the actual data over it
1661
                    start_idx = hdr.find(b"\xef\xcd\xab\x99\x78\x56\x34\x12")
1662
                    assert start_idx != -1
1663
                    f.seek(start_idx)
1664
                    pos = 0
1665
                    while pos < len(consts):
1666
                        rc = f.write(consts[pos:])
1667
                        pos += rc
1668
            return consts_o
1669

1670
        from filelock import FileLock
1671

1672
        lock_dir = get_lock_dir()
1673
        lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
1674
        with lock:
1675
            # Currently, this only supports serializing extern nodes in fbcode
1676
            # Eventually, we should also have a serializer for OSS.
1677
            if config.is_fbcode() and serialized_extern_kernel_nodes:
1678
                output_json = os.path.splitext(input_path)[0] + ".json"
1679
                with open(output_json, "w") as f:
1680
                    f.write(serialized_extern_kernel_nodes)
1681

1682
            output_so = (
1683
                config.aot_inductor.output_path
1684
                if specified_so_name
1685
                else os.path.splitext(input_path)[0] + ".so"
1686
            )
1687

1688
            output_o = os.path.splitext(input_path)[0] + ".o"
1689
            cmd = cpp_compile_command(
1690
                input=input_path,
1691
                output=output_o,
1692
                vec_isa=picked_vec_isa,
1693
                cuda=cuda,
1694
                aot_mode=graph.aot_mode,
1695
                compile_only=True,
1696
                use_absolute_path=use_absolute_path,
1697
            )
1698
            log.debug("aot compilation command: %s", cmd)
1699
            if fbcode_aot_cpu_re:
1700
                compile_file(input_path, output_o, cmd.split())
1701
                os.chmod(output_o, 0o644)
1702
            else:
1703
                run_command_and_check(cmd)
1704

1705
            def _to_bytes(t: torch.Tensor) -> bytes:
1706
                # This serializes the tensor's untyped_storage to bytes by accessing
1707
                # the raw data of the underlying structure.
1708
                import ctypes
1709

1710
                if t.numel() == 0:
1711
                    return b""
1712

1713
                t_cpu = t.untyped_storage().cpu()
1714
                raw_array = ctypes.cast(
1715
                    t_cpu.data_ptr(),
1716
                    ctypes.POINTER(ctypes.c_ubyte * t_cpu.nbytes()),
1717
                )
1718

1719
                return bytes(raw_array.contents)
1720

1721
            aot_constants = b"".join(
1722
                _to_bytes(tensor)
1723
                for name, tensor in graph.constants.items()
1724
                if name not in graph.folded_constants
1725
            )
1726
            consts_o = {
1727
                "linux": _compile_consts_linux,
1728
                "darwin": _compile_consts_darwin,
1729
            }[sys.platform](aot_constants)
1730

1731
            cmd = cpp_compile_command(
1732
                input=[output_o, consts_o],
1733
                output=output_so,
1734
                vec_isa=picked_vec_isa,
1735
                cuda=cuda,
1736
                aot_mode=graph.aot_mode,
1737
                use_absolute_path=use_absolute_path,
1738
            )
1739
            log.debug("aot linkage command: %s", cmd)
1740
            if fbcode_aot_cpu_re:
1741
                compile_file([output_o, consts_o], output_so, cmd.split())
1742
                os.chmod(output_so, 0o755)
1743
            else:
1744
                run_command_and_check(cmd)
1745

1746
        return output_so
1747

1748

1749
# Putting this fn in cpp.py (unfortunately) causes a deadlock, which is why it's in codecache.py.
1750
# Why? importing from cpp.py invokes codecache.pick_vec_isa(), which takes out a lock.
1751
# Cycle goes:
1752
# - CppCodeCache.load()
1753
# - pick_vec_isa()
1754
# - valid_vec_isa_list()
1755
# - VecISA.__bool__() <-- takes out a lock
1756
# - compile_file() <-- imports cpp_prefix_path from cpp, which causes us to try to take out the same lock.
1757
@functools.lru_cache
1758
def cpp_prefix_path() -> str:
1759
    path = Path(__file__).parent / "codegen/cpp_prefix.h"
1760
    with path.open() as f:
1761
        content = f.read()
1762
        _, filename = write(
1763
            content,
1764
            "h",
1765
        )
1766
    return filename
1767

1768

1769
def cpp_prefix() -> str:
1770
    filename = cpp_prefix_path()
1771
    if config.is_fbcode():
1772
        # We need relative paths, since we bundle up
1773
        # everything that we compile into a folder for remote compilation.
1774
        return f'#include "{os.path.basename(filename)}"'
1775
    else:
1776
        return f'#include "{filename}"'
1777

1778

1779
# Given a path to an input cpp file and an output path,
1780
# Attempts to compile the file, storing the output in "output_path"
1781
@dynamo_timed
1782
def compile_file(
1783
    input_path: Union[str, List[str]], output_path: str, cmd: List[str]
1784
) -> None:
1785
    input_paths = [input_path] if isinstance(input_path, str) else input_path
1786
    input_files = [
1787
        os.path.basename(ip) if config.is_fbcode() else ip for ip in input_paths
1788
    ]
1789
    try:
1790
        if config.is_fbcode():
1791
            # Need to copy our header into the same folder as the sourcecode.
1792
            header_path = cpp_prefix_path()
1793
            header_name = os.path.basename(header_path)
1794
            output_name = os.path.basename(output_path)
1795
            # When we build remotely, we need to make sure to carefully copy any files
1796
            # that are required during the compilation process into our build directory.
1797
            # This is where all of the ATen/c10/Torch includes come from.
1798
            torch_includes_path = os.path.join(_TORCH_PATH, "include")
1799
            with tempfile.TemporaryDirectory() as tmp_dir:
1800
                # Copy everything to tmp compilation folder
1801
                shutil.copy(header_path, os.path.join(tmp_dir, header_name))
1802
                shutil.copy(_LINKER_SCRIPT, os.path.join(tmp_dir, "script.ld"))
1803
                for p, f in zip(input_paths, input_files):
1804
                    shutil.copy(p, os.path.join(tmp_dir, f))
1805
                dest_include_path = os.path.join(tmp_dir, "include")
1806
                shutil.copytree(torch_includes_path, dest_include_path)
1807
                # Run the build
1808
                output_file_path = _run_build_command(cmd, tmp_dir, output_name)
1809
                # Copy output from the build
1810
                if os.path.exists(output_path):
1811
                    os.remove(output_path)
1812
                shutil.copy(output_file_path, output_path)
1813
        else:
1814
            subprocess.check_output(cmd, stderr=subprocess.STDOUT)
1815
    except subprocess.CalledProcessError as e:
1816
        output = e.output.decode("utf-8")
1817
        openmp_problem = "'omp.h' file not found" in output or "libomp" in output
1818
        if openmp_problem and sys.platform == "darwin":
1819
            instruction = (
1820
                "\n\nOpenMP support not found. Please try one of the following solutions:\n"
1821
                "(1) Set the `CXX` environment variable to a compiler other than Apple clang++/g++ "
1822
                "that has builtin OpenMP support;\n"
1823
                "(2) install OpenMP via conda: `conda install llvm-openmp`;\n"
1824
                "(3) install libomp via brew: `brew install libomp`;\n"
1825
                "(4) manually setup OpenMP and set the `OMP_PREFIX` environment variable to point to a path"
1826
                " with `include/omp.h` under it."
1827
            )
1828
            output += instruction
1829
        raise exc.CppCompileError(cmd, output) from e
1830

1831

1832
_libgomp: Optional[CDLL] = None
1833

1834

1835
class CppCodeCache:
1836
    cache: Dict[str, Union[CDLL, ModuleType]] = {}
1837
    clear = staticmethod(cache.clear)
1838
    cpp_compile_command_flags: Dict[str, Any] = {}
1839

1840
    @staticmethod
1841
    def _load_library_inner(path: str, key: str) -> Union[CDLL, ModuleType]:
1842
        return cdll.LoadLibrary(path)
1843

1844
    @classmethod
1845
    def _load_library(cls, path: str, key: str) -> Union[CDLL, ModuleType]:
1846
        try:
1847
            return cls._load_library_inner(path, key)
1848
        except (ImportError, OSError) as e:
1849
            if "gomp" in str(e) and os.path.exists("/usr/lib64/libgomp.so.1"):
1850
                # hacky workaround for fbcode/buck
1851
                global _libgomp
1852
                _libgomp = cdll.LoadLibrary("/usr/lib64/libgomp.so.1")
1853
                return cls._load_library_inner(path, key)
1854
            if "failed to map segment from shared object" in str(e):
1855
                raise OSError(
1856
                    f"{e}.  The most common reason this may occur is if the {tempfile.gettempdir()} folder "
1857
                    "is mounted with noexec (e.g., by default Docker mounts tmp file systems "
1858
                    f"as noexec).  Please remount {tempfile.gettempdir()} with exec enabled, or set another "
1859
                    "temporary directory with TORCHINDUCTOR_CACHE_DIR environment variable."
1860
                ) from e
1861
            raise
1862

1863
    @classmethod
1864
    def load(cls, source_code: str, cuda: bool = False) -> Union[CDLL, ModuleType]:
1865
        cls.cpp_compile_command_flags.update({"cuda": cuda})
1866
        picked_vec_isa = pick_vec_isa()
1867
        cpp_command = repr(
1868
            cpp_compile_command(
1869
                "i", "o", vec_isa=picked_vec_isa, **cls.cpp_compile_command_flags
1870
            )
1871
        )
1872
        key, input_path = write(source_code, "cpp", extra=cpp_command)
1873
        if key not in cls.cache:
1874
            from filelock import FileLock
1875

1876
            lock_dir = get_lock_dir()
1877
            lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
1878
            with lock:
1879
                output_path = input_path[:-3] + "so"
1880
                if not os.path.exists(output_path):
1881
                    cmd = shlex.split(
1882
                        cpp_compile_command(
1883
                            input=input_path,
1884
                            output=output_path,
1885
                            vec_isa=picked_vec_isa,
1886
                            **cls.cpp_compile_command_flags,
1887
                        )
1888
                    )
1889
                    compile_file(input_path, output_path, cmd)
1890
                cls.cache[key] = cls._load_library(output_path, key)
1891
                cls.cache[key].key = key  # type: ignore[union-attr]
1892

1893
        return cls.cache[key]
1894

1895
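# --- Illustrative sketch (not part of the original file) ---
# Typical use of CppCodeCache.load(): compile a small C++ source (hypothetical
# toy kernel below) into a shared library and get back a ctypes CDLL whose
# exported symbols are callable attributes.
def _example_cpp_code_cache_usage():
    toy_source = 'extern "C" void kernel(float* out) { out[0] = 1.0f; }'
    lib = CppCodeCache.load(toy_source)  # cached by content hash of the source
    return lib.kernel  # ctypes function pointer for the exported symbol
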

1896
# Customized Python binding for cpp kernels
1897
class CppPythonBindingsCodeCache(CppCodeCache):
1898
    cache: Dict[str, Union[CDLL, ModuleType]] = {}
1899
    clear = staticmethod(cache.clear)
1900
    cpp_compile_command_flags = {
1901
        # kernels have no dependency on libtorch
1902
        "include_pytorch": False,
1903
        "shared": True,
1904
    }
1905
    entry_function = "kernel"
1906
    call_entry_function = "kernel(%s);Py_RETURN_NONE;"
1907
    extra_parse_arg = ""
1908
    suffix_template = textwrap.dedent(
1909
        """
1910
        // Python bindings to call %s():
1911
        #define PY_SSIZE_T_CLEAN
1912
        #include <Python.h>
1913
        #include <sstream>
1914
        #include <cstdlib>
1915

1916
        // This is defined in guards.cpp so we don't need to import PyTorch headers that are slooow.
1917
        // We manually link it below to work around issues with the fbcode build.
1918
        static void* (*_torchinductor_pyobject_tensor_data_ptr)(PyObject* obj);
1919

1920
        template <typename T> static inline T parse_arg(PyObject* args, size_t n) {
1921
            static_assert(std::is_pointer<T>::value, "arg type must be pointer or long");
1922
            return static_cast<T>(_torchinductor_pyobject_tensor_data_ptr(PyTuple_GET_ITEM(args, n)));
1923
        }
1924
        template <> inline long parse_arg<long>(PyObject* args, size_t n) {
1925
            auto result = PyLong_AsSsize_t(PyTuple_GET_ITEM(args, n));
1926
            if(result == -1 && PyErr_Occurred())
1927
                [[unlikely]] throw std::runtime_error("expected int arg");
1928
            return result;
1929
        }
1930

1931
        %s
1932

1933
        static PyObject* %s_py(PyObject* self, PyObject* args) {
1934
            try {
1935
                if(!PyTuple_CheckExact(args))
1936
                    [[unlikely]] throw std::runtime_error("tuple args required");
1937
                if(PyTuple_GET_SIZE(args) != %s)
1938
                    [[unlikely]] throw std::runtime_error("requires %s args");
1939
                %s
1940
            } catch(std::exception const& e) {
1941
                PyErr_SetString(PyExc_RuntimeError, e.what());
1942
                return nullptr;
1943
            } catch(...) {
1944
                PyErr_SetString(PyExc_RuntimeError, "unhandled error");
1945
                return nullptr;
1946
            }
1947
        }
1948

1949
        static PyMethodDef py_methods[] = {
1950
            {"%s", %s_py, METH_VARARGS, ""},
1951
            {NULL, NULL, 0, NULL}};
1952

1953
        static struct PyModuleDef py_module =
1954
            {PyModuleDef_HEAD_INIT, "%s", NULL, -1, py_methods};
1955

1956
        PyMODINIT_FUNC PyInit_%s(void) {
1957
            const char* str_addr = std::getenv("_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR");
1958
            if(!str_addr) {
1959
                PyErr_SetString(PyExc_RuntimeError, "_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR must be set");
1960
                return nullptr;
1961
            }
1962
            std::istringstream iss(str_addr);
1963
            uintptr_t addr = 0;
1964
            iss >> addr;
1965
            _torchinductor_pyobject_tensor_data_ptr =
1966
                reinterpret_cast<decltype(_torchinductor_pyobject_tensor_data_ptr)>(addr);
1967
            return PyModule_Create(&py_module);
1968
        }
1969
        """
1970
    )
1971

1972
    @classmethod
1973
    def _load_library_inner(cls, path: str, key: str) -> ModuleType:
1974
        os.environ["_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR"] = str(
1975
            torch._C._dynamo.guards._torchinductor_pyobject_tensor_data_ptr  # type: ignore[attr-defined]
1976
        )
1977
        return importlib.machinery.ExtensionFileLoader(
1978
            f"{key}.{cls.entry_function}", path
1979
        ).load_module()  # type: ignore[call-arg]
1980

1981
    @classmethod
1982
    def load_pybinding(
1983
        cls,
1984
        argtypes: List[str],
1985
        source_code: str,
1986
        cuda: bool = False,
1987
        num_outputs: int = -1,
1988
    ) -> Any:
1989
        """
1990
        Wrap a C++ function in fast Python bindings.
1991

1992
        Args:
1993
            argtypes: The types of args to ENTRY_FUNCTION(), e.g. ["float*", "long"]
1994
            source_code: C++ source code containing a ENTRY_FUNCTION() function
1995

1996
        Returns:
1997
            A python version of ENTRY_FUNCTION()
1998
        """
1999
        parseargs = ", ".join(
2000
            f"parse_arg<{argtype.replace('const ', '')}>(args, {n})"
2001
            for n, argtype in enumerate(argtypes)
2002
        )
2003
        suffix = cls.suffix_template % (
2004
            cls.entry_function,
2005
            cls.extra_parse_arg % num_outputs if cls.extra_parse_arg else "",
2006
            cls.entry_function,
2007
            len(argtypes),
2008
            len(argtypes),
2009
            cls.call_entry_function % parseargs,
2010
            cls.entry_function,
2011
            cls.entry_function,
2012
            cls.entry_function,
2013
            cls.entry_function,
2014
        )
2015
        result = cls.load(source_code + suffix, cuda)
2016
        assert isinstance(result, ModuleType)
2017
        return getattr(result, cls.entry_function)
2018

2019
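# --- Illustrative sketch (not part of the original file) ---
# load_pybinding() appends the suffix_template above to the user source and
# returns the generated Python entry point. The argument types and toy C++
# source below are hypothetical.
def _example_load_pybinding():
    toy_source = 'extern "C" void kernel(float* out, long n) { out[0] = (float)n; }'
    fn = CppPythonBindingsCodeCache.load_pybinding(["float*", "long"], toy_source)
    return fn  # callable from Python as fn(some_tensor, 42)
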

2020
class CppWrapperCodeCache(CppPythonBindingsCodeCache):
2021
    cache: Dict[str, Union[CDLL, ModuleType]] = {}
2022
    clear = staticmethod(cache.clear)
2023
    cpp_compile_command_flags = {
2024
        "include_pytorch": True,
2025
        "shared": True,
2026
    }
2027
    entry_function = "inductor_entry_cpp"
2028
    call_entry_function = "return THPVariable_WrapList(inductor_entry_cpp(%s));"
2029
    extra_parse_arg = textwrap.dedent(
2030
        """
2031
        #include <torch/csrc/autograd/python_variable.h>
2032
        #include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
2033

2034
        template <> inline std::vector<at::Tensor> parse_arg<std::vector<at::Tensor>>(PyObject* args, size_t n) {
2035
            return THPVariable_UnpackList(PyTuple_GET_ITEM(args, n));
2036
        }
2037

2038
        std::vector<at::Tensor> inductor_entry_cpp(std::vector<at::Tensor>&& inputs) {
2039
            auto input_handles =
2040
                torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(inputs);
2041
            // For outputs, we only allocate a vector to hold returned tensor handles,
2042
            // not allocating the actual output tensor storage here
2043
            std::vector<AtenTensorHandle> output_handles(%s);
2044

2045
            try {
2046
                inductor_entry_impl(input_handles.data(), output_handles.data());
2047
            } catch(std::exception const& e) {
2048
                PyErr_SetString(PyExc_RuntimeError, e.what());
2049
                return {};
2050
            } catch(...) {
2051
                PyErr_SetString(PyExc_RuntimeError, "unhandled error");
2052
                return {};
2053
            }
2054

2055
            return torch::aot_inductor::alloc_tensors_by_stealing_from_handles(
2056
                output_handles.data(), output_handles.size());
2057
        }
2058
        """
2059
    )
2060

2061

2062
class PyCodeCache:
2063
    cache: Dict[str, ModuleType] = dict()
2064
    linemaps: Dict[str, List[Tuple[Any, ...]]] = dict()
2065
    clear = staticmethod(cache.clear)
2066

2067
    @classmethod
2068
    def write(cls, source_code: str, extra: str = "") -> Tuple[str, str]:
2069
        return write(source_code, "py", extra=extra)
2070

2071
    @classmethod
2072
    def load(
2073
        cls,
2074
        source_code: str,
2075
        extra: str = "",
2076
        linemap: Optional[List[Tuple[int, str]]] = None,
2077
        attrs: Optional[Dict[str, Any]] = None,
2078
    ) -> ModuleType:
2079
        key, path = write(source_code, "py", extra=extra)
2080
        return cls.load_by_key_path(key, path, linemap, attrs)
2081

2082
    @classmethod
2083
    def load_by_key_path(
2084
        cls,
2085
        key: str,
2086
        path: str,
2087
        linemap: Optional[List[Tuple[int, str]]] = None,
2088
        attrs: Optional[Dict[str, Any]] = None,
2089
    ) -> ModuleType:
2090
        if linemap is None:
2091
            linemap = []
2092
        if key not in cls.cache:
2093
            with open(path) as f:
2094
                try:
2095
                    code = compile(f.read(), path, "exec")
2096
                except Exception as e:
2097
                    raise RuntimeError(
2098
                        f"Failed to import {path}\n{type(e).__name__}: {e}"
2099
                    ) from None
2100
                mod = ModuleType(f"{__name__}.{key}")
2101
                mod.__file__ = path
2102
                mod.key = key  # type: ignore[attr-defined]
2103
                exec(code, mod.__dict__, mod.__dict__)
2104
                sys.modules[mod.__name__] = mod
2105
                # another thread might set this first
2106
                cls.cache.setdefault(key, mod)
2107
                # unzip into separate lines/nodes lists
2108
                cls.linemaps[path] = list(zip(*linemap))
2109

2110
                if attrs is not None:
2111
                    for k, v in attrs.items():
2112
                        setattr(mod, k, v)
2113

2114
        return cls.cache[key]
2115

2116
    @classmethod
2117
    @functools.lru_cache(None)
2118
    def stack_frames_for_code(
2119
        cls, path: str, lineno: int
2120
    ) -> Optional[List[Dict[str, Any]]]:
2121
        if path not in cls.linemaps:
2122
            return None
2123
        # [(starting_line, <fx node>), ...]
2124
        lines, nodes = cls.linemaps[path]
2125
        p = bisect_right(lines, lineno)
2126
        if p == 0:
2127
            return None
2128
        entry = nodes[p - 1]
2129
        if not entry:
2130
            return None
2131

2132
        def parse_stack_trace(stack_trace: str) -> List[Dict[str, Any]]:
2133
            # ideally fx stores stack traces as data rather than a string
2134
            # but this is not along a performance critical path
2135
            regex = r'File "(.+)", line (\d+), in (.+)\n'
2136
            matches = re.findall(regex, stack_trace)
2137
            return [
2138
                {"filename": f, "line": int(l), "name": n}
2139
                for f, l, n in reversed(matches)
2140
            ]
2141

2142
        return parse_stack_trace(entry)
2143

2144
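# --- Illustrative sketch (not part of the original file) ---
# PyCodeCache.load() writes the source into the cache dir, imports it as a
# fresh module and memoizes it by content hash. The toy source and attribute
# below are hypothetical.
def _example_py_code_cache_usage() -> ModuleType:
    mod = PyCodeCache.load("answer = 42\n", attrs={"provenance": "example"})
    assert mod.answer == 42 and mod.provenance == "example"
    return mod
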

2145
class TritonCodeCache:
2146
    @classmethod
2147
    def load(cls, kernel_name: str, source_code: str) -> ModuleType:
2148
        mod = PyCodeCache.load(source_code)
2149
        return getattr(mod, kernel_name)
2150

2151

2152
def _cuda_compiler() -> Optional[str]:
2153
    if cuda_env.nvcc_exist(config.cuda.cuda_cxx):
2154
        return config.cuda.cuda_cxx
2155
    if cuda_env.nvcc_exist(os.getenv("CUDACXX")):
2156
        return os.getenv("CUDACXX", "")
2157
    if cuda_env.nvcc_exist(os.getenv("CUDA_HOME")):
2158
        return os.path.join(os.getenv("CUDA_HOME", ""), "bin/nvcc")
2159
    return "nvcc"
2160

2161

2162
def _cutlass_include_paths() -> List[str]:
2163
    cutlass_path = config.cuda.cutlass_dir
2164
    return [
2165
        os.path.join(cutlass_path, "include"),
2166
        os.path.join(cutlass_path, "tools/library/include"),
2167
        os.path.join(cutlass_path, "tools/library/src"),
2168
        os.path.join(cutlass_path, "tools/util/include"),
2169
    ]
2170

2171

2172
def _cuda_lib_options() -> List[str]:
2173
    from torch.utils import cpp_extension
2174

2175
    extra_ldflags: List[str] = []
2176
    if is_linux():
2177
        extra_lib_dir = "lib64"
2178
        if not os.path.exists(
2179
            cpp_extension._join_cuda_home(extra_lib_dir)
2180
        ) and os.path.exists(cpp_extension._join_cuda_home("lib")):
2181
            # 64-bit CUDA may be installed in "lib"
2182
            # Note that it's also possible both don't exist (see _find_cuda_home) - in that case we stay with "lib64"
2183
            extra_lib_dir = "lib"
2184
        extra_ldflags.append(f"-L{cpp_extension._join_cuda_home(extra_lib_dir)}")
2185
        extra_ldflags.append(
2186
            f'-L{cpp_extension._join_cuda_home(extra_lib_dir, "stubs")}'
2187
        )
2188
        extra_ldflags.append("-lcuda")
2189
        extra_ldflags.append("-lcudart")
2190
    else:
2191
        raise NotImplementedError(
2192
            "Unsupported env, failed to find cuda libs! Currently only Linux is supported."
2193
        )
2194
    return extra_ldflags
2195

2196

2197
def _nvcc_host_compiler_options() -> List[str]:
2198
    return [
2199
        "-fPIC",
2200
        "-fno-strict-aliasing",
2201
        "-fvisibility=hidden",
2202
        "-Wconversion",
2203
    ]
2204

2205

2206
def _nvcc_compiler_options() -> List[str]:
2207
    arch = cuda_env.get_cuda_arch()
2208
    if arch == "90":
2209
        # Required by cutlass compilation.
2210
        arch = "90a"
2211
    code = [f"sm_{arch}", f"compute_{arch}"]
2212
    if config.cuda.enable_cuda_lto:
2213
        code += [f"lto_{arch}"]
2214
    options = [
2215
        "-t=0",
2216
        "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
2217
        "-w",
2218
        f"-gencode=arch=compute_{arch},code=[{','.join(code)}]",
2219
        config.cuda.compile_opt_level,
2220
        "-std=c++17",
2221
        "--expt-relaxed-constexpr",
2222
        "-DNDEBUG",
2223
    ]
2224
    if config.cuda.enable_debug_info:
2225
        options.extend(["-lineinfo", "-g", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"])
2226
    if config.cuda.enable_ptxas_info:
2227
        options.extend(
2228
            [
2229
                "--keep",  # Keep the intermediate files for debugging (including ptx, sass, cubin etc.)
2230
                "--ptxas-options=--warn-on-local-memory-usage",  # warn us if local memory is used in CUDA Kernels
2231
                "--ptxas-options=--warn-on-spills",  # warn us if register spilling happens in CUDA Kernels
2232
                "--resource-usage",  # Report on CUDA resource usage (shared mem, registers etc.)
2233
                "--source-in-ptx",
2234
            ]
2235
        )  # Annotate the ptx file with source information
2236
    if config.cuda.use_fast_math:
2237
        options.extend(
2238
            [
2239
                "--use_fast_math",
2240
                "-DCUTLASS_USE_TANH_FOR_SIGMOID=1",
2241
            ]
2242
        )
2243
    return options
2244

2245

2246
def cuda_compile_command(
2247
    src_files: List[str],
2248
    dst_file: str,
2249
    dst_file_ext: str,
2250
) -> str:
2251
    include_paths = _cutlass_include_paths()
2252
    cuda_lib_options = _cuda_lib_options()
2253
    nvcc_host_compiler_options = _nvcc_host_compiler_options()
2254
    nvcc_compiler_options = _nvcc_compiler_options()
2255
    options = (
2256
        nvcc_compiler_options
2257
        + [
2258
            f"-Xcompiler {opt}" if "=" in opt else f"-Xcompiler={opt}"
2259
            for opt in nvcc_host_compiler_options
2260
        ]
2261
        + ["-I" + path for path in include_paths]
2262
        + cuda_lib_options
2263
    )
2264
    src_file = " ".join(src_files)
2265
    res = ""
2266
    if dst_file_ext == "o":
2267
        res = f"{_cuda_compiler()} {' '.join(options)} -c -o {dst_file} {src_file}"
2268
    elif dst_file_ext == "so":
2269
        options.append("-shared")
2270
        res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}"
2271
    else:
2272
        raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!")
2273
    log.debug("CUDA command: %s", res)
2274
    return res
2275

2276
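# --- Illustrative sketch (not part of the original file) ---
# cuda_compile_command() only assembles the nvcc command line. With the
# hypothetical file names below, the "o" variant compiles with "-c" only, while
# the "so" variant adds "-shared" and links the CUDA libraries.
def _example_cuda_compile_command() -> Tuple[str, str]:
    obj_cmd = cuda_compile_command(["kernel.cu"], "kernel.o", "o")
    so_cmd = cuda_compile_command(["kernel.cu", "extra.cu"], "kernel.so", "so")
    return obj_cmd, so_cmd
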

2277
class DLLWrapper:
2278
    """A wrapper for a dynamic library."""
2279

2280
    def __init__(
2281
        self,
2282
        lib_path: str,
2283
    ):
2284
        self.lib_path = lib_path
2285
        self.DLL = cdll.LoadLibrary(lib_path)
2286
        self.is_open = True
2287

2288
    def close(self):
2289
        if self.is_open:
2290
            self._dlclose()
2291
            self.is_open = False
2292

2293
    def _dlclose(self):
2294
        f_dlclose = None
2295

2296
        if is_linux():
2297
            syms = CDLL(None)
2298
            if not hasattr(syms, "dlclose"):
2299
                # Alpine Linux
2300
                syms = CDLL("libc.so")
2301

2302
            if hasattr(syms, "dlclose"):
2303
                f_dlclose = syms.dlclose
2304
        else:
2305
            raise NotImplementedError("Unsupported env, failed to do dlclose!")
2306

2307
        if f_dlclose is not None:
2308
            f_dlclose.argtypes = [c_void_p]
2309
            f_dlclose(self.DLL._handle)
2310
        else:
2311
            log.warning(
2312
                "dll unloading function was not found, library may not be unloaded properly!"
2313
            )
2314

2315
    def __getattr__(self, name):
2316
        if not self.is_open:
2317
            raise RuntimeError(f"Cannot use closed DLL library: {self.lib_path}")
2318

2319
        method = getattr(self.DLL, name)
2320

2321
        def _wrapped_func(*args):
2322
            err = method(*args)
2323
            if err:
2324
                raise RuntimeError(f"Error in function: {method.__name__}")
2325

2326
        return _wrapped_func
2327

2328
    def __enter__(self):
2329
        return self
2330

2331
    def __exit__(self, *args):
2332
        self.close()
2333

2334
    def __del__(self):
2335
        self.close()
2336

2337
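# --- Illustrative sketch (not part of the original file) ---
# DLLWrapper can be used as a context manager so the library is dlclose()d on
# exit; wrapped symbols raise RuntimeError on a non-zero return code. The
# library path and symbol name below are hypothetical.
def _example_dll_wrapper_usage(lib_path: str) -> None:
    with DLLWrapper(lib_path) as lib:
        lib.initialize()  # raises RuntimeError if initialize() returns non-zero
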

2338
class CUDACodeCache:
2339
    @dataclasses.dataclass
2340
    class CacheEntry:
2341
        input_path: str
2342
        output_path: str
2343

2344
    cache: Dict[str, CacheEntry] = dict()
2345
    clear = staticmethod(cache.clear)
2346
    _SOURCE_CODE_SUFFIX = "cu"
2347

2348
    @classmethod
2349
    def write(cls, source_code, dst_file_ext) -> Tuple[str, str]:
2350
        """
2351
        Writes source code into a file with dst_file_ext as the file extension.
2352
        Returns the hash key of source code, and the path to the file.
2353
        """
2354

2355
        cuda_command = repr(
2356
            cuda_compile_command(["dummy_input"], "dummy_output", dst_file_ext)
2357
        )
2358
        key, input_path = write(
2359
            source_code, cls._SOURCE_CODE_SUFFIX, extra=cuda_command
2360
        )
2361
        return key, input_path
2362

2363
    @classmethod
2364
    def compile(cls, source_code, dst_file_ext) -> Tuple[str, str, str]:
2365
        """
2366
        Compiles CUDA source_code into a file with dst_file_ext extension.
2367
        Returns a tuple of dst_file_path, hash_key, source_code_path
2368
        """
2369

2370
        key, input_path = cls.write(source_code, dst_file_ext)
2371
        if key not in cls.cache:
2372
            from filelock import FileLock
2373

2374
            lock_dir = get_lock_dir()
2375
            lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
2376
            with lock:
2377
                output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext
2378
                if not os.path.exists(output_path):
2379
                    cmd = cuda_compile_command(
2380
                        [input_path], output_path, dst_file_ext
2381
                    ).split(" ")
2382
                    try:
2383
                        subprocess.check_output(
2384
                            cmd, stderr=subprocess.STDOUT, env=os.environ
2385
                        )
2386
                    except subprocess.CalledProcessError as error:
2387
                        raise exc.CUDACompileError(cmd, error.output) from error
2388
                cls.cache[key] = CUDACodeCache.CacheEntry(input_path, output_path)
2389

2390
        return (cls.cache[key].output_path, key, input_path)
2391

2392
    @classmethod
2393
    def load(cls, source_code, dst_file_ext) -> Tuple[DLLWrapper, str, str]:
2394
        """
2395
        Compiles source code and loads the generated .so file.
2396
        Returns a tuple of DLLWrapper, hash_key, source_code_path
2397
        """
2398

2399
        if dst_file_ext != "so":
2400
            raise RuntimeError(
2401
                f"Only support loading a .so file for now. "
2402
                f"Requested file extension: {dst_file_ext}. Source code: {source_code}"
2403
            )
2404
        dst_file_path, hash_key, source_code_path = cls.compile(
2405
            source_code, dst_file_ext
2406
        )
2407
        return (DLLWrapper(dst_file_path), hash_key, source_code_path)
2408

2409

2410
def caching_device_properties():
2411
    for _, device_interface in get_registered_device_interfaces():
2412
        if device_interface.is_available():
2413
            device_interface.Worker.get_device_properties()
2414

2415

2416
def _set_triton_ptxas_path() -> None:
2417
    if os.environ.get("TRITON_PTXAS_PATH") is not None:
2418
        return
2419
    ptxas_path = os.path.abspath(
2420
        os.path.join(os.path.dirname(__file__), "..", "bin", "ptxas")
2421
    )
2422
    if not os.path.exists(ptxas_path):
2423
        return
2424
    if os.path.isfile(ptxas_path) and os.access(ptxas_path, os.X_OK):
2425
        os.environ["TRITON_PTXAS_PATH"] = ptxas_path
2426
    else:
2427
        warnings.warn(f"{ptxas_path} exists but is not an executable")
2428

2429

2430
def _worker_compile(
2431
    kernel_name: str, source_code: str, cc: int, device: torch.device
2432
) -> None:
2433
    device_interface = get_interface_for_device(device.type)
2434
    device_interface.Worker.set_device(device.index)
2435
    kernel = TritonCodeCache.load(kernel_name, source_code)
2436
    kernel.precompile(warm_cache_only_with_cc=cc)
2437

2438

2439
def _load_kernel(kernel_name: str, source_code: str) -> ModuleType:
2440
    _set_triton_ptxas_path()
2441
    kernel = TritonCodeCache.load(kernel_name, source_code)
2442
    kernel.precompile()
2443
    return kernel
2444

2445

2446
class TritonFuture:
2447
    kernel: ModuleType
2448

2449
    def __init__(
2450
        self,
2451
        kernel_name: str,
2452
        source_code: str,
2453
        future: Future[Any],
2454
    ) -> None:
2455
        self.kernel_name = kernel_name
2456
        self.source_code = source_code
2457
        self.future = future
2458

2459
    # @dynamo_utils.dynamo_timed
2460
    def result(self) -> ModuleType:
2461
        t0 = time()
2462
        if hasattr(self, "kernel"):
2463
            return self.kernel
2464
        # If the worker failed this will throw an exception.
2465
        self.future.result()
2466
        kernel = self.kernel = _load_kernel(self.kernel_name, self.source_code)
2467
        latency = time() - t0
2468
        if latency > 50:
2469
            developer_warning(
2470
                f"Detected long compilation time of {latency} seconds for kernel name {self.kernel_name}"
2471
            )
2472
            developer_warning(self.source_code)
2473
        del self.kernel_name, self.source_code, self.future
2474
        return kernel
2475

2476

2477
# If this process dies abnormally (e.g. segfault)
2478
# it will not shut down the workers. Instead
2479
# the workers will have their parent reassigned to the
2480
# init process. This launches a separate thread to
2481
# watch for the worker getting reassigned,
2482
# and cleans it up in this case.
2483
#
2484
# This function cannot be an inner function since otherwise mp_context="spawn" would
2485
# not work for ProcessPoolExecutor since inner functions cannot be pickled.
2486
def _async_compile_initializer(orig_ppid) -> None:
2487
    def run() -> None:
2488
        while True:
2489
            sleep(1)
2490
            if orig_ppid != os.getppid():
2491
                os.kill(os.getpid(), signal.SIGKILL)
2492

2493
    global _watchdog_thread
2494
    _watchdog_thread = Thread(target=run, daemon=True)
2495
    _watchdog_thread.start()
2496
    # Ignore Ctrl-C (i.e. SIGINT) sent to pool workers to avoid meaningless log spam.
2497
    signal.signal(signal.SIGINT, signal.SIG_IGN)
2498

2499

2500
_watchdog_thread: Optional[Thread] = None
2501

2502
# Used to keep track of all process pools invoked so far.
2503
_pool_set: Set[ProcessPoolExecutor] = set()
2504

2505

2506
def shutdown_compile_workers() -> None:
2507
    """Shut down all outstanding compile-worker pools."""
2508
    global _pool_set
2509
    for pool in _pool_set:
2510
        pool.shutdown()
2511
    _pool_set = set()
2512

2513

2514
class AsyncCompile:
2515
    def __init__(self) -> None:
2516
        pass
2517

2518
    @staticmethod
2519
    @functools.lru_cache(1)
2520
    def pool() -> ThreadPoolExecutor:
2521
        assert config.compile_threads > 1
2522
        return ThreadPoolExecutor(config.compile_threads)
2523

2524
    @staticmethod
2525
    @functools.lru_cache(1)
2526
    def process_pool() -> ProcessPoolExecutor:
2527
        # ensure properties have been calculated before processes
2528
        # are forked
2529
        caching_device_properties()
2530
        assert config.compile_threads > 1
2531
        orig_ppid = os.getpid()
2532

2533
        ctx = multiprocessing.get_context(config.worker_start_method)
2534
        pool = ProcessPoolExecutor(
2535
            config.compile_threads,
2536
            mp_context=ctx,
2537
            initializer=partial(_async_compile_initializer, orig_ppid),
2538
        )
2539

2540
        global _pool_set
2541
        _pool_set.add(pool)
2542

2543
        # when this pool is created in a subprocess object, the normal exit handler
2544
        # doesn't run, and we need to register our own handler.
2545
        # exitpriority has to be high, because another one of the finalizers will
2546
        # kill the worker thread that sends the shutdown message to the workers...
2547
        multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize)
2548
        return pool
2549

2550
    @classmethod
2551
    def warm_pool(cls) -> None:
2552
        if config.compile_threads <= 1:
2553
            return
2554
        _compile_start()
2555
        pool = cls.process_pool()
2556

2557
        # We have to fork processes for compiler workers, but the more memory and other resources that are loaded, the
2558
        # slower the os.fork time is, quite drastically. It also holds the GIL so we can't put it on another thread.
2559

2560
        # Examples:
2561
        # A simple x + x + x script: 10ms in the middle of the program, 2ms at startup
2562
        # tf_efficientnet_b0 benchmark: 50ms! in the middle of the program, 3ms at startup
2563

2564
        # So we want to start the workers early when it is still cheap, and also to allow the workers to get
2565
        # ready before we have work for them.
2566

2567
        # ProcessPoolExecutor also does not launch the workers until it finds a point when all the workers are idle.
2568
        # But if we waited until then, fork time would be long and we would be waiting for the processes to initialize.
2569

2570
        # We force them to start here with some YOLOing of the internal methods.
2571
        if hasattr(pool, "_start_queue_management_thread"):
2572
            pool._start_queue_management_thread()
2573
        else:
2574
            for _ in range(config.compile_threads):
2575
                pool._adjust_process_count()
2576
            if hasattr(pool, "_start_executor_manager_thread"):
2577
                pool._start_executor_manager_thread()
2578
        _compile_end()
2579

2580
    @classmethod
2581
    def submit(cls, task: Callable[..., Any]) -> Any:
2582
        if config.compile_threads <= 1:
2583
            return task()
2584
        return cls.pool().submit(task)
2585

2586
    @classmethod
2587
    def map(cls, fn: Callable[..., Any], seq: List[Any]) -> List[Any]:
2588
        if config.compile_threads <= 1 or len(seq) <= 1:
2589
            return list(map(fn, seq))
2590
        return [t.result() for t in [cls.pool().submit(fn, x) for x in seq]]
2591

2592
    def triton(
2593
        self, kernel_name: str, source_code: str, device_str: str = "cuda"
2594
    ) -> Union[TritonFuture, ModuleType]:
2595
        _compile_start()
2596

2597
        if config.compile_threads > 1:
2598
            device_interface = get_interface_for_device(device_str)
2599
            device = torch.device(device_str, device_interface.current_device())
2600
            cc = device_interface.get_compute_capability(device)
2601
            future = self.process_pool().submit(
2602
                _worker_compile, kernel_name, source_code, cc, device
2603
            )
2604
            return TritonFuture(kernel_name, source_code, future)
2605
        else:
2606
            return _load_kernel(kernel_name, source_code)
2607

2608
    def multi_kernel(self, *args, **kwargs) -> ModuleType:
2609
        """
2610
        Async compile the python shim for multi-kernel.
2611
        """
2612

2613
        def task():
2614
            from torch._inductor.codegen.multi_kernel import MultiKernelCall
2615

2616
            return MultiKernelCall(*args, **kwargs)
2617

2618
        return self.submit(task)
2619

2620
    def cpp(self, source_code: str) -> ModuleType:
2621
        def task():
2622
            return CppCodeCache.load(source_code).kernel
2623

2624
        return self.submit(task)
2625

2626
    def cpp_pybinding(self, argtypes: List[str], source_code: str) -> ModuleType:
2627
        return self.submit(
2628
            functools.partial(
2629
                CppPythonBindingsCodeCache.load_pybinding, argtypes, source_code
2630
            )
2631
        )
2632

2633
    def cuda(self, source_code, dst_file_ext):
2634
        def task():
2635
            return CUDACodeCache.load(source_code, dst_file_ext)[0]
2636

2637
        return self.submit(task)
2638

2639
    def wait(self, scope: Dict[str, Any]) -> None:
2640
        num_kernels = len(
2641
            [
2642
                value
2643
                for key, value in scope.items()
2644
                if isinstance(value, (Future, TritonFuture))
2645
            ]
2646
        )
2647
        pbar = tqdm(
2648
            total=num_kernels,
2649
            desc="Inductor Compilation",
2650
            disable=config.disable_progress,
2651
            delay=0,
2652
        )
2653
        if config.compile_threads > 1:
2654
            for key, result in scope.items():
2655
                if config.verbose_progress and not isinstance(pbar, _Faketqdm):
2656
                    pbar.set_postfix_str(key)
2657
                if isinstance(result, (Future, TritonFuture)):
2658
                    scope[key] = result.result()
2659
                    pbar.update(1)
2660

2661
        _compile_end()
2662

2663
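# --- Illustrative sketch (not part of the original file) ---
# How the AsyncCompile API above is typically driven: kernels are submitted
# while compilation proceeds in the worker pool, and wait() later replaces any
# futures in the given scope with the finished kernels. The kernel name and
# source passed in below are hypothetical placeholders.
def _example_async_compile_usage(triton_source: str) -> None:
    async_compile = AsyncCompile()
    scope = {"triton_kernel_0": async_compile.triton("triton_kernel_0", triton_source)}
    async_compile.wait(scope)  # blocks until every Future/TritonFuture resolves
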

2664
AsyncCompile.warm_pool()
2665
