pytorch

Форк
0
/
coordinate_descent_tuner.py 
315 строк · 10.1 Кб
1
import copy
2
import itertools
3
import logging
4
from typing import Callable, Optional
5

6
from torch.utils._triton import has_triton
7
from .utils import red_text, triton_config_to_hashable
8

9
# Triton is optional: bind the name to None when it is unavailable so this
# module still imports; the tuner below only refers to ``triton`` inside
# string type annotations.
if has_triton():
    import triton
else:
    triton = None

from . import config as inductor_config

log = logging.getLogger(__name__)
17

18

19
def get_field(config, name):
    """
    Read the tunable field ``name`` from a triton config.

    ``num_warps`` and ``num_stages`` are attributes of the config object; every
    other field (XBLOCK, RBLOCK, BLOCK_M, ...) is looked up in
    ``config.kwargs``. Returns None when such a field is absent from
    ``kwargs`` (e.g. a kernel without YBLOCK).
    """
    if name in ("num_warps", "num_stages"):
        return getattr(config, name)
    return config.kwargs.get(name)
26

27

28
def set_field(config, name, value):
    """
    Write ``value`` into the tunable field ``name`` of a triton config, in place.

    Mirrors ``get_field``: ``num_warps`` and ``num_stages`` are set as
    attributes on the config; every other field is stored in
    ``config.kwargs``.
    """
    if name in ("num_warps", "num_stages"):
        setattr(config, name, value)
    else:
        config.kwargs[name] = value
35

36

37
class CoordescTuner:
    """
    The coordinate descent tuner. Tune one field/coordinate at a time.

    Starting from a baseline ``triton.Config``, each tunable field is moved to
    its neighbouring values (doubling/halving block sizes and num_warps,
    +/-1 for num_stages). A change is kept whenever the benchmark function
    reports a timing more than 0.1% better than the current best. Passes
    repeat until one finds no improvement.

    TODO: will it be necessary to tune multiple fields simultaneously?

    TODO: what if both increasing and decreasing a field can improve perf,
          i.e., there are multiple local optima?
    """

    def __init__(self, is_mm=False, name="unknown", size_hints=None):
        self.is_mm = is_mm  # we will tune num_stages for mm
        # Benchmark timings keyed by triton_config_to_hashable(config), so a
        # config that is proposed again is not re-benchmarked.
        self.cached_benchmark_results = {}
        self.name = name  # only used in debug logging
        # Optional per-dimension sizes: indices 0/1/2 cap XBLOCK/YBLOCK/ZBLOCK
        # and the last entry caps RBLOCK.
        self.size_hints = size_hints

    def _get_block_max(self, prefix, dim):
        """
        Return the largest allowed block size for dimension ``prefix``
        ("X", "Y" or "Z").

        This is the global ``max_block[prefix]`` limit, further capped by
        ``size_hints[dim]`` when a hint for that dimension exists.
        """
        block_max = inductor_config.triton.max_block[prefix]
        if self.size_hints and len(self.size_hints) > dim:
            block_max = min(block_max, self.size_hints[dim])
        return block_max

    def get_xmax(self):
        """Upper bound for XBLOCK."""
        return self._get_block_max("X", 0)

    def get_ymax(self):
        """Upper bound for YBLOCK."""
        return self._get_block_max("Y", 1)

    def get_zmax(self):
        """Upper bound for ZBLOCK."""
        return self._get_block_max("Z", 2)

    def get_rmax(self):
        """Upper bound for RBLOCK."""
        if self.size_hints:
            return self.size_hints[-1]  # the last one is for reduction
        # large enough. We should not pick this large RBLOCK anyway
        return 2**30

    def get_warpsmax(self):
        # Currently, CUDA has a maximum of 1024 threads, so 32 is the max
        # number of warps.
        return 1024 // 32

    def cache_benchmark_result(self, config, timing):
        """Remember ``timing`` as the benchmark result for ``config``."""
        self.cached_benchmark_results[triton_config_to_hashable(config)] = timing

    def lookup_in_cache(self, config):
        """Return the cached timing for ``config``, or None if not benchmarked yet."""
        return self.cached_benchmark_results.get(triton_config_to_hashable(config))

    def call_func(self, func, config):
        """
        Benchmark ``config`` with ``func``, reusing a cached timing if present.

        Exceptions raised by ``func`` propagate to the caller, and nothing is
        cached for that config.
        """
        found = self.lookup_in_cache(config)
        if found is not None:
            log.debug("  CACHED")
            return found
        timing = func(config)
        self.cache_benchmark_result(config, timing)
        return timing

    @property
    def tunable_fields(self):
        """Names of the config fields to tune; num_stages is included only for mm."""
        out = [
            "XBLOCK",
            "YBLOCK",
            "ZBLOCK",
            # NOTE: we should not tune RBLOCK for persistent reduction.
            # We rely on the fact that persistent reduction's triton.Config
            # does not have the RBLOCK field to guarantee that.
            "RBLOCK",
            # the following 3 are for mm
            "BLOCK_M",
            "BLOCK_N",
            "BLOCK_K",
            "num_warps",
        ]
        if self.is_mm:
            out.append("num_stages")

        return out

    def value_too_large(self, name, val):
        """
        Return True if ``val`` exceeds the upper bound for field ``name``.

        Only X/Y/Z/RBLOCK and num_warps have bounds. Every other field
        (including BLOCK_M/N/K and num_stages) is never considered too large.
        """
        if name == "XBLOCK":
            return val > self.get_xmax()
        if name == "YBLOCK":
            return val > self.get_ymax()
        if name == "ZBLOCK":
            return val > self.get_zmax()
        if name == "RBLOCK":
            return val > self.get_rmax()
        if name == "num_warps":
            return val > self.get_warpsmax()

        return False

    def get_neighbour_values(self, name, orig_val, radius=1, include_self=False):
        """
        Get the neighbour values of ``orig_val`` for field ``name``, within
        ``radius`` steps in each direction.

        A step doubles/halves (floor division) the value, except for
        num_stages, where it adds/subtracts 1. Going up stops at the first
        value rejected by ``value_too_large``; going down stops at the first
        value <= 0.

        The original value is not returned as its own neighbour unless
        ``include_self`` is True, in which case it is appended last.
        """
        assert radius >= 1

        def update(cur_val, inc=True):
            if name == "num_stages":
                return cur_val + 1 if inc else cur_val - 1
            return cur_val * 2 if inc else cur_val // 2

        out = []
        # increment loop
        cur_val = orig_val
        for _ in range(radius):
            cur_val = update(cur_val, True)
            if self.value_too_large(name, cur_val):
                break
            out.append(cur_val)

        # decrement loop
        cur_val = orig_val
        for _ in range(radius):
            cur_val = update(cur_val, False)
            if cur_val <= 0:
                break
            out.append(cur_val)

        if include_self:
            out.append(orig_val)
        return out

    @staticmethod
    def has_improvement(baseline, test):
        """Return True iff ``test`` is a timing more than 0.1% lower than ``baseline``."""
        threshold = 0.001  # 0.1%
        return test is not None and test < baseline * (1 - threshold)

    @staticmethod
    def _speedup(old_timing, new_timing):
        """
        Ratio ``old_timing / new_timing``, used only in debug log messages.

        Returns inf when ``new_timing`` is 0. Logging arguments are evaluated
        eagerly, so a zero timing would otherwise raise ZeroDivisionError and
        abort tuning.
        """
        if new_timing == 0:
            return float("inf")
        return old_timing / new_timing

    def check_all_tuning_directions(
        self,
        func: Callable[["triton.Config"], float],
        best_config,
        best_timing,
    ):
        """
        Check all directions. We only do this once the regular coordinate
        descent tuning finds no better choices any more.

        Every combination of each present field's neighbour values (search
        radius ``inductor_config.coordinate_descent_search_radius``, current
        value included) is benchmarked. We only have a few tunable fields, so
        this should be fine.

        Returns ``(improved, best_config, best_timing)``.
        """
        candidate_values_list = []
        effective_fields = []
        for field in self.tunable_fields:
            old_value = get_field(best_config, field)
            if old_value is None:
                continue
            candidate_values = self.get_neighbour_values(
                field,
                old_value,
                radius=inductor_config.coordinate_descent_search_radius,
                include_self=True,
            )
            candidate_values_list.append(candidate_values)
            effective_fields.append(field)

        choices = itertools.product(*candidate_values_list)
        improved = False
        for choice in choices:
            assert len(choice) == len(effective_fields)
            candidate_config = copy.deepcopy(best_config)
            for new_val, field in zip(choice, effective_fields):
                set_field(candidate_config, field, new_val)
            cmp_res, candidate_timing = self.compare_config(
                func, candidate_config, best_config, best_timing
            )
            if cmp_res:
                improved = True
                best_config = candidate_config
                best_timing = candidate_timing

        return improved, best_config, best_timing

    def compare_config(self, func, candidate_config, best_config, best_timing):
        """
        Check if candidate_config is better than best_config.

        Return a tuple of (compare_result, candidate_timing).
        compare_result is true iff candidate_config is better.

        Any exception raised while benchmarking the candidate is logged at
        debug level and reported as ``(False, inf)``, so a failing candidate
        is simply skipped.
        """
        log.debug("Try config %s", candidate_config)
        try:
            candidate_timing = self.call_func(func, candidate_config)
        except Exception as e:
            log.debug("Got exception %s", e)
            return False, float("inf")

        if self.has_improvement(best_timing, candidate_timing):
            log.debug(
                "Tune from %s %f -> %s %f",
                best_config,
                best_timing,
                candidate_config,
                candidate_timing,
            )

            return True, candidate_timing
        return False, candidate_timing

    def autotune(
        self,
        func: Callable[["triton.Config"], float],
        baseline_config: "triton.Config",
        baseline_timing: Optional[float] = None,
    ) -> "triton.Config":
        """
        Run coordinate descent from ``baseline_config`` and return the best
        config found.

        ``func`` benchmarks a config and returns its timing (lower is better).
        If ``baseline_timing`` is None, the baseline is benchmarked first; an
        exception from that call propagates to the caller. Exceptions while
        benchmarking candidates count as "no improvement" (see
        ``compare_config``).
        """
        if baseline_timing is None:
            baseline_timing = self.call_func(func, baseline_config)

        log.debug("= Do coordinate descent tuning for %s =", self.name)
        log.debug(
            "Baseline Config %s, baseline timing %f", baseline_config, baseline_timing
        )
        improved = True
        best_config = baseline_config
        best_timing = baseline_timing
        tunable_fields = self.tunable_fields

        while improved:
            improved = False

            for name in tunable_fields:
                cur_val = get_field(best_config, name)
                # some kernel don't have RBLOCK/YBLOCK/ZBLOCK. So cur_val may be None
                if cur_val is None:
                    continue

                # It's possible that candidate_values is empty.
                # E.g., if XBLOCK is 1 initially and size_hint for x is also 1.
                # We would not try either larger or smaller XBLOCK in this case.
                candidate_values = self.get_neighbour_values(name, cur_val)

                for next_val in candidate_values:
                    candidate_config = copy.deepcopy(best_config)
                    set_field(candidate_config, name, next_val)

                    cmp_res, candidate_timing = self.compare_config(
                        func, candidate_config, best_config, best_timing
                    )
                    if cmp_res:
                        improved = True
                        best_config, best_timing = candidate_config, candidate_timing

            if not improved and inductor_config.coordinate_descent_check_all_directions:
                old_best_timing = best_timing
                improved, best_config, best_timing = self.check_all_tuning_directions(
                    func, best_config, best_timing
                )

                if improved:
                    msg = red_text(
                        "Coordinate descend tuning found improvement of %.3fx by looking in all directions."
                    )
                    log.debug(
                        msg,
                        self._speedup(old_best_timing, best_timing),
                    )

        log.debug(
            "Improve from %s %f -> %s %f, %.3fx",
            baseline_config,
            baseline_timing,
            best_config,
            best_timing,
            self._speedup(baseline_timing, best_timing),
        )

        return best_config
316

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.