DeepSpeed
Mirror of https://github.com/microsoft/DeepSpeed
161 lines · 6.1 KB
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import abc
from abc import ABC

import gc
from deepspeed.ops.op_builder import FPQuantizerBuilder
from deepspeed.accelerator import get_accelerator

fp_quant_module = None

class Quantizer(ABC):
    """
    Abstract Quantizer class that implements quantize/dequantize methods.

    Arguments:
        group_size (int, optional): number of values or elements that are grouped
            together for the quantization process.
    """

    def __init__(self, group_size=512) -> None:
        self.group_size = group_size

    @abc.abstractmethod
    def quantize(self,
                 input,
                 q_bits=8,
                 q_mantisa_bits=3,
                 stochastic_mode=False,
                 return_meta_tensor=False) -> torch.Tensor:
        ...

    @abc.abstractmethod
    def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor:
        ...

class FP_Quantize(Quantizer):

    def __init__(self, group_size=512) -> None:
        global fp_quant_module
        super().__init__(group_size=group_size)
        if fp_quant_module is None:
            fp_quant_module = FPQuantizerBuilder().load()
        self.orig_dtype = None

    def quantize(self,
                 input,
                 q_bits=8,
                 q_mantisa_bits=3,
                 stochastic_mode=False,
                 return_meta_tensor=False) -> torch.Tensor:
        assert input.dtype == torch.bfloat16, "only support bf16 for now"
        if return_meta_tensor:
            assert q_bits == 8, "meta tensor is only supported with q_bit=8"

        self.orig_dtype = input.dtype
        self.orig_shape = input.shape

        if q_bits == 8:
            pass
        elif q_bits == 12:
            q_mantisa_bits = 4
        elif q_bits == 6:
            q_mantisa_bits = 2
        elif q_bits == 4:
            q_mantisa_bits = 1
        else:
            assert (0), \
                f"Missing {q_bits}-quantization, please add the template arguments for the kernel to support this precision!"
        self.num_groups = input.numel() // self.group_size
        self.input_q = torch.ones(self.num_groups,
                                  int(self.group_size * q_bits) // 8 + 4,
                                  dtype=torch.uint8,
                                  device=input.device)
        out = fp_quant_module.quantize(self.input_q, input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits)
        if return_meta_tensor:
            data, self.scale = out.split(self.group_size, dim=-1)
            data = data.contiguous().reshape(input.shape)
            self.scale = self.scale.contiguous()
            del self.input_q
            del out
            gc.collect()
            get_accelerator().empty_cache()
            return data, self.scale

        return out

    def to(self, *args, **kwargs):
        # Intermediate tensors may need to be moved to different devices
        if hasattr(self, 'input_q'):
            self.input_q = self.input_q.to(*args, **kwargs)
        if hasattr(self, 'scale'):
            self.scale = self.scale.to(*args, **kwargs)

    def get_scales(self):
        return fp_quant_module.get_scales(self.scale, self.num_groups)

    def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor:
        assert (self.orig_dtype is not None), \
            "[De-quantization Error]: you need to call quantize before dequantizing!"
        fp_out = torch.empty(self.orig_shape, dtype=self.orig_dtype,
                             device=input_q.device) if fp_out is None else fp_out
        if q_bits == 8:
            pass
        elif q_bits == 12:
            q_mantisa_bits = 4
        elif q_bits == 6:
            q_mantisa_bits = 2
        elif q_bits == 4:
            q_mantisa_bits = 1
        else:
            assert (0), \
                f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"

        if scale is not None:
            assert input_q.numel() == fp_out.numel(), \
                f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
            input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()
        fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1)
        return fp_out

    def selective_dequantize(self,
                             input_q,
                             indexes,
                             fp_out=None,
                             q_bits=8,
                             q_mantisa_bits=3,
                             scale=None) -> torch.Tensor:
        assert (not hasattr(self, 'orig_shape') or len(self.orig_shape) == 3), \
            "Selective-Dequantization works on 3d tensor only! Please reshape the tensor before calling dequantize function."
        assert (self.orig_dtype is not None), \
            "[De-quantization Error]: you need to call quantize before dequantizing!"
        fp_out = torch.empty(
            (indexes.shape[0],
             *self.orig_shape[1:]), dtype=self.orig_dtype, device=input_q.device) if fp_out is None else fp_out
        if q_bits == 8:
            pass
        elif q_bits == 12:
            q_mantisa_bits = 4
        elif q_bits == 6:
            q_mantisa_bits = 2
        elif q_bits == 4:
            q_mantisa_bits = 1
        else:
            assert (0), \
                f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"

        if scale is not None:
            assert input_q.numel() == fp_out.numel(), \
                f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
            input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()

        fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.group_size, q_mantisa_bits,
                                             q_bits - q_mantisa_bits - 1)
        return fp_out