1
#include <ATen/Context.h>
2
#include <torch/csrc/jit/mobile/module.h>
3
#include <torch/csrc/jit/mobile/quantization.h>
8
namespace quantization {
10
void PTQQuanizationHelper::quantize_dynamic(
11
torch::jit::mobile::Module& m,
12
const std::string& method_name) {
13
at::globalContext().setReleaseWeightsWhenPrepacking(false);
14
std::string reset_observers_method_name = "reset_observers_" + method_name;
15
std::string observe_method_name = "observe_" + method_name;
16
std::string quantize_method_name = "quantize_" + method_name;
17
std::string quantized_method_name = "quantized_" + method_name;
20
m.find_method(reset_observers_method_name).has_value(),
21
"PTQ ready module must have",
22
reset_observers_method_name,
25
m.find_method(observe_method_name),
26
"PTQ ready module must have",
27
reset_observers_method_name,
30
m.find_method(quantize_method_name),
31
"PTQ ready module must have",
35
m.find_method(quantized_method_name),
36
"PTQ ready module must have",
37
quantized_method_name,
40
m.find_method("get_all_bundled_inputs"),
41
"PTQ ready module must have get_all_bundled_inputs method.");
43
auto inputs = m.run_method("get_all_bundled_inputs")
49
m.get_method(reset_observers_method_name)({});
50
m.get_method(observe_method_name)(inputs);
51
m.get_method(quantize_method_name)(inputs);
53
m.compareMethodSchemas(method_name, quantized_method_name);
54
m.unsafeRemoveMethod(method_name);
55
const Function& to_be_copied =
56
m.find_method(quantized_method_name).value().function();
57
m.unsafeCopyMethod(method_name, to_be_copied);
58
m.unsafeRemoveMethod(quantized_method_name);
59
m.unsafeRemoveMethod(quantize_method_name);
60
m.unsafeRemoveMethod(observe_method_name);
61
m.unsafeRemoveMethod(reset_observers_method_name);
63
} // namespace quantization