import argparse
import time
from contextlib import nullcontext

import numpy as np
import torch
from tqdm.auto import tqdm

from quanto.library import disable_extensions
def get_dqmm_bench(input_dtype, device):
27
input = torch.rand([1024, 1024], dtype=input_dtype).to(device)
28
other = torch.randint(-127, 127, [1024, 1024], dtype=torch.int8).to(device)
29
other_scale = torch.ones((1024,), dtype=input_dtype, device=device) * 0.5
32
return torch.ops.quanto.dqmm(input, other, other_scale)
37
def get_unpack_bench(bits, device):
    """Build a benchmark callable for the `quanto.unpack` kernel.

    Args:
        bits: number of bits per packed value; values are drawn in
            [0, 2**bits) so every packed code is representable.
        device: torch.device the packed tensor is moved to.

    Returns:
        A zero-argument callable running one `torch.ops.quanto.unpack` call.
    """
    # `qmax` was undefined in the garbled source — the upper bound must be
    # 2**bits for `bits`-bit codes (exclusive upper bound of randint).
    qmax = 2**bits
    packed = torch.randint(0, qmax, [10240, 10240], dtype=torch.uint8).to(device)

    # NOTE(review): closure reconstructed (interior lines lost in extraction);
    # `timing` expects a zero-argument callable — confirm against upstream.
    def bench_fn():
        return torch.ops.quanto.unpack(packed, bits)

    return bench_fn
def timing(get_bench_func, device, iterations=10):
    """Time a kernel with and without the C++/CUDA extensions enabled.

    For each iteration the benchmark runs twice: once under
    `disable_extensions()` (pure-python fallback, column 0) and once under
    `nullcontext()` (extensions active, column 1).

    Args:
        get_bench_func: factory taking a device and returning a
            zero-argument benchmark callable.
        device: torch.device to benchmark on.
        iterations: number of timed repetitions per variant.

    Returns:
        Tuple `(python_ms, ext_ms)` of mean latencies in milliseconds.
    """

    def synchronize(device):
        # Make sure all queued device work has completed before/after timing.
        if device.type == "cuda":
            torch.cuda.synchronize()
        elif device.type == "mps":
            torch.mps.synchronize()
        else:
            # NOTE(review): reconstructed `else:` (line lost in extraction);
            # `torch.cpu.synchronize()` is a no-op barrier on CPU.
            torch.cpu.synchronize()

    def timing_event(device):
        # Use native device events where available; fall back to a
        # wall-clock stand-in exposing the same record/elapsed_time API.
        if device.type == "cuda":
            return torch.cuda.Event(enable_timing=True)
        elif device.type == "mps":
            return torch.mps.Event(enable_timing=True)

        # NOTE(review): class scaffolding reconstructed around the surviving
        # `record`/`elapsed_time` bodies — confirm against upstream.
        class CPUEvent:
            def __init__(self):
                self.time = None

            def record(self):
                self.time = time.time()

            def elapsed_time(self, other):
                assert self.time is not None
                assert other.time is not None
                # Match torch events: elapsed time is reported in ms.
                return (other.time - self.time) * 1000

        return CPUEvent()

    synchronize(device)
    bench_func = get_bench_func(device)
    # Warm up once so lazy extension loading/compilation is not timed.
    bench_func()
    latencies = np.empty((iterations, 2))
    for i in tqdm(range(iterations)):
        for j, context in enumerate([disable_extensions(), nullcontext()]):
            start_event = timing_event(device)
            end_event = timing_event(device)
            # NOTE(review): record/sync sequence reconstructed (lines lost in
            # extraction) — confirm against upstream.
            synchronize(device)
            start_event.record()
            with context:
                bench_func()
            end_event.record()
            synchronize(device)
            latencies[i, j] = start_event.elapsed_time(end_event)
    return np.mean(latencies[:, 0]), np.mean(latencies[:, 1])
# Registry mapping a kernel name (CLI `--kernel` value) to a factory that,
# given a device, returns a zero-argument benchmark callable.
# NOTE(review): the closing brace was not visible in the garbled source;
# restored here.
GET_BENCH_FUNCTIONS = {
    "dqmm_w8a16": lambda device: get_dqmm_bench(torch.float16, device),
    "unpack_2bit": lambda device: get_unpack_bench(2, device),
    "unpack_4bit": lambda device: get_unpack_bench(4, device),
}
def main():
    """CLI entry point: parse arguments and benchmark the selected kernels.

    NOTE(review): the `def main():` header and both `else:` branches were
    lost in extraction and are reconstructed here, consistent with the
    trailing `if __name__ == "__main__":` guard — confirm against upstream.
    """
    parser = argparse.ArgumentParser(description="Kernel benchmark")
    parser.add_argument("--kernel", type=str, default=None, help="The kernel to benchmark. None to test all of them")
    parser.add_argument("--device", type=str, default=None, help="The device to use for benchmark.")
    parser.add_argument("--it", type=int, default=10, help="The number of benchmark iterations")
    args = parser.parse_args()
    if args.device is None:
        # Auto-select the fastest available backend: cuda > mps > cpu.
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
    else:
        device = torch.device(args.device)
    all_kernels = GET_BENCH_FUNCTIONS.keys()
    kernels = all_kernels if args.kernel is None else [args.kernel]
    for kernel in kernels:
        # KeyError here means an unknown --kernel name was passed.
        get_bench_fn = GET_BENCH_FUNCTIONS[kernel]
        python_ms, ext_ms = timing(get_bench_fn, device, iterations=args.it)
        ratio = python_ms / ext_ms
        print(f"\n{kernel}[{device.type}]: python = {python_ms:.3f} ms, ext = {ext_ms:.3f} ms, ratio = {ratio:.1f}x")
if __name__ == "__main__":