pytorch

Форк
0
/
torchao_backend.py 
57 строк · 2.2 Кб
1
from typing import Any, Callable
2

3
import torch
4

5

6
def setup_baseline():
7
    from torchao.quantization.utils import recommended_inductor_config_setter
8

9
    recommended_inductor_config_setter()
10
    torch._dynamo.config.automatic_dynamic_shapes = False
11
    torch._dynamo.config.cache_size_limit = 10000
12

13

14
def torchao_optimize_ctx(quantization: str):
15
    from torchao.quantization.quant_api import (
16
        autoquant,
17
        int4_weight_only,
18
        int8_dynamic_activation_int8_weight,
19
        int8_weight_only,
20
        quantize_,
21
    )
22
    from torchao.utils import unwrap_tensor_subclass
23

24
    def inner(model_iter_fn: Callable):
25
        def _torchao_apply(module: torch.nn.Module, example_inputs: Any):
26
            if getattr(module, "_quantized", None) is None:
27
                if quantization == "int8dynamic":
28
                    quantize_(
29
                        module,
30
                        int8_dynamic_activation_int8_weight(),
31
                        set_inductor_config=False,
32
                    )
33
                elif quantization == "int8weightonly":
34
                    quantize_(module, int8_weight_only(), set_inductor_config=False)
35
                elif quantization == "int4weightonly":
36
                    quantize_(module, int4_weight_only(), set_inductor_config=False)
37
                if quantization == "autoquant":
38
                    autoquant(module, error_on_unseen=False, set_inductor_config=False)
39
                    if isinstance(example_inputs, dict):
40
                        module(**example_inputs)
41
                    else:
42
                        module(*example_inputs)
43
                    from torchao.quantization.autoquant import AUTOQUANT_CACHE
44

45
                    if len(AUTOQUANT_CACHE) == 0:
46
                        raise Exception(  # noqa: TRY002`
47
                            "NotAutoquantizable"
48
                            f"Found no autoquantizable layers in model {type(module)}, stopping autoquantized run"
49
                        )
50
                else:
51
                    unwrap_tensor_subclass(module)
52
                setattr(module, "_quantized", True)  # noqa: B010
53
            model_iter_fn(module, example_inputs)
54

55
        return _torchao_apply
56

57
    return inner
58

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.