// pytorch (Fork: 0) / intra_inter_benchmark.cc
// 166 lines · 4.9 KB
#include "ATen/ATen.h"
#include "ATen/Parallel.h"

#include "c10/util/Flags.h"
#include "caffe2/core/init.h"

#include <atomic>
#include <chrono>
#include <cmath>
#include <condition_variable>
#include <cstdint>
#include <ctime>
#include <iostream>
#include <limits>
#include <mutex>
#include <thread>
#include <unordered_set>
#include <vector>

// Command-line flags for the benchmark.
// Workload shape: each round launches 2^iter_pow tasks (2^warmup_iter_pow for
// the warmup round), and each task runs sub_iter subtasks.
C10_DEFINE_int(iter_pow, 10, "Number of tasks, 2^N");
C10_DEFINE_int(sub_iter, 1024, "Number of subtasks");
C10_DEFINE_int(warmup_iter_pow, 3, "Number of warmup tasks, 2^N");
// Thread-pool sizes; values <= 0 keep the runtime defaults (see main()).
C10_DEFINE_int(inter_op_threads, 0, "Number of inter-op threads");
C10_DEFINE_int(intra_op_threads, 0, "Number of intra-op threads");
// Operands are square tensor_dim x tensor_dim float tensors.
C10_DEFINE_int(tensor_dim, 50, "Tensor dim");
// Terminated with ';' like every other definition here. Leaving it off only
// compiles when the macro expansion happens to end in a closing brace.
C10_DEFINE_int(benchmark_iter, 10, "Number of times to run benchmark");
C10_DEFINE_bool(
    extra_stats,
    false,
    "Collect extra stats; warning: skews results");
C10_DEFINE_string(task_type, "add", "Tensor operation: add or mm");

namespace {
// Completion tracking. Every leaf subtask increments `counter`; wait() sleeps
// on `cv` (guarded by `tasks_mutex`) until `counter` reaches `overall_tasks`.
int overall_tasks = 0;
std::atomic<int> counter{0};
std::mutex tasks_mutex;
std::condition_variable cv;

// Chosen once in main() from --task_type: true runs mm, false runs add.
bool run_mm = false;

// --extra_stats bookkeeping: ids of the threads that executed leaf work.
// Workers insert into `tids` under `stats_mutex`.
std::mutex stats_mutex;
std::unordered_set<std::thread::id> tids;
} // namespace

void wait() {
37
  std::unique_lock<std::mutex> lk(tasks_mutex);
38
  while (counter < overall_tasks) {
39
    cv.wait(lk);
40
  }
41
}
42

43
void _launch_tasks_tree(
44
    int level, int end_level, at::Tensor& left, at::Tensor& right) {
45
  if (level == end_level) {
46
    at::parallel_for(0, FLAGS_sub_iter, 1,
47
        [&left, &right](int64_t begin, int64_t end) {
48
      if (FLAGS_extra_stats) {
49
        std::unique_lock<std::mutex> lk(stats_mutex);
50
        tids.insert(std::this_thread::get_id());
51
      }
52
      for (auto k = begin; k < end; ++k) {
53
        if (run_mm) {
54
          left.mm(right);
55
        } else {
56
          left.add(right);
57
        }
58
        auto cur_ctr = ++counter;
59
        if (cur_ctr == overall_tasks) {
60
          std::unique_lock<std::mutex> lk(tasks_mutex);
61
          cv.notify_one();
62
        }
63
      }
64
    });
65
  } else {
66
    at::launch([&left, &right, level, end_level]() {
67
      _launch_tasks_tree(level + 1, end_level, left, right);
68
    });
69
    at::launch([&left, &right, level, end_level]() {
70
      _launch_tasks_tree(level + 1, end_level, left, right);
71
    });
72
  }
73
};
74

75
// Runs one benchmark round and blocks until every subtask has finished.
// The round is a binary tree of 2^iter_pow tasks, each running
// FLAGS_sub_iter subtasks on `left` and `right`.
//
// iter_pow must be in [0, 30]:
// - A negative value never matches the leaf level in _launch_tasks_tree, so
//   the recursion would not terminate.
// - The total subtask count (2^iter_pow * FLAGS_sub_iter) must fit in the
//   `int` completion counter.
// Violations throw via TORCH_CHECK.
void launch_tasks_and_wait(at::Tensor& left, at::Tensor& right, int iter_pow) {
  TORCH_CHECK(
      iter_pow >= 0 && iter_pow <= 30,
      "iter_pow must be in [0, 30], got ",
      iter_pow);
  // Exact integer arithmetic: floating-point pow() narrowed to int is
  // undefined behaviour once the product leaves int's range.
  const int64_t total_tasks =
      (int64_t{1} << iter_pow) * static_cast<int64_t>(FLAGS_sub_iter);
  TORCH_CHECK(
      total_tasks >= std::numeric_limits<int>::min() &&
          total_tasks <= std::numeric_limits<int>::max(),
      "Total number of subtasks (2^",
      iter_pow,
      " * ",
      FLAGS_sub_iter,
      ") does not fit in int");
  overall_tasks = static_cast<int>(total_tasks);
  counter = 0;

  _launch_tasks_tree(0, iter_pow, left, right);
  wait();
}

void reset_extra_stats() {
84
  tids.clear();
85
}
86

87
void print_extra_stats() {
88
  std::cout << "# threads: " << tids.size() << std::endl;
89
}
90

91
// Prints the sample count, mean and population standard deviation of the
// per-round runtimes (main() records them in milliseconds).
//
// The deviation is computed in two passes in double precision. A one-pass
// E[x^2] - E[x]^2 in float suffers catastrophic cancellation: for (nearly)
// identical samples it can come out slightly negative, and std::sqrt then
// yields NaN.
//
// An empty input (e.g. --benchmark_iter=0) is a flag value, not an internal
// invariant violation. It is reported instead of aborting the program.
void print_runtime_stats(const std::vector<float>& runtimes) {
  if (runtimes.empty()) {
    std::cout << "N = 0, no runtimes recorded" << std::endl;
    return;
  }
  const size_t N = runtimes.size();

  double sum = 0.0;
  for (const float r : runtimes) {
    sum += r;
  }
  const double mean = sum / N;

  double sq_dev_sum = 0.0;
  for (const float r : runtimes) {
    const double dev = r - mean;
    sq_dev_sum += dev * dev;
  }
  const double sd = std::sqrt(sq_dev_sum / N);

  // Narrow back to float for printing, matching the precision of the samples.
  std::cout << "N = " << N << ", mean = " << static_cast<float>(mean)
            << ", sd = " << static_cast<float>(sd) << std::endl;
}

int main(int argc, char** argv) {
107
  if (!c10::ParseCommandLineFlags(&argc, &argv)) {
108
    std::cout << "Failed to parse command line flags" << std::endl;
109
    return -1;
110
  }
111
  caffe2::unsafeRunCaffe2InitFunction("registerThreadPools");
112
  at::init_num_threads();
113

114
  if (FLAGS_inter_op_threads > 0) {
115
    at::set_num_interop_threads(FLAGS_inter_op_threads);
116
  }
117
  if (FLAGS_intra_op_threads > 0) {
118
    at::set_num_threads(FLAGS_intra_op_threads);
119
  }
120

121
  TORCH_CHECK(FLAGS_task_type == "add" || FLAGS_task_type == "mm");
122
  run_mm = FLAGS_task_type == "mm";
123

124
  auto left = at::ones({FLAGS_tensor_dim, FLAGS_tensor_dim}, at::kFloat);
125
  auto right = at::ones({FLAGS_tensor_dim, FLAGS_tensor_dim}, at::kFloat);
126

127
  std::cout << "Launching " << pow(2, FLAGS_warmup_iter_pow)
128
            << " warmup tasks" << std::endl;
129

130
  typedef std::chrono::high_resolution_clock clock;
131
  typedef std::chrono::milliseconds ms;
132

133
  std::chrono::time_point<clock> start_time = clock::now();
134
  launch_tasks_and_wait(left, right, FLAGS_warmup_iter_pow);
135
  auto duration = static_cast<float>(
136
      std::chrono::duration_cast<ms>(clock::now() - start_time).count());
137

138
  std::cout << "Warmup time: " << duration << " ms." << std::endl;
139

140
  std::cout << "Launching " << pow(2, FLAGS_iter_pow) << " tasks with "
141
            << FLAGS_sub_iter << " subtasks each, using "
142
            << at::get_num_interop_threads() << " inter-op threads and "
143
            << at::get_num_threads() << " intra-op threads, "
144
            << "tensor dim: " << FLAGS_tensor_dim
145
            << ", task type: " << FLAGS_task_type << std::endl;
146

147
  std::vector<float> runtimes;
148
  for (auto bench_iter = 0; bench_iter < FLAGS_benchmark_iter; ++bench_iter) {
149
    reset_extra_stats();
150
    start_time = clock::now();
151
    launch_tasks_and_wait(left, right, FLAGS_iter_pow);
152
    duration = static_cast<float>(
153
        std::chrono::duration_cast<ms>(clock::now() - start_time).count());
154
    runtimes.push_back(duration);
155

156
    if (FLAGS_extra_stats) {
157
      print_extra_stats();
158
    }
159

160
    std::cout << "Runtime: " << duration << " ms." << std::endl;
161
  }
162

163
  print_runtime_stats(runtimes);
164

165
  return 0;
166
}
167

// Use of cookies
//
// We use cookies in accordance with the Privacy Policy and the Cookie Usage Policy.
//
// By clicking "Accept", you consent to SberTech JSC processing your personal data in order to improve our website and the GitVerse Service and to make them more convenient to use.
//
// You can disable the use of cookies yourself in your browser settings.