from typing import Any, Callable

import torch


def setup_baseline():
    from torchao.quantization.utils import recommended_inductor_config_setter

    # Apply torchao's recommended inductor settings and pin dynamo's
    # shape handling and cache size so benchmark runs are reproducible.
    recommended_inductor_config_setter()
    torch._dynamo.config.automatic_dynamic_shapes = False
    torch._dynamo.config.cache_size_limit = 10000


def torchao_optimize_ctx(quantization: str):
    from torchao.quantization.quant_api import (
        autoquant,
        int4_weight_only,
        int8_dynamic_activation_int8_weight,
        int8_weight_only,
        quantize_,
    )
    from torchao.utils import unwrap_tensor_subclass

    def inner(model_iter_fn: Callable):
        def _torchao_apply(module: torch.nn.Module, example_inputs: Any):
            # Quantize the module once (marked via the _quantized attribute),
            # then delegate every call to model_iter_fn.
            if getattr(module, "_quantized", None) is None:
                if quantization == "int8dynamic":
                    quantize_(
                        module,
                        int8_dynamic_activation_int8_weight(),
                        set_inductor_config=False,
                    )
                elif quantization == "int8weightonly":
                    quantize_(module, int8_weight_only(), set_inductor_config=False)
                elif quantization == "int4weightonly":
                    quantize_(module, int4_weight_only(), set_inductor_config=False)
                if quantization == "autoquant":
                    autoquant(module, error_on_unseen=False, set_inductor_config=False)
                    # Autoquant needs a calibration forward pass on the example inputs.
                    if isinstance(example_inputs, dict):
                        module(**example_inputs)
                    else:
                        module(*example_inputs)
                    from torchao.quantization.autoquant import AUTOQUANT_CACHE

                    if len(AUTOQUANT_CACHE) == 0:
                        raise Exception(  # noqa: TRY002
                            "NotAutoquantizable"
                            f"Found no autoquantizable layers in model {type(module)}, stopping autoquantized run"
                        )
                else:
                    unwrap_tensor_subclass(module)
                setattr(module, "_quantized", True)  # noqa: B010
            model_iter_fn(module, example_inputs)

        return _torchao_apply

    return inner
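A minimal usage sketch of the two helpers above, assuming torchao is installed; the toy model, inputs, and model_iter_fn are hypothetical stand-ins for whatever benchmark harness normally supplies them:

import torch

def model_iter_fn(model, example_inputs):
    # One benchmark iteration: a plain forward pass.
    return model(*example_inputs)

# Hypothetical toy model and inputs for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())
example_inputs = (torch.randn(8, 64),)

setup_baseline()
apply_fn = torchao_optimize_ctx("int8weightonly")(model_iter_fn)
apply_fn(model, example_inputs)  # quantizes the model once, then runs model_iter_fn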