pytorch

Форк
0
/
blobs_queue.cc 
176 строк · 5.8 Кб
1
#include "caffe2/queue/blobs_queue.h"

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <queue>
#include <utility>

#include "caffe2/core/blob_stats.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/stats.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/timer.h"
#include "caffe2/core/workspace.h"

#include <c10/util/irange.h>
17

18
namespace caffe2 {
19

20
// Constants for user tracepoints
21
C10_UNUSED static constexpr int SDT_NONBLOCKING_OP = 0;
22
C10_UNUSED static constexpr int SDT_BLOCKING_OP = 1;
23
C10_UNUSED static constexpr uint64_t SDT_TIMEOUT = (uint64_t)-1;
24
C10_UNUSED static constexpr uint64_t SDT_ABORT = (uint64_t)-2;
25
C10_UNUSED static constexpr uint64_t SDT_CANCEL = (uint64_t)-3;
26

27
// Creates a queue with `capacity` preallocated slots, each holding `numBlobs`
// blobs created in `ws`. Blobs are named "<queueName>_<slot>_<field>".
//
// @param ws                Workspace that owns the slot blobs.
// @param queueName         Queue name; also the prefix of every slot blob name
//                          and the key of the queue's stats.
// @param capacity          Number of records the queue can hold.
// @param numBlobs          Number of blobs (fields) per record.
// @param enforceUniqueName If true, fails via CAFFE_ENFORCE when any slot blob
//                          name already exists in `ws`.
// @param fieldNames        Optional per-field labels for the
//                          queue_dequeued_bytes stat; if non-empty it must
//                          contain exactly `numBlobs` entries.
BlobsQueue::BlobsQueue(
    Workspace* ws,
    const std::string& queueName,
    size_t capacity,
    size_t numBlobs,
    bool enforceUniqueName,
    const std::vector<std::string>& fieldNames)
    : numBlobs_(numBlobs), name_(queueName), stats_(queueName) {
  if (!fieldNames.empty()) {
    CAFFE_ENFORCE_EQ(
        fieldNames.size(), numBlobs, "Wrong number of fieldNames provided.");
    stats_.queue_dequeued_bytes.setDetails(fieldNames);
  }
  queue_.reserve(capacity);
  for (const auto i : c10::irange(capacity)) {
    std::vector<Blob*> blobs;
    blobs.reserve(numBlobs);
    for (const auto j : c10::irange(numBlobs)) {
      const auto blobName = queueName + "_" + to_string(i) + "_" + to_string(j);
      if (enforceUniqueName) {
        CAFFE_ENFORCE(
            !ws->GetBlob(blobName),
            "Queue internal blob already exists: ",
            blobName);
      }
      blobs.push_back(ws->CreateBlob(blobName));
    }
    // The per-slot vector is not used after this point; move it instead of
    // copying it into the queue.
    queue_.push_back(std::move(blobs));
  }
  TORCH_DCHECK_EQ(queue_.size(), capacity);
}
58

59
bool BlobsQueue::blockingRead(
60
    const std::vector<Blob*>& inputs,
61
    float timeout_secs) {
62
  Timer readTimer;
63
  auto keeper = this->shared_from_this();
64
  C10_UNUSED const auto& name = name_.c_str();
65
  TORCH_SDT(queue_read_start, name, (void*)this, SDT_BLOCKING_OP);
66
  std::unique_lock<std::mutex> g(mutex_);
67
  auto canRead = [this]() {
68
    CAFFE_ENFORCE_LE(reader_, writer_);
69
    return reader_ != writer_;
70
  };
71
  // Decrease queue balance before reading to indicate queue read pressure
72
  // is being increased (-ve queue balance indicates more reads than writes)
73
  CAFFE_EVENT(stats_, queue_balance, -1);
74
  if (timeout_secs > 0) {
75
    std::chrono::milliseconds timeout_ms(int(timeout_secs * 1000));
76
    cv_.wait_for(
77
        g, timeout_ms, [this, canRead]() { return closing_ || canRead(); });
78
  } else {
79
    cv_.wait(g, [this, canRead]() { return closing_ || canRead(); });
80
  }
81
  if (!canRead()) {
82
    if (timeout_secs > 0 && !closing_) {
83
      LOG(ERROR) << "DequeueBlobs timed out in " << timeout_secs << " secs";
84
      TORCH_SDT(queue_read_end, name, (void*)this, SDT_TIMEOUT);
85
    } else {
86
      TORCH_SDT(queue_read_end, name, (void*)this, SDT_CANCEL);
87
    }
88
    return false;
89
  }
90
  DCHECK(canRead());
91
  auto& result = queue_[reader_ % queue_.size()];
92
  CAFFE_ENFORCE(inputs.size() >= result.size());
93
  for (const auto i : c10::irange(result.size())) {
94
    auto bytes = BlobStat::sizeBytes(*result[i]);
95
    CAFFE_EVENT(stats_, queue_dequeued_bytes, bytes, i);
96
    using std::swap;
97
    swap(*(inputs[i]), *(result[i]));
98
  }
99
  TORCH_SDT(queue_read_end, name, (void*)this, writer_ - reader_);
100
  CAFFE_EVENT(stats_, queue_dequeued_records);
101
  ++reader_;
102
  cv_.notify_all();
103
  CAFFE_EVENT(stats_, read_time_ns, readTimer.NanoSeconds());
104
  return true;
105
}
106

107
// Non-blocking enqueue. If a free slot is available, swaps `inputs` into it
// and returns true; if the queue is full, returns false immediately without
// touching `inputs`.
bool BlobsQueue::tryWrite(const std::vector<Blob*>& inputs) {
  Timer writeTimer;
  // Strong self-reference kept alive until this call returns.
  auto self = this->shared_from_this();
  C10_UNUSED const auto& name = name_.c_str();
  TORCH_SDT(queue_write_start, name, (void*)this, SDT_NONBLOCKING_OP);
  std::unique_lock<std::mutex> guard(mutex_);
  const bool hasRoom = canWrite();
  if (hasRoom) {
    // Write pressure is only recorded once the write is known to proceed
    // (+ve queue balance indicates more writes than reads).
    CAFFE_EVENT(stats_, queue_balance, 1);
    DCHECK(canWrite());
    doWrite(inputs);
    CAFFE_EVENT(stats_, write_time_ns, writeTimer.NanoSeconds());
  } else {
    TORCH_SDT(queue_write_end, name, (void*)this, SDT_ABORT);
  }
  return hasRoom;
}
125

126
// Blocking enqueue. Waits until a slot is free or the queue is closed.
// Returns true once `inputs` have been swapped into the queue, or false if the
// queue was closed while it was still full.
bool BlobsQueue::blockingWrite(const std::vector<Blob*>& inputs) {
  Timer writeTimer;
  // Strong self-reference kept alive until this call returns.
  auto self = this->shared_from_this();
  C10_UNUSED const auto& name = name_.c_str();
  TORCH_SDT(queue_write_start, name, (void*)this, SDT_BLOCKING_OP);
  std::unique_lock<std::mutex> guard(mutex_);
  // Write pressure is recorded before waiting, mirroring blockingRead
  // (+ve queue balance indicates more writes than reads).
  CAFFE_EVENT(stats_, queue_balance, 1);
  const auto spaceOrClosing = [this]() { return closing_ || canWrite(); };
  cv_.wait(guard, spaceOrClosing);
  const bool hasRoom = canWrite();
  if (hasRoom) {
    DCHECK(canWrite());
    doWrite(inputs);
    CAFFE_EVENT(stats_, write_time_ns, writeTimer.NanoSeconds());
  } else {
    // Woken by close() while the queue was still full.
    TORCH_SDT(queue_write_end, name, (void*)this, SDT_ABORT);
  }
  return hasRoom;
}
145

146
void BlobsQueue::close() {
147
  closing_ = true;
148

149
  std::lock_guard<std::mutex> g(mutex_);
150
  cv_.notify_all();
151
}
152

153
bool BlobsQueue::canWrite() {
154
  // writer is always within [reader, reader + size)
155
  // we can write if reader is within [reader, reader + size)
156
  CAFFE_ENFORCE_LE(reader_, writer_);
157
  CAFFE_ENFORCE_LE(writer_, static_cast<int64_t>(reader_ + queue_.size()));
158
  // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
159
  return writer_ != reader_ + queue_.size();
160
}
161

162
// Swaps `inputs` into the slot at the writer cursor, advances the cursor and
// wakes waiters. Callers hold mutex_ and have already checked canWrite().
void BlobsQueue::doWrite(const std::vector<Blob*>& inputs) {
  const auto slotCount = queue_.size();
  auto& slot = queue_[writer_ % slotCount];
  CAFFE_ENFORCE(inputs.size() >= slot.size());
  C10_UNUSED const auto& name = name_.c_str();
  for (const auto field : c10::irange(slot.size())) {
    // Swap rather than copy: the caller's blobs receive the slot's previous
    // contents.
    using std::swap;
    swap(*(inputs[field]), *(slot[field]));
  }
  // Report the number of free slots before this write is committed.
  TORCH_SDT(queue_write_end, name, (void*)this, reader_ + slotCount - writer_);
  ++writer_;
  // A record is now available; wake readers blocked on an empty queue.
  cv_.notify_all();
}
175

176
} // namespace caffe2
177

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.