# quanto
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import timeit

import torch
from torch_int._CUDA import bmm_s8t_s8n_f32t

21def mm(A, B):22return torch._int_mm(A.squeeze(), B.squeeze().transpose(1, 0))23
24
# Build two random int8 operands on the GPU.
# Values in [1, 10) keep the int32 accumulation far from overflow even
# with a 12288-long contraction dimension.
A = torch.randint(1, 10, [1, 512, 12288]).type(torch.int8).cuda()
B = torch.randint(1, 10, [1, 512, 12288]).type(torch.int8).cuda()
# Sanity-check print of the generated operand (debug output).
print(A)

# Using torch int matmul
# Warmup (slow): the first call pays one-time CUDA/kernel-selection costs.
mm(A, B)
# Average over several calls.
# NOTE(review): CUDA launches are asynchronous; without a
# torch.cuda.synchronize() inside the timed callable these numbers may
# under-report the true kernel time — confirm before trusting them.
it = 1000
print("torch _int_mm")
print(timeit.Timer(lambda: mm(A, B)).timeit(it) / it)

# Using torch_int custom kernels
# Warmup (slow); 0.1 is the output dequantization scale expected by the kernel.
bmm_s8t_s8n_f32t(A, B, 0.1)
# Average over several calls (same async-timing caveat as above applies).
it = 1000
print("torch_int kernels")
print(timeit.Timer(lambda: bmm_s8t_s8n_f32t(A, B, 0.1)).timeit(it) / it)

# Using torch f16 matmul as the baseline.
# Warmup (slow). Note: A and B are rebound to fp16 copies here, so the
# int8 benchmarks above must run first.
A = A.type(torch.float16)
B = B.type(torch.float16)
torch.matmul(A, B.transpose(2, 1))
# Average over several calls (same async-timing caveat as above applies).
it = 1000
print("torch fp16 matmul")
print(timeit.Timer(lambda: torch.matmul(A, B.transpose(2, 1))).timeit(it) / it)