pytorch

Форк
0
310 строк · 10.6 Кб
1
import contextlib
2
import json
3
import os
4
import time
5

6
import numpy as np
7

8
import torch
9

10
from . import tensor_engine
11

12

13
class Benchmark:
    """Base class for tensor-expression micro-benchmarks.

    A subclass describes one operator. It must implement ``forward``,
    ``config``, ``module``, ``memory_workload`` and ``default_configs``, and
    set ``self.inputs`` to the positional arguments passed to ``forward``.
    NOTE(review): ``run_impl`` reads ``self.jit_mode``, which this class never
    assigns -- presumably the driver sets it before calling ``run``; confirm.

    Every public callable attribute of the active tensor engine is forwarded
    onto the instance unless this class already defines an attribute with the
    same name (``check`` presumably relies on this for ``self.numpy``).
    """

    def __init__(self, mode, device, dtype):
        """Create a benchmark instance.

        Args:
            mode: ``"fwd"`` (forward only) or ``"both"`` (forward + backward).
            device: device string handed to the engine, e.g. ``"cuda"``.
            dtype: torch dtype used for byte accounting in the report.

        Raises:
            ValueError: if ``mode`` is neither ``"fwd"`` nor ``"both"``.
        """
        self.mode = mode
        self.deterministic = False
        self.device = device
        self.dtype = dtype
        self.output_type = "stdout"
        self.print_ir = False
        self.print_kernel = False
        if mode == "both":
            self.requires_grad = True
        elif mode == "fwd":
            self.requires_grad = False
        else:
            raise ValueError(f"invalid mode: {mode}")
        self.result_grad = None
        self.grad_variables = []
        # Traced forward, populated by run_impl(); None means "call forward()".
        # Initialized here so compute()/check() also work before a run.
        self.bm_jit = None
        self.engine = tensor_engine.get_engine()
        self.engine.reset(device)

        # forward all member functions in self.engine to self
        for method in dir(self.engine):
            if not callable(getattr(self.engine, method)):
                continue
            # don't forward if this function is overridden here
            if hasattr(self, method):
                continue
            # don't forward if it is an internal function
            if method.startswith("_"):
                continue
            setattr(self, method, getattr(self.engine, method))

    def forward(self):
        """Do one step worth of computation (must be overridden)."""
        raise ValueError("this method should be reimplemented by subclass")

    def check(self):
        """Compare ``compute()`` against ``reference()`` (atol=1e-2).

        Skipped unless ``self.deterministic`` is true.
        """
        if not self.deterministic:
            return
        np.testing.assert_allclose(
            self.reference(), self.numpy(self.compute()), atol=1e-2
        )

    def config(self):
        """Return a list describing the current benchmark configuration."""
        raise ValueError("this method should be reimplemented by subclass")

    def desc(self):
        """Return the description string of the current benchmark.

        Format: ``"<engine mode>: <module>_<mode>_<device>_<config...>"``. The
        ``NNC_NUM_THREADS`` environment variable, when set, is appended to the
        device component.
        """
        config_str = "_".join([str(x) for x in self.config()])
        device = self.device
        if "NNC_NUM_THREADS" in os.environ:
            device += os.environ["NNC_NUM_THREADS"]
        return f"{self.engine.mode}: {self.module()}_{self.mode}_{device}_{config_str}"

    @staticmethod
    def module():
        """Return the benchmark's module name (must be overridden)."""
        raise ValueError("this method should be reimplemented by subclass")

    def memory_workload(self):
        """Return a dict with ``"sol"`` and ``"algorithmic"`` element counts.

        Must be overridden; ``run_impl`` multiplies these by the dtype size.
        """
        raise ValueError("this method should be reimplemented by subclass")

    def compute_workload(self):
        """Return the number of scalar operations it takes to finish the tensor op.

        ``None`` (the default) omits compute throughput from the report.
        """
        return None

    @staticmethod
    def input_iterable():
        """A benchmark child class should return true if it utilizes the input iter arg"""
        return False

    def dtype_to_bytes(self):
        """Return the size in bytes of one element of ``self.dtype``."""
        return torch.tensor(0, dtype=self.dtype).element_size()

    @staticmethod
    def default_configs():
        """Return a list of default configs for this benchmark."""
        raise ValueError("this method should be reimplemented by subclass")

    def is_supported(self):
        """Return whether this benchmark can run; subclasses may override."""
        return True

    def rand(self, shape, device=None, dtype=None, requires_grad=False):
        """Create a random tensor via the engine.

        Tensors created with ``requires_grad=True`` are remembered in
        ``grad_variables`` so the backward pass in ``run_impl`` targets them.
        """
        v = self.engine.rand(
            shape, device=device, dtype=dtype, requires_grad=requires_grad
        )
        if requires_grad:
            self.grad_variables.append(v)
        return v

    def nchw_rand(self, shape, device=None, requires_grad=False):
        """Like ``rand`` but delegates to the engine's ``nchw_rand``."""
        v = self.engine.nchw_rand(shape, device=device, requires_grad=requires_grad)
        if requires_grad:
            self.grad_variables.append(v)
        return v

    def compute(self):
        """Run one step, through the traced module when one is available."""
        # getattr keeps compute() usable even if __init__ was bypassed.
        bm_jit = getattr(self, "bm_jit", None)
        if bm_jit:
            return bm_jit(*self.inputs)
        return self.forward(*self.inputs)

    def run(self, args):
        """Configure the fuser selected by ``args.cuda_fuser`` and run.

        ``"old"``, ``"te"`` and ``"nvf"`` enable the corresponding fuser; any
        other value runs without a fuser. Global JIT flags set here are not
        restored afterwards; only the TE pointwise knobs are scoped.
        """
        self.print_ir = args.print_ir
        if args.cuda_fuser == "old":
            torch._C._jit_override_can_fuse_on_gpu(True)
            if args.print_kernel:
                os.environ["PYTORCH_FUSION_DEBUG"] = "1"
            return self.run_impl(True)
        elif args.cuda_fuser == "te":
            torch._C._jit_set_texpr_fuser_enabled(True)
            with cuda_pointwise_context(
                args.cuda_pointwise_loop_levels,
                args.cuda_pointwise_block_count,
                args.cuda_pointwise_block_size,
            ):
                return self.run_impl(True)
        elif args.cuda_fuser == "nvf":
            torch._C._jit_set_nvfuser_enabled(True)
            torch._C._jit_set_profiling_executor(True)
            torch._C._jit_set_profiling_mode(True)
            torch._C._jit_override_can_fuse_on_cpu(False)
            torch._C._jit_override_can_fuse_on_gpu(False)
            torch._C._jit_set_bailout_depth(20)
            if args.print_kernel:
                os.environ["PYTORCH_CUDA_FUSER_DEBUG"] = "1"
            return self.run_impl(True)
        else:
            return self.run_impl(False)

    def run_impl(self, use_fuser):
        """Time ``compute`` (plus backward in ``"both"`` mode) and report.

        Runs 10 warm-up iterations, then 1000 timed iterations on CUDA (10
        otherwise). On the first iteration the forward is traced when
        ``self.jit_mode == "trace"`` and a fuser is in use, and the result is
        validated against ``reference()`` if the subclass defines one.
        """
        warmups = 10
        iters = 1000 if self.device == "cuda" else 10
        engine = tensor_engine.get_engine()
        use_trace = self.jit_mode == "trace" and use_fuser

        self.bm_jit = None
        for i in range(warmups + iters):
            if i == warmups:
                # Drain queued GPU work so warm-up does not leak into the timing.
                if self.device == "cuda":
                    engine.sync_cuda()
                # perf_counter is monotonic and high-resolution, unlike time.time().
                time_start = time.perf_counter()

            if i == 0:
                if use_trace:
                    self.bm_jit = torch.jit.trace(
                        self.forward, example_inputs=self.inputs, check_trace=False
                    )
                if callable(getattr(self, "reference", None)):
                    self.check()
                else:
                    print("Warning: no reference result for ", self.module())
            elif i == 1:
                # The fusion graph is visible after the first iter is executed
                if use_trace and self.print_ir:
                    print(self.bm_jit.graph_for(*self.inputs))
            z = self.compute()
            if self.mode == "both":
                if self.result_grad is None:
                    self.result_grad = engine.rand_like(z)
                engine.backward([z], [self.result_grad], self.grad_variables)

        if self.device == "cuda":
            engine.sync_cuda()

        iter_time = (time.perf_counter() - time_start) / iters
        memory_workload = self.memory_workload()
        compute_workload = self.compute_workload()
        bytes_per_elem = self.dtype_to_bytes()

        result_dict = {
            "desc": self.desc(),
            "us": iter_time * 1e6,
            "sol": memory_workload["sol"] * bytes_per_elem / iter_time / 1e9,
            "algorithmic": memory_workload["algorithmic"]
            * bytes_per_elem
            / iter_time
            / 1e9,
        }
        if compute_workload:
            result_dict["compute_workload"] = compute_workload / iter_time / 1e9
        self.dump_result(result_dict)

    def dump_result(self, result_dict):
        """Print ``result_dict`` as JSON or as a human-readable line.

        Raises:
            Exception: if ``self.output_type`` is not ``"json"`` or ``"stdout"``.
        """
        if self.output_type == "json":
            print(json.dumps(result_dict))
        elif self.output_type == "stdout":
            msg = "{}: {:.2f} us, SOL {:.2f} GB/s, algorithmic {:.2f} GB/s".format(
                result_dict["desc"],
                result_dict["us"],
                result_dict["sol"],
                result_dict["algorithmic"],
            )
            if "compute_workload" in result_dict:
                msg += f", compute {result_dict['compute_workload']:.2f} Gops/s"
            print(msg)
        else:
            # f-string so a non-str output_type still yields a readable error.
            raise Exception(f"Unknown output_type {self.output_type}")  # noqa: TRY002
216

217

218
@contextlib.contextmanager
def cuda_pointwise_context(loop_levels, block_count, block_size):
    """Temporarily override the TE CUDA pointwise codegen knobs.

    Each argument that is truthy replaces the corresponding global setting
    for the duration of the ``with`` block; falsy arguments (``None``/0)
    leave that setting untouched and never touch its torch binding.

    Every overridden setting is restored on exit, including when the body
    raises, when a later override fails during setup, or when another
    restore fails.
    """
    knobs = (
        (
            loop_levels,
            "_jit_get_te_cuda_pointwise_loop_levels",
            "_jit_set_te_cuda_pointwise_loop_levels",
        ),
        (
            block_count,
            "_jit_get_te_cuda_pointwise_block_count",
            "_jit_set_te_cuda_pointwise_block_count",
        ),
        (
            block_size,
            "_jit_get_te_cuda_pointwise_block_size",
            "_jit_set_te_cuda_pointwise_block_size",
        ),
    )
    with contextlib.ExitStack() as restore_stack:
        for value, getter_name, setter_name in knobs:
            if not value:
                continue
            # Bindings are looked up lazily, only for knobs actually overridden.
            setter = getattr(torch._C, setter_name)
            old_value = getattr(torch._C, getter_name)()
            # Register the restore before overriding so a failure in any
            # later step still rolls this knob back.
            restore_stack.callback(setter, old_value)
            setter(value)
        yield
239

240

241
# Auxiliary class to facilitate dynamic input shape
242
class DynamicShape:
    r"""
    An auxiliary mixin for dynamic shape benchmarks.

    Pre-computes inputs with random shapes and overrides ``compute`` so that
    each call feeds the fuser an input with a different tensor shape. It
    delegates ``compute`` and ``run`` to the next class in the MRO via
    ``super()``, so it is meant to be listed before a benchmark base class.
    """

    # Number of random inputs in an instance
    SAMPLE_SIZE = 100

    def __init__(self, dynamic_range=1.2):
        self._input_samples = []
        self._input_sample_index = 0
        # Normalized to the lower bound of the per-dimension scaling ratio
        # used by rand_shape, so 1.2 and 1/1.2 both mean ratios in [1/1.2, 1).
        self._dynamic_range = (
            1.0 / dynamic_range if dynamic_range > 1.0 else dynamic_range
        )
        self._enable_dynamic_shapes = True

    # Returns the input test case that current index points to
    @property
    def inputs(self):
        return self._input_samples[self._input_sample_index]

    # An inputs assignment actually adds a test case in the class buffer
    @inputs.setter
    def inputs(self, val):
        self._input_samples.append(val)

    def compute(self):
        """Run the normal compute on the current sample, then advance.

        Returns whatever the base ``compute`` returns, so callers (e.g. the
        backward pass or ``check``) receive the actual output tensor.
        """
        result = super().compute()
        # Wrap over the samples actually stored, so fewer than SAMPLE_SIZE
        # samples cycle instead of raising IndexError.
        self._input_sample_index = (self._input_sample_index + 1) % len(
            self._input_samples
        )
        return result

    # Defined by benchmark, the benchmark needs to specify the input
    # tensor construction in this method, essentially the same way
    # a benchmark creates the inputs list in the initializer
    def instantiate_input(self):
        raise NotImplementedError

    def run(self, args):
        """Instantiate random-shaped inputs, then start the benchmark run.

        ``args.no_dynamic_shape`` force-disables shape randomization.
        """
        if args.no_dynamic_shape:
            self._enable_dynamic_shapes = False
        self.load_inputs()
        return super().run(args)

    def load_inputs(self):
        """Pre-compute inputs so random tensor creation is not timed.

        Creates ``SAMPLE_SIZE - 1`` samples; presumably the subclass
        initializer has already added the first one -- confirm per benchmark.
        """
        for _ in range(self.SAMPLE_SIZE - 1):
            self.instantiate_input()

    def rand_shape(self, shape):
        """Return ``shape`` with each dim scaled by a random ratio.

        Each ratio is drawn uniformly from ``[self._dynamic_range, 1.0)``.
        The input is returned unchanged when dynamic shapes are disabled.
        """
        if not self._enable_dynamic_shapes:
            return shape
        ratios = np.random.uniform(self._dynamic_range, 1.0, len(shape))
        scaled = np.multiply(shape, ratios).astype(int)
        # Truncation can turn small dims (e.g. 1) into 0; keep every dim >= 1.
        return [max(int(dim), 1) for dim in scaled]
304

305

306
# Registry of benchmark classes, populated by register_benchmark_class().
benchmark_classes = []


def register_benchmark_class(benchmark_cls):
    """Add ``benchmark_cls`` to ``benchmark_classes`` and return it.

    Returning the class lets this be used either as a plain call
    (``register_benchmark_class(Cls)``) or as a ``@register_benchmark_class``
    decorator.
    """
    benchmark_classes.append(benchmark_cls)
    return benchmark_cls
311

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.