google-research

Форк
0
/
quantization_utils.py 
114 строк · 4.1 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Helper routines for quantization."""
17

18
from typing import Any
19

20
import chex
21
from flax import struct
22
import jax.numpy as jnp
23

24

25
# pylint:disable=no-value-for-parameter
26
@struct.dataclass
27
class QuantizedValue:
28
  """State associated with quantized value."""
29
  quantized: chex.Array
30
  diagonal: chex.Array  # Diagonal (if extract_diagonal is set)
31
  bucket_size: chex.Array
32
  quantized_dtype: jnp.dtype = struct.field(
33
      pytree_node=False)  # Dtype for the quantized value.
34
  extract_diagonal: bool = struct.field(
35
      pytree_node=False)  # In case its centered.
36
  shape: Any = struct.field(pytree_node=False)  # Shape of the tensor.
37

38
  @classmethod
39
  def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
40
    if isinstance(fvalue, list) and not fvalue:
41
      return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, [])
42
    quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize(
43
        fvalue, quantized_dtype, extract_diagonal)
44
    return QuantizedValue(quantized, diagonal_fvalue, bucket_size,
45
                          quantized_dtype, extract_diagonal,
46
                          list(quantized.shape))
47

48
  # Quantization is from Lingvo JAX optimizers.
49
  # We extend it for int16 quantization of PSD matrices.
50
  @classmethod
51
  def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
52
    """Returns quantized value and the bucket."""
53
    if quantized_dtype == jnp.float32:
54
      return fvalue, [], []
55
    elif quantized_dtype == jnp.bfloat16:
56
      return fvalue.astype(jnp.bfloat16), [], []
57

58
    float_dtype = fvalue.dtype
59
    if quantized_dtype == jnp.int8:
60
      # value -128 is not used.
61
      num_buckets = jnp.array(127.0, dtype=float_dtype)
62
    elif quantized_dtype == jnp.int16:
63
      # value -32768 is not used.
64
      num_buckets = jnp.array(32767.0, dtype=float_dtype)
65
    else:
66
      raise ValueError(f'Quantized dtype {quantized_dtype} not supported.')
67
    # max value is mapped to num_buckets
68

69
    if extract_diagonal and fvalue.ndim != 2:
70
      raise ValueError(
71
          f'Input array {fvalue} must be 2D to work with extract_diagonal.')
72

73
    diagonal_fvalue = []
74
    if extract_diagonal:
75
      diagonal_fvalue = jnp.diag(fvalue)
76
      # Remove the diagonal entries.
77
      fvalue = fvalue - jnp.diag(diagonal_fvalue)
78

79
    # TODO(rohananil): Extend this by making use of information about the blocks
80
    # SM3 style which will be useful for diagonal statistics
81
    # We first decide the scale.
82
    if fvalue.ndim < 1:
83
      raise ValueError(
84
          f'Input array {fvalue} must have a strictly positive number of '
85
          'dimensions.')
86

87
    max_abs = jnp.max(jnp.abs(fvalue), axis=0)
88
    bucket_size = max_abs / num_buckets
89
    bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
90
    # To avoid divide by 0.0
91
    bs_nonzero = jnp.where(bs_expanded > 0.0, bs_expanded,
92
                           jnp.ones_like(bs_expanded))
93
    ratio = fvalue / bs_nonzero
94
    # We use rounding to remove bias.
95
    quantized = jnp.round(ratio)
96
    return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size
97

98
  def to_float(self):
99
    """Returns the float value."""
100
    if isinstance(self.quantized, list) and not self.quantized:
101
      return self.quantized
102

103
    if self.quantized_dtype == jnp.float32:
104
      return self.quantized
105

106
    if self.quantized_dtype == jnp.bfloat16:
107
      return self.quantized.astype(jnp.float32)
108

109
    float_dtype = self.bucket_size.dtype
110
    bucket_size = self.bucket_size[jnp.newaxis, Ellipsis]
111
    val = self.quantized.astype(float_dtype) * bucket_size
112
    if self.extract_diagonal:
113
      val += jnp.diag(self.diagonal)
114
    return val
115

116

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.