1
#include "caffe2/contrib/nccl/cuda_nccl_gpu.h"
7
std::vector<int> getDevices(const NCCLExecution& ex) {
8
std::vector<int> result;
9
result.reserve(ex.elements.size());
10
for (const auto& el : ex.elements) {
11
result.push_back(el.device);
18
explicit NCCLContext(const NCCLExecution& ex)
19
: devices_(getDevices(ex)), master_gpu_id_(ex.stream_gpu_id) {
20
comms_.resize(devices_.size());
22
ncclCommInitAll(comms_.data(), devices_.size(), devices_.data()));
24
streams_.resize(devices_.size());
25
events_.resize(devices_.size());
26
for (auto i = 0U; i < devices_.size(); ++i) {
27
CUDAGuard g(devices_[i]);
28
// get stream priorities
30
CUDA_ENFORCE(cudaDeviceGetStreamPriorityRange(&lo_pri, &hi_pri));
31
CUDA_ENFORCE(cudaStreamCreateWithPriority(
32
&streams_[i], cudaStreamNonBlocking, hi_pri));
33
CUDA_ENFORCE(cudaEventCreateWithFlags(
34
&events_[i], cudaEventDefault | cudaEventDisableTiming));
36
CUDAGuard g(master_gpu_id_);
37
CUDA_ENFORCE(cudaEventCreateWithFlags(
38
&master_event_, cudaEventDefault | cudaEventDisableTiming));
42
for (auto i = 0U; i < devices_.size(); ++i) {
43
CUDAGuard g(devices_[i]);
44
CUDA_ENFORCE(cudaStreamDestroy(streams_[i]));
45
CUDA_ENFORCE(cudaEventDestroy(events_[i]));
47
CUDAGuard g(master_gpu_id_);
48
CUDA_ENFORCE(cudaEventDestroy(master_event_));
50
for (auto& comm : comms_) {
51
ncclCommDestroy(comm);
55
std::vector<int> devices_;
56
std::vector<ncclComm_t> comms_;
57
std::vector<cudaStream_t> streams_;
59
cudaEvent_t master_event_;
60
std::vector<cudaEvent_t> events_;
62
C10_DISABLE_COPY_AND_ASSIGN(NCCLContext);
65
// We share the contexts across multiple operators, hence the cache.
// Returns the process-wide mutex guarding that cache; a function-local
// static is returned by reference so every caller shares one instance.
static std::mutex& gContextsMutex() {
  static std::mutex m;
  return m;
}
// Cache of NCCLContext instances, keyed by ncclKey() strings (current
// device plus participating devices). Access is protected by
// gContextsMutex() — see runNCCL() and destroyContexts().
std::unordered_map<std::string, std::unique_ptr<NCCLContext>>& gContexts() {
  // Function-local static so the map is constructed on first use.
  static std::unordered_map<std::string, std::unique_ptr<NCCLContext>> m;
  return m;
}
std::string ncclKey(const NCCLExecution& ex) {
79
CUDA_CHECK(cudaGetDevice(&curr_device));
80
result += to_string(curr_device) + ":";
81
for (const auto& el : ex.elements) {
82
result += to_string(el.device) + ",";
87
// Returns the cached NCCLContext for this execution's device set,
// creating and caching it on first use. Caller must hold
// gContextsMutex() (runNCCL does).
// NOTE(review): the create-on-miss guard was reconstructed from the
// "Creating..." log + reset() — verify against upstream.
NCCLContext* getNCCLContext(const NCCLExecution& ex) {
  auto& contexts = gContexts();
  const auto key = ncclKey(ex);
  if (!contexts[key]) {
    LOG(INFO) << "Creating NCCLContext for key: " << key;
    contexts[key].reset(new NCCLContext(ex));
  }
  return TORCH_CHECK_NOTNULL(contexts[key].get());
}
class ncclTypeWrapper<float> {
103
static const ncclDataType_t type = ncclFloat;
107
class ncclTypeWrapper<int> {
109
static const ncclDataType_t type = ncclInt;
112
#ifdef CAFFE_HAS_CUDA_FP16
114
class ncclTypeWrapper<at::Half> {
116
static const ncclDataType_t type = ncclHalf;
120
template <typename T, typename InitF, typename F>
121
void runNCCL(const NCCLExecution& ex, InitF&& init_f, F&& f) {
123
for (auto i = 0U; i < ex.elements.size(); ++i) {
124
auto& ctx = ex.elements[i];
125
CUDAGuard g(ctx.device);
126
init_f(ex.elements[i]);
129
std::lock_guard<std::mutex> g(gContextsMutex());
130
auto* context = getNCCLContext(ex);
131
auto& comms = context->comms_;
132
auto& streams = context->streams_;
133
auto& events = context->events_;
134
// Record an event on the master context, wait on it in each of the
135
// children streams, so the children streams are synchronized WRT
136
// the original stream.
138
CUDAGuard g(ex.stream_gpu_id);
139
CUDA_ENFORCE(cudaEventRecord(context->master_event_, ex.stream));
143
// lock out alloc / free while NCCL launches
144
std::lock_guard<std::mutex> lock(CUDAContext::mutex());
146
#if NCCL_VERSION_MIN(2, 0, 0)
147
CAFFE_NCCL_CHECK(ncclGroupStart());
150
for (auto i = 0U; i < ex.elements.size(); ++i) {
151
auto& ctx = ex.elements[i];
152
CUDAGuard g(ctx.device);
153
auto& comm = comms[i];
154
auto& stream = streams[i];
156
TORCH_DCHECK_EQ(ctx.device, GetGPUIDForPointer(ctx.src->raw_data()));
157
CUDA_ENFORCE(cudaStreamWaitEvent(stream, context->master_event_, 0));
158
f(ctx, comm, stream);
161
#if NCCL_VERSION_MIN(2, 0, 0)
162
CAFFE_NCCL_CHECK(ncclGroupEnd());
165
for (auto i = 0U; i < ex.elements.size(); ++i) {
166
auto& ctx = ex.elements[i];
167
CUDAGuard g(ctx.device);
168
auto& stream = streams[i];
169
auto& event = events[i];
171
// Record an event on each children stream that we have finished
173
CUDA_ENFORCE(cudaEventRecord(event, stream));
177
// Now, wait on all the events in the original stream.
178
CUDAGuard dg(ex.stream_gpu_id);
179
for (auto& event : events) {
180
CUDA_ENFORCE(cudaStreamWaitEvent(TORCH_CHECK_NOTNULL(ex.stream), event, 0));
186
void destroyContexts() {
187
std::lock_guard<std::mutex> g(gContextsMutex());
188
auto& contexts = gContexts();
193
// Sum-reduces src across all participating devices, writing the full
// result into every element's dst (dst resized to match src).
// NOTE(review): missing argument lines (src pointer, count, ncclSum)
// reconstructed — verify against upstream.
template <typename T>
void NCCL<T>::AllReduce(const NCCLExecution& ex) {
  return runNCCL<T>(
      ex,
      [](const NCCLElement& ctx) {
        ctx.dst->Resize(ctx.src->sizes());
        ctx.dst->template mutable_data<T>();
      },
      [](const NCCLElement& ctx, ncclComm_t comm, cudaStream_t stream) {
        CAFFE_NCCL_CHECK(ncclAllReduce(
            ctx.src->raw_data(),
            ctx.dst->raw_mutable_data(),
            ctx.src->numel(),
            ncclTypeWrapper<T>::type,
            ncclSum,
            comm,
            stream));
      });
}
// Broadcasts dst in place from the root rank (ex.root, per the visible
// [&ex] capture) to every participating device; each dst is first
// resized to match its src.
template <typename T>
void NCCL<T>::Broadcast(const NCCLExecution& ex) {
  return runNCCL<T>(
      ex,
      [](const NCCLElement& ctx) {
        ctx.dst->Resize(ctx.src->sizes());
        ctx.dst->template mutable_data<T>();
      },
      [&ex](const NCCLElement& ctx, ncclComm_t comm, cudaStream_t stream) {
        CAFFE_NCCL_CHECK(ncclBcast(
            ctx.dst->raw_mutable_data(),
            ctx.dst->numel(),
            ncclTypeWrapper<T>::type,
            ex.root,
            comm,
            stream));
      });
}
// Sum-reduces src across devices into the root's dst only. Non-root
// elements may have a null dst (see the visible `ctx.dst ? ... :
// nullptr`), so dst is resized/allocated only where present.
template <typename T>
void NCCL<T>::Reduce(const NCCLExecution& ex) {
  return runNCCL<T>(
      ex,
      [](const NCCLElement& ctx) {
        if (ctx.dst) {
          ctx.dst->Resize(ctx.src->sizes());
          ctx.dst->template mutable_data<T>();
        }
      },
      [&ex](const NCCLElement& ctx, ncclComm_t comm, cudaStream_t stream) {
        CAFFE_NCCL_CHECK(ncclReduce(
            ctx.src->raw_data(),
            ctx.dst ? ctx.dst->raw_mutable_data() : nullptr,
            ctx.src->numel(),
            ncclTypeWrapper<T>::type,
            ncclSum,
            ex.root,
            comm,
            stream));
      });
}
// Gathers every element's src into every element's dst, stacking the n
// contributions along a new leading dimension (dst shape = [n, *src
// shape], per the visible dims.reserve(dim()+1) / [n] capture). src and
// dst must be distinct tensors. NCCL 1.x swaps the argument order of
// ncclAllGather, hence the #if/#else.
template <typename T>
void NCCL<T>::AllGather(const NCCLExecution& ex) {
  const auto n = ex.elements.size();
  return runNCCL<T>(
      ex,
      [n](const NCCLElement& ctx) {
        CAFFE_ENFORCE_NE(ctx.src, ctx.dst);
        std::vector<int64_t> dims;
        dims.reserve(ctx.src->dim() + 1);
        dims.push_back(n);
        for (auto d : ctx.src->sizes()) {
          dims.push_back(d);
        }
        ctx.dst->Resize(dims);
        ctx.dst->template mutable_data<T>();
      },
      [](const NCCLElement& ctx, ncclComm_t comm, cudaStream_t stream) {
#if NCCL_VERSION_MIN(2, 0, 0)
        CAFFE_NCCL_CHECK(ncclAllGather(
            ctx.src->raw_data(),
            ctx.dst->raw_mutable_data(),
            ctx.src->numel(),
            ncclTypeWrapper<T>::type,
            comm,
            stream));
#else
        CAFFE_NCCL_CHECK(ncclAllGather(
            ctx.src->raw_data(),
            ctx.src->numel(),
            ncclTypeWrapper<T>::type,
            ctx.dst->raw_mutable_data(),
            comm,
            stream));
#endif
      });
}
// Sum-reduces src across devices and scatters one slice of the result
// to each device: dst gets src's shape with the leading (rank)
// dimension dropped. src and dst must be distinct tensors.
// NOTE(review): count/ncclSum argument lines reconstructed — verify.
template <typename T>
void NCCL<T>::ReduceScatter(const NCCLExecution& ex) {
  return runNCCL<T>(
      ex,
      [](const NCCLElement& ctx) {
        CAFFE_ENFORCE_NE(ctx.src, ctx.dst);
        const auto& srcDims = ctx.src->sizes();
        std::vector<int64_t> dstDims(srcDims.begin() + 1, srcDims.end());
        ctx.dst->Resize(dstDims);
        ctx.dst->template mutable_data<T>();
      },
      [](const NCCLElement& ctx, ncclComm_t comm, cudaStream_t stream) {
        CAFFE_NCCL_CHECK(ncclReduceScatter(
            ctx.src->raw_data(),
            ctx.dst->raw_mutable_data(),
            ctx.dst->numel(),
            ncclTypeWrapper<T>::type,
            ncclSum,
            comm,
            stream));
      });
}
// Explicit instantiation
315
template class NCCL<float>;
316
template class NCCL<int>;
317
#ifdef CAFFE_HAS_CUDA_FP16
318
template class NCCL<at::Half>;