4
from typing import Callable, Optional
6
from torch.utils._triton import has_triton
7
from .utils import red_text, triton_config_to_hashable
14
from . import config as inductor_config
16
log = logging.getLogger(__name__)
19
def get_field(config, name):
    """Return the current value of the tunable field ``name`` on ``config``.

    ``num_warps`` and ``num_stages`` are read from the attributes of the same
    name on ``config``; every other field is looked up in ``config.kwargs``.

    Args:
        config: Object exposing ``num_warps``, ``num_stages`` and a
            dict-like ``kwargs`` attribute.
        name: Name of the field to read.

    Returns:
        The field value, or ``None`` when ``name`` is not one of the
        attribute-backed fields and is absent from ``config.kwargs``.
    """
    if name == "num_warps":
        return config.num_warps
    if name == "num_stages":
        return config.num_stages
    # Anything else lives in the kwargs dict; a missing key yields None.
    return config.kwargs.get(name)
28
def set_field(config, name, value):
    """Assign ``value`` to the tunable field ``name`` on ``config``.

    ``num_warps`` and ``num_stages`` are stored as attributes on ``config``;
    every other field is stored as an entry of ``config.kwargs``.

    Args:
        config: Object exposing ``num_warps``, ``num_stages`` and a
            dict-like ``kwargs`` attribute. It is mutated in place.
        name: Name of the field to write.
        value: New value for the field.
    """
    if name == "num_warps":
        config.num_warps = value
    elif name == "num_stages":
        config.num_stages = value
    else:
        # Only non-attribute fields belong in kwargs; writing num_warps or
        # num_stages there as well would add spurious entries to the dict.
        config.kwargs[name] = value
39
The coordinate descent tuner. Tune one field/coordinate at a time.
41
TODO: will it be necessary to tune multiple fields simultaneously?
44
TODO: what if both increasing and decreasing a field can improve perf.
45
i.e., there are multiple local optima.
48
def __init__(self, is_mm=False, name="unknown", size_hints=None):
50
self.cached_benchmark_results = {}
52
self.size_hints = size_hints
55
xmax = inductor_config.triton.max_block["X"]
56
if self.size_hints and len(self.size_hints) > 0:
57
xmax = min(xmax, self.size_hints[0])
61
ymax = inductor_config.triton.max_block["Y"]
62
if self.size_hints and len(self.size_hints) > 1:
63
ymax = min(ymax, self.size_hints[1])
67
zmax = inductor_config.triton.max_block["Z"]
68
if self.size_hints and len(self.size_hints) > 2:
69
zmax = min(zmax, self.size_hints[2])
73
if self.size_hints and len(self.size_hints) > 0:
74
return self.size_hints[-1]
79
def get_warpsmax(self):
84
def cache_benchmark_result(self, config, timing):
    """Store ``timing`` in ``self.cached_benchmark_results`` for ``config``.

    The cache key is whatever ``triton_config_to_hashable(config)`` returns;
    an existing entry under the same key is overwritten.
    """
    cache_key = triton_config_to_hashable(config)
    self.cached_benchmark_results[cache_key] = timing
87
def lookup_in_cache(self, config):
    """Return the timing cached for ``config``, or ``None`` if there is none.

    The lookup key is ``triton_config_to_hashable(config)``, the same key
    derivation used when results are stored.
    """
    cache_key = triton_config_to_hashable(config)
    return self.cached_benchmark_results.get(cache_key)
90
def call_func(self, func, config):
91
found = self.lookup_in_cache(config)
96
self.cache_benchmark_result(config, timing)
100
def tunable_fields(self):
116
out.append("num_stages")
120
def value_too_large(self, name, val):
122
return val > self.get_xmax()
124
return val > self.get_ymax()
126
return val > self.get_zmax()
128
return val > self.get_rmax()
129
if name == "num_warps":
130
return val > self.get_warpsmax()
134
def get_neighbour_values(self, name, orig_val, radius=1, include_self=False):
136
Get neighbour values in 'radius' steps. The original value is not
137
returned as its own neighbour.
141
def update(cur_val, inc=True):
142
if name == "num_stages":
156
for _ in range(radius):
157
cur_val = update(cur_val, True)
158
if self.value_too_large(name, cur_val):
164
for _ in range(radius):
165
cur_val = update(cur_val, False)
175
def has_improvement(baseline, test):
177
return test is not None and test < baseline * (1 - threshold)
179
def check_all_tuning_directions(
181
func: Callable[["triton.Config"], float],
186
Check all directions. We only do this once the regular coordinate
187
descent tuning finds no better choices any more.
188
We only have a few tunable fields, so this should be fine.
190
candidate_values_list = []
191
effective_fields = []
192
for field in self.tunable_fields:
193
old_value = get_field(best_config, field)
194
if old_value is None:
196
candidate_values = self.get_neighbour_values(
199
radius=inductor_config.coordinate_descent_search_radius,
202
candidate_values_list.append(candidate_values)
203
effective_fields.append(field)
205
choices = itertools.product(*candidate_values_list)
207
for choice in choices:
208
assert len(choice) == len(effective_fields)
209
candidate_config = copy.deepcopy(best_config)
210
for new_val, field in zip(choice, effective_fields):
211
set_field(candidate_config, field, new_val)
212
cmp_res, candidate_timing = self.compare_config(
213
func, candidate_config, best_config, best_timing
217
best_config = candidate_config
218
best_timing = candidate_timing
220
return improved, best_config, best_timing
222
def compare_config(self, func, candidate_config, best_config, best_timing):
224
Check if candidate_config is better than best_config.
226
Return a tuple of (compare_result, candidate_timing).
227
compare_result is true iff candidate_config is better.
229
log.debug("Try config %s", candidate_config)
231
candidate_timing = self.call_func(func, candidate_config)
232
except Exception as e:
233
log.debug("Got exception %s", e)
234
return False, float("inf")
236
if self.has_improvement(best_timing, candidate_timing):
238
"Tune from %s %f -> %s %f",
245
return True, candidate_timing
246
return False, candidate_timing
250
func: Callable[["triton.Config"], float],
251
baseline_config: "triton.Config",
252
baseline_timing: Optional[float] = None,
253
) -> "triton.Config":
254
if baseline_timing is None:
255
baseline_timing = self.call_func(func, baseline_config)
257
log.debug("= Do coordinate descent tuning for %s =", self.name)
259
"Baseline Config %s, baseline timing %f", baseline_config, baseline_timing
262
best_config = baseline_config
263
best_timing = baseline_timing
264
tunable_fields = self.tunable_fields
269
for name in tunable_fields:
270
cur_val = get_field(best_config, name)
278
candidate_values = self.get_neighbour_values(name, cur_val)
280
for next_val in candidate_values:
281
candidate_config = copy.deepcopy(best_config)
282
set_field(candidate_config, name, next_val)
284
cmp_res, candidate_timing = self.compare_config(
285
func, candidate_config, best_config, best_timing
289
best_config, best_timing = candidate_config, candidate_timing
291
if not improved and inductor_config.coordinate_descent_check_all_directions:
292
old_best_timing = best_timing
293
improved, best_config, best_timing = self.check_all_tuning_directions(
294
func, best_config, best_timing
299
"Coordinate descend tuning found improvement of %.3fx by looking in all directions."
303
old_best_timing / best_timing,
307
"Improve from %s %f -> %s %f, %.3fx",
312
baseline_timing / best_timing,