pytorch

Форк
0
/
cuda_nccl_gpu.cc 
322 строки · 8.7 Кб
1
#include "caffe2/contrib/nccl/cuda_nccl_gpu.h"
2

3
namespace caffe2 {
4
namespace nccl {
5
namespace {
6

7
std::vector<int> getDevices(const NCCLExecution& ex) {
8
  std::vector<int> result;
9
  result.reserve(ex.elements.size());
10
  for (const auto& el : ex.elements) {
11
    result.push_back(el.device);
12
  }
13
  return result;
14
}
15

16
// Owns the per-device NCCL communicators, non-blocking streams, and
// synchronization events for one fixed set of GPUs. Instances are cached
// per device set (see gContexts()) because ncclCommInitAll is expensive.
class NCCLContext {
 public:
  // Creates one comm/stream/event triple per participating device, plus a
  // single "master" event on the GPU that owns the caller's stream
  // (ex.stream_gpu_id).
  explicit NCCLContext(const NCCLExecution& ex)
      : devices_(getDevices(ex)), master_gpu_id_(ex.stream_gpu_id) {
    comms_.resize(devices_.size());
    CAFFE_NCCL_CHECK(
        ncclCommInitAll(comms_.data(), devices_.size(), devices_.data()));

    streams_.resize(devices_.size());
    events_.resize(devices_.size());
    for (auto i = 0U; i < devices_.size(); ++i) {
      CUDAGuard g(devices_[i]);
      // get stream priorities; per the CUDA API, the second out-param is
      // the greatest (numerically least) priority, used for the stream below
      int lo_pri, hi_pri;
      CUDA_ENFORCE(cudaDeviceGetStreamPriorityRange(&lo_pri, &hi_pri));
      CUDA_ENFORCE(cudaStreamCreateWithPriority(
          &streams_[i], cudaStreamNonBlocking, hi_pri));
      // timing is disabled: these events exist purely for cross-stream
      // ordering via cudaStreamWaitEvent
      CUDA_ENFORCE(cudaEventCreateWithFlags(
          &events_[i], cudaEventDefault | cudaEventDisableTiming));
    }
    CUDAGuard g(master_gpu_id_);
    CUDA_ENFORCE(cudaEventCreateWithFlags(
        &master_event_, cudaEventDefault | cudaEventDisableTiming));
  }

  // Destroys streams/events on their owning devices, then the comms.
  // NOTE(review): CUDA_ENFORCE can throw; throwing from a destructor
  // terminates the process if unwinding is already in progress — confirm
  // this hard-fail behavior is intended.
  ~NCCLContext() {
    for (auto i = 0U; i < devices_.size(); ++i) {
      CUDAGuard g(devices_[i]);
      CUDA_ENFORCE(cudaStreamDestroy(streams_[i]));
      CUDA_ENFORCE(cudaEventDestroy(events_[i]));
    }
    CUDAGuard g(master_gpu_id_);
    CUDA_ENFORCE(cudaEventDestroy(master_event_));

    for (auto& comm : comms_) {
      ncclCommDestroy(comm);
    }
  }

  std::vector<int> devices_; // GPU ordinals, one per execution element
  std::vector<ncclComm_t> comms_; // NCCL communicator per device
  std::vector<cudaStream_t> streams_; // high-priority launch stream per device
  int master_gpu_id_; // GPU owning the caller's stream
  cudaEvent_t master_event_; // recorded on the caller's stream
  std::vector<cudaEvent_t> events_; // per-device completion events

  C10_DISABLE_COPY_AND_ASSIGN(NCCLContext);
};
64

65
// We share the contexts across multiple operators, hence the cache.
66
// Returns the lazily-constructed mutex that guards the context cache.
static std::mutex& gContextsMutex() {
  static std::mutex guard;
  return guard;
}
70

71
// Cache of NCCLContexts keyed by ncclKey(); guarded by gContextsMutex().
std::unordered_map<std::string, std::unique_ptr<NCCLContext>>& gContexts() {
  static std::unordered_map<std::string, std::unique_ptr<NCCLContext>> cache;
  return cache;
}
75

76
std::string ncclKey(const NCCLExecution& ex) {
77
  std::string result;
78
  int curr_device;
79
  CUDA_CHECK(cudaGetDevice(&curr_device));
80
  result += to_string(curr_device) + ":";
81
  for (const auto& el : ex.elements) {
82
    result += to_string(el.device) + ",";
83
  }
84
  return result;
85
}
86

87
// Fetches (or lazily creates) the cached context for ex's device set.
// Caller must hold gContextsMutex().
NCCLContext* getNCCLContext(const NCCLExecution& ex) {
  const auto key = ncclKey(ex);
  auto& slot = gContexts()[key];
  if (slot == nullptr) {
    LOG(INFO) << "Creating NCCLContext for key: " << key;
    slot.reset(new NCCLContext(ex));
  }
  return TORCH_CHECK_NOTNULL(slot.get());
}
96

97
template <typename T>
98
class ncclTypeWrapper;
99

100
template <>
101
class ncclTypeWrapper<float> {
102
 public:
103
  static const ncclDataType_t type = ncclFloat;
104
};
105

106
template <>
107
class ncclTypeWrapper<int> {
108
 public:
109
  static const ncclDataType_t type = ncclInt;
110
};
111

112
#ifdef CAFFE_HAS_CUDA_FP16
113
template <>
114
class ncclTypeWrapper<at::Half> {
115
 public:
116
  static const ncclDataType_t type = ncclHalf;
117
};
118
#endif
119

120
// Runs one NCCL collective described by `ex`.
//
// `init_f(element)` is invoked once per element, on that element's device,
// to allocate/resize outputs before any NCCL call. `f(element, comm,
// stream)` then issues the actual NCCL primitive on the per-device
// comm/stream owned by the cached NCCLContext.
//
// Stream-ordering protocol:
//   1. an event recorded on ex.stream is awaited by every child stream,
//   2. the collective runs on the child streams,
//   3. per-child events are recorded and awaited back on ex.stream,
// so the collective appears ordered with respect to ex.stream.
template <typename T, typename InitF, typename F>
void runNCCL(const NCCLExecution& ex, InitF&& init_f, F&& f) {
  // do initialization (output allocation) on each participating device
  for (auto i = 0U; i < ex.elements.size(); ++i) {
    auto& ctx = ex.elements[i];
    CUDAGuard g(ctx.device);
    init_f(ex.elements[i]);
  }

  // Serialize access to the context cache — and hence to this context's
  // comms/streams — for the duration of the launch.
  std::lock_guard<std::mutex> g(gContextsMutex());
  auto* context = getNCCLContext(ex);
  auto& comms = context->comms_;
  auto& streams = context->streams_;
  auto& events = context->events_;
  // Record an event on the master context, wait on it in each of the
  // children streams, so the children streams are synchronized WRT
  // the original stream.
  {
    CUDAGuard g(ex.stream_gpu_id);
    CUDA_ENFORCE(cudaEventRecord(context->master_event_, ex.stream));
  }

  {
    // lock out alloc / free while NCCL launches
    std::lock_guard<std::mutex> lock(CUDAContext::mutex());

#if NCCL_VERSION_MIN(2, 0, 0)
    // NCCL >= 2 requires grouping the per-device calls so they are fused
    // into a single collective launch.
    CAFFE_NCCL_CHECK(ncclGroupStart());
#endif

    for (auto i = 0U; i < ex.elements.size(); ++i) {
      auto& ctx = ex.elements[i];
      CUDAGuard g(ctx.device);
      auto& comm = comms[i];
      auto& stream = streams[i];

      // sanity check: the source buffer must live on the device this
      // element claims (debug builds only)
      TORCH_DCHECK_EQ(ctx.device, GetGPUIDForPointer(ctx.src->raw_data()));
      CUDA_ENFORCE(cudaStreamWaitEvent(stream, context->master_event_, 0));
      f(ctx, comm, stream);
    }

#if NCCL_VERSION_MIN(2, 0, 0)
    CAFFE_NCCL_CHECK(ncclGroupEnd());
#endif

    for (auto i = 0U; i < ex.elements.size(); ++i) {
      auto& ctx = ex.elements[i];
      CUDAGuard g(ctx.device);
      auto& stream = streams[i];
      auto& event = events[i];

      // Record an event on each children stream that we have finished
      // our computation
      CUDA_ENFORCE(cudaEventRecord(event, stream));
    }
  }

  // Now, wait on all the events in the original stream.
  CUDAGuard dg(ex.stream_gpu_id);
  for (auto& event : events) {
    CUDA_ENFORCE(cudaStreamWaitEvent(TORCH_CHECK_NOTNULL(ex.stream), event, 0));
  }
}
183

184
} // namespace
185

186
void destroyContexts() {
187
  std::lock_guard<std::mutex> g(gContextsMutex());
188
  auto& contexts = gContexts();
189
  contexts.clear();
190
}
191

192
template <typename T>
193
void NCCL<T>::AllReduce(const NCCLExecution& ex) {
194
  return runNCCL<T>(
195
      ex,
196
      [](const NCCLElement& ctx) {
197
        ctx.dst->Resize(ctx.src->sizes());
198
        ctx.dst->template mutable_data<T>();
199
      },
200
      [](const NCCLElement& ctx, ncclComm_t comm, cudaStream_t stream) {
201
        CAFFE_NCCL_CHECK(ncclAllReduce(
202
            ctx.src->raw_data(),
203
            ctx.dst->raw_mutable_data(),
204
            ctx.dst->numel(),
205
            ncclTypeWrapper<T>::type,
206
            ncclSum,
207
            comm,
208
            stream));
209
      });
210
}
211

212
template <typename T>
213
void NCCL<T>::Broadcast(const NCCLExecution& ex) {
214
  return runNCCL<T>(
215
      ex,
216
      [](const NCCLElement& ctx) {
217
        ctx.dst->Resize(ctx.src->sizes());
218
        ctx.dst->template mutable_data<T>();
219
      },
220
      [&ex](const NCCLElement& ctx, ncclComm_t comm, cudaStream_t stream) {
221
        CAFFE_NCCL_CHECK(ncclBcast(
222
            ctx.dst->raw_mutable_data(),
223
            ctx.dst->numel(),
224
            ncclTypeWrapper<T>::type,
225
            ex.root,
226
            comm,
227
            stream));
228
      });
229
}
230

231
template <typename T>
232
void NCCL<T>::Reduce(const NCCLExecution& ex) {
233
  return runNCCL<T>(
234
      ex,
235
      [](const NCCLElement& ctx) {
236
        if (ctx.dst) {
237
          ctx.dst->Resize(ctx.src->sizes());
238
          ctx.dst->template mutable_data<T>();
239
        }
240
      },
241
      [&ex](const NCCLElement& ctx, ncclComm_t comm, cudaStream_t stream) {
242
        CAFFE_NCCL_CHECK(ncclReduce(
243
            ctx.src->raw_data(),
244
            ctx.dst ? ctx.dst->raw_mutable_data() : nullptr,
245
            ctx.src->numel(),
246
            ncclTypeWrapper<T>::type,
247
            ncclSum,
248
            ex.root,
249
            comm,
250
            stream));
251
      });
252
}
253

254
template <typename T>
255
void NCCL<T>::AllGather(const NCCLExecution& ex) {
256
  const auto n = ex.elements.size();
257
  return runNCCL<T>(
258
      ex,
259
      [n](const NCCLElement& ctx) {
260
        CAFFE_ENFORCE_NE(ctx.src, ctx.dst);
261
        std::vector<int64_t> dims;
262
        dims.reserve(ctx.src->dim() + 1);
263
        dims.push_back(n);
264
        for (auto d : ctx.src->sizes()) {
265
          dims.push_back(d);
266
        }
267
        ctx.dst->Resize(dims);
268
        ctx.dst->template mutable_data<T>();
269
      },
270
      [](const NCCLElement& ctx, ncclComm_t comm, cudaStream_t stream) {
271
#if NCCL_VERSION_MIN(2, 0, 0)
272
        CAFFE_NCCL_CHECK(ncclAllGather(
273
            ctx.src->raw_data(),
274
            ctx.dst->raw_mutable_data(),
275
            ctx.src->numel(),
276
            ncclTypeWrapper<T>::type,
277
            comm,
278
            stream));
279
#else
280
        CAFFE_NCCL_CHECK(ncclAllGather(
281
            ctx.src->raw_data(),
282
            ctx.src->size(),
283
            ncclTypeWrapper<T>::type,
284
            ctx.dst->raw_mutable_data(),
285
            comm,
286
            stream));
287
#endif
288
      });
289
}
290

291
template <typename T>
292
void NCCL<T>::ReduceScatter(const NCCLExecution& ex) {
293
  return runNCCL<T>(
294
      ex,
295
      [](const NCCLElement& ctx) {
296
        CAFFE_ENFORCE_NE(ctx.src, ctx.dst);
297
        const auto& srcDims = ctx.src->sizes();
298
        std::vector<int64_t> dstDims(srcDims.begin() + 1, srcDims.end());
299
        ctx.dst->Resize(dstDims);
300
        ctx.dst->template mutable_data<T>();
301
      },
302
      [](const NCCLElement& ctx, ncclComm_t comm, cudaStream_t stream) {
303
        CAFFE_NCCL_CHECK(ncclReduceScatter(
304
            ctx.src->raw_data(),
305
            ctx.dst->raw_mutable_data(),
306
            ctx.dst->numel(),
307
            ncclTypeWrapper<T>::type,
308
            ncclSum,
309
            comm,
310
            stream));
311
      });
312
}
313

314
// Explicit instantiation: these are the only element types the NCCL ops
// support (matching the ncclTypeWrapper specializations).
template class NCCL<float>;
template class NCCL<int>;
#ifdef CAFFE_HAS_CUDA_FP16
template class NCCL<at::Half>;
#endif
320

321
} // namespace nccl
322
} // namespace caffe2
323

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.