DeepSpeed
Зеркало из https://github.com/microsoft/DeepSpeed
1# Copyright (c) Microsoft Corporation.
2# SPDX-License-Identifier: Apache-2.0
3
4# DeepSpeed Team
5
6import torch
7
8from deepspeed.ops.op_builder import QuantizerBuilder
9
# Handle to the DeepSpeed quantizer extension. It stays None until
# ds_quantizer() first needs it, so QuantizerBuilder().load() runs on first
# use rather than at import time.
quantizer_cuda_module = None


def ds_quantizer(input, groups=1, bit_num=8, sr=False, asym=False):
    """Quantize ``input`` with a kernel from the DeepSpeed quantizer extension.

    On the first call the extension is obtained via
    ``QuantizerBuilder().load()`` and cached in the module-level
    ``quantizer_cuda_module``; later calls reuse the cached handle.

    The kernel is chosen by name from three switches:

    * ``sr``   -- use the ``ds_sr_quantize*`` family instead of
      ``ds_quantize*`` (presumably stochastic rounding -- TODO confirm);
    * ``asym`` -- use the ``*_asym`` variant of that family;
    * dtype    -- ``*_fp16`` when ``input.dtype == torch.half``, otherwise
      ``*_fp32``.

    NOTE(review): every dtype other than ``torch.half`` (e.g. bfloat16 or
    float64) is routed to the fp32 kernel; how the extension treats such
    inputs is not visible here.

    Args:
        input: tensor to quantize; only its ``dtype`` is inspected here, and
            it is passed through unchanged to the kernel.
        groups: forwarded unchanged as the kernel's second argument.
        bit_num: forwarded unchanged as the kernel's third argument.
        sr: select the ``ds_sr_quantize*`` kernels.
        asym: select the asymmetric (``*_asym``) kernels.

    Returns:
        Whatever the selected extension kernel returns.
    """
    global quantizer_cuda_module
    if quantizer_cuda_module is None:
        # First call: load the extension once and cache the handle.
        quantizer_cuda_module = QuantizerBuilder().load()

    # Kernel family picked by (sr, asym); the precision suffix is added below.
    # Full names: ds_quantize_fp16/_fp32, ds_quantize_asym_fp16/_fp32,
    # ds_sr_quantize_fp16/_fp32, ds_sr_quantize_asym_fp16/_fp32.
    stem = {
        (False, False): "ds_quantize",
        (False, True): "ds_quantize_asym",
        (True, False): "ds_sr_quantize",
        (True, True): "ds_sr_quantize_asym",
    }[(bool(sr), bool(asym))]

    # Only torch.half gets the fp16 kernel; every other dtype takes fp32.
    suffix = "_fp16" if input.dtype == torch.half else "_fp32"

    kernel = getattr(quantizer_cuda_module, stem + suffix)
    return kernel(input, groups, bit_num)
30