quanto

Форк
0
/
benchmark.py 
128 строк · 4.2 Кб
1
# Copyright 2024 The HuggingFace Team. All rights reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import argparse
16
import time
17
from contextlib import nullcontext
18

19
import numpy as np
20
import torch
21
from tqdm.auto import tqdm
22

23
from quanto.library import disable_extensions
24

25

26
def get_dqmm_bench(input_dtype, device):
27
    input = torch.rand([1024, 1024], dtype=input_dtype).to(device)
28
    other = torch.randint(-127, 127, [1024, 1024], dtype=torch.int8).to(device)
29
    other_scale = torch.ones((1024,), dtype=input_dtype, device=device) * 0.5
30

31
    def bench_fn():
32
        return torch.ops.quanto.dqmm(input, other, other_scale)
33

34
    return bench_fn
35

36

37
def get_unpack_bench(bits, device):
38
    qmax = 2**bits
39
    a = torch.randint(0, qmax, [10240, 10240], dtype=torch.uint8).to(device)
40

41
    def bench_fn():
42
        return torch.ops.quanto.unpack(a, bits)
43

44
    return bench_fn
45

46

47
def timing(get_bench_func, device, iterations=10):
48
    def synchronize(device):
49
        if device.type == "cuda":
50
            torch.cuda.synchronize()
51
        elif device.type == "mps":
52
            torch.mps.synchronize()
53
        else:
54
            torch.cpu.synchronize()
55

56
    def timing_event(device):
57
        if device.type == "cuda":
58
            return torch.cuda.Event(enable_timing=True)
59
        elif device.type == "mps":
60
            return torch.mps.Event(enable_timing=True)
61

62
        class CPUEvent:
63
            def __init__(self):
64
                self.time = None
65

66
            def record(self):
67
                self.time = time.time()
68

69
            def elapsed_time(self, other):
70
                assert self.time is not None
71
                assert other.time is not None
72
                return (other.time - self.time) * 1000
73

74
        return CPUEvent()
75

76
    synchronize(device)
77

78
    bench_func = get_bench_func(device)
79
    # Warmup to load library
80
    bench_func()
81
    latencies = np.empty((iterations, 2))
82
    for i in tqdm(range(iterations)):
83
        for j, context in enumerate([disable_extensions(), nullcontext()]):
84
            start_event = timing_event(device)
85
            end_event = timing_event(device)
86
            synchronize(device)
87
            start_event.record()
88
            with context:
89
                bench_func()
90
            end_event.record()
91
            synchronize(device)
92
            latencies[i, j] = start_event.elapsed_time(end_event)
93
    return np.mean(latencies[:, 0]), np.mean(latencies[:, 1])
94

95

96
GET_BENCH_FUNCTIONS = {
97
    "dqmm_w8a16": lambda device: get_dqmm_bench(torch.float16, device),
98
    "unpack_2bit": lambda device: get_unpack_bench(2, device),
99
    "unpack_4bit": lambda device: get_unpack_bench(4, device),
100
}
101

102

103
def main():
104
    parser = argparse.ArgumentParser(description="Kernel benchmark")
105
    parser.add_argument("--kernel", type=str, default=None, help="The kernel to benchmark. None to test all of them")
106
    parser.add_argument("--device", type=str, default=None, help="The device to use for benchmark.")
107
    parser.add_argument("--it", type=int, default=10, help="The number of benchmark iterations")
108
    args = parser.parse_args()
109
    if args.device is None:
110
        if torch.cuda.is_available():
111
            device = torch.device("cuda")
112
        elif torch.backends.mps.is_available():
113
            device = torch.device("mps")
114
        else:
115
            device = torch.device("cpu")
116
    else:
117
        device = torch.device(args.device)
118
    all_kernels = GET_BENCH_FUNCTIONS.keys()
119
    kernels = all_kernels if args.kernel is None else [args.kernel]
120
    for kernel in kernels:
121
        get_bench_fn = GET_BENCH_FUNCTIONS[kernel]
122
        python_ms, ext_ms = timing(get_bench_fn, device, iterations=args.it)
123
        ratio = python_ms / ext_ms
124
        print(f"\n{kernel}[{device.type}]: python = {python_ms:.3f} ms, ext = {ext_ms:.3f} ms, ratio = {ratio:.1f}x")
125

126

127
if __name__ == "__main__":
128
    main()
129

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.