HairFastGAN
36 строк · 870.0 Байт
1import functools2import sys3import time4
5import numpy as np6import torch7
8
def get_time():
    """Return the current wall-clock time in seconds, syncing CUDA first if present.

    When CUDA is available, the current CUDA stream is synchronized before the
    timestamp is taken, so GPU work queued on that stream is included in any
    interval measured between two calls. Without CUDA (CPU-only machine or a
    torch build without CUDA) the sync is skipped instead of raising, so
    benchmarking still works.

    Returns:
        float: ``time.time()``, i.e. seconds since the epoch.
    """
    if torch.cuda.is_available():
        # CUDA work is launched asynchronously; syncing makes the timestamp
        # reflect completion of work queued on the current stream.
        torch.cuda.current_stream().synchronize()
    return time.time()
13
def bench_session(func):
    """Decorate *func* with opt-in timing.

    The returned wrapper accepts an extra keyword argument ``benchmark``
    (default ``False``), which is removed before *func* is called. If it is
    falsy, the call is forwarded unchanged. If it is truthy, the call is timed
    with ``get_time()`` and the duration is appended to a history. That history
    is shared by every call of this particular decorated function. The latest
    duration and the min / median / std of the history are then printed to
    stderr.

    Returns:
        The wrapper. Its metadata (``__name__``, ``__doc__``, ...) is copied
        from *func* via ``functools.wraps``.
    """
    history = []  # durations of every benchmarked call of *func* so far

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Always strip the flag so *func* never receives an unexpected kwarg.
        if not kwargs.pop('benchmark', False):
            return func(*args, **kwargs)

        started_at = get_time()
        outcome = func(*args, **kwargs)
        elapsed = get_time() - started_at
        history.append(elapsed)

        print(f'\n{len(history)} experiment ended in {elapsed:.3f}(s)', file=sys.stderr)
        print(
            f'min time: {np.min(history):.3f}(s),'
            f' median time: {np.median(history):.3f}(s),'
            f' std time: {np.std(history):.3f}(s)',
            file=sys.stderr,
        )
        return outcome

    return wrapper