pytorch

Форк
0
/
at_launch_benchmark.cc 
94 строки · 2.5 Кб
1
#include "ATen/Parallel.h"
2

3
#include "c10/util/Flags.h"
4
#include "caffe2/core/init.h"
5

6
#include <atomic>
7
#include <chrono>
8
#include <condition_variable>
9
#include <iostream>
10
#include <mutex>
11
#include <ctime>
12

13
// Command-line flags for this benchmark; main() reads them as FLAGS_<name>.
// `iter` is the number of top-level launch_tasks() calls per timed round.
// It is written as an integer literal: the former `10e4` was a double
// (== 100000) silently converted to int.
C10_DEFINE_int(iter, 100000, "Number of at::launch iterations (tasks)");
C10_DEFINE_int(warmup_iter, 10, "Number of warmup iterations");
C10_DEFINE_int(inter_op_threads, 0, "Number of inter-op threads");
C10_DEFINE_int(benchmark_iter, 3, "Number of times to run benchmark");
17

18
namespace {
// Number of innermost tasks expected in the current round.
// - Written by the main thread in launch_tasks_and_wait(); read by worker tasks.
// - Atomic because the task that brings `counter` up to this value reads it
//   *after* its increment. The main thread may already have observed
//   completion and started the next round (overwriting this value) while
//   that read is still in flight; a plain int made that a data race.
std::atomic<int> iter{0};
// Number of innermost tasks that have finished in the current round.
std::atomic<int> counter{0};
// Notified by the task whose increment makes `counter` equal `iter`.
std::condition_variable cv;
// Held by the waiter while checking `counter` and by the notifier while
// notifying, so the wakeup cannot fall between the check and the wait.
std::mutex mutex;
} // namespace
24

25
 void launch_tasks() {
26
  at::launch([]() {
27
    at::launch([](){
28
      at::launch([]() {
29
        auto cur_ctr = ++counter;
30
        if (cur_ctr == iter) {
31
          std::unique_lock<std::mutex> lk(mutex);
32
          cv.notify_one();
33
        }
34
      });
35
    });
36
  });
37
}
38

39
void launch_tasks_and_wait(int tasks_num) {
40
  iter = tasks_num;
41
  counter = 0;
42
  for (auto idx = 0; idx < iter; ++idx) {
43
    launch_tasks();
44
  }
45
  {
46
    std::unique_lock<std::mutex> lk(mutex);
47
    while (counter < iter) {
48
      cv.wait(lk);
49
    }
50
  }
51
}
52

53
int main(int argc, char** argv) {
54
  if (!c10::ParseCommandLineFlags(&argc, &argv)) {
55
    std::cout << "Failed to parse command line flags" << std::endl;
56
    return -1;
57
  }
58
  caffe2::unsafeRunCaffe2InitFunction("registerThreadPools");
59
  at::init_num_threads();
60

61
  if (FLAGS_inter_op_threads > 0) {
62
    at::set_num_interop_threads(FLAGS_inter_op_threads);
63
  }
64

65
  typedef std::chrono::high_resolution_clock clock;
66
  typedef std::chrono::milliseconds ms;
67

68
  std::cout << "Launching " << FLAGS_warmup_iter << " warmup tasks using "
69
            << at::get_num_interop_threads() << " threads "
70
            << std::endl;
71

72
  std::chrono::time_point<clock> start_time = clock::now();
73
  launch_tasks_and_wait(FLAGS_warmup_iter);
74
  auto duration = static_cast<float>(
75
      std::chrono::duration_cast<ms>(clock::now() - start_time).count());
76

77
  std::cout << "Warmup time: " << duration << " ms." << std::endl;
78

79
  std::cout << "Launching " << FLAGS_iter << " tasks using "
80
            << at::get_num_interop_threads() << " threads "
81
            << std::endl;
82

83
  for (auto bench_iter = 0; bench_iter < FLAGS_benchmark_iter; ++bench_iter) {
84
    start_time = clock::now();
85
    launch_tasks_and_wait(FLAGS_iter);
86
    duration = static_cast<float>(
87
        std::chrono::duration_cast<ms>(clock::now() - start_time).count());
88

89
    std::cout << "Time to run " << iter << " iterations "
90
              << (duration/1000.0) << " s." << std::endl;
91
  }
92

93
  return 0;
94
}
95

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.