pytorch

Форк
0
/
quantization.cpp 
66 строк · 2.2 Кб
1
#include <ATen/Context.h>
2
#include <torch/csrc/jit/mobile/module.h>
3
#include <torch/csrc/jit/mobile/quantization.h>
4

5
namespace torch {
6
namespace jit {
7
namespace mobile {
8
namespace quantization {
9

10
void PTQQuanizationHelper::quantize_dynamic(
11
    torch::jit::mobile::Module& m,
12
    const std::string& method_name) {
13
  at::globalContext().setReleaseWeightsWhenPrepacking(false);
14
  std::string reset_observers_method_name = "reset_observers_" + method_name;
15
  std::string observe_method_name = "observe_" + method_name;
16
  std::string quantize_method_name = "quantize_" + method_name;
17
  std::string quantized_method_name = "quantized_" + method_name;
18

19
  TORCH_CHECK(
20
      m.find_method(reset_observers_method_name).has_value(),
21
      "PTQ ready module must have",
22
      reset_observers_method_name,
23
      " method.");
24
  TORCH_CHECK(
25
      m.find_method(observe_method_name),
26
      "PTQ ready module must have",
27
      reset_observers_method_name,
28
      " method.");
29
  TORCH_CHECK(
30
      m.find_method(quantize_method_name),
31
      "PTQ ready module must have",
32
      quantize_method_name,
33
      " method.");
34
  TORCH_CHECK(
35
      m.find_method(quantized_method_name),
36
      "PTQ ready module must have",
37
      quantized_method_name,
38
      " method.");
39
  TORCH_CHECK(
40
      m.find_method("get_all_bundled_inputs"),
41
      "PTQ ready module must have get_all_bundled_inputs method.");
42

43
  auto inputs = m.run_method("get_all_bundled_inputs")
44
                    .toList()
45
                    .get(0)
46
                    .toTupleRef()
47
                    .elements()
48
                    .vec();
49
  m.get_method(reset_observers_method_name)({});
50
  m.get_method(observe_method_name)(inputs);
51
  m.get_method(quantize_method_name)(inputs);
52

53
  m.compareMethodSchemas(method_name, quantized_method_name);
54
  m.unsafeRemoveMethod(method_name);
55
  const Function& to_be_copied =
56
      m.find_method(quantized_method_name).value().function();
57
  m.unsafeCopyMethod(method_name, to_be_copied);
58
  m.unsafeRemoveMethod(quantized_method_name);
59
  m.unsafeRemoveMethod(quantize_method_name);
60
  m.unsafeRemoveMethod(observe_method_name);
61
  m.unsafeRemoveMethod(reset_observers_method_name);
62
}
63
} // namespace quantization
64
} // namespace mobile
65
} // namespace jit
66
} // namespace torch
67

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.