"""Microbenchmarks for the torch.fft module"""
from argparse import ArgumentParser
from collections import namedtuple
from collections.abc import Iterable

import torch
import torch.fft

from torch.utils import benchmark
from torch.utils.benchmark.op_fuzzers.spectral import SpectralOpFuzzer
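
# Example invocation (the script filename below is a placeholder; the flags match
# the argparse options defined at the bottom of this file):
#   python spectral_ops_fuzz_test.py --device cpu --bench fft_real rfft --samples 10 -o fft_results.csv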


def _dim_options(ndim):
    # Dimension arguments to sweep for a given input rank (1-3).
    if ndim == 1:
        return [None]
    elif ndim == 2:
        return [0, 1, None]
    elif ndim == 3:
        return [0, 1, 2, (0, 1), (0, 2), None]
    raise ValueError(f"Expected ndim in range 1-3, got {ndim}")


def run_benchmark(name: str, function: object, dtype: torch.dtype, seed: int, device: str, samples: int,
                  probability_regular: float):
    cuda = device == 'cuda'
    spectral_fuzzer = SpectralOpFuzzer(seed=seed, dtype=dtype, cuda=cuda,
                                       probability_regular=probability_regular)
    results = []
    for tensors, tensor_params, params in spectral_fuzzer.take(samples):
        shape = [params['k0'], params['k1'], params['k2']][:params['ndim']]
        str_shape = ' x '.join([f"{s:<4}" for s in shape])
        sub_label = f"{str_shape} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}"
        for dim in _dim_options(params['ndim']):
            # Sweep thread counts on CPU; CUDA runs use a single host thread.
            for nthreads in (1, 4, 16) if not cuda else (1,):
                measurement = benchmark.Timer(
                    stmt='func(x, dim=dim)',
                    globals={'func': function, 'x': tensors['x'], 'dim': dim},
                    label=f"{name}_{device}",
                    sub_label=sub_label,
                    description=f"dim={dim}",
                    num_threads=nthreads,
                ).blocked_autorange(min_run_time=1)
                # Attach the metadata that _output_csv reads when writing rows.
                measurement.metadata = {
                    'name': name, 'device': device, 'dim': dim, 'shape': shape}
                measurement.metadata.update(tensor_params['x'])
                results.append(measurement)
    return results


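# Benchmark configurations: each entry pairs a CLI --bench name with the
# torch.fft function it exercises and the input dtype handed to the fuzzer.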
Benchmark = namedtuple('Benchmark', ['name', 'function', 'dtype'])
BENCHMARKS = [
    Benchmark('fft_real', torch.fft.fftn, torch.float32),
    Benchmark('fft_complex', torch.fft.fftn, torch.complex64),
    Benchmark('ifft', torch.fft.ifftn, torch.complex64),
    Benchmark('rfft', torch.fft.rfftn, torch.float32),
    Benchmark('irfft', torch.fft.irfftn, torch.complex64),
]
BENCHMARK_MAP = {b.name: b for b in BENCHMARKS}
BENCHMARK_NAMES = [b.name for b in BENCHMARKS]

DEVICE_NAMES = ['cpu', 'cuda']


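# Emit one CSV row per measurement, using the metadata attached in run_benchmark.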
def _output_csv(file, results):
    file.write('benchmark,device,num_threads,numel,shape,contiguous,dim,mean (us),median (us),iqr (us)\n')

    for measurement in results:
        metadata = measurement.metadata
        device, dim, shape, name, numel, contiguous = (
            metadata['device'], metadata['dim'], metadata['shape'],
            metadata['name'], metadata['numel'], metadata['is_contiguous'])

        # dim may be a tuple of dims, a single int, or None
        if isinstance(dim, Iterable):
            dim_str = '-'.join(str(d) for d in dim)
        else:
            dim_str = str(dim)
        shape_str = 'x'.join(str(s) for s in shape)

        print(name, device, measurement.task_spec.num_threads, numel, shape_str, contiguous, dim_str,  # type: ignore[possibly-undefined]
              measurement.mean * 1e6, measurement.median * 1e6, measurement.iqr * 1e6,
              sep=',', file=file)


if __name__ == '__main__':
    parser = ArgumentParser(description=__doc__)
    parser.add_argument('--device', type=str, choices=DEVICE_NAMES, nargs='+', default=DEVICE_NAMES)
    parser.add_argument('--bench', type=str, choices=BENCHMARK_NAMES, nargs='+', default=BENCHMARK_NAMES)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--samples', type=int, default=10)
    parser.add_argument('--probability-regular', '--probability_regular', type=float, default=1.0)
    parser.add_argument('-o', '--output', type=str)
    args = parser.parse_args()

    num_benchmarks = len(args.device) * len(args.bench)
    i = 0
    results = []
    for device in args.device:
        for bench in (BENCHMARK_MAP[b] for b in args.bench):
            results += run_benchmark(
                name=bench.name, function=bench.function, dtype=bench.dtype,
                seed=args.seed, device=device, samples=args.samples,
                probability_regular=args.probability_regular)
            i += 1
            print(f'Completed {bench.name} benchmark on {device} ({i} of {num_benchmarks})')

    if args.output is not None:
        with open(args.output, 'w') as f:
            _output_csv(f, results)

    # Print a comparison table of all measurements to stdout.
    compare = benchmark.Compare(results)
    compare.trim_significant_figures()
    compare.colorize()
    compare.print()