2
#include <torch/torch.h>
3
#include <ATen/record_function.h>
5
#include "c10/util/Flags.h"
11
C10_DEFINE_int(iter, 10000, "Number of iterations");
12
C10_DEFINE_int(sampled_iter, 10e6,
13
"Number of iterations for the sampled observer benchmark");
16
const int kTensorSize = 16;
17
const int kSmallTensorSize = 1;
18
const float kLowSamplingProb = 0.0001;
22
// Fragment of a helper that builds an at::RecordFunctionCallback, optionally
// sets a sampling probability on it, and registers it with
// at::addGlobalCallback.
//
// Parameters visible here:
//   sampling_prob - defaults to 1.0; only values below 1.0 are forwarded to
//                   cb.samplingProb().
//   fn            - a StartCallback; defaults to a lambda that returns nullptr.
//
// NOTE(review): the declaration line that opens this parameter list is not
// present in this view; the call `addTestCallback(kLowSamplingProb)` later in
// the file suggests the helper is named addTestCallback — confirm.
// NOTE(review): the bare integers on their own lines are not valid C++ and look
// like line-number residue from extraction. The gaps in that numbering
// (25 -> 27 -> 29, 30 -> 32) suggest lines are missing, e.g. the end of the
// RecordFunctionCallback(...) expression and the closing braces — restore them
// from the original file.
double sampling_prob = 1.0,
23
at::RecordFunctionCallback::StartCallback fn =
24
[](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> { return nullptr; }) {
25
auto cb = at::RecordFunctionCallback(
27
// Empty lambda taking the RecordFunction and its ObserverContext; presumably
// the end callback — TODO confirm against the RecordFunctionCallback API.
[](const at::RecordFunction&, at::ObserverContext*) {})
29
// A probability of 1.0 (the default) leaves the callback's sampling setting
// untouched.
if (sampling_prob < 1.0) {
30
cb.samplingProb(sampling_prob);
32
at::addGlobalCallback(cb);
35
// Times a loop of `iter` iterations that follows the creation of a random
// tensor_size x tensor_size tensor, and computes the elapsed time in
// microseconds as a float.
//
// NOTE(review): the loop body, the return statement and the closing braces are
// not present in this view (the interleaved numbers jump 40 -> 43 and
// 44 -> 48). The function name suggests the loop performs a matrix multiply on
// `inp`, and callers print the result as microseconds — confirm against the
// original file.
// NOTE(review): the bare integers on their own lines are not valid C++ and look
// like line-number residue from extraction.
float runTensorGEMMBench(int tensor_size, int iter) {
36
typedef std::chrono::high_resolution_clock clock;
37
typedef std::chrono::microseconds us;
38
// The timer starts before the input tensor is created, so the torch::randn
// call below is part of the measured interval.
std::chrono::time_point<clock> start_time = clock::now();
39
auto inp = torch::randn({tensor_size, tensor_size});
40
for (auto idx = 0; idx < iter; ++idx) {
43
// Elapsed time since start_time, truncated to whole microseconds by
// duration_cast and then converted to float.
auto duration = static_cast<float>(
44
std::chrono::duration_cast<us>(clock::now() - start_time).count());
48
float runPureRecordFunctionBench(int iter) {
49
typedef std::chrono::high_resolution_clock clock;
50
typedef std::chrono::microseconds us;
51
std::chrono::time_point<clock> start_time = clock::now();
52
for (auto idx = 0; idx < iter; ++idx) {
53
auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::USER_SCOPE);
54
if (step_callbacks.has_value()) {
55
at::RecordFunction guard(std::move(*step_callbacks));
56
guard.before("Test", -1);
59
auto duration = static_cast<float>(
60
std::chrono::duration_cast<us>(clock::now() - start_time).count());
66
// Fragment of a driver that runs the tensor benchmark for each tensor size and
// then the pure RecordFunction benchmark, each with FLAGS_iter iterations, and
// prints the results to std::cout.
//
// NOTE(review): the enclosing function header and the declaration of
// `duration` are not present in this view, and neither are the closing braces.
// NOTE(review): the bare integers on their own lines are not valid C++ and look
// like line-number residue from extraction; gaps in that numbering (68 -> 72,
// 77 -> 79) suggest parts of both output statements are missing — restore them
// from the original file.
//
// std::set iterates in ascending order, so kSmallTensorSize (1) is benchmarked
// before kTensorSize (16).
for (auto tensor_size : std::set<int>({kSmallTensorSize, kTensorSize})) {
67
duration = runTensorGEMMBench(tensor_size, FLAGS_iter);
68
std::cout << "Tensor GEMM benchmark ("
72
<< ", " << FLAGS_iter << "): " << duration
73
<< " us." << std::endl;
75
duration = runPureRecordFunctionBench(FLAGS_iter);
76
// As visible here, no value is streamed between the "): " label and " us.";
// presumably the missing line prints `duration` — confirm.
std::cout << "Pure RecordFunction benchmark ("
77
<< FLAGS_iter << "): "
79
<< " us." << std::endl;
82
// Program entry point: parses command-line flags, enables RecordFunction, and
// announces a sequence of benchmark phases (warm up, no observers, empty
// non-sampled observer, empty sampled observer, sampled invocation count).
//
// NOTE(review): the bare integers on their own lines are not valid C++ and look
// like line-number residue from extraction. The gaps in that numbering show
// that many statements are missing from this view — e.g. whatever runs after
// each "Running ..." message, the end of the flag-parsing branch, the body of
// the counting lambda, and the end of main. Restore them from the original
// file before building.
int main(int argc, char** argv) {
83
if (!c10::ParseCommandLineFlags(&argc, &argv)) {
84
// Flag parsing failed; what follows the message (presumably an early return)
// is not visible here.
std::cout << "Failed to parse command line flags" << std::endl;
88
at::enableRecordFunction();
91
std::cout << "Warm up" << std::endl;
94
std::cout << "Running without observers" << std::endl;
98
std::cout << "Running with empty non-sampled observer" << std::endl;
100
// Remove registered callbacks before setting up the next phase.
at::clearCallbacks();
102
// Register a test callback with the low sampling probability.
addTestCallback(kLowSamplingProb);
103
std::cout << "Running with empty sampled observer" << std::endl;
105
at::clearCallbacks();
107
std::cout << "Checking number of sampled observer invocations" << std::endl;
108
// Counter reported below as the number of callback invocations; presumably
// incremented by the capture-less lambda that follows (its body is not
// visible) — confirm.
static int cb_count = 0;
111
[](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
117
auto duration = runPureRecordFunctionBench(FLAGS_sampled_iter);
119
// The expected invocation count is iterations * sampling probability,
// truncated to int by the cast.
std::cout << "Pure RecordFunction runtime of " << FLAGS_sampled_iter
120
<< " iterations: " << duration
121
<< " us, number of callback invocations: " << cb_count
122
<< ", expected number: ~" << (int)(FLAGS_sampled_iter * kLowSamplingProb)
123
<< " invocations" << std::endl;
125
at::clearCallbacks();