from __future__ import annotations

import base64
import copyreg
import ctypes
import dataclasses
import functools
import hashlib
import importlib
import io
import json
import logging
import os
import pathlib
import pickle
import pkgutil
import platform
import re
import shlex
import shutil
import subprocess
import sys
import sysconfig
import tempfile
import textwrap
import threading
import warnings
import weakref
from bisect import bisect_right
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
from copy import copy
from ctypes import c_void_p, cdll, CDLL
from dataclasses import field
from functools import partial
from pathlib import Path
from threading import Thread
from time import sleep, time
from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union

import torch

from torch._dynamo.device_interface import (
    get_interface_for_device,
    get_registered_device_interfaces,
)
from torch._dynamo.utils import counters, dynamo_timed
from torch._inductor import config, exc
from torch._inductor.codegen.cuda import cuda_env
from torch._inductor.utils import cache_dir, developer_warning, is_linux
from torch._subclasses.fake_tensor import (
    extract_tensor_metadata,
    FakeTensor,
    TensorMetadata,
)
from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv

if TYPE_CHECKING:
    from torch._inductor.graph import GraphLowering
    from torch._inductor.select_algorithm import ChoiceCaller

from torch.hub import _Faketqdm, tqdm
_HERE = os.path.abspath(__file__)
_TORCH_PATH = os.path.dirname(os.path.dirname(_HERE))
_LINKER_SCRIPT = os.path.join(_TORCH_PATH, "_inductor/script.ld")

if config.is_fbcode():
    from triton.fb import build_paths
    from triton.fb.build import _run_build_command

    from torch._inductor.fb.utils import (
        log_global_cache_errors,
        log_global_cache_stats,
        log_global_cache_vals,
        use_global_cache,
    )
else:

    def log_global_cache_errors(*args, **kwargs):
        pass

    def log_global_cache_stats(*args, **kwargs):
        pass

    def log_global_cache_vals(*args, **kwargs):
        pass

    def use_global_cache() -> bool:
        return False
# timing metrics for time spent in the compilation
_cumulative_compile_time = 0.0
_t0: Optional[float] = None


def _compile_start() -> None:
    global _t0
    if _t0 is None:
        _t0 = time()


def _compile_end() -> None:
    global _cumulative_compile_time, _t0
    if _t0 is not None:
        t1 = time()
        _cumulative_compile_time += t1 - _t0
        _t0 = None
        # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time)


log = logging.getLogger(__name__)
def cpp_wrapper_cache_dir(name: str) -> str:
    cu_str = (
        "cpu"
        if torch.version.cuda is None
        else f'cu{torch.version.cuda.replace(".", "")}'
    )
    python_version = f"py{sys.version_info.major}{sys.version_info.minor}"
    build_folder = f"{python_version}_{cu_str}"

    cpp_wrapper_dir = os.path.join(cache_dir(), build_folder)
    cpp_wrapper_build_directory = os.path.join(cpp_wrapper_dir, name)
    os.makedirs(cpp_wrapper_build_directory, exist_ok=True)
    return cpp_wrapper_build_directory


def get_cpp_wrapper_cubin_path_name():
    return "cubin_path" if torch.version.hip is None else "hsaco_path"
class CacheBase:
    @staticmethod
    @functools.lru_cache(None)
    def get_system() -> Dict[str, Any]:
        try:
            import triton

            triton_version = triton.__version__
        except ModuleNotFoundError:
            triton_version = None

        try:
            system: Dict[str, Any] = {
                "device": {
                    "name": torch.cuda.get_device_properties(
                        torch.cuda.current_device()
                    ).name,
                },
                "version": {
                    "cuda": torch.version.cuda,
                    "triton": triton_version,
                },
            }
        except (AssertionError, RuntimeError):
            # If cuda is not installed, none of the above config is relevant.
            system = {}

        system["hash"] = hashlib.sha256(
            json.dumps(system, sort_keys=True).encode("utf-8")
        ).hexdigest()

        return system

    @staticmethod
    @functools.lru_cache(None)
    def get_local_cache_path() -> Path:
        return Path(os.path.join(cache_dir(), "cache", CacheBase.get_system()["hash"]))
    @staticmethod
    @functools.lru_cache(None)
    def get_global_cache_path() -> Optional[Path]:
        return (
            Path(os.path.join(config.global_cache_dir, CacheBase.get_system()["hash"]))
            if config.global_cache_dir is not None
            else None
        )

    def __init__(self) -> None:
        if not torch.cuda.is_available():
            return

        self.system = CacheBase.get_system()
        self.local_cache_path = CacheBase.get_local_cache_path()
        self.global_cache_path = CacheBase.get_global_cache_path()
    def get_local_cache(self) -> Dict[str, Any]:
        if not self.local_cache_path.is_file():
            return {}
        with open(self.local_cache_path) as local_cache_fp:
            local_cache = json.load(local_cache_fp)
        return local_cache["cache"]

    def update_local_cache(self, local_cache: Dict[str, Any]) -> None:
        if not os.path.exists(self.local_cache_path.parent):
            os.makedirs(self.local_cache_path.parent, exist_ok=True)

        write_atomic(
            str(self.local_cache_path),
            json.dumps({"system": self.system, "cache": local_cache}, indent=4),
        )
class LocalCache(CacheBase):
    def lookup(self, *keys: str) -> Optional[Dict[str, Any]]:
        cache = self.get_local_cache()

        sub_cache = cache
        for key in keys:
            if key in sub_cache:
                sub_cache = sub_cache[key]
            else:
                return None

        return sub_cache

    def set_value(self, *keys: str, value: Any) -> None:
        cache = self.get_local_cache()

        sub_cache = cache
        for key in keys[0:-1]:
            sub_cache.setdefault(key, {})
            sub_cache = sub_cache[key]
        sub_cache[keys[-1]] = value

        self.update_local_cache(cache)
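
# Illustrative usage of LocalCache (an editorial sketch, not part of the
# original module; assumes a CUDA-enabled build so __init__ sets the cache
# paths). The nested-key layout is inferred from lookup()/set_value():
#
#   cache = LocalCache()
#   cache.set_value("conv2d", "shape=(8,3,224,224)", value=0.42)
#   assert cache.lookup("conv2d", "shape=(8,3,224,224)") == 0.42
#   assert cache.lookup("conv2d", "missing") is None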
class PersistentCache(CacheBase):
    @functools.lru_cache(None)
    def get_global_cache(self):
        if self.global_cache_path is None or not self.global_cache_path.is_file():
            return {}
        with open(self.global_cache_path) as global_cache_fp:
            global_cache = json.load(global_cache_fp)
        return global_cache["cache"]
    def lookup(
        self,
        choices: List[ChoiceCaller],
        op: str,
        inputs: str,
        benchmark: Callable[[Any], Dict[ChoiceCaller, float]],
    ) -> Dict[ChoiceCaller, float]:
        """
        Check to see if we have benchmarked the given choice callers. For each
        choice caller:

            1. Check global_cache[op][inputs][choice][precision], return benchmark if cached.
            2. Check local_cache[op][inputs][choice][precision], return benchmark if cached.
            3. If benchmark is not None:
                a. `max_autotune_gemm=True`: benchmark the choice, update
                    local_cache[op][inputs][choice], and return the benchmark.
                b. `max_autotune_gemm=False`: don't benchmark the choice, return nothing.
        """
        precision = torch.get_float32_matmul_precision()

        log_stats = partial(log_global_cache_stats, self.system, op, inputs, precision)
        log_vals = partial(log_global_cache_vals, self.system, op, inputs, precision)
        log_errors = partial(
            log_global_cache_errors, self.system, op, inputs, precision
        )
        timings = {}

        def check_cache(cache, callback=None) -> bool:
            """Check if `cache` contains data for all the choices"""
            hit = True
            for choice in choices:
                choice_hash = choice.hash_key()
                if choice_hash in cache.get(op, {}).get(inputs, {}).get(precision, {}):
                    # cache hit
                    timings[choice] = cache[op][inputs][precision][choice_hash]
                else:
                    # cache miss
                    hit = False
                    break
            if callback:
                callback(cached=hit)
            return hit

        if config.max_autotune or config.max_autotune_gemm:
            local_cache = self.get_local_cache()
            # check local cache first since it is data specific to the current machine
            if not check_cache(local_cache) and not (
                use_global_cache()
                and check_cache(self.get_global_cache(), callback=log_stats)
            ):
                try:
                    # re-benchmark everything to try to get consistent numbers from the same machine
                    timings = benchmark(choices)
                    assert all(choice in timings for choice in choices)
                    local_cache.setdefault(op, {})
                    local_cache[op].setdefault(inputs, {}).setdefault(precision, {})
                    for choice, timing in timings.items():
                        local_cache[op][inputs][precision][choice.hash_key()] = timing
                except RuntimeError as e:
                    # catch and log autotuning failures
                    log_errors(e)
                    raise e

                self.update_local_cache(local_cache)

                timings_to_log = {
                    choice.hash_key(): timings[choice] for choice in choices
                }
                log_vals(timings_to_log)
        elif use_global_cache():
            # only check global cache, not local one
            check_cache(self.get_global_cache(), callback=log_stats)
            # may have a partial cache hit, where not everything is benchmarked

        return timings
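
# The on-disk autotuning cache written above is a nested JSON dict; a sketch
# of its shape (editorial illustration inferred from update_local_cache() and
# the indexing in lookup(), not produced verbatim by this module):
#
#   {
#       "system": {...},
#       "cache": {
#           "<op>": {
#               "<inputs>": {
#                   "<precision>": {"<choice hash>": <benchmark timing>, ...}
#               }
#           }
#       }
#   }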
LOCK_TIMEOUT = 600  # seconds to wait when acquiring the file locks used below


def get_lock_dir() -> str:
    lock_dir = os.path.join(cache_dir(), "locks")
    if not os.path.exists(lock_dir):
        os.makedirs(lock_dir, exist_ok=True)
    return lock_dir


def sha256_hash(data: bytes) -> str:
    # [:51] to strip off the "Q====" suffix common to every hash value.
    return base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower()
def code_hash(code: Union[str, bytes], extra: str = ""):
    hashing_str = code if isinstance(code, bytes) else code.encode("utf-8")
    if extra != "":
        hashing_str = hashing_str + b"||" + extra.encode("utf-8")
    return "c" + sha256_hash(hashing_str)


def get_path(
    basename: str, extension: str, specified_dir: str = ""
) -> Tuple[str, str, str]:
    if specified_dir:
        if os.path.isabs(specified_dir):
            subdir = specified_dir
        else:
            subdir = os.path.join(cache_dir(), specified_dir)
    else:
        subdir = os.path.join(cache_dir(), basename[1:3])
    path = os.path.join(subdir, f"{basename}.{extension}")
    return basename, subdir, path
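
# A sketch of the resulting content-addressed layout (editorial note, not part
# of the original module): a key like "cabc123..." lands in a two-character
# fan-out subdirectory, since basename[1:3] takes the two characters after the
# "c" prefix:
#
#   get_path("cabc123", "so")
#   # -> ("cabc123", "<cache_dir>/ab", "<cache_dir>/ab/cabc123.so")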
def get_hash(content: Union[str, bytes], extra: str = "", hash_type: str = "code"):
    if hash_type == "code":
        return code_hash(content, extra)
    if hash_type in ["cubin", "hsaco"]:
        return code_hash(repr(content))
    raise AssertionError(f"Unknown hash type {hash_type}")
def write(
    content: Union[str, bytes],
    extension: str,
    extra: str = "",
    hash_type: str = "code",
    specified_dir: str = "",
) -> Tuple[str, str]:
    # use stripped content to compute hash so we don't end up with different
    # hashes just because the content begins/ends with a different number of
    # empty lines
    key: str = get_hash(content.strip(), extra, hash_type)
    basename, subdir, path = get_path(key, extension, specified_dir)
    if not os.path.exists(subdir):
        os.makedirs(subdir, exist_ok=True)
    if not os.path.exists(path):
        write_atomic(path, content)
    return basename, path
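
# Illustrative round trip (an editorial sketch, not part of the upstream
# module): because the hash is computed over stripped content, trailing
# whitespace does not create a second cache entry.
#
#   key1, path1 = write("int main() { return 0; }", "cpp")
#   key2, path2 = write("int main() { return 0; }\n\n", "cpp")
#   assert (key1, path1) == (key2, path2)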
def write_atomic(path: str, content: Union[str, bytes]) -> None:
    # Write into temporary file first to avoid conflicts between threads
    # Avoid using a named temporary file, as those have restricted permissions
    assert isinstance(
        content, (str, bytes)
    ), "Only strings and byte arrays can be saved in the cache"
    path = pathlib.Path(path)
    tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp"
    write_mode = "w" if isinstance(content, str) else "wb"
    with tmp_path.open(write_mode) as f:
        f.write(content)
    tmp_path.rename(path)
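
# Editorial note: the temp-file-then-rename dance relies on Path.rename (a
# rename within one filesystem) being atomic on POSIX, so concurrent writers
# racing on the same path each produce a complete file and the last rename
# wins; readers never observe a partially written cache entry.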
@dataclasses.dataclass
class TensorMetadataAndValues:
    """
    TensorMetadata plus the elements as a list of raw values.
    Used for hashing inlined constants.
    """

    tensor_metadata: TensorMetadata
    values: List[Any]
def _ident(x: Any) -> Any:
    return x


def _reduce_fake_tensor(t):
    """
    See FxGraphCachePickler. Custom reducer to pickle FakeTensors.
    """
    metadata = extract_tensor_metadata(t)
    return (_ident, (metadata,))


def _reduce_tensor(t):
    """
    See FxGraphCachePickler. Custom reducer to pickle Tensors.
    """
    # If we see tensors, we know they're constants stored as attributes on
    # the GraphModule. See tensor lowering; small constants are inlined. If
    # we see a small tensor, therefore, no reference will ultimately remain
    # in the generated code. So we need to include its value in the cache key.
    # Large constants are effectively treated as inputs and we consider only
    # their metadata.
    metadata = extract_tensor_metadata(t)
    if len(t.shape) == 0 or torch._inductor.graph.GraphLowering.can_inline_constant(t):
        return (_ident, (TensorMetadataAndValues(metadata, t.tolist()),))
    else:
        return (_ident, (metadata,))


def _reduce_symint(s):
    """
    See FxGraphCachePickler. Custom reducer to pickle SymInts.
    """
    # For hashing purposes, we only care about the name of the symbol and
    # not the backed value. We evaluate guards stored with a cached graph
    # to ensure a cached entity with SymInt args is safe to reuse.
    return (_ident, (str(s),))
class FxGraphCachePickler(pickle.Pickler):
    """
    Custom pickler to customize the pickling of some objects (Tensors), only for the
    purpose of computing a hash for keying into the FxGraphCache. Tensors contain
    objects that don't pickle and/or vary between runs, and we want to capture the
    data that allow us to compute a stable, but safe hash.
    """

    dispatch_table = copyreg.dispatch_table.copy()
    dispatch_table[FakeTensor] = _reduce_fake_tensor
    dispatch_table[torch.Tensor] = _reduce_tensor
    dispatch_table[torch.SymInt] = _reduce_symint

    @staticmethod
    def dumps(obj) -> bytes:
        """
        Pickle an object using the FxGraphCachePickler.
        """
        with io.BytesIO() as stream:
            pickler = FxGraphCachePickler(stream)
            pickler.dump(obj)
            return stream.getvalue()

    @staticmethod
    def get_hash(obj: Any) -> str:
        """
        Serialize an object using the FxGraphCachePickler and return a hash
        of the pickled object.
        """
        serialized_data = FxGraphCachePickler.dumps(obj)
        return sha256_hash(serialized_data)
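
# Editorial sketch of the stable-hash behavior (assumes the tensors are large
# enough that GraphLowering.can_inline_constant() is False, so _reduce_tensor
# pickles only metadata): two tensors with identical metadata but different
# data hash the same, while a dtype change produces a different key.
#
#   a, b = torch.randn(1024), torch.randn(1024)
#   assert FxGraphCachePickler.get_hash(a) == FxGraphCachePickler.get_hash(b)
#   assert FxGraphCachePickler.get_hash(a) != FxGraphCachePickler.get_hash(a.double())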
@functools.lru_cache(None)
def get_inductor_code_hash() -> bytes:
    """
    Compute a hash of all inductor code modules. Used by the FxGraph cache
    so any inductor code changes would result in new cache keys.
    """
    inductor_root = os.path.dirname(__file__)

    contents: Dict[str, bytes] = {}
    for lib in pkgutil.iter_modules([inductor_root]):
        spec = lib.module_finder.find_spec(lib.name, None)
        assert spec is not None
        module = spec.origin
        assert module is not None
        with open(module, "rb") as f:
            contents[module] = f.read()

    return hashlib.sha256(pickle.dumps(contents)).digest()
@dataclasses.dataclass
class OrderedSetHolder:
    """
    See FxGraphHashDetails. Holds a sorted list to support stable hashing
    of set kwargs.
    """

    items: List[Any]


class FxGraphHashDetails:
    """
    Object to capture all the details for a compiled FX graph relevant to computing
    a safe and stable cache key.
    """

    # Excluded kwargs param that are not stable between runs
    EXCLUDED_KWARGS = ["graph_id"]
    def __init__(
        self,
        gm: torch.fx.GraphModule,
        example_inputs: List[torch.Tensor],
        fx_kwargs: Dict[str, Any],
    ):
        self.gm = gm
        self.example_inputs = example_inputs

        # Order kwargs so hashing is stable to changes in kwarg order.
        self.fx_kwargs = {}
        for k in sorted(fx_kwargs):
            if k not in self.EXCLUDED_KWARGS:
                if type(fx_kwargs[k]) is set:
                    # Special case to handle set params. Python sets can't be
                    # ordered, so sort the elements and store them in a proxy.
                    self.fx_kwargs[k] = OrderedSetHolder(sorted(fx_kwargs[k]))
                else:
                    self.fx_kwargs[k] = fx_kwargs[k]

        # Also hash on various system info (including the triton compiler version), as
        # well as the inductor configuration and code.
        self.torch_version = torch.__version__
        self.system_info = CacheBase.get_system()
        self.inductor_config = config.save_config()
        self.inductor_code_hash = get_inductor_code_hash()
    def debug_str(self) -> str:
        """
        Get a printable string describing in more detail all the attributes
        comprising this object. Useful for debugging when one graph hashes
        to a different value than another.
        """

        def get_str(obj) -> str:
            if isinstance(obj, torch.Tensor):
                return str(extract_tensor_metadata(obj))
            elif isinstance(obj, bytes):
                return "<bytes>"
            else:
                return str(obj)

        lines = []
        for attr, obj in vars(self).items():
            if isinstance(obj, list):
                for ii in range(len(obj)):
                    h = FxGraphCachePickler.get_hash(obj[ii])
                    lines.append(f"[{h}] {attr}[{ii}]: {get_str(obj[ii])}")
            elif isinstance(obj, dict):
                for k, v in obj.items():
                    h = FxGraphCachePickler.get_hash(v)
                    lines.append(f"[{h}] {attr}[{k}]: {get_str(v)}")
            else:
                h = FxGraphCachePickler.get_hash(obj)
                lines.append(f"[{h}] {attr}: {get_str(obj)}")
        return "\n".join(lines)
def compiled_fx_graph_hash(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    fx_kwargs: Dict[str, Any],
) -> str:
    """
    Generate a unique hash of the FX graph for caching.
    """
    details = FxGraphHashDetails(gm, example_inputs, fx_kwargs)
    # The prefix distinguishes among the other kinds of objects we
    # cache in this module.
    key = "f" + FxGraphCachePickler.get_hash(details)
    log.debug("FX graph cache hash details for key %s:\n%s", key, details.debug_str())
    return key
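
# Key prefix convention used in this module (editorial note): code_hash()
# returns "c" + sha256 and compiled_fx_graph_hash() returns "f" + sha256, so
# the two kinds of cache entries cannot collide in the shared on-disk layout.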
class FxGraphCache:
    """
    Supports caching and reusing compiled Fx graphs.

    The overall strategy is as follows:
    - This cache stores entries on disk. When saving an entry, we can't
      serialize callables (that could be C++, Triton, etc.), so we serialize
      their own disk cache location. We then recreate the compiled artifact
      after fetching from disk.
    - For indexing the cache, we gather the fields relevant to identifying an
      FxGraph (the graph module, graph inputs, system settings etc.) into an
      FxGraphCacheDetails object, pickle it, and compute a hash for the key.
      See FxGraphCachePickler.
    - Among the metadata we store, we also include a guards expression that's
      appropriate for validating any symbols for Tensor arguments that have
      symbolic bounds. On cache lookup then, we evaluate those guards in the
      current context to validate that a cached entry can be served.
    - A given graph could have multiple compiled versions, corresponding to
      different sets of guards. Therefore, we store cache entries in the form:
          <temp dir>/<fx graph hash>/<serialized metadata>
    - On lookup, we compute the key from the graph details, iterate over all
      leaf files in the corresponding subdirectory, deserialize the entry, and
      evaluate its guards expression. If the evaluation succeeds, we have a
      cache hit. If it fails, we compile the graph and store a new entry.
    - Finally, on a cache hit, we need to make sure any guards that would
      have been created during compilation are added to the current context.
    """

    # TODO(masnesral): Investigate whether it's beneficial to store compiled graphs
    # in an in-memory cache after loading from disk.
    @staticmethod
    def _get_tmp_dir() -> str:
        """
        Get the toplevel temporary directory for storing compiled graphs.
        """
        return os.path.join(cache_dir(), "fxgraph")

    @staticmethod
    def _get_tmp_dir_for_key(key: str) -> str:
        """
        Return the disk location for a given cache key.
        """
        return os.path.join(FxGraphCache._get_tmp_dir(), key[1:3], key)

    @staticmethod
    def _filter_symints(inputs: List[Any]) -> List[torch.SymInt]:
        """
        Get the SymInt objects from the input list.
        """
        return [s for s in inputs if isinstance(s, torch.SymInt)]

    @staticmethod
    def _get_shape_env() -> ShapeEnv:
        """
        Helper to get the shape env from the tracing context.
        """
        return torch._guards.TracingContext.get().fake_mode.shape_env
    @staticmethod
    def _lookup_graph(
        key: str,
        example_inputs: List[torch.Tensor],
    ) -> Optional[CompiledFxGraph]:
        """
        Lookup a compiled graph in the cache by key. On a hit, return the
        deserialized CompiledFxGraph object. On a miss, return None.
        """
        subdir = FxGraphCache._get_tmp_dir_for_key(key)
        if not os.path.exists(subdir):
            return None

        # Iterate over any entries in the subdir for this key and evaluate
        # their guards to determine whether there's a hit.
        for path in sorted(os.listdir(subdir)):
            with open(os.path.join(subdir, path), "rb") as f:
                graph: CompiledFxGraph = pickle.load(f)

            guards_expr = graph.guards_expr
            if not guards_expr:
                # No guards to evaluate
                return graph

            # Evaluate the guard expression in the current context.
            shape_env = FxGraphCache._get_shape_env()
            symints = FxGraphCache._filter_symints(example_inputs)

            # If there's not a cache hit, we don't want the evaluation to
            # affect the current env, e.g., cause the creation of new guards,
            # so we evaluate with the hints instead of the symbols.
            assert all(has_hint(s) for s in symints)
            hints = [hint_int(s) for s in symints]
            hit = bool(shape_env.evaluate_guards_expression(guards_expr, hints))
            log.debug(
                "fx graph cache key %s evaluating guards for %s with values %s => %s",
                key,
                guards_expr,
                hints,
                hit,
            )
            if hit:
                # Now re-evaluate with the symints to add any guards to the current env.
                check = bool(
                    shape_env.evaluate_guards_expression(guards_expr, symints)
                )
                assert check is True
                log.debug(
                    "fx graph cache key %s post-load guards: %s",
                    key,
                    shape_env.guards,
                )
                return graph

        return None
    @staticmethod
    def _save_graph(
        key: str, compiled_graph: CompiledFxGraph, example_inputs: List[torch.Tensor]
    ):
        """
        Store a serialized CompiledFxGraph on disk.
        """
        disk_compiled_graph = copy(compiled_graph)
        # Important as compiled models are not pickleable:
        disk_compiled_graph.compiled_artifact = None

        # Before serializing, compute the guard expression that will be used to
        # ensure that a CompiledFxGraph is valid when loaded from the cache. It's
        # sufficient to consider only the SymInt args to the fx graph since the
        # Tensor shapes are already captured in the hash for the cache key. Any
        # Tensor arg with a symbolic shape will have a SymInt arg for the graph.
        shape_env = FxGraphCache._get_shape_env()
        symints = FxGraphCache._filter_symints(example_inputs)
        disk_compiled_graph.guards_expr = shape_env.produce_guards_expression(symints)

        content = pickle.dumps(disk_compiled_graph)

        subdir = FxGraphCache._get_tmp_dir_for_key(key)
        if not os.path.exists(subdir):
            os.makedirs(subdir, exist_ok=True)

        # Use a hash of the serialized CompiledFxGraph to get a unique file
        # name. The specific name doesn't matter since a lookup involves
        # iterating over all entries in the parent subdir.
        path = os.path.join(subdir, sha256_hash(content))
        write_atomic(path, content)
    @staticmethod
    def load(
        compile_fx_fn: Callable[..., Any],
        gm: torch.fx.GraphModule,
        example_inputs: List[torch.Tensor],
        fx_kwargs: Dict[str, Any],
    ):
        """
        Load a compiled graph from the cache. If a cached entry does not exist,
        compile the graph and save it to the cache.
        """
        from filelock import FileLock

        key = compiled_fx_graph_hash(gm, example_inputs, fx_kwargs)

        lock_dir = get_lock_dir()
        lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
        with lock:
            compiled_graph = FxGraphCache._lookup_graph(key, example_inputs)
            if compiled_graph is None:
                log.debug("fx graph cache miss for key %s", key)
                counters["inductor"]["fxgraph_cache_miss"] += 1
                compiled_graph = compile_fx_fn(gm, example_inputs, **fx_kwargs)
                FxGraphCache._save_graph(key, compiled_graph, example_inputs)
            else:
                log.debug("fx graph cache hit for key %s", key)
                counters["inductor"]["fxgraph_cache_hit"] += 1

        return compiled_graph

    @staticmethod
    def clear():
        """
        Clear out the on-disk cache.
        """
        shutil.rmtree(FxGraphCache._get_tmp_dir())
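
# Illustrative call site for FxGraphCache.load (an editorial sketch; the real
# caller lives in torch._inductor.compile_fx and may pass different kwargs):
#
#   def compile_fx_inner(gm, example_inputs, **fx_kwargs):
#       ...  # run inductor compilation and return a CompiledFxGraph
#
#   compiled = FxGraphCache.load(compile_fx_inner, gm, example_inputs, fx_kwargs)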
@dataclasses.dataclass
class CompiledFxGraph:
    """
    Class holding a compiled FX graph. This is the object serialized on disk
    to support FxGraph caching.
    """

    compiled_artifact: Optional[Callable[..., Any]] = None
    current_callable: Optional[Callable[..., Any]] = None
    cache_key: Optional[str] = None
    artifact_path: Optional[str] = None
    cache_linemap: Optional[List[Tuple[int, str]]] = None
    device_types: Set[str] = field(default_factory=set)
    device_idxs: Set[int] = field(default_factory=set)
    mutated_inputs: Set[str] = field(default_factory=set)
    mutated_input_idxs: Set[int] = field(default_factory=set)
    constants: Dict[str, torch.Tensor] = field(default_factory=dict)
    output_strides: Optional[List[Optional[Tuple[int, ...]]]] = None
    # This is a string representation of an expression we serialize
    # with the object so the guards can be evaluated in a different
    # context in order to verify the validity of serving a cached
    # fx graph. The expression must be generated by:
    # ShapeEnv.produce_guards_expression()
    guards_expr: Optional[str] = None

    _boxed_call: Optional[bool] = None
    disabled_cudagraphs_reason: Optional[str] = None
    def __init__(
        self,
        compiled_artifact: Optional[Callable[..., Any]],
        graph: GraphLowering,
        output_strides: List[Optional[Tuple[int, ...]]],
        disabled_cudagraphs_reason: Optional[str],
    ):
        self.compiled_artifact = compiled_artifact
        self.cache_key = graph.cache_key
        self.artifact_path = graph.cache_path
        self.cache_linemap = graph.cache_linemap
        self.device_types = graph.device_types
        self.device_idxs = graph.device_idxs
        self.mutated_inputs = graph.mutated_inputs
        self.mutated_input_idxs = set(graph.mutated_input_idxs)
        self.constants = graph.constants
        self.output_strides = output_strides
        self.guards_expr = None
        self.disabled_cudagraphs_reason = disabled_cudagraphs_reason

    def __call__(self, inputs: List[Any]) -> Any:
        return self.get_current_callable()(inputs)
    def get_current_callable(self) -> Callable[..., Any]:
        if self.current_callable is None:
            # This prevents a circular reference that makes CompiledFxGraph
            # get stuck without getting garbage collected
            return functools.partial(_run_from_cache, weakref.proxy(self))
        else:
            return self.current_callable


def _run_from_cache(compiled_graph: CompiledFxGraph, inputs: List[Any]) -> Any:
    # We can't really serialize callables that may be C++/Triton/etc.,
    # so we serialize their disk cache location instead
    # TODO: When making an API that can save compiled models e2e to disk
    # this will need to be better
    if compiled_graph.compiled_artifact is None:
        from .codecache import PyCodeCache

        assert compiled_graph.cache_key
        assert compiled_graph.artifact_path
        compiled_graph.compiled_artifact = PyCodeCache.load_by_key_path(
            compiled_graph.cache_key,
            compiled_graph.artifact_path,
            compiled_graph.cache_linemap,
            compiled_graph.constants,
        ).call

    return compiled_graph.compiled_artifact(inputs)
def cpp_compiler() -> str:
    if config.is_fbcode():
        return build_paths.cc()
    if isinstance(config.cpp.cxx, (list, tuple)):
        search = tuple(config.cpp.cxx)
    else:
        search = (config.cpp.cxx,)
    return cpp_compiler_search(search)
@functools.lru_cache(1)
def cpp_compiler_search(search: str) -> str:
    for cxx in search:
        try:
            if cxx is None:
                # gxx package is only available for Linux
                # according to https://anaconda.org/conda-forge/gxx/
                if sys.platform != "linux":
                    continue
                # Do not install GXX by default
                if not os.getenv("TORCH_INDUCTOR_INSTALL_GXX"):
                    continue
                from filelock import FileLock

                lock_dir = get_lock_dir()
                lock = FileLock(
                    os.path.join(lock_dir, "g++.lock"), timeout=LOCK_TIMEOUT
                )
                with lock:
                    cxx = install_gcc_via_conda()
            subprocess.check_output([cxx, "--version"])
            return cxx
        except (subprocess.SubprocessError, FileNotFoundError, ImportError):
            continue
    raise exc.InvalidCxxCompiler()
def install_gcc_via_conda() -> str:
    """On older systems, this is a quick way to get a modern compiler"""
    prefix = os.path.join(cache_dir(), "gcc")
    cxx_path = os.path.join(prefix, "bin", "g++")
    if not os.path.exists(cxx_path):
        log.info("Downloading GCC via conda")
        conda = os.environ.get("CONDA_EXE", "conda")
        if shutil.which(conda) is None:
            conda = shutil.which("conda")
        if conda is not None:
            subprocess.check_call(
                [
                    conda,
                    "create",
                    f"--prefix={prefix}",
                    "--channel=conda-forge",
                    "--quiet",
                    "-y",
                    "gxx",
                ],
                stdout=subprocess.PIPE,
            )
    return cxx_path


def is_gcc() -> bool:
    return bool(re.search(r"(gcc|g\+\+)", cpp_compiler()))
def is_clang() -> bool:
    return bool(re.search(r"(clang|clang\+\+)", cpp_compiler()))


@functools.lru_cache(None)
def is_apple_clang() -> bool:
    cxx = cpp_compiler()
    version_string = subprocess.check_output([cxx, "--version"]).decode("utf8")
    return "Apple" in version_string.splitlines()[0]
class VecISA:
    _bit_width: int
    _macro: str
    _arch_flags: str
    _dtype_nelements: Dict[torch.dtype, int]

    # Note [Checking for Vectorized Support in Inductor]
    # TorchInductor CPU vectorization reuses PyTorch vectorization utility functions.
    # Hence, TorchInductor depends on Sleef* to accelerate mathematical functions
    # like exp, pow, sin, cos, etc.
    # But PyTorch and TorchInductor might use different compilers to build code. If
    # PyTorch uses gcc-7/g++-7 to build the release package, the libtorch_cpu.so
    # will not expose the Sleef* AVX512 symbols since gcc-7/g++-7 cannot pass the
    # avx512 check in CMake - FindAVX.cmake. But TorchInductor installs the latest
    # gcc/g++ by default, which does support AVX512 compilation.
    # Therefore, there would be a Sleef version conflict between PyTorch and
    # TorchInductor. Hence, we dry-compile the following code to check whether the
    # current HW platform and PyTorch both support AVX512 or AVX2. We assume ARM
    # needs the same logic.
    # In fbcode however, we are using the same compiler for pytorch and for inductor codegen,
    # making the runtime check unnecessary.
    _avx_code = """
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON)
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#endif

__attribute__((aligned(64))) float in_out_ptr0[16] = {0.0};

extern "C" void __avx_chk_kernel() {
    auto tmp0 = at::vec::Vectorized<float>(1);
    auto tmp1 = tmp0.exp();
    tmp1.store(in_out_ptr0);
}
"""

    _avx_py_load = """
import torch
from ctypes import cdll
cdll.LoadLibrary("__lib_path__")
"""
    def bit_width(self) -> int:
        return self._bit_width

    def nelements(self, dtype: torch.dtype = torch.float) -> int:
        return self._dtype_nelements[dtype]

    def build_macro(self) -> str:
        return self._macro

    def build_arch_flags(self) -> str:
        return self._arch_flags

    def __hash__(self) -> int:
        return hash(str(self))
    @functools.lru_cache(None)
    def __bool__(self) -> bool:
        if config.cpp.vec_isa_ok is not None:
            return config.cpp.vec_isa_ok

        if config.is_fbcode():
            return True

        key, input_path = write(VecISA._avx_code, "cpp")
        from filelock import FileLock

        lock_dir = get_lock_dir()
        lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
        with lock:
            output_path = input_path[:-3] + "so"
            build_cmd = shlex.split(
                cpp_compile_command(
                    input_path, output_path, warning_all=False, vec_isa=self
                )
            )
            try:
                # Check build result
                compile_file(input_path, output_path, build_cmd)
                subprocess.check_call(
                    [
                        sys.executable,
                        "-c",
                        VecISA._avx_py_load.replace("__lib_path__", output_path),
                    ],
                    stderr=subprocess.DEVNULL,
                    env={**os.environ, "PYTHONPATH": ":".join(sys.path)},
                )
            except Exception as e:
                return False

            return True
@dataclasses.dataclass
class VecNEON(VecISA):
    _bit_width = 256  # This is required to leverage the compute implemented in aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h
    _macro = "-DCPU_CAPABILITY_NEON"
    _arch_flags = ""  # Unused
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16}

    def __str__(self) -> str:
        return "neon"  # Unused

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__
@dataclasses.dataclass
class VecAVX512(VecISA):
    _bit_width = 512
    _macro = "-DCPU_CAPABILITY_AVX512"
    _arch_flags = "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma"
    _dtype_nelements = {torch.float: 16, torch.bfloat16: 32, torch.float16: 32}

    def __str__(self) -> str:
        return "avx512"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__
@dataclasses.dataclass
class VecAVX2(VecISA):
    _bit_width = 256
    _macro = "-DCPU_CAPABILITY_AVX2"
    _arch_flags = "-mavx2 -mfma"
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}

    def __str__(self) -> str:
        return "avx2"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__
@dataclasses.dataclass
class VecZVECTOR(VecISA):
    _bit_width = 256
    _macro = "-DCPU_CAPABILITY_ZVECTOR -DCPU_CAPABILITY=ZVECTOR -DHAVE_ZVECTOR_CPU_DEFINITION"
    _arch_flags = "-mvx -mzvector"
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}

    def __str__(self) -> str:
        return "zvector"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__
class InvalidVecISA(VecISA):
    _bit_width = 0
    _macro = ""
    _arch_flags = ""
    _dtype_nelements = {}

    def __str__(self) -> str:
        return "INVALID_VEC_ISA"

    def __bool__(self) -> bool:  # type: ignore[override]
        return False

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


invalid_vec_isa = InvalidVecISA()
supported_vec_isa_list = [
    VecAVX512(),
    VecAVX2(),
]  # This order matters for test_cpu_repro
# Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content
# might have too much redundant content that is useless for ISA check. Hence,
# we only cache some key isa information.
@functools.lru_cache(None)
def valid_vec_isa_list() -> List[VecISA]:
    if sys.platform != "linux":
        return []

    if platform.machine() == "s390x":
        return [VecZVECTOR()]

    isa_list = []
    with open("/proc/cpuinfo") as _cpu_info:
        _cpu_info_content = _cpu_info.read()
        for isa in supported_vec_isa_list:
            # cpuinfo does not reveal info about NEON support. All aarch64 processors do support NEON though.
            if (
                (str(isa) in _cpu_info_content)
                or (isinstance(isa, VecNEON) and platform.processor() == "aarch64")
            ) and isa:
                isa_list.append(isa)
        return isa_list
def pick_vec_isa() -> VecISA:
    if config.is_fbcode():
        return VecAVX2()

    _valid_vec_isa_list: List[VecISA] = valid_vec_isa_list()
    if not _valid_vec_isa_list:
        return invalid_vec_isa

    # If simdlen is None, it indicates that the vectorization length should be
    # determined automatically.
    if config.cpp.simdlen is None:
        assert _valid_vec_isa_list
        return _valid_vec_isa_list[0]

    for isa in _valid_vec_isa_list:
        if config.cpp.simdlen == isa.bit_width():
            return isa

    return invalid_vec_isa
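
# Editorial sketch of the selection rule above: with config.cpp.simdlen unset,
# the first (best) valid ISA wins; with an explicit simdlen, only an exact
# bit-width match is accepted.
#
#   # simdlen=None -> VecAVX512() on an AVX512-capable Linux x86 machine
#   # simdlen=256  -> VecAVX2()
#   # simdlen=128  -> invalid_vec_isa (no 128-bit ISA in the supported list)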
def get_compile_only(compile_only: bool = True) -> str:
    return "-c" if compile_only else ""


def get_shared(shared: bool = True, compile_only: bool = False) -> str:
    if not shared:
        return ""
    if compile_only:
        return "-fPIC"
    if platform.system() == "Darwin" and "clang" in cpp_compiler():
        # This causes undefined symbols to behave the same as linux
        return "-shared -fPIC -undefined dynamic_lookup"
    else:
        return "-shared -fPIC"


def get_warning_all_flag(warning_all: bool = True) -> str:
    return "-Wall" if warning_all else ""


def get_glibcxx_abi_build_flags() -> str:
    return "-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))
def cpp_flags() -> str:
    flags = ["-std=c++17", "-Wno-unused-variable", "-Wno-unknown-pragmas"]
    if is_clang():
        flags.append("-Werror=ignored-optimization-argument")
    return " ".join(flags)


def cpp_wrapper_flags() -> str:
    return "-DTORCH_INDUCTOR_CPP_WRAPPER"
def optimization_flags() -> str:
    base_flags = "-O0 -g" if config.aot_inductor.debug_compile else "-O3 -DNDEBUG"
    base_flags += " -ffast-math -fno-finite-math-only"
    if not config.cpp.enable_unsafe_math_opt_flag:
        base_flags += " -fno-unsafe-math-optimizations"
    if not config.cpp.enable_floating_point_contract_flag:
        base_flags += " -ffp-contract=off"

    if config.is_fbcode():
        # FIXME: passing `-fopenmp` adds libgomp.so to the generated shared library's dependencies.
        # This causes `dlopen` to fail in fbcode, because libgomp does not exist in the default paths.
        # We will fix it later by exposing the lib path.
        return base_flags

    if sys.platform == "darwin":
        # Per https://mac.r-project.org/openmp/ right way to pass `openmp` flags to MacOS is via `-Xclang`
        # Also, `-march=native` is unrecognized option on M1
        base_flags += " -Xclang"
    else:
        if platform.machine() == "ppc64le":
            base_flags += " -mcpu=native"
        else:
            base_flags += " -march=native"

    # Internal cannot find libgomp.so
    if not config.is_fbcode():
        base_flags += " -fopenmp"
    return base_flags
def use_custom_generated_macros() -> str:
    return "-D C10_USING_CUSTOM_GENERATED_MACROS"


def use_fb_internal_macros() -> str:
    if config.is_fbcode():
        openmp_lib = build_paths.openmp_lib()
        preprocessor_flags = " ".join(
            (
                "-D C10_USE_GLOG",
                "-D C10_USE_MINIMAL_GLOG",
                "-D C10_DISABLE_TENSORIMPL_EXTENSIBILITY",
            )
        )
        return f"-Wp,-fopenmp {openmp_lib} {preprocessor_flags}"
    else:
        return ""


def use_standard_sys_dir_headers() -> str:
    if config.is_fbcode():
        return "-nostdinc"
    else:
        return ""
@functools.lru_cache(None)
def is_conda_llvm_openmp_installed() -> bool:
    try:
        command = "conda list llvm-openmp --json"
        output = subprocess.check_output(command.split()).decode("utf8")
        return len(json.loads(output)) > 0
    except subprocess.SubprocessError:
        return False


@functools.lru_cache(None)
def homebrew_libomp() -> Tuple[bool, str]:
    try:
        # check if `brew` is installed
        subprocess.check_output(["which", "brew"])
        # get the location of `libomp` if it is installed
        # this is the location that `libomp` **would** be installed
        # see https://github.com/Homebrew/brew/issues/10261#issuecomment-756563567 for details
        libomp_path = (
            subprocess.check_output(["brew", "--prefix", "libomp"])
            .decode("utf8")
            .strip()
        )
        # check if `libomp` is installed
        omp_available = os.path.exists(libomp_path)
        return omp_available, libomp_path
    except subprocess.SubprocessError:
        return False, ""
def get_include_and_linking_paths(
    include_pytorch: bool = False,
    vec_isa: VecISA = invalid_vec_isa,
    cuda: bool = False,
    aot_mode: bool = False,
) -> Tuple[List[str], str, str, str, str]:
    if (
        config.is_fbcode()
        and "CUDA_HOME" not in os.environ
        and "CUDA_PATH" not in os.environ
    ):
        os.environ["CUDA_HOME"] = os.path.dirname(build_paths.cuda())

    from torch.utils import cpp_extension

    macros = ""
    build_arch_flags = ""
    if sys.platform == "linux" and (
        include_pytorch
        or vec_isa != invalid_vec_isa
        or cuda
        or config.cpp.enable_kernel_profile
    ):
        # Note - We include pytorch only on linux right now. There is more work
        # to do to enable OMP build on darwin where PyTorch is built with IOMP
        # and we need a way to link to what PyTorch links.
        ipaths = cpp_extension.include_paths(cuda) + [sysconfig.get_path("include")]
        lpaths = cpp_extension.library_paths(cuda) + [
            sysconfig.get_config_var("LIBDIR")
        ]

        libs = []

        # No need to manually specify libraries in fbcode.
        if not config.is_fbcode():
            libs += ["torch", "torch_cpu"]
            libs += ["gomp"]
            if not aot_mode:
                libs += ["torch_python"]
        else:
            # internal remote execution is able to find omp, but not gomp
            libs += ["omp"]
            if aot_mode:
                ipaths += [os.path.dirname(cpp_prefix_path())]
                if cuda:
                    # This is a special treatment for Meta internal cuda-12 where all libs
                    # are in lib/cuda-12 and lib/cuda-12/stubs
                    for i, path in enumerate(lpaths):
                        if path.startswith(
                            os.environ["CUDA_HOME"]
                        ) and not os.path.exists(f"{path}/libcudart_static.a"):
                            for root, dirs, files in os.walk(path):
                                if "libcudart_static.a" in files:
                                    lpaths[i] = os.path.join(path, root)
                                    lpaths.append(os.path.join(lpaths[i], "stubs"))
                                    break
        macros = vec_isa.build_macro()
        if macros:
            if config.is_fbcode() and vec_isa != invalid_vec_isa:
                cap = str(vec_isa).upper()
                macros = " ".join(
                    [
                        vec_isa.build_arch_flags(),
                        f"-D CPU_CAPABILITY={cap}",
                        f"-D CPU_CAPABILITY_{cap}",
                        f"-D HAVE_{cap}_CPU_DEFINITION",
                    ]
                )

        if aot_mode and cuda:
            if macros is None:
                macros = ""
            macros += " -D USE_ROCM" if torch.version.hip else " -D USE_CUDA"

        if cuda:
            if torch.version.hip is not None:
                libs += ["c10_hip", "torch_hip"]
                macros += " -D __HIP_PLATFORM_AMD__"
            else:
                if config.is_fbcode():
                    libs += ["cuda"]
                else:
                    libs += ["c10_cuda", "cuda", "torch_cuda"]
        build_arch_flags = vec_isa.build_arch_flags()
    else:
        # Note - this is effectively a header only inclusion. Usage of some header files may result in
        # symbol not found, if those header files require a library.
        # For those cases, include the lpath and libs command as we do for pytorch above.
        # This approach allows us to only pay for what we use.
        ipaths = cpp_extension.include_paths(cuda) + [sysconfig.get_path("include")]
        if aot_mode:
            ipaths += [os.path.dirname(cpp_prefix_path())]
        lpaths = []
        if sys.platform == "darwin":
            # only Apple builtin compilers (Apple Clang++) require openmp
            omp_available = not is_apple_clang()

            # check the `OMP_PREFIX` environment first
            if os.getenv("OMP_PREFIX") is not None:
                header_path = os.path.join(os.getenv("OMP_PREFIX"), "include", "omp.h")  # type: ignore[arg-type]
                valid_env = os.path.exists(header_path)
                if valid_env:
                    ipaths.append(os.path.join(os.getenv("OMP_PREFIX"), "include"))  # type: ignore[arg-type]
                    lpaths.append(os.path.join(os.getenv("OMP_PREFIX"), "lib"))  # type: ignore[arg-type]
                else:
                    warnings.warn("environment variable `OMP_PREFIX` is invalid.")
                omp_available = omp_available or valid_env

            libs = [] if omp_available else ["omp"]

            # prefer to use openmp from `conda install llvm-openmp`
            if not omp_available and os.getenv("CONDA_PREFIX") is not None:
                omp_available = is_conda_llvm_openmp_installed()
                if omp_available:
                    conda_lib_path = os.path.join(os.getenv("CONDA_PREFIX"), "lib")  # type: ignore[arg-type]
                    ipaths.append(os.path.join(os.getenv("CONDA_PREFIX"), "include"))  # type: ignore[arg-type]
                    lpaths.append(conda_lib_path)
                    # Prefer Intel OpenMP on x86 machine
                    if os.uname().machine == "x86_64" and os.path.exists(
                        os.path.join(conda_lib_path, "libiomp5.dylib")
                    ):
                        libs = ["iomp5"]

            # next, try to use openmp from `brew install libomp`
            if not omp_available:
                omp_available, libomp_path = homebrew_libomp()
                if omp_available:
                    ipaths.append(os.path.join(libomp_path, "include"))
                    lpaths.append(os.path.join(libomp_path, "lib"))

            # if openmp is still not available, we let the compiler have a try,
            # and raise error together with instructions at compilation error later
        else:
            libs = ["omp"] if config.is_fbcode() else ["gomp"]

    # Unconditionally import c10 for non-abi-compatible mode to use TORCH_CHECK - See PyTorch #108690
    if not config.abi_compatible:
        libs += ["c10"]
        lpaths += [cpp_extension.TORCH_LIB_PATH]

    # third party libs
    if config.is_fbcode():
        ipaths.append(build_paths.sleef())
        ipaths.append(build_paths.openmp())
        ipaths.append(build_paths.cc_include())
        ipaths.append(build_paths.libgcc())
        ipaths.append(build_paths.libgcc_arch())
        ipaths.append(build_paths.libgcc_backward())
        ipaths.append(build_paths.glibc())
        ipaths.append(build_paths.linux_kernel())
        ipaths.append(build_paths.cuda())
        # We also need to bundle includes with absolute paths into a remote directory
        # (later on, we copy the include paths from cpp_extensions into our remote dir)
        ipaths.append("include")

    static_link_libs = []
    if aot_mode and cuda and config.is_fbcode():
        # For Meta internal cuda-12, it is recommended to static link cudart
        static_link_libs = ["-Wl,-Bstatic", "-lcudart_static", "-Wl,-Bdynamic"]

    lpaths_str = " ".join(["-L" + p for p in lpaths])
    libs_str = " ".join(static_link_libs + ["-l" + p for p in libs])
    return ipaths, lpaths_str, libs_str, macros, build_arch_flags
def cpp_compile_command(
    input: Union[str, List[str]],
    output: str,
    warning_all: bool = True,
    shared: bool = True,
    include_pytorch: bool = False,
    vec_isa: VecISA = invalid_vec_isa,
    cuda: bool = False,
    aot_mode: bool = False,
    compile_only: bool = False,
    use_absolute_path: bool = False,
) -> str:
    ipaths, lpaths, libs, macros, build_arch_flags = get_include_and_linking_paths(
        include_pytorch, vec_isa, cuda, aot_mode
    )
    if isinstance(input, str):
        input = [input]
    ipaths_str = " ".join(["-I" + p for p in ipaths])
    clang_flags = ""
    if config.is_fbcode():
        if aot_mode and not use_absolute_path:
            inp_name = input
            out_name = output
            linker_script = _LINKER_SCRIPT
        else:
            # We need to copy any absolute-path torch includes
            inp_name = [os.path.basename(i) for i in input]
            out_name = os.path.basename(output)
            linker_script = os.path.basename(_LINKER_SCRIPT)
        assert is_clang()
        # Use clang runtime instead of libgcc
        clang_flags += " --rtlib=compiler-rt"
        clang_flags += " -fuse-ld=lld"
        clang_flags += f" -Wl,--script={linker_script}"
        linker_paths = "-B" + build_paths.glibc_lib()
        linker_paths += " -L" + build_paths.glibc_lib()
    else:
        inp_name = input
        out_name = output
        linker_paths = ""  # let the compiler pick
    if compile_only:
        libs, lpaths = "", ""
    inp_name_str = " ".join(inp_name)
    return re.sub(
        r"[ \n]+",
        " ",
        f"""
            {cpp_compiler()} {inp_name_str} {get_shared(shared, compile_only)}
            {get_warning_all_flag(warning_all)} {cpp_flags()}
            {get_glibcxx_abi_build_flags()}
            {ipaths_str} {lpaths} {libs} {build_arch_flags}
            {macros} {linker_paths} {clang_flags}
            {optimization_flags()}
            {use_custom_generated_macros()}
            {use_fb_internal_macros()}
            {use_standard_sys_dir_headers()}
            {get_compile_only(compile_only)}
            -o {out_name}
        """,
    ).strip()
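
# A rough example of the command this produces on Linux outside fbcode
# (editorial illustration; the exact flags depend on the local toolchain,
# ISA, and config):
#
#   g++ /tmp/.../xyz.cpp -shared -fPIC -Wall -std=c++17 ... -march=native
#       -O3 -DNDEBUG -ffast-math ... -o /tmp/.../xyz.so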
def run_command_and_check(cmd: str):
    cmd = shlex.split(cmd)
    try:
        subprocess.check_call(cmd)
    except subprocess.CalledProcessError as e:
        raise exc.CppCompileError(cmd, e.output) from e
@functools.lru_cache(None)
def split_aot_inductor_output_path(path: str) -> Tuple[str, str]:
    """Returns the path where the AOT Inductor compiled kernels are stored."""
    if path.endswith(".so"):
        return os.path.split(path)
    else:
        return path, ""
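
# For example (editorial note):
#   split_aot_inductor_output_path("/tmp/model.so") -> ("/tmp", "model.so")
#   split_aot_inductor_output_path("subdir")        -> ("subdir", "")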
class CudaKernelParamCache:
    cache: Dict[str, Dict[str, str]] = dict()
    clear = staticmethod(cache.clear)

    @classmethod
    def set(cls, key: str, params: Dict[str, str], cubin: str) -> None:
        bin_type = "cubin" if torch.version.hip is None else "hsaco"
        _, path = write(
            cubin,
            bin_type,
            hash_type=bin_type,
            specified_dir=split_aot_inductor_output_path(
                config.aot_inductor.output_path
            )[0],
        )

        params[get_cpp_wrapper_cubin_path_name()] = path
        cls.cache[key] = params

    @classmethod
    def get(cls, key: str) -> Optional[Dict[str, str]]:
        return cls.cache.get(key, None)

    @classmethod
    def get_keys(cls):
        return cls.cache.keys()
class AotCodeCompiler:
    @classmethod
    def compile(
        cls,
        graph: GraphLowering,
        source_code: str,
        serialized_extern_kernel_nodes: Optional[str],
        cuda: bool,
    ) -> str:
        picked_vec_isa = pick_vec_isa()
        cpp_command = repr(
            cpp_compile_command(
                "i", "o", vec_isa=picked_vec_isa, cuda=cuda, aot_mode=graph.aot_mode
            )
        )
        fbcode_aot_cpu_re = False
        use_absolute_path = False
        if config.is_fbcode():
            ld_command = build_paths.ld()
            if not cuda and graph.aot_mode:  # Meta internal AOTInductor CPU
                objcopy_command = build_paths.objcopy_fallback()
                fbcode_aot_cpu_re = True
                use_absolute_path = True
            else:
                objcopy_command = build_paths.objcopy()
        else:
            ld_command = "ld"
            objcopy_command = "objcopy"

        (
            specified_output_path,
            specified_so_name,
        ) = split_aot_inductor_output_path(config.aot_inductor.output_path)
        key, input_path = write(
            source_code,
            "cpp",
            extra=cpp_command,
            specified_dir=specified_output_path,
        )
        def _compile_consts_linux(consts: bytes) -> str:
            _, consts_path = write(
                consts,
                "bin",
                specified_dir=specified_output_path,
            )

            consts_o = os.path.splitext(consts_path)[0] + ".o"
            if fbcode_aot_cpu_re:
                cmd = f"{ld_command} -r -b binary -o {os.path.basename(consts_o)} {os.path.basename(consts_path)}"
                compile_file(consts_path, consts_o, cmd.split())
                os.chmod(consts_o, 0o644)
            else:
                cmd = f"{ld_command} -r -b binary -o {consts_o} {consts_path}"
                run_command_and_check(cmd)
            log.debug("aot constant binary command: %s", cmd)

            cmd = (
                f"{objcopy_command} --rename-section"
                " .data=.lrodata,alloc,load,readonly,data,contents"
                f" {consts_o} {consts_o}"
            )
            log.debug("aot constant obj command: %s", cmd)
            run_command_and_check(cmd)

            cmd = f"rm {consts_path}"
            log.debug("aot constant bin removal command: %s", cmd)
            run_command_and_check(cmd)

            if fbcode_aot_cpu_re:
                body = re.sub(r"[\W]", "_", os.path.basename(consts_path))
            else:
                body = re.sub(r"[\W]", "_", consts_path)

            symbol_list = []
            symbol_list.append(
                f"{objcopy_command} --redefine-sym _binary_{body}_start=_binary_constants_bin_start {consts_o}"
            )
            symbol_list.append(
                f"{objcopy_command} --redefine-sym _binary_{body}_size=_binary_constants_bin_size {consts_o}"
            )
            symbol_list.append(
                f"{objcopy_command} --redefine-sym _binary_{body}_end=_binary_constants_bin_end {consts_o}"
            )
            log.debug("aot constant binary redefine symbol: %s", " ".join(symbol_list))
            for cmd in symbol_list:
                run_command_and_check(cmd)
            return consts_o
        def _compile_consts_darwin(consts: bytes) -> str:
            is_large_consts = len(consts) > 1024
            consts_asm = "\t.section\t__TEXT,__const\n"
            consts_asm += "\t.globl\t__binary_constants_bin_start\n"
            consts_asm += "__binary_constants_bin_start:\n"
            if not is_large_consts:
                for c in consts:
                    consts_asm += f"\t.byte {c}\n"
                # Add one element even if constants are empty
                # Otherwise assembler will not put them in data section
                if not consts:
                    consts_asm += "\t.space 1\n"
            else:
                consts_asm += "\t.quad 0x1234567899abcdef\n"
                consts_asm += f"\t.space {len(consts) - 8}\n"
            consts_asm += ".globl\t__binary_constants_bin_end\n"
            consts_asm += "__binary_constants_bin_end:\n"
            _, consts_path = write(
                consts_asm,
                "S",
                specified_dir=specified_output_path,
            )
            consts_o = os.path.splitext(consts_path)[0] + ".o"
            cmd = f"{cpp_compiler()} -c -o {consts_o} {consts_path}"
            run_command_and_check(cmd)
            if is_large_consts:
                with open(consts_o, "r+b") as f:
                    f.seek(0)
                    hdr = f.read(1024)
                    # Search for magic number and write the actual data over it
                    start_idx = hdr.find(b"\xef\xcd\xab\x99\x78\x56\x34\x12")
                    assert start_idx != -1
                    f.seek(start_idx)
                    pos = 0
                    while pos < len(consts):
                        rc = f.write(consts[pos:])
                        pos += rc
            return consts_o
        from filelock import FileLock

        lock_dir = get_lock_dir()
        lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
        with lock:
            # Currently, this only supports serializing extern nodes in fbcode.
            # Eventually, we should also have a serializer for OSS.
            if config.is_fbcode() and serialized_extern_kernel_nodes:
                output_json = os.path.splitext(input_path)[0] + ".json"
                with open(output_json, "w") as f:
                    f.write(serialized_extern_kernel_nodes)

            output_so = (
                config.aot_inductor.output_path
                if specified_so_name
                else os.path.splitext(input_path)[0] + ".so"
            )

            output_o = os.path.splitext(input_path)[0] + ".o"
            cmd = cpp_compile_command(
                input=input_path,
                output=output_o,
                vec_isa=picked_vec_isa,
                cuda=cuda,
                aot_mode=graph.aot_mode,
                compile_only=True,
                use_absolute_path=use_absolute_path,
            )
            log.debug("aot compilation command: %s", cmd)
            if fbcode_aot_cpu_re:
                compile_file(input_path, output_o, cmd.split())
                os.chmod(output_o, 0o644)
            else:
                run_command_and_check(cmd)
            def _to_bytes(t: torch.Tensor) -> bytes:
                # This serializes the tensor's untyped_storage to bytes by accessing
                # the raw data of the underlying structure.
                import ctypes

                if t.numel() == 0:
                    return b""

                t_cpu = t.untyped_storage().cpu()
                raw_array = ctypes.cast(
                    t_cpu.data_ptr(),
                    ctypes.POINTER(ctypes.c_ubyte * t_cpu.nbytes()),
                )

                return bytes(raw_array.contents)

            aot_constants = b"".join(
                _to_bytes(tensor)
                for name, tensor in graph.constants.items()
                if name not in graph.folded_constants
            )
            consts_o = {
                "linux": _compile_consts_linux,
                "darwin": _compile_consts_darwin,
            }[sys.platform](aot_constants)
            cmd = cpp_compile_command(
                input=[output_o, consts_o],
                output=output_so,
                vec_isa=picked_vec_isa,
                cuda=cuda,
                aot_mode=graph.aot_mode,
                use_absolute_path=use_absolute_path,
            )
            log.debug("aot linkage command: %s", cmd)
            if fbcode_aot_cpu_re:
                compile_file([output_o, consts_o], output_so, cmd.split())
                os.chmod(output_so, 0o755)
            else:
                run_command_and_check(cmd)

        return output_so
# Putting this fn in cpp.py (unfortunately) causes a deadlock, which is why it's in codecache.py.
# Why? importing from cpp.py invokes codecache.pick_vec_isa(), which takes out a lock.
# The cycle goes:
# - CppCodeCache.load()
# - pick_vec_isa()
# - valid_vec_isa_list()
# - VecISA.__bool__() <-- takes out a lock
# - compile_file() <-- imports cpp_prefix_path from cpp, which causes us to try to take out the same lock.
@functools.lru_cache(None)
def cpp_prefix_path() -> str:
    path = Path(__file__).parent / "codegen/cpp_prefix.h"
    with path.open() as f:
        content = f.read()
        _, filename = write(
            content,
            "h",
        )
    return filename
def cpp_prefix() -> str:
    filename = cpp_prefix_path()
    if config.is_fbcode():
        # We need relative paths, since we bundle up
        # everything that we compile into a folder for remote compilation.
        return f'#include "{os.path.basename(filename)}"'
    else:
        return f'#include "{filename}"'
# Given a path to an input cpp file and an output path,
# Attempts to compile the file, storing the output in "output_path"
@dynamo_timed
def compile_file(
    input_path: Union[str, List[str]], output_path: str, cmd: List[str]
) -> None:
    input_paths = [input_path] if isinstance(input_path, str) else input_path
    input_files = [
        os.path.basename(ip) if config.is_fbcode() else ip for ip in input_paths
    ]
    try:
        if config.is_fbcode():
            # Need to copy our header into the same folder as the source code.
            header_path = cpp_prefix_path()
            header_name = os.path.basename(header_path)
            output_name = os.path.basename(output_path)
            # When we build remotely, we need to make sure to carefully copy any files
            # that are required during the compilation process into our build directory.
            # This is where all of the ATen/c10/Torch includes come from.
            torch_includes_path = os.path.join(_TORCH_PATH, "include")
            with tempfile.TemporaryDirectory() as tmp_dir:
                # Copy everything to tmp compilation folder
                shutil.copy(header_path, os.path.join(tmp_dir, header_name))
                shutil.copy(_LINKER_SCRIPT, os.path.join(tmp_dir, "script.ld"))
                for p, f in zip(input_paths, input_files):
                    shutil.copy(p, os.path.join(tmp_dir, f))
                dest_include_path = os.path.join(tmp_dir, "include")
                shutil.copytree(torch_includes_path, dest_include_path)
                # Run the build
                output_file_path = _run_build_command(cmd, tmp_dir, output_name)
                # Copy output from the build
                if os.path.exists(output_path):
                    os.remove(output_path)
                shutil.copy(output_file_path, output_path)
        else:
            subprocess.check_output(cmd, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as e:
        output = e.output.decode("utf-8")
        openmp_problem = "'omp.h' file not found" in output or "libomp" in output
        if openmp_problem and sys.platform == "darwin":
            instruction = (
                "\n\nOpenMP support not found. Please try one of the following solutions:\n"
                "(1) Set the `CXX` environment variable to a compiler other than Apple clang++/g++ "
                "that has builtin OpenMP support;\n"
                "(2) install OpenMP via conda: `conda install llvm-openmp`;\n"
                "(3) install libomp via brew: `brew install libomp`;\n"
                "(4) manually setup OpenMP and set the `OMP_PREFIX` environment variable to point to a path"
                " with `include/omp.h` under it."
            )
            output += instruction
        raise exc.CppCompileError(cmd, output) from e
_libgomp: Optional[CDLL] = None


class CppCodeCache:
    cache: Dict[str, Union[CDLL, ModuleType]] = {}
    clear = staticmethod(cache.clear)
    cpp_compile_command_flags: Dict[str, Any] = {}

    @staticmethod
    def _load_library_inner(path: str, key: str) -> Union[CDLL, ModuleType]:
        return cdll.LoadLibrary(path)

    @classmethod
    def _load_library(cls, path: str, key: str) -> Union[CDLL, ModuleType]:
        try:
            return cls._load_library_inner(path, key)
        except (ImportError, OSError) as e:
            if "gomp" in str(e) and os.path.exists("/usr/lib64/libgomp.so.1"):
                # hacky workaround for fbcode/buck
                global _libgomp
                _libgomp = cdll.LoadLibrary("/usr/lib64/libgomp.so.1")
                return cls._load_library_inner(path, key)
            if "failed to map segment from shared object" in str(e):
                raise OSError(
                    f"{e}. The most common reason this may occur is if the {tempfile.gettempdir()} folder "
                    "is mounted with noexec (e.g., by default Docker mounts tmp file systems "
                    f"as noexec). Please remount {tempfile.gettempdir()} with exec enabled, or set another "
                    "temporary directory with TORCHINDUCTOR_CACHE_DIR environment variable."
                ) from e
            raise
    @classmethod
    def load(cls, source_code: str, cuda: bool = False) -> Union[CDLL, ModuleType]:
        cls.cpp_compile_command_flags.update({"cuda": cuda})
        picked_vec_isa = pick_vec_isa()
        cpp_command = repr(
            cpp_compile_command(
                "i", "o", vec_isa=picked_vec_isa, **cls.cpp_compile_command_flags
            )
        )
        key, input_path = write(source_code, "cpp", extra=cpp_command)
        if key not in cls.cache:
            from filelock import FileLock

            lock_dir = get_lock_dir()
            lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
            with lock:
                output_path = input_path[:-3] + "so"
                if not os.path.exists(output_path):
                    cmd = shlex.split(
                        cpp_compile_command(
                            input=input_path,
                            output=output_path,
                            vec_isa=picked_vec_isa,
                            **cls.cpp_compile_command_flags,
                        )
                    )
                    compile_file(input_path, output_path, cmd)
                cls.cache[key] = cls._load_library(output_path, key)
                cls.cache[key].key = key  # type: ignore[union-attr]

        return cls.cache[key]
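
# Illustrative usage (an editorial sketch, not part of the upstream module):
# compile a tiny kernel and call it through the returned CDLL handle.
#
#   lib = CppCodeCache.load('extern "C" float add1(float x) { return x + 1; }')
#   lib.add1.restype = ctypes.c_float
#   assert lib.add1(ctypes.c_float(1.0)) == 2.0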
# Customized Python binding for cpp kernels
class CppPythonBindingsCodeCache(CppCodeCache):
    cache: Dict[str, Union[CDLL, ModuleType]] = {}
    clear = staticmethod(cache.clear)
    cpp_compile_command_flags = {
        # kernels have no dependency on libtorch
        "include_pytorch": False,
        "shared": True,
    }
    entry_function = "kernel"
    call_entry_function = "kernel(%s);Py_RETURN_NONE;"
    extra_parse_arg = ""
    suffix_template = textwrap.dedent(
        """
        // Python bindings to call %s():
        #define PY_SSIZE_T_CLEAN
        #include <Python.h>
        #include <sstream>
        #include <cstdlib>

        // This is defined in guards.cpp so we don't need to import PyTorch headers that are slooow.
        // We manually link it below to workaround issues with fbcode build.
        static void* (*_torchinductor_pyobject_tensor_data_ptr)(PyObject* obj);

        template <typename T> static inline T parse_arg(PyObject* args, size_t n) {
            static_assert(std::is_pointer<T>::value, "arg type must be pointer or long");
            return static_cast<T>(_torchinductor_pyobject_tensor_data_ptr(PyTuple_GET_ITEM(args, n)));
        }
        template <> inline long parse_arg<long>(PyObject* args, size_t n) {
            auto result = PyLong_AsSsize_t(PyTuple_GET_ITEM(args, n));
            if(result == -1 && PyErr_Occurred())
                [[unlikely]] throw std::runtime_error("expected int arg");
            return result;
        }

        %s

        static PyObject* %s_py(PyObject* self, PyObject* args) {
            try {
                if(!PyTuple_CheckExact(args))
                    [[unlikely]] throw std::runtime_error("tuple args required");
                if(PyTuple_GET_SIZE(args) != %s)
                    [[unlikely]] throw std::runtime_error("requires %s args");
                %s
            } catch(std::exception const& e) {
                PyErr_SetString(PyExc_RuntimeError, e.what());
                return nullptr;
            } catch(...) {
                PyErr_SetString(PyExc_RuntimeError, "unhandled error");
                return nullptr;
            }
        }

        static PyMethodDef py_methods[] = {
            {"%s", %s_py, METH_VARARGS, ""},
            {NULL, NULL, 0, NULL}};

        static struct PyModuleDef py_module =
            {PyModuleDef_HEAD_INIT, "%s", NULL, -1, py_methods};

        PyMODINIT_FUNC PyInit_%s(void) {
            const char* str_addr = std::getenv("_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR");
            if(!str_addr) {
                PyErr_SetString(PyExc_RuntimeError, "_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR must be set");
                return nullptr;
            }
            std::istringstream iss(str_addr);
            uintptr_t addr = 0;
            iss >> addr;
            _torchinductor_pyobject_tensor_data_ptr =
                reinterpret_cast<decltype(_torchinductor_pyobject_tensor_data_ptr)>(addr);
            return PyModule_Create(&py_module);
        }
        """
    )
1973
def _load_library_inner(cls, path: str, key: str) -> ModuleType:
1974
os.environ["_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR"] = str(
1975
torch._C._dynamo.guards._torchinductor_pyobject_tensor_data_ptr # type: ignore[attr-defined]
1977
return importlib.machinery.ExtensionFileLoader(
1978
f"{key}.{cls.entry_function}", path
1979
).load_module() # type: ignore[call-arg]

    @classmethod
    def load_pybinding(
        cls,
        argtypes: List[str],
        source_code: str,
        cuda: bool = False,
        num_outputs: int = -1,
    ) -> Any:
        """
        Wrap a C++ function in fast Python bindings.

        Args:
            argtypes: The types of args to ENTRY_FUNCTION(), e.g. ["float*", "long"]
            source_code: C++ source code containing a ENTRY_FUNCTION() function

        Returns:
            A python version of ENTRY_FUNCTION()
        """
        parseargs = ", ".join(
            f"parse_arg<{argtype.replace('const ', '')}>(args, {n})"
            for n, argtype in enumerate(argtypes)
        )
        suffix = cls.suffix_template % (
            cls.entry_function,
            cls.extra_parse_arg % num_outputs if cls.extra_parse_arg else "",
            cls.entry_function,
            len(argtypes),
            len(argtypes),
            cls.call_entry_function % parseargs,
            cls.entry_function,
            cls.entry_function,
            cls.entry_function,
            cls.entry_function,
        )
        result = cls.load(source_code + suffix, cuda)
        assert isinstance(result, ModuleType)
        return getattr(result, cls.entry_function)
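
# Usage sketch (illustrative): wrapping a minimal entry point. The C++ source
# and argtypes below are hypothetical; pointer args are extracted from tensor
# arguments via _torchinductor_pyobject_tensor_data_ptr, "long" args via
# PyLong_AsSsize_t.
#
#     fn = CppPythonBindingsCodeCache.load_pybinding(
#         ["long*", "long"],
#         'extern "C" void kernel(long* out, long val) { *out = val; }',
#     )
#     buf = torch.zeros(1, dtype=torch.int64)
#     fn(buf, 42)  # buf's data pointer is passed as the long* argument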

class CppWrapperCodeCache(CppPythonBindingsCodeCache):
    cache: Dict[str, Union[CDLL, ModuleType]] = {}
    clear = staticmethod(cache.clear)
    cpp_compile_command_flags = {
        "include_pytorch": True,
        "shared": True,
    }
    entry_function = "inductor_entry_cpp"
    call_entry_function = "return THPVariable_WrapList(inductor_entry_cpp(%s));"
    extra_parse_arg = textwrap.dedent(
        """
        #include <torch/csrc/autograd/python_variable.h>
        #include <torch/csrc/inductor/aoti_torch/tensor_converter.h>

        template <> inline std::vector<at::Tensor> parse_arg<std::vector<at::Tensor>>(PyObject* args, size_t n) {
            return THPVariable_UnpackList(PyTuple_GET_ITEM(args, n));
        }

        std::vector<at::Tensor> inductor_entry_cpp(std::vector<at::Tensor>&& inputs) {
            auto input_handles =
                torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(inputs);
            // For outputs, we only allocate a vector to hold returned tensor handles,
            // not allocating the actual output tensor storage here
            std::vector<AtenTensorHandle> output_handles(%s);

            try {
                inductor_entry_impl(input_handles.data(), output_handles.data());
            } catch(std::exception const& e) {
                PyErr_SetString(PyExc_RuntimeError, e.what());
                return {};
            } catch(...) {
                PyErr_SetString(PyExc_RuntimeError, "unhandled error");
                return {};
            }

            return torch::aot_inductor::alloc_tensors_by_stealing_from_handles(
                output_handles.data(), output_handles.size());
        }
        """
    )
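
# Usage sketch (illustrative; the exact call sites live in the cpp wrapper
# codegen): the generated entry takes a flat std::vector<at::Tensor> of
# inputs and returns the output tensors, so a load looks roughly like
#
#     entry = CppWrapperCodeCache.load_pybinding(
#         ["std::vector<at::Tensor>"], wrapper_src, cuda, num_outputs
#     )
#     outputs = entry([inp0, inp1])  # wrapper_src/num_outputs are hypothetical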

class PyCodeCache:
    cache: Dict[str, ModuleType] = dict()
    linemaps: Dict[str, List[Tuple[Any, ...]]] = dict()
    clear = staticmethod(cache.clear)

    @classmethod
    def write(cls, source_code: str, extra: str = "") -> Tuple[str, str]:
        return write(source_code, "py", extra=extra)

    @classmethod
    def load(
        cls,
        source_code: str,
        extra: str = "",
        linemap: Optional[List[Tuple[int, str]]] = None,
        attrs: Optional[Dict[str, Any]] = None,
    ) -> ModuleType:
        key, path = write(source_code, "py", extra=extra)
        return cls.load_by_key_path(key, path, linemap, attrs)

    @classmethod
    def load_by_key_path(
        cls,
        key: str,
        path: str,
        linemap: Optional[List[Tuple[int, str]]] = None,
        attrs: Optional[Dict[str, Any]] = None,
    ) -> ModuleType:
        if linemap is None:
            linemap = []
        if key not in cls.cache:
            with open(path) as f:
                try:
                    code = compile(f.read(), path, "exec")
                except Exception as e:
                    raise RuntimeError(
                        f"Failed to import {path}\n{type(e).__name__}: {e}"
                    ) from None
                mod = ModuleType(f"{__name__}.{key}")
                mod.__file__ = path
                mod.key = key  # type: ignore[attr-defined]
                exec(code, mod.__dict__, mod.__dict__)
                sys.modules[mod.__name__] = mod
                # another thread might set this first
                cls.cache.setdefault(key, mod)
                # unzip into separate lines/nodes lists
                cls.linemaps[path] = list(zip(*linemap))

                if attrs is not None:
                    for k, v in attrs.items():
                        setattr(mod, k, v)

        return cls.cache[key]
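
    # Usage sketch (illustrative): modules are keyed by source hash, written
    # into the on-disk cache dir, and exec'd at most once per key.
    #
    #     mod = PyCodeCache.load("def answer():\n    return 42\n")
    #     assert mod.answer() == 42
    #     assert PyCodeCache.load("def answer():\n    return 42\n") is mod  # cached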

    @classmethod
    @functools.lru_cache(None)
    def stack_frames_for_code(
        cls, path: str, lineno: int
    ) -> Optional[List[Dict[str, Any]]]:
        if path not in cls.linemaps:
            return None
        # [(starting_line, <fx node>), ...]
        lines, nodes = cls.linemaps[path]
        p = bisect_right(lines, lineno)
        if p == 0:
            return None
        entry = nodes[p - 1]
        if not entry:
            return None

        def parse_stack_trace(stack_trace: str) -> List[Dict[str, Any]]:
            # ideally fx stores stack traces as data rather than a string
            # but this is not along a performance critical path
            regex = r'File "(.+)", line (\d+), in (.+)\n'
            matches = re.findall(regex, stack_trace)
            return [
                {"filename": f, "line": int(l), "name": n}
                for f, l, n in reversed(matches)
            ]

        return parse_stack_trace(entry)
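
    # Illustrative input/output for parse_stack_trace:
    #
    #     parse_stack_trace('  File "model.py", line 10, in forward\n    ...\n')
    #       -> [{"filename": "model.py", "line": 10, "name": "forward"}]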

class TritonCodeCache:
    @classmethod
    def load(cls, kernel_name: str, source_code: str) -> ModuleType:
        mod = PyCodeCache.load(source_code)
        return getattr(mod, kernel_name)

def _cuda_compiler() -> Optional[str]:
    if cuda_env.nvcc_exist(config.cuda.cuda_cxx):
        return config.cuda.cuda_cxx
    if cuda_env.nvcc_exist(os.getenv("CUDACXX")):
        return os.getenv("CUDACXX", "")
    if cuda_env.nvcc_exist(os.getenv("CUDA_HOME")):
        return os.path.join(os.getenv("CUDA_HOME", ""), "bin/nvcc")
    return "nvcc"

def _cutlass_include_paths() -> List[str]:
    cutlass_path = config.cuda.cutlass_dir
    return [
        os.path.join(cutlass_path, "include"),
        os.path.join(cutlass_path, "tools/library/include"),
        os.path.join(cutlass_path, "tools/library/src"),
        os.path.join(cutlass_path, "tools/util/include"),
    ]

def _cuda_lib_options() -> List[str]:
    from torch.utils import cpp_extension

    extra_ldflags: List[str] = []
    if is_linux():
        extra_lib_dir = "lib64"
        if not os.path.exists(
            cpp_extension._join_cuda_home(extra_lib_dir)
        ) and os.path.exists(cpp_extension._join_cuda_home("lib")):
            # 64-bit CUDA may be installed in "lib"
            # Note that it's also possible both don't exist (see _find_cuda_home) - in that case we stay with "lib64"
            extra_lib_dir = "lib"
        extra_ldflags.append(f"-L{cpp_extension._join_cuda_home(extra_lib_dir)}")
        extra_ldflags.append(
            f'-L{cpp_extension._join_cuda_home(extra_lib_dir, "stubs")}'
        )
        extra_ldflags.append("-lcuda")
        extra_ldflags.append("-lcudart")
    else:
        raise NotImplementedError(
            "Unsupported env, failed to find cuda libs! Currently only Linux is supported."
        )
    return extra_ldflags

def _nvcc_host_compiler_options() -> List[str]:
    return [
        "-fPIC",
        "-fno-strict-aliasing",
        "-fvisibility=hidden",
        "-Wconversion",
    ]

def _nvcc_compiler_options() -> List[str]:
    arch = cuda_env.get_cuda_arch()
    if arch == "90":
        # Required by cutlass compilation.
        arch = "90a"
    code = [f"sm_{arch}", f"compute_{arch}"]
    if config.cuda.enable_cuda_lto:
        code += [f"lto_{arch}"]
    options = [
        "-t=0",
        "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
        "-w",
        f"-gencode=arch=compute_{arch},code=[{','.join(code)}]",
        config.cuda.compile_opt_level,
        "-std=c++17",
        "--expt-relaxed-constexpr",
    ]
    if config.cuda.enable_debug_info:
        options.extend(["-lineinfo", "-g", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"])
    if config.cuda.enable_ptxas_info:
        options.extend(
            [
                "--keep",  # Keep the intermediate files for debugging (including ptx, sass, cubin etc.)
                "--ptxas-options=--warn-on-local-memory-usage",  # warn us if local memory is used in CUDA Kernels
                "--ptxas-options=--warn-on-spills",  # warn us if register spilling happens in CUDA Kernels
                "--resource-usage",  # Report on CUDA resource usage (shared mem, registers etc.)
                "--source-in-ptx",
            ]
        )  # Annotate the ptx file with source information
    if config.cuda.use_fast_math:
        options.extend(
            [
                "--use_fast_math",
                "-DCUTLASS_USE_TANH_FOR_SIGMOID=1",
            ]
        )
    return options

def cuda_compile_command(
    src_files: List[str],
    dst_file: str,
    dst_file_ext: str,
) -> str:
    include_paths = _cutlass_include_paths()
    cuda_lib_options = _cuda_lib_options()
    nvcc_host_compiler_options = _nvcc_host_compiler_options()
    nvcc_compiler_options = _nvcc_compiler_options()
    options = (
        nvcc_compiler_options
        + [
            f"-Xcompiler {opt}" if "=" in opt else f"-Xcompiler={opt}"
            for opt in nvcc_host_compiler_options
        ]
        + ["-I" + path for path in include_paths]
        + cuda_lib_options
    )
    src_file = " ".join(src_files)
    res = ""
    if dst_file_ext == "o":
        res = f"{_cuda_compiler()} {' '.join(options)} -c -o {dst_file} {src_file}"
    elif dst_file_ext == "so":
        options.append("-shared")
        res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}"
    else:
        raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!")
    log.debug("CUDA command: %s", res)
    return res
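
# Example shape of the assembled command (paths illustrative):
#
#     nvcc <nvcc options> -Xcompiler=<host option> ... -I<cutlass include> \
#         <lib options> -shared -o /tmp/out.so /tmp/in.cu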

class DLLWrapper:
    """A wrapper for a dynamic library."""

    def __init__(
        self,
        lib_path: str,
    ):
        self.lib_path = lib_path
        self.DLL = cdll.LoadLibrary(lib_path)
        self.is_open = True

    def close(self):
        if self.is_open:
            self._dlclose()
            self.is_open = False

    def _dlclose(self):
        f_dlclose = None

        if is_linux():
            syms = CDLL(None)
            if not hasattr(syms, "dlclose"):
                # Alpine Linux
                syms = CDLL("libc.so")

            if hasattr(syms, "dlclose"):
                f_dlclose = syms.dlclose
        else:
            raise NotImplementedError("Unsupported env, failed to do dlclose!")

        if f_dlclose is not None:
            f_dlclose.argtypes = [c_void_p]
            f_dlclose(self.DLL._handle)
        else:
            log.warning(
                "dll unloading function was not found, library may not be unloaded properly!"
            )

    def __getattr__(self, name):
        if not self.is_open:
            raise RuntimeError(f"Cannot use closed DLL library: {self.lib_path}")

        method = getattr(self.DLL, name)

        def _wrapped_func(*args):
            err = method(*args)
            if err:
                raise RuntimeError(f"Error in function: {method.__name__}")

        return _wrapped_func

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def __del__(self):
        self.close()
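
# Usage sketch (illustrative path and symbol): DLLWrapper is a context
# manager, and wrapped functions raise if the underlying call returns nonzero.
#
#     with DLLWrapper("/tmp/libkernel.so") as lib:
#         lib.run()  # hypothetical exported symbol returning an error code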

class CUDACodeCache:
    @dataclasses.dataclass
    class CacheEntry:
        input_path: str
        output_path: str

    cache: Dict[str, CacheEntry] = dict()
    clear = staticmethod(cache.clear)
    _SOURCE_CODE_SUFFIX = "cu"

    @classmethod
    def write(cls, source_code, dst_file_ext) -> Tuple[str, str]:
        """
        Writes source code into a file with dst_file_ext as the file extension.
        Returns the hash key of source code, and the path to the file.
        """

        cuda_command = repr(
            cuda_compile_command(["dummy_input"], "dummy_output", dst_file_ext)
        )
        key, input_path = write(
            source_code, cls._SOURCE_CODE_SUFFIX, extra=cuda_command
        )
        return key, input_path

    @classmethod
    def compile(cls, source_code, dst_file_ext) -> Tuple[str, str, str]:
        """
        Compiles CUDA source_code into a file with dst_file_ext extension.
        Returns a tuple of dst_file_path, hash_key, source_code_path.
        """

        key, input_path = cls.write(source_code, dst_file_ext)
        if key not in cls.cache:
            from filelock import FileLock

            lock_dir = get_lock_dir()
            lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
            with lock:
                output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext
                if not os.path.exists(output_path):
                    cmd = cuda_compile_command(
                        [input_path], output_path, dst_file_ext
                    ).split(" ")
                    try:
                        subprocess.check_output(
                            cmd, stderr=subprocess.STDOUT, env=os.environ
                        )
                    except subprocess.CalledProcessError as error:
                        raise exc.CUDACompileError(cmd, error.output) from error
                cls.cache[key] = CUDACodeCache.CacheEntry(input_path, output_path)

        return (cls.cache[key].output_path, key, input_path)

    @classmethod
    def load(cls, source_code, dst_file_ext) -> Tuple[DLLWrapper, str, str]:
        """
        Compiles source code and loads the generated .so file.
        Returns a tuple of DLLWrapper, hash_key, source_code_path.
        """

        if dst_file_ext != "so":
            raise RuntimeError(
                f"Only support loading a .so file for now. "
                f"Requested file extension: {dst_file_ext}. Source code: {source_code}"
            )
        dst_file_path, hash_key, source_code_path = cls.compile(
            source_code, dst_file_ext
        )
        return (DLLWrapper(dst_file_path), hash_key, source_code_path)
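
    # Usage sketch (illustrative): cuda_src is a hypothetical CUDA source
    # string defining the symbols later invoked through the returned wrapper.
    #
    #     dll, hash_key, src_path = CUDACodeCache.load(cuda_src, "so")
    #     ...
    #     dll.close()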

def caching_device_properties():
    for _, device_interface in get_registered_device_interfaces():
        if device_interface.is_available():
            device_interface.Worker.get_device_properties()

@functools.lru_cache(None)
def _set_triton_ptxas_path() -> None:
    if os.environ.get("TRITON_PTXAS_PATH") is not None:
        return
    ptxas_path = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "..", "bin", "ptxas")
    )
    if not os.path.exists(ptxas_path):
        return
    if os.path.isfile(ptxas_path) and os.access(ptxas_path, os.X_OK):
        os.environ["TRITON_PTXAS_PATH"] = ptxas_path
    else:
        warnings.warn(f"{ptxas_path} exists but is not an executable")

def _worker_compile(
    kernel_name: str, source_code: str, cc: int, device: torch.device
) -> None:
    device_interface = get_interface_for_device(device.type)
    device_interface.Worker.set_device(device.index)
    kernel = TritonCodeCache.load(kernel_name, source_code)
    kernel.precompile(warm_cache_only_with_cc=cc)


def _load_kernel(kernel_name: str, source_code: str) -> ModuleType:
    _set_triton_ptxas_path()
    kernel = TritonCodeCache.load(kernel_name, source_code)
    kernel.precompile()
    return kernel

class TritonFuture:
    kernel: ModuleType

    def __init__(
        self,
        kernel_name: str,
        source_code: str,
        future: Future[Any],
    ) -> None:
        self.kernel_name = kernel_name
        self.source_code = source_code
        self.future = future

    # @dynamo_utils.dynamo_timed
    def result(self) -> ModuleType:
        t0 = time()
        if hasattr(self, "kernel"):
            return self.kernel
        # If the worker failed this will throw an exception.
        self.future.result()
        kernel = self.kernel = _load_kernel(self.kernel_name, self.source_code)
        latency = time() - t0
        if latency > 50:
            developer_warning(
                f"Detected long compilation time of {latency} seconds for kernel name {self.kernel_name}"
            )
            developer_warning(self.source_code)
        del self.kernel_name, self.source_code, self.future
        return kernel
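
# Usage sketch (illustrative): with compile_threads > 1, AsyncCompile.triton()
# returns a TritonFuture; result() blocks on the worker, then loads the kernel
# in this process (src below is hypothetical Triton source).
#
#     maybe_future = AsyncCompile().triton("triton_fused_0", src)
#     kernel = (
#         maybe_future.result()
#         if isinstance(maybe_future, TritonFuture)
#         else maybe_future
#     )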

# If this process dies abnormally (e.g. segfault)
# it will not shut down the workers. Instead,
# the workers will have their parent reassigned to the
# init process. This launches a separate thread to
# watch for the worker getting reassigned,
# and cleans it up in this case.
#
# This function cannot be an inner function since otherwise mp_context="spawn" would
# not work for ProcessPoolExecutor since inner functions cannot be pickled.
def _async_compile_initializer(orig_ppid) -> None:
    def run() -> None:
        while True:
            sleep(1)
            if orig_ppid != os.getppid():
                os.kill(os.getpid(), signal.SIGKILL)

    global _watchdog_thread
    _watchdog_thread = Thread(target=run, daemon=True)
    _watchdog_thread.start()
    # Ignore Ctrl-C (i.e. SIGINT) sent to pool workers to avoid meaningless log spam.
    signal.signal(signal.SIGINT, signal.SIG_IGN)

_watchdog_thread: Optional[Thread] = None

# Used to keep track of all process pools invoked so far.
_pool_set: Set[ProcessPoolExecutor] = set()


def shutdown_compile_workers() -> None:
    """Shut down all outstanding compile-worker pools."""
    for pool in _pool_set:
        pool.shutdown()

class AsyncCompile:
    def __init__(self) -> None:
        pass

    @staticmethod
    @functools.lru_cache(1)
    def pool() -> ThreadPoolExecutor:
        assert config.compile_threads > 1
        return ThreadPoolExecutor(config.compile_threads)

    @staticmethod
    @functools.lru_cache(1)
    def process_pool() -> ProcessPoolExecutor:
        # ensure properties have been calculated before processes
        # are forked
        caching_device_properties()
        assert config.compile_threads > 1
        orig_ppid = os.getpid()

        ctx = multiprocessing.get_context(config.worker_start_method)
        pool = ProcessPoolExecutor(
            config.compile_threads,
            mp_context=ctx,
            initializer=partial(_async_compile_initializer, orig_ppid),
        )
        _pool_set.add(pool)

        # when this pool is created in a subprocess object, the normal exit handler
        # doesn't run, and we need to register our own handler.
        # exitpriority has to be high, because another one of the finalizers will
        # kill the worker thread that sends the shutdown message to the workers...
        multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize)
        return pool

    @classmethod
    def warm_pool(cls) -> None:
        if config.compile_threads <= 1:
            return
        _compile_start()
        pool = cls.process_pool()

        # We have to fork processes for compiler workers, but the more memory and other resources that are loaded, the
        # slower the os.fork time is, quite drastically. It also holds the GIL so we can't put it on another thread.

        # Examples:
        # A simple x + x + x script: 10ms in the middle of the program, 2ms at startup
        # tf_efficientnet_b0 benchmark: 50ms! in the middle of the program, 3ms at startup

        # So we want to start the workers early when it is still cheap, and also to allow the workers to get
        # ready before we have work for them.

        # ProcessPoolExecutor also does not launch the workers until it finds a point when all the workers are idle.
        # But if we waited until then, fork time will be long and we will be waiting for the processes to initialize.

        # We force them to start here with some YOLOing of the internal methods.
        if hasattr(pool, "_start_queue_management_thread"):
            pool._start_queue_management_thread()
        else:
            for _ in range(config.compile_threads):
                pool._adjust_process_count()
            if hasattr(pool, "_start_executor_manager_thread"):
                pool._start_executor_manager_thread()
        _compile_end()

    @classmethod
    def submit(cls, task: Callable[..., Any]) -> Any:
        if config.compile_threads <= 1:
            return task()
        return cls.pool().submit(task)

    @classmethod
    def map(cls, fn: Callable[..., Any], seq: List[Any]) -> List[Any]:
        if config.compile_threads <= 1 or len(seq) <= 1:
            return list(map(fn, seq))
        return [t.result() for t in [cls.pool().submit(fn, x) for x in seq]]

    def triton(
        self, kernel_name: str, source_code: str, device_str: str = "cuda"
    ) -> Union[TritonFuture, ModuleType]:
        _compile_start()

        if config.compile_threads > 1:
            device_interface = get_interface_for_device(device_str)
            device = torch.device(device_str, device_interface.current_device())
            cc = device_interface.get_compute_capability(device)
            future = self.process_pool().submit(
                _worker_compile, kernel_name, source_code, cc, device
            )
            return TritonFuture(kernel_name, source_code, future)
        else:
            return _load_kernel(kernel_name, source_code)

    def multi_kernel(self, *args, **kwargs) -> ModuleType:
        """
        Async compile the python shim for multi-kernel.
        """

        def task():
            from torch._inductor.codegen.multi_kernel import MultiKernelCall

            return MultiKernelCall(*args, **kwargs)

        return self.submit(task)

    def cpp(self, source_code: str) -> ModuleType:
        def task():
            return CppCodeCache.load(source_code).kernel

        return self.submit(task)

    def cpp_pybinding(self, argtypes: List[str], source_code: str) -> ModuleType:
        return self.submit(
            functools.partial(
                CppPythonBindingsCodeCache.load_pybinding, argtypes, source_code
            )
        )

    def cuda(self, source_code, dst_file_ext):
        def task():
            return CUDACodeCache.load(source_code, dst_file_ext)[0]

        return self.submit(task)

    def wait(self, scope: Dict[str, Any]) -> None:
        num_kernels = len(
            [
                value
                for key, value in scope.items()
                if isinstance(value, (Future, TritonFuture))
            ]
        )
        pbar = tqdm(
            total=num_kernels,
            desc="Inductor Compilation",
            disable=config.disable_progress,
            delay=0,
        )
        if config.compile_threads > 1:
            for key, result in scope.items():
                if config.verbose_progress and not isinstance(pbar, _Faketqdm):
                    pbar.set_postfix_str(key)
                if isinstance(result, (Future, TritonFuture)):
                    scope[key] = result.result()
                    pbar.update(1)

        _compile_end()
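
    # Usage sketch (illustrative, mirroring the code Inductor generates):
    # kernels are submitted asynchronously, then wait() materializes every
    # Future / TritonFuture bound in the module scope (triton_src is a
    # hypothetical source string).
    #
    #     async_compile = AsyncCompile()
    #     triton_fused_0 = async_compile.triton("triton_fused_0", triton_src)
    #     async_compile.wait(globals())  # replaces futures with loaded kernels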


AsyncCompile.warm_pool()