quanto

Форк
0
/
helpers.py 
87 строк · 3.1 Кб
1
# Copyright 2024 The HuggingFace Team. All rights reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import functools
16
import gc
17

18
import pytest
19
import torch
20
from packaging import version
21

22
from quanto import absmax_scale, qint8, quantize_activation, quantize_weight
23

24

25
def torch_min_version(v):
26
    def torch_min_version_decorator(test):
27
        @functools.wraps(test)
28
        def test_wrapper(*args, **kwargs):
29
            if version.parse(torch.__version__) < version.parse(v):
30
                pytest.skip(f"Requires pytorch >= {v}")
31
            test(*args, **kwargs)
32

33
        return test_wrapper
34

35
    return torch_min_version_decorator
36

37

38
def device_eq(a, b):
39
    if a.type != b.type:
40
        return False
41
    a_index = a.index if a.index is not None else 0
42
    b_index = b.index if b.index is not None else 0
43
    return a_index == b_index
44

45

46
def random_tensor(shape, dtype=torch.float32):
47
    # Return a random tensor between -1. and 1.
48
    return torch.rand(shape, dtype=dtype) * 2 - 1
49

50

51
def random_qactivation(shape, qtype=qint8, dtype=torch.float32):
52
    t = random_tensor(shape, dtype)
53
    scale = absmax_scale(t, qtype=qtype)
54
    return quantize_activation(t, qtype=qtype, scale=scale)
55

56

57
def random_qweight(shape, qtype, dtype=torch.float32, axis=0, group_size=None):
58
    t = random_tensor(shape, dtype)
59
    return quantize_weight(t, qtype=qtype, axis=axis, group_size=group_size)
60

61

62
def assert_similar(a, b, atol=None, rtol=None):
63
    """Verify that the cosine similarity of the two inputs is close to 1.0 everywhere"""
64
    assert a.dtype == b.dtype
65
    assert a.shape == b.shape
66
    if atol is None:
67
        # We use torch finfo resolution
68
        atol = torch.finfo(a.dtype).resolution
69
    if rtol is None:
70
        # Please refer to that discussion for default rtol values based on the float type:
71
        # https://scicomp.stackexchange.com/questions/43111/float-equality-tolerance-for-single-and-half-precision
72
        rtol = {torch.float32: 1e-5, torch.float16: 1e-3, torch.bfloat16: 1e-1}[a.dtype]
73
    sim = torch.nn.functional.cosine_similarity(a.flatten(), b.flatten(), dim=0)
74
    if not torch.allclose(sim, torch.tensor(1.0, dtype=sim.dtype), atol=atol, rtol=rtol):
75
        max_deviation = torch.min(sim)
76
        raise ValueError(f"Alignment {max_deviation:.8f} deviates too much from 1.0 with atol={atol}, rtol={rtol}")
77

78

79
def get_device_memory(device):
80
    gc.collect()
81
    if device.type == "cuda":
82
        torch.cuda.empty_cache()
83
        return torch.cuda.memory_allocated()
84
    elif device.type == "mps":
85
        torch.mps.empty_cache()
86
        return torch.mps.current_allocated_memory()
87
    return None
88

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.