pytorch

Форк
0
/
test_counters.cpp 
305 строк · 6.6 Кб
1
#include <gtest/gtest.h>
2

3
#include <thread>
4

5
#include <torch/csrc/monitor/counters.h>
6
#include <torch/csrc/monitor/events.h>
7

8
using namespace torch::monitor;
9

10
// Exercises a double-valued Stat that tracks MEAN and COUNT: two samples are
// added and the aggregated result is compared against the expected map.
TEST(MonitorTest, CounterDouble) {
  using ResultMap = std::unordered_map<Aggregation, double, AggregationHash>;

  Stat<double> stat{
      "a",
      {Aggregation::MEAN, Aggregation::COUNT},
      std::chrono::milliseconds(100000),
      2,
  };

  // One sample added: count() is expected to report it.
  stat.add(5.0);
  ASSERT_EQ(stat.count(), 1);

  // After the second sample the test expects count() to be back at zero.
  stat.add(6.0);
  ASSERT_EQ(stat.count(), 0);

  // Mean of 5.0 and 6.0 over two samples.
  const ResultMap expected = {
      {Aggregation::COUNT, 2.0},
      {Aggregation::MEAN, 5.5},
  };
  const auto actual = stat.get();
  ASSERT_EQ(actual, expected);
}
29

30
// Exercises SUM aggregation on an int64_t Stat: 5 + 6 must be reported as 11.
TEST(MonitorTest, CounterInt64Sum) {
  using ResultMap = std::unordered_map<Aggregation, int64_t, AggregationHash>;

  Stat<int64_t> stat{
      "a",
      {Aggregation::SUM},
      std::chrono::milliseconds(100000),
      2,
  };

  const ResultMap expected = {
      {Aggregation::SUM, 11},
  };

  stat.add(5);
  stat.add(6);

  const auto actual = stat.get();
  ASSERT_EQ(actual, expected);
}
45

46
// Exercises VALUE aggregation on an int64_t Stat: after adding 5 and then 6
// the reported value is expected to be 6.
TEST(MonitorTest, CounterInt64Value) {
  using ResultMap = std::unordered_map<Aggregation, int64_t, AggregationHash>;

  Stat<int64_t> stat{
      "a",
      {Aggregation::VALUE},
      std::chrono::milliseconds(100000),
      2,
  };

  const ResultMap expected = {
      {Aggregation::VALUE, 6},
  };

  stat.add(5);
  stat.add(6);

  const auto actual = stat.get();
  ASSERT_EQ(actual, expected);
}
61

62
// Exercises MEAN aggregation on an int64_t Stat, both before any sample has
// been added and after two samples.
TEST(MonitorTest, CounterInt64Mean) {
  using ResultMap = std::unordered_map<Aggregation, int64_t, AggregationHash>;

  Stat<int64_t> stat{
      "a",
      {Aggregation::MEAN},
      std::chrono::milliseconds(100000),
      2,
  };

  // Zero samples case: the mean is expected to be reported as 0.
  const ResultMap expectedWhenEmpty = {
      {Aggregation::MEAN, 0},
  };
  const auto actualWhenEmpty = stat.get();
  ASSERT_EQ(actualWhenEmpty, expectedWhenEmpty);

  stat.add(0);
  stat.add(10);

  // Mean of 0 and 10.
  const ResultMap expectedAfterSamples = {
      {Aggregation::MEAN, 5},
  };
  const auto actualAfterSamples = stat.get();
  ASSERT_EQ(actualAfterSamples, expectedAfterSamples);
}
89

90
// Exercises COUNT aggregation on an int64_t Stat and tracks count() after
// every add().
TEST(MonitorTest, CounterInt64Count) {
  using ResultMap = std::unordered_map<Aggregation, int64_t, AggregationHash>;

  Stat<int64_t> stat{
      "a",
      {Aggregation::COUNT},
      std::chrono::milliseconds(100000),
      2,
  };

  // Nothing added yet.
  ASSERT_EQ(stat.count(), 0);

  stat.add(0);
  ASSERT_EQ(stat.count(), 1);

  // After the second sample the test expects count() to be back at zero.
  stat.add(10);
  ASSERT_EQ(stat.count(), 0);

  const ResultMap expected = {
      {Aggregation::COUNT, 2},
  };
  const auto actual = stat.get();
  ASSERT_EQ(actual, expected);
}
109

110
// Exercises MIN and MAX aggregation on an int64_t Stat, both before any sample
// has been added and after six samples with mixed signs.
TEST(MonitorTest, CounterInt64MinMax) {
  using ResultMap = std::unordered_map<Aggregation, int64_t, AggregationHash>;

  Stat<int64_t> stat{
      "a",
      {Aggregation::MIN, Aggregation::MAX},
      std::chrono::milliseconds(100000),
      6,
  };

  // With no samples both extremes are expected to be reported as 0.
  const ResultMap expectedWhenEmpty = {
      {Aggregation::MIN, 0},
      {Aggregation::MAX, 0},
  };
  const auto actualWhenEmpty = stat.get();
  ASSERT_EQ(actualWhenEmpty, expectedWhenEmpty);

  // Same values, in the same order, as the individual add() calls they replace.
  const int64_t samples[] = {0, 5, -5, -6, 9, 2};
  for (const int64_t sample : samples) {
    stat.add(sample);
  }

  const ResultMap expectedAfterSamples = {
      {Aggregation::MIN, -6},
      {Aggregation::MAX, 9},
  };
  const auto actualAfterSamples = stat.get();
  ASSERT_EQ(actualAfterSamples, expectedAfterSamples);
}
141

142
// Exercises the sample limit (3) of an int64_t Stat: the first three samples
// are aggregated, and a fourth sample must not change the reported result.
TEST(MonitorTest, CounterInt64WindowSize) {
  using ResultMap = std::unordered_map<Aggregation, int64_t, AggregationHash>;

  // NOTE(review): the last argument is labelled windowSize, but TestStat's
  // constructor in this file names that position maxSamples — confirm against
  // the Stat declaration.
  Stat<int64_t> stat{
      "a",
      {Aggregation::COUNT, Aggregation::SUM},
      std::chrono::milliseconds(100000),
      /*windowSize=*/3,
  };

  stat.add(1);
  stat.add(2);
  ASSERT_EQ(stat.count(), 2);

  // The third sample reaches the limit; count() is expected to drop to zero.
  stat.add(3);
  ASSERT_EQ(stat.count(), 0);

  // after logging max for window, should be zero
  stat.add(4);
  ASSERT_EQ(stat.count(), 0);

  // Only 1 + 2 + 3 are expected in the aggregated result.
  const ResultMap expected = {
      {Aggregation::SUM, 6},
      {Aggregation::COUNT, 3},
  };
  const auto actual = stat.get();
  ASSERT_EQ(actual, expected);
}
166

167
// Same scenario as the sample-limit test, but with a ten-year duration passed
// as the third constructor argument.
TEST(MonitorTest, CounterInt64WindowSizeHuge) {
  using ResultMap = std::unordered_map<Aggregation, int64_t, AggregationHash>;

  // NOTE(review): the last argument is labelled windowSize, but TestStat's
  // constructor in this file names that position maxSamples — confirm against
  // the Stat declaration.
  Stat<int64_t> stat{
      "a",
      {Aggregation::COUNT, Aggregation::SUM},
      std::chrono::hours(24 * 365 * 10), // 10 years
      /*windowSize=*/3,
  };

  stat.add(1);
  stat.add(2);
  ASSERT_EQ(stat.count(), 2);

  // The third sample reaches the limit; count() is expected to drop to zero.
  stat.add(3);
  ASSERT_EQ(stat.count(), 0);

  // after logging max for window, should be zero
  stat.add(4);
  ASSERT_EQ(stat.count(), 0);

  // Only 1 + 2 + 3 are expected in the aggregated result.
  const ResultMap expected = {
      {Aggregation::SUM, 6},
      {Aggregation::COUNT, 3},
  };
  const auto actual = stat.get();
  ASSERT_EQ(actual, expected);
}
191

192
template <typename T>
193
struct TestStat : public Stat<T> {
194
  uint64_t mockWindowId{1};
195

196
  TestStat(
197
      std::string name,
198
      std::initializer_list<Aggregation> aggregations,
199
      std::chrono::milliseconds windowSize,
200
      int64_t maxSamples = std::numeric_limits<int64_t>::max())
201
      : Stat<T>(name, aggregations, windowSize, maxSamples) {}
202

203
  uint64_t currentWindowId() const override {
204
    return mockWindowId;
205
  }
206
};
207

208
struct AggregatingEventHandler : public EventHandler {
209
  std::vector<Event> events;
210

211
  void handle(const Event& e) override {
212
    events.emplace_back(e);
213
  }
214
};
215

216
// Scope guard that creates a handler of type T, registers it with
// registerEventHandler() on construction and passes it to
// unregisterEventHandler() on destruction.
//
// The guard is non-copyable: a copy would share the same handler pointer and
// its destructor would call unregisterEventHandler() a second time for it.
template <typename T>
struct HandlerGuard {
  // The registered handler; tests read collected state through it.
  std::shared_ptr<T> handler;

  HandlerGuard() : handler(std::make_shared<T>()) {
    registerEventHandler(handler);
  }

  HandlerGuard(const HandlerGuard&) = delete;
  HandlerGuard& operator=(const HandlerGuard&) = delete;

  ~HandlerGuard() {
    unregisterEventHandler(handler);
  }
};
228

229
// Exercises a Stat constructed with a 1 ms duration against the real clock:
// a sample is added, the test sleeps for 2 ms, and a second sample is added.
// The bounds are deliberately loose (count() at most 1, one or two events)
// because the outcome depends on timing.
TEST(MonitorTest, Stat) {
  HandlerGuard<AggregatingEventHandler> guard;

  Stat<int64_t> a{
      "a",
      {Aggregation::COUNT, Aggregation::SUM},
      std::chrono::milliseconds(1),
  };
  // events.size() is unsigned; unsigned literals keep the comparisons inside
  // the gtest templates free of signed/unsigned mixing.
  ASSERT_EQ(guard.handler->events.size(), 0U);

  a.add(1);
  ASSERT_LE(a.count(), 1);

  // Sleep for longer than the 1 ms duration before adding the next sample.
  std::this_thread::sleep_for(std::chrono::milliseconds(2));
  a.add(2);
  ASSERT_LE(a.count(), 1);

  ASSERT_GE(guard.handler->events.size(), 1U);
  ASSERT_LE(guard.handler->events.size(), 2U);
}
249

250
// Exercises event emission with a mocked window id: two samples are added
// without any event, then the window id is changed and the next add() is
// expected to produce exactly one event carrying the first two samples.
TEST(MonitorTest, StatEvent) {
  HandlerGuard<AggregatingEventHandler> guard;

  TestStat<int64_t> a{
      "a",
      {Aggregation::COUNT, Aggregation::SUM},
      std::chrono::milliseconds(1),
  };
  // events.size() is unsigned; unsigned literals keep the comparisons inside
  // the gtest templates free of signed/unsigned mixing.
  ASSERT_EQ(guard.handler->events.size(), 0U);

  a.add(1);
  ASSERT_EQ(a.count(), 1);
  a.add(2);
  ASSERT_EQ(a.count(), 2);
  ASSERT_EQ(guard.handler->events.size(), 0U);

  // Change the value returned by TestStat::currentWindowId().
  a.mockWindowId = 100;

  a.add(3);
  ASSERT_LE(a.count(), 1);

  ASSERT_EQ(guard.handler->events.size(), 1U);
  Event e = guard.handler->events.at(0);
  ASSERT_EQ(e.name, "torch.monitor.Stat");
  // The timestamp must have been set, i.e. differ from the default time_point.
  ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
  // Expected payload: sum and count of the first two samples (1 and 2).
  // NOTE(review): the `L` suffix yields `long`; confirm this selects the
  // intended data_value_t alternative where int64_t is not `long`.
  std::unordered_map<std::string, data_value_t> data{
      {"a.sum", 3L},
      {"a.count", 2L},
  };
  ASSERT_EQ(e.data, data);
}
281

282
// Exercises event emission on destruction: a TestStat with one pending sample
// goes out of scope, and the test expects exactly one event carrying that
// sample to have been delivered to the handler.
TEST(MonitorTest, StatEventDestruction) {
  HandlerGuard<AggregatingEventHandler> guard;

  {
    TestStat<int64_t> a{
        "a",
        {Aggregation::COUNT, Aggregation::SUM},
        std::chrono::hours(10),
    };
    a.add(1);
    ASSERT_EQ(a.count(), 1);
    // events.size() is unsigned; unsigned literals keep the comparisons
    // inside the gtest templates free of signed/unsigned mixing.
    ASSERT_EQ(guard.handler->events.size(), 0U);
  }
  // `a` has been destroyed here; one event is expected as a result.
  ASSERT_EQ(guard.handler->events.size(), 1U);

  Event e = guard.handler->events.at(0);
  ASSERT_EQ(e.name, "torch.monitor.Stat");
  // The timestamp must have been set, i.e. differ from the default time_point.
  ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
  // Expected payload: sum and count of the single sample.
  std::unordered_map<std::string, data_value_t> data{
      {"a.sum", 1L},
      {"a.count", 1L},
  };
  ASSERT_EQ(e.data, data);
}
306

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.