pytorch

Форк
0
/
spectral_ops_fuzz_test.py 
113 строк · 4.6 Кб
1
"""Microbenchmarks for the torch.fft module"""
2
from argparse import ArgumentParser
3
from collections import namedtuple
4
from collections.abc import Iterable
5

6
import torch
7
import torch.fft
8
from torch.utils import benchmark
9
from torch.utils.benchmark.op_fuzzers.spectral import SpectralOpFuzzer
10

11

12
def _dim_options(ndim):
13
    if ndim == 1:
14
        return [None]
15
    elif ndim == 2:
16
        return [0, 1, None]
17
    elif ndim == 3:
18
        return [0, 1, 2, (0, 1), (0, 2), None]
19
    raise ValueError(f"Expected ndim in range 1-3, got {ndim}")
20

21

22
def run_benchmark(name: str, function: object, dtype: torch.dtype, seed: int, device: str, samples: int,
23
                  probability_regular: float):
24
    cuda = device == 'cuda'
25
    spectral_fuzzer = SpectralOpFuzzer(seed=seed, dtype=dtype, cuda=cuda,
26
                                       probability_regular=probability_regular)
27
    results = []
28
    for tensors, tensor_params, params in spectral_fuzzer.take(samples):
29
        shape = [params['k0'], params['k1'], params['k2']][:params['ndim']]
30
        str_shape = ' x '.join([f"{s:<4}" for s in shape])
31
        sub_label = f"{str_shape} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}"
32
        for dim in _dim_options(params['ndim']):
33
            for nthreads in (1, 4, 16) if not cuda else (1,):
34
                measurement = benchmark.Timer(
35
                    stmt='func(x, dim=dim)',
36
                    globals={'func': function, 'x': tensors['x'], 'dim': dim},
37
                    label=f"{name}_{device}",
38
                    sub_label=sub_label,
39
                    description=f"dim={dim}",
40
                    num_threads=nthreads,
41
                ).blocked_autorange(min_run_time=1)
42
                measurement.metadata = {
43
                    'name': name,
44
                    'device': device,
45
                    'dim': dim,
46
                    'shape': shape,
47
                }
48
                measurement.metadata.update(tensor_params['x'])
49
                results.append(measurement)
50
    return results
51

52

53
Benchmark = namedtuple('Benchmark', ['name', 'function', 'dtype'])
54
BENCHMARKS = [
55
    Benchmark('fft_real', torch.fft.fftn, torch.float32),
56
    Benchmark('fft_complex', torch.fft.fftn, torch.complex64),
57
    Benchmark('ifft', torch.fft.ifftn, torch.complex64),
58
    Benchmark('rfft', torch.fft.rfftn, torch.float32),
59
    Benchmark('irfft', torch.fft.irfftn, torch.complex64),
60
]
61
BENCHMARK_MAP = {b.name: b for b in BENCHMARKS}
62
BENCHMARK_NAMES = [b.name for b in BENCHMARKS]
63
DEVICE_NAMES = ['cpu', 'cuda']
64

65
def _output_csv(file, results):
66
    file.write('benchmark,device,num_threads,numel,shape,contiguous,dim,mean (us),median (us),iqr (us)\n')
67
    for measurement in results:
68
        metadata = measurement.metadata
69
        device, dim, shape, name, numel, contiguous = (
70
            metadata['device'], metadata['dim'], metadata['shape'],
71
            metadata['name'], metadata['numel'], metadata['is_contiguous'])
72

73
        if isinstance(dim, Iterable):
74
            dim_str = '-'.join(str(d) for d in dim)
75
        else:
76
            dim_str = str(dim)
77
            shape_str = 'x'.join(str(s) for s in shape)
78

79
        print(name, device, measurement.task_spec.num_threads, numel, shape_str, contiguous, dim_str,  # type: ignore[possibly-undefined]
80
              measurement.mean * 1e6, measurement.median * 1e6, measurement.iqr * 1e6,
81
              sep=',', file=file)
82

83

84
if __name__ == '__main__':
85
    parser = ArgumentParser(description=__doc__)
86
    parser.add_argument('--device', type=str, choices=DEVICE_NAMES, nargs='+', default=DEVICE_NAMES)
87
    parser.add_argument('--bench', type=str, choices=BENCHMARK_NAMES, nargs='+', default=BENCHMARK_NAMES)
88
    parser.add_argument('--seed', type=int, default=0)
89
    parser.add_argument('--samples', type=int, default=10)
90
    parser.add_argument('--probability-regular', '--probability_regular', type=float, default=1.0)
91
    parser.add_argument('-o', '--output', type=str)
92
    args = parser.parse_args()
93

94
    num_benchmarks = len(args.device) * len(args.bench)
95
    i = 0
96
    results = []
97
    for device in args.device:
98
        for bench in (BENCHMARK_MAP[b] for b in args.bench):
99
            results += run_benchmark(
100
                name=bench.name, function=bench.function, dtype=bench.dtype,
101
                seed=args.seed, device=device, samples=args.samples,
102
                probability_regular=args.probability_regular)
103
            i += 1
104
            print(f'Completed {bench.name} benchmark on {device} ({i} of {num_benchmarks})')
105

106
    if args.output is not None:
107
        with open(args.output, 'w') as f:
108
            _output_csv(f, results)
109

110
    compare = benchmark.Compare(results)
111
    compare.trim_significant_figures()
112
    compare.colorize()
113
    compare.print()
114

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.