google-research
114 строк · 4.1 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper routines for quantization."""
17
18from typing import Any19
20import chex21from flax import struct22import jax.numpy as jnp23
24
# pylint:disable=no-value-for-parameter
@struct.dataclass
class QuantizedValue:
  """State associated with quantized value.

  Holds either a symmetrically quantized integer tensor (int8/int16) together
  with its scales, or a float32/bfloat16 tensor stored without integer
  quantization. Build instances with `from_float_value` and recover a float
  tensor with `to_float`.
  """
  # Stored payload: integer codes for int8/int16, the input itself for
  # float32, the bfloat16-cast input for bfloat16, or [] for an empty input.
  quantized: chex.Array
  # Diagonal of the input (in the input's float dtype) when extract_diagonal
  # is set and quantized_dtype is an integer type; [] otherwise.
  diagonal: chex.Array
  # Scale per slice along axis 0 (max |value| / number of levels) for
  # integer quantization; [] otherwise.
  bucket_size: chex.Array
  # Dtype for the quantized value (static field, not a pytree leaf).
  quantized_dtype: jnp.dtype = struct.field(pytree_node=False)
  # Whether the diagonal is stored separately from the quantized
  # off-diagonal entries (static field, not a pytree leaf).
  extract_diagonal: bool = struct.field(pytree_node=False)
  # Shape of `quantized` as a list, or [] for an empty input (static field).
  shape: Any = struct.field(pytree_node=False)

  @classmethod
  def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
    """Quantizes `fvalue` and wraps the result in a new instance.

    Args:
      fvalue: Array to quantize. An empty Python list is also accepted and
        yields an instance whose array fields and shape are empty lists.
      quantized_dtype: Target dtype; one of jnp.float32, jnp.bfloat16,
        jnp.int8 or jnp.int16.
      extract_diagonal: If True and `quantized_dtype` is an integer type, the
        diagonal of the 2-D input is kept separately and excluded from
        quantization. The flag is stored as given in every case.

    Returns:
      An instance of `cls` holding the quantized state.

    Raises:
      ValueError: Propagated from `quantize` for an unsupported dtype or an
        invalid input rank.
    """
    if isinstance(fvalue, list) and not fvalue:
      # Empty placeholder: there is nothing to quantize.
      return cls([], [], [], quantized_dtype, extract_diagonal, [])
    # Use `cls` (not the hard-coded class name) so subclasses get instances of
    # their own type and any overridden `quantize` is honored.
    quantized, diagonal_fvalue, bucket_size = cls.quantize(
        fvalue, quantized_dtype, extract_diagonal)
    return cls(quantized, diagonal_fvalue, bucket_size, quantized_dtype,
               extract_diagonal, list(quantized.shape))

  # Quantization is from Lingvo JAX optimizers.
  # We extend it for int16 quantization of PSD matrices.
  @classmethod
  def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
    """Returns quantized value and the bucket.

    float32 returns `fvalue` unchanged and bfloat16 only casts it; in both
    cases the diagonal and bucket outputs are empty lists. int8/int16 round
    each entry to the nearest level, with the largest magnitude along axis 0
    mapped to 127 / 32767 respectively (the most negative integer is unused).

    Args:
      fvalue: Float array to quantize.
      quantized_dtype: One of jnp.float32, jnp.bfloat16, jnp.int8, jnp.int16.
      extract_diagonal: For integer dtypes only: split off the diagonal of a
        2-D `fvalue` and quantize just the off-diagonal entries.

    Returns:
      A tuple `(quantized, diagonal, bucket_size)`.

    Raises:
      ValueError: If `quantized_dtype` is unsupported; or, for integer dtypes,
        if `extract_diagonal` is set for a non-2-D input or if `fvalue` is
        0-dimensional.
    """
    if quantized_dtype == jnp.float32:
      return fvalue, [], []
    elif quantized_dtype == jnp.bfloat16:
      return fvalue.astype(jnp.bfloat16), [], []

    float_dtype = fvalue.dtype
    if quantized_dtype == jnp.int8:
      # value -128 is not used.
      num_buckets = jnp.array(127.0, dtype=float_dtype)
    elif quantized_dtype == jnp.int16:
      # value -32768 is not used.
      # NOTE(review): 32767.0 is not exactly representable in float16 or
      # bfloat16 (it rounds to 32768), so with half-precision inputs the
      # extreme entries can round to +/-32768. Confirm int16 is only used
      # with float32 (or wider) inputs.
      num_buckets = jnp.array(32767.0, dtype=float_dtype)
    else:
      raise ValueError(f'Quantized dtype {quantized_dtype} not supported.')
    # max value is mapped to num_buckets

    if extract_diagonal and fvalue.ndim != 2:
      raise ValueError(
          f'Input array {fvalue} must be 2D to work with extract_diagonal.')

    diagonal_fvalue = []
    if extract_diagonal:
      # The diagonal is kept at full input precision.
      diagonal_fvalue = jnp.diag(fvalue)
      # Remove the diagonal entries.
      fvalue = fvalue - jnp.diag(diagonal_fvalue)

    # TODO(rohananil): Extend this by making use of information about the blocks
    # SM3 style which will be useful for diagonal statistics
    # We first decide the scale.
    if fvalue.ndim < 1:
      raise ValueError(
          f'Input array {fvalue} must have a strictly positive number of '
          'dimensions.')

    # One scale per slice along axis 0 (one per column for a 2-D input).
    max_abs = jnp.max(jnp.abs(fvalue), axis=0)
    bucket_size = max_abs / num_buckets
    # Re-add the reduced axis so the scale broadcasts against fvalue.
    bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
    # To avoid divide by 0.0 (all-zero slices quantize to 0).
    bs_nonzero = jnp.where(bs_expanded > 0.0, bs_expanded,
                           jnp.ones_like(bs_expanded))
    ratio = fvalue / bs_nonzero
    # We use rounding to remove bias.
    quantized = jnp.round(ratio)
    return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size

  def to_float(self):
    """Returns the float value.

    Inverts `quantize`:
      * An empty placeholder returns its empty list.
      * float32 values are returned as stored.
      * bfloat16 values are upcast to float32.
      * Integer codes are cast to the dtype of `bucket_size` and multiplied by
        it, and the separately stored diagonal is added back when
        `extract_diagonal` is set.
    """
    if isinstance(self.quantized, list) and not self.quantized:
      return self.quantized

    if self.quantized_dtype == jnp.float32:
      return self.quantized

    if self.quantized_dtype == jnp.bfloat16:
      return self.quantized.astype(jnp.float32)

    float_dtype = self.bucket_size.dtype
    # Re-add the axis reduced in `quantize` so the scale broadcasts.
    bucket_size = self.bucket_size[jnp.newaxis, Ellipsis]
    val = self.quantized.astype(float_dtype) * bucket_size
    if self.extract_diagonal:
      val += jnp.diag(self.diagonal)
    return val
116