pytorch

Форк
0
/
record_function_benchmark.cc 
126 строк · 3.8 Кб
1

2
#include <torch/torch.h>
3
#include <ATen/record_function.h>
4

5
#include "c10/util/Flags.h"
6

7
#include <chrono>
8
#include <iostream>
9
#include <ctime>
10

11
C10_DEFINE_int(iter, 10000, "Number of iterations");
12
C10_DEFINE_int(sampled_iter, 10e6,
13
    "Number of iterations for the sampled observer benchmark");
14

15
namespace {
16
const int kTensorSize = 16;
17
const int kSmallTensorSize = 1;
18
const float kLowSamplingProb = 0.0001;
19
}
20

21
void addTestCallback(
22
    double sampling_prob = 1.0,
23
    at::RecordFunctionCallback::StartCallback fn =
24
        [](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> { return nullptr; }) {
25
  auto cb = at::RecordFunctionCallback(
26
      fn,
27
      [](const at::RecordFunction&, at::ObserverContext*) {})
28
    .needsInputs(false);
29
  if (sampling_prob < 1.0) {
30
    cb.samplingProb(sampling_prob);
31
  }
32
  at::addGlobalCallback(cb);
33
}
34

35
float runTensorGEMMBench(int tensor_size, int iter) {
36
  typedef std::chrono::high_resolution_clock clock;
37
  typedef std::chrono::microseconds us;
38
  std::chrono::time_point<clock> start_time = clock::now();
39
  auto inp = torch::randn({tensor_size, tensor_size});
40
  for (auto idx = 0; idx < iter; ++idx) {
41
    torch::mm(inp, inp);
42
  }
43
  auto duration = static_cast<float>(
44
      std::chrono::duration_cast<us>(clock::now() - start_time).count());
45
  return duration;
46
}
47

48
float runPureRecordFunctionBench(int iter) {
49
  typedef std::chrono::high_resolution_clock clock;
50
  typedef std::chrono::microseconds us;
51
  std::chrono::time_point<clock> start_time = clock::now();
52
  for (auto idx = 0; idx < iter; ++idx) {
53
    auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::USER_SCOPE);
54
    if (step_callbacks.has_value()) {
55
      at::RecordFunction guard(std::move(*step_callbacks));
56
      guard.before("Test", -1);
57
    }
58
  }
59
  auto duration = static_cast<float>(
60
      std::chrono::duration_cast<us>(clock::now() - start_time).count());
61
  return duration;
62
}
63

64
void runBenchmark() {
65
  float duration = 0;
66
  for (auto tensor_size : std::set<int>({kSmallTensorSize, kTensorSize})) {
67
    duration = runTensorGEMMBench(tensor_size, FLAGS_iter);
68
    std::cout << "Tensor GEMM benchmark ("
69
              << tensor_size
70
              << "x"
71
              << tensor_size
72
              << ", " << FLAGS_iter << "): " << duration
73
              << " us." << std::endl;
74
  }
75
  duration = runPureRecordFunctionBench(FLAGS_iter);
76
  std::cout << "Pure RecordFunction benchmark ("
77
            << FLAGS_iter << "): "
78
            << duration
79
            << " us." << std::endl;
80
}
81

82
int main(int argc, char** argv) {
83
  if (!c10::ParseCommandLineFlags(&argc, &argv)) {
84
    std::cout << "Failed to parse command line flags" << std::endl;
85
    return -1;
86
  }
87

88
  at::enableRecordFunction();
89
  at::clearCallbacks();
90

91
  std::cout << "Warm up" << std::endl;
92
  runBenchmark();
93

94
  std::cout << "Running without observers" << std::endl;
95
  runBenchmark();
96

97
  addTestCallback();
98
  std::cout << "Running with empty non-sampled observer" << std::endl;
99
  runBenchmark();
100
  at::clearCallbacks();
101

102
  addTestCallback(kLowSamplingProb);
103
  std::cout << "Running with empty sampled observer" << std::endl;
104
  runBenchmark();
105
  at::clearCallbacks();
106

107
  std::cout << "Checking number of sampled observer invocations" << std::endl;
108
  static int cb_count = 0;
109
  addTestCallback(
110
      kLowSamplingProb,
111
      [](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
112
        ++cb_count;
113
        return nullptr;
114
      }
115
  );
116

117
  auto duration = runPureRecordFunctionBench(FLAGS_sampled_iter);
118

119
  std::cout << "Pure RecordFunction runtime of " << FLAGS_sampled_iter
120
            << " iterations: " << duration
121
            << " us, number of callback invocations: " << cb_count
122
            << ", expected number: ~" << (int)(FLAGS_sampled_iter * kLowSamplingProb)
123
            << " invocations" << std::endl;
124

125
  at::clearCallbacks();
126
  return 0;
127
}
128

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.