1
#include <gtest/gtest.h>
5
#include <torch/csrc/monitor/counters.h>
6
#include <torch/csrc/monitor/events.h>
8
using namespace torch::monitor;
10
TEST(MonitorTest, CounterDouble) {
13
{Aggregation::MEAN, Aggregation::COUNT},
14
std::chrono::milliseconds(100000),
18
ASSERT_EQ(a.count(), 1);
20
ASSERT_EQ(a.count(), 0);
23
std::unordered_map<Aggregation, double, AggregationHash> want = {
24
{Aggregation::MEAN, 5.5},
25
{Aggregation::COUNT, 2.0},
27
ASSERT_EQ(stats, want);
30
TEST(MonitorTest, CounterInt64Sum) {
34
std::chrono::milliseconds(100000),
40
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
41
{Aggregation::SUM, 11},
43
ASSERT_EQ(stats, want);
46
TEST(MonitorTest, CounterInt64Value) {
50
std::chrono::milliseconds(100000),
56
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
57
{Aggregation::VALUE, 6},
59
ASSERT_EQ(stats, want);
62
TEST(MonitorTest, CounterInt64Mean) {
66
std::chrono::milliseconds(100000),
72
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
73
{Aggregation::MEAN, 0},
75
ASSERT_EQ(stats, want);
83
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
84
{Aggregation::MEAN, 5},
86
ASSERT_EQ(stats, want);
90
TEST(MonitorTest, CounterInt64Count) {
94
std::chrono::milliseconds(100000),
97
ASSERT_EQ(a.count(), 0);
99
ASSERT_EQ(a.count(), 1);
101
ASSERT_EQ(a.count(), 0);
103
auto stats = a.get();
104
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
105
{Aggregation::COUNT, 2},
107
ASSERT_EQ(stats, want);
110
TEST(MonitorTest, CounterInt64MinMax) {
113
{Aggregation::MIN, Aggregation::MAX},
114
std::chrono::milliseconds(100000),
118
auto stats = a.get();
119
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
120
{Aggregation::MAX, 0},
121
{Aggregation::MIN, 0},
123
ASSERT_EQ(stats, want);
133
auto stats = a.get();
134
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
135
{Aggregation::MAX, 9},
136
{Aggregation::MIN, -6},
138
ASSERT_EQ(stats, want);
142
TEST(MonitorTest, CounterInt64WindowSize) {
145
{Aggregation::COUNT, Aggregation::SUM},
146
std::chrono::milliseconds(100000),
151
ASSERT_EQ(a.count(), 2);
153
ASSERT_EQ(a.count(), 0);
155
// after logging max for window, should be zero
157
ASSERT_EQ(a.count(), 0);
159
auto stats = a.get();
160
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
161
{Aggregation::COUNT, 3},
162
{Aggregation::SUM, 6},
164
ASSERT_EQ(stats, want);
167
TEST(MonitorTest, CounterInt64WindowSizeHuge) {
170
{Aggregation::COUNT, Aggregation::SUM},
171
std::chrono::hours(24 * 365 * 10), // 10 years
176
ASSERT_EQ(a.count(), 2);
178
ASSERT_EQ(a.count(), 0);
180
// after logging max for window, should be zero
182
ASSERT_EQ(a.count(), 0);
184
auto stats = a.get();
185
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
186
{Aggregation::COUNT, 3},
187
{Aggregation::SUM, 6},
189
ASSERT_EQ(stats, want);
193
struct TestStat : public Stat<T> {
194
uint64_t mockWindowId{1};
198
std::initializer_list<Aggregation> aggregations,
199
std::chrono::milliseconds windowSize,
200
int64_t maxSamples = std::numeric_limits<int64_t>::max())
201
: Stat<T>(name, aggregations, windowSize, maxSamples) {}
203
uint64_t currentWindowId() const override {
208
struct AggregatingEventHandler : public EventHandler {
209
std::vector<Event> events;
211
void handle(const Event& e) override {
212
events.emplace_back(e);
218
std::shared_ptr<T> handler;
220
HandlerGuard() : handler(std::make_shared<T>()) {
221
registerEventHandler(handler);
225
unregisterEventHandler(handler);
229
TEST(MonitorTest, Stat) {
230
HandlerGuard<AggregatingEventHandler> guard;
234
{Aggregation::COUNT, Aggregation::SUM},
235
std::chrono::milliseconds(1),
237
ASSERT_EQ(guard.handler->events.size(), 0);
240
ASSERT_LE(a.count(), 1);
242
std::this_thread::sleep_for(std::chrono::milliseconds(2));
244
ASSERT_LE(a.count(), 1);
246
ASSERT_GE(guard.handler->events.size(), 1);
247
ASSERT_LE(guard.handler->events.size(), 2);
250
TEST(MonitorTest, StatEvent) {
251
HandlerGuard<AggregatingEventHandler> guard;
255
{Aggregation::COUNT, Aggregation::SUM},
256
std::chrono::milliseconds(1),
258
ASSERT_EQ(guard.handler->events.size(), 0);
261
ASSERT_EQ(a.count(), 1);
263
ASSERT_EQ(a.count(), 2);
264
ASSERT_EQ(guard.handler->events.size(), 0);
266
a.mockWindowId = 100;
269
ASSERT_LE(a.count(), 1);
271
ASSERT_EQ(guard.handler->events.size(), 1);
272
Event e = guard.handler->events.at(0);
273
ASSERT_EQ(e.name, "torch.monitor.Stat");
274
ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
275
std::unordered_map<std::string, data_value_t> data{
279
ASSERT_EQ(e.data, data);
282
TEST(MonitorTest, StatEventDestruction) {
283
HandlerGuard<AggregatingEventHandler> guard;
288
{Aggregation::COUNT, Aggregation::SUM},
289
std::chrono::hours(10),
292
ASSERT_EQ(a.count(), 1);
293
ASSERT_EQ(guard.handler->events.size(), 0);
295
ASSERT_EQ(guard.handler->events.size(), 1);
297
Event e = guard.handler->events.at(0);
298
ASSERT_EQ(e.name, "torch.monitor.Stat");
299
ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
300
std::unordered_map<std::string, data_value_t> data{
304
ASSERT_EQ(e.data, data);