DeepSpeed

Зеркало из https://github.com/microsoft/DeepSpeed
Форк
0
161 строка · 6.1 Кб
1
# Copyright (c) Microsoft Corporation.
2
# SPDX-License-Identifier: Apache-2.0
3

4
# DeepSpeed Team
5

6
import torch
7
import abc
8
from abc import ABC
9

10
import gc
11
from deepspeed.ops.op_builder import FPQuantizerBuilder
12
from deepspeed.accelerator import get_accelerator
13

14
# Module-level cache for the FP quantizer op module. It starts out empty and is
# filled in by the first FP_Quantize instance via FPQuantizerBuilder().load(),
# after which every instance reuses the same handle.
fp_quant_module = None


class Quantizer(ABC):
    """
    Abstract Quantizer class that implements quantize/dequantize methods.

    Concrete subclasses must override both ``quantize`` and ``dequantize``;
    instantiating this class directly raises ``TypeError`` (enforced by ``abc``).

    Arguments:
        group_size (int, optional): number of values or elements that are grouped
            together for the quantization process. Defaults to 512.
    """

    def __init__(self, group_size=512) -> None:
        self.group_size = group_size

    @abc.abstractmethod
    def quantize(self,
                 input,
                 q_bits=8,
                 q_mantisa_bits=3,
                 stochastic_mode=False,
                 return_meta_tensor=False) -> torch.Tensor:
        """Quantize ``input`` group-wise (``self.group_size`` elements per group).

        Arguments:
            input: tensor to quantize.
            q_bits (int): total bit width of each quantized value.
            q_mantisa_bits (int): number of mantissa bits in the quantized format.
            stochastic_mode (bool): request stochastic rounding, if the
                implementation supports it.
            return_meta_tensor (bool): if True, implementations may return the
                quantized data and its per-group scales separately.
        """

    @abc.abstractmethod
    def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor:
        """Reconstruct a floating-point tensor from quantized data ``input_q``.

        Arguments:
            input_q: quantized data produced by ``quantize``.
            fp_out: optional preallocated output tensor.
            q_bits (int): total bit width used during quantization.
            q_mantisa_bits (int): mantissa bit width used during quantization.
            scale: optional per-group scales, for data that was returned
                separately from its scales.
        """


class FP_Quantize(Quantizer):
    """Group-wise floating-point quantizer backed by the op module returned by
    ``FPQuantizerBuilder().load()``.

    Supported total widths (``q_bits``) are 8, 12, 6 and 4. For 8 bits the
    caller's ``q_mantisa_bits`` is used as given. For the other widths the
    mantissa width is fixed (12 -> 4, 6 -> 2, 4 -> 1). Only bfloat16 inputs are
    accepted by ``quantize``.

    ``quantize`` records state on the instance that the dequantize methods and
    ``get_scales`` read back:
    - ``orig_dtype`` and ``orig_shape``;
    - ``num_groups`` and ``input_q``;
    - ``scale``, in meta-tensor mode only.

    Because of this, an instance describes the most recently quantized tensor.
    """

    def __init__(self, group_size=512) -> None:
        global fp_quant_module
        super().__init__(group_size=group_size)
        # Load the op module once and cache it in the module-level global so
        # later instances reuse the same handle.
        if fp_quant_module is None:
            fp_quant_module = FPQuantizerBuilder().load()
        # None means "quantize() has not been called yet"; dequantize checks it.
        self.orig_dtype = None

    @staticmethod
    def _resolve_mantisa_bits(q_bits, q_mantisa_bits, op_name):
        """Return the mantissa width to hand to the kernel for ``q_bits``.

        For 8 bits the requested ``q_mantisa_bits`` is kept. For 12/6/4 bits the
        width is fixed by the supported kernel formats.

        Raises:
            AssertionError: for any other ``q_bits``. It is raised explicitly rather
                than via ``assert`` so the check still runs under ``python -O``.
                The exception type is kept for compatibility with existing callers.
        """
        if q_bits == 8:
            return q_mantisa_bits
        if q_bits == 12:
            return 4
        if q_bits == 6:
            return 2
        if q_bits == 4:
            return 1
        raise AssertionError(
            f"Missing {q_bits}-{op_name}, please add the template arguments for the kernel to support this precision!")

    def _attach_scale(self, input_q, fp_out, scale):
        """Re-join split-off quantized data with its per-group ``scale`` columns.

        This reverses the meta-tensor split done in ``quantize``. There, each
        ``group_size``-wide row of packed data is separated from its trailing
        scale columns.
        """
        assert input_q.numel() == fp_out.numel(), \
            '[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
        return torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()

    def quantize(self,
                 input,
                 q_bits=8,
                 q_mantisa_bits=3,
                 stochastic_mode=False,
                 return_meta_tensor=False) -> torch.Tensor:
        """Quantize a bfloat16 tensor group-wise.

        Arguments:
            input: bfloat16 tensor to quantize.
            q_bits (int): total bit width; one of 8, 12, 6, 4.
            q_mantisa_bits (int): mantissa width. Only honored when ``q_bits == 8``.
            stochastic_mode (bool): forwarded to the kernel's quantize op.
            return_meta_tensor (bool): only valid with ``q_bits == 8``.

        Returns:
            By default, the buffer returned by the kernel. It has one row per group
            of ``group_size * q_bits // 8 + 4`` uint8 bytes.

            With ``return_meta_tensor=True``, a ``(data, scale)`` tuple instead:
            ``data`` is reshaped to ``input.shape``, and ``scale`` holds the
            trailing 4 bytes of each row.
        """
        assert input.dtype == torch.bfloat16, "only support bf16 for now"
        if return_meta_tensor:
            assert q_bits == 8, "meta tensor is only supported with q_bit=8"

        # Validate the format before recording any state, so a rejected call does
        # not leave orig_dtype/orig_shape describing a tensor that was never quantized.
        q_mantisa_bits = self._resolve_mantisa_bits(q_bits, q_mantisa_bits, "quantization")

        self.orig_dtype = input.dtype
        self.orig_shape = input.shape

        # NOTE(review): integer division silently drops trailing elements when
        # input.numel() is not a multiple of group_size. Confirm that callers
        # guarantee divisibility.
        self.num_groups = input.numel() // self.group_size
        # One row per group: packed data (group_size * q_bits / 8 bytes) plus 4
        # extra bytes, which the meta-tensor path below splits off as the scale.
        self.input_q = torch.ones(self.num_groups,
                                  int(self.group_size * q_bits) // 8 + 4,
                                  dtype=torch.uint8,
                                  device=input.device)
        out = fp_quant_module.quantize(self.input_q, input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits)
        if return_meta_tensor:
            # With q_bits == 8 each row is group_size data bytes followed by the scale bytes.
            data, self.scale = out.split(self.group_size, dim=-1)
            data = data.contiguous().reshape(input.shape)
            self.scale = self.scale.contiguous()
            # The combined buffer is no longer needed; release it and ask the
            # accelerator to return cached memory.
            del self.input_q
            del out
            gc.collect()
            get_accelerator().empty_cache()
            return data, self.scale

        return out

    def to(self, *args, **kwargs):
        """Move cached intermediate tensors (``input_q``, ``scale``) via ``Tensor.to``.

        Returns ``self`` so the call can be chained.
        """
        # Intermediate tensors may need to be moved to different devices
        if hasattr(self, 'input_q'):
            self.input_q = self.input_q.to(*args, **kwargs)
        if hasattr(self, 'scale'):
            self.scale = self.scale.to(*args, **kwargs)
        return self

    def get_scales(self):
        """Return ``fp_quant_module.get_scales`` applied to the stored scales.

        ``self.scale`` is only set by ``quantize(..., return_meta_tensor=True)``.
        """
        return fp_quant_module.get_scales(self.scale, self.num_groups)

    def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor:
        """Dequantize ``input_q`` back to the dtype/shape recorded by ``quantize``.

        Arguments:
            input_q: quantized buffer. When ``scale`` is given, it is the data part
                returned in meta-tensor mode.
            fp_out: optional preallocated output. Defaults to a new tensor of
                ``self.orig_shape`` / ``self.orig_dtype`` on ``input_q.device``.
            q_bits (int): total bit width; one of 8, 12, 6, 4.
            q_mantisa_bits (int): mantissa width. Only honored when ``q_bits == 8``.
            scale: per-group scales to re-attach before calling the kernel.

        Returns:
            ``fp_out``, filled by the kernel.
        """
        assert (self.orig_dtype is not None), \
            "[De-quantization Error]: you need to call quantize before dequantizing!"
        q_mantisa_bits = self._resolve_mantisa_bits(q_bits, q_mantisa_bits, "dequantization")
        fp_out = torch.empty(self.orig_shape, dtype=self.orig_dtype,
                             device=input_q.device) if fp_out is None else fp_out

        if scale is not None:
            input_q = self._attach_scale(input_q, fp_out, scale)
        # Exponent width is what remains after the mantissa and one more bit
        # (presumably the sign bit).
        fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1)
        return fp_out

    def selective_dequantize(self,
                             input_q,
                             indexes,
                             fp_out=None,
                             q_bits=8,
                             q_mantisa_bits=3,
                             scale=None) -> torch.Tensor:
        """Dequantize only the slices of a 3-D quantized tensor selected by ``indexes``.

        Arguments:
            input_q: quantized buffer. When ``scale`` is given, it is the data part
                returned in meta-tensor mode.
            indexes: selection along the first dimension of the original tensor.
            fp_out: optional preallocated output. Defaults to a new tensor of shape
                ``(indexes.shape[0], *self.orig_shape[1:])`` with ``self.orig_dtype``.
            q_bits (int): total bit width; one of 8, 12, 6, 4.
            q_mantisa_bits (int): mantissa width. Only honored when ``q_bits == 8``.
            scale: per-group scales to re-attach before calling the kernel.

        Returns:
            ``fp_out``, filled by the kernel.
        """
        assert (not hasattr(self, 'orig_shape') or len(self.orig_shape) == 3), \
            "Selective-Dequantization works on 3d tensor only! Please reshape the tensor before calling dequantize function."
        assert (self.orig_dtype is not None), \
            "[De-quantization Error]: you need to call quantize before dequantizing!"
        q_mantisa_bits = self._resolve_mantisa_bits(q_bits, q_mantisa_bits, "dequantization")
        fp_out = torch.empty(
            (indexes.shape[0],
             *self.orig_shape[1:]), dtype=self.orig_dtype, device=input_q.device) if fp_out is None else fp_out

        if scale is not None:
            input_q = self._attach_scale(input_q, fp_out, scale)

        fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.group_size, q_mantisa_bits,
                                             q_bits - q_mantisa_bits - 1)
        return fp_out

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.