#include <atomic>

#include "caffe2/core/context_gpu.h"
#include "caffe2/core/operator.h"

#include "caffe2/contrib/nccl/cuda_nccl_gpu.h"

namespace caffe2 {
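
// Build an NCCLExecution describing one collective call: for each input i,
// record the source tensor, the destination tensor (output i for an N-N op,
// or the single output for an N-1 op), and the GPU that owns it, along with
// the stream and root supplied by the calling operator.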
nccl::NCCLExecution getNCCLElements(
    OperatorBase* op,
    const CUDAContext& context) {
  // We either do an N-N op, or an N-1 op.
  CAFFE_ENFORCE(op->InputSize() == op->OutputSize() || op->OutputSize() == 1);
  nccl::NCCLExecution ex;
  ex.stream_gpu_id = context.device_id();
  ex.stream = context.cuda_stream();
  ex.root = op->template GetSingleArgument<int>("root", 0);
  ex.elements.resize(op->InputSize());
  for (auto i = 0; i < op->InputSize(); ++i) {
    auto& el = ex.elements[i];
    el.src = &(op->Input<Tensor>(i, CUDA));
    if (op->OutputSize() == 1) {
      // N-1 op: every input reduces into the single output.
      el.dst = op->Output<Tensor>(0, CUDA);
    } else if (i < op->OutputSize()) {
      // N-N op: input i maps to output i.
      el.dst = op->Output<Tensor>(i, CUDA);
    }
    // TODO - expensive (>1ms) - cache these.
    el.device = GetGPUIDForPointer(op->Input<Tensor>(i, CUDA).raw_data());
  }
  return ex;
}

namespace {

// Check whether every input tensor holds elements of type T.
template <typename T>
bool AllInputsAre(OperatorBase* op) {
  for (auto i = 0; i < op->InputSize(); ++i) {
    if (op->Input<Tensor>(i, CUDA).IsType<T>()) {
      continue;
    } else {
      return false;
    }
  }
  return true;
}

// Manual count of all instantiated NCCL ops.
// If this drops to zero after destructing the last NCCL op,
// it means we can safely destroy all lazily created NCCL contexts.
std::atomic<int> kNCCLOpCounter(0);

} // namespace
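
// Common base for all NCCL ops: bump kNCCLOpCounter on construction and,
// once the last NCCL op is destroyed, tear down the lazily created NCCL
// contexts.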
class NCCLBaseOp : public Operator<CUDAContext> {
 public:
  using Operator::Operator;

  NCCLBaseOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CUDAContext>(operator_def, ws) {
    kNCCLOpCounter++;
  }

  ~NCCLBaseOp() override {
    if (--kNCCLOpCounter == 0) {
      nccl::destroyContexts();
    }
  }
};
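
// NCCLAllreduce: elementwise reduction across the input GPUs, written either
// to one output per input (N-N) or to a single output (N-1). ShapeInference
// and CostInference back the operator schema registered below.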
class NCCLAllreduceOp final : public NCCLBaseOp {
 public:
  using NCCLBaseOp::NCCLBaseOp;

  bool RunOnDevice() override {
    if (InputSize() == 1)
      return true;
    if (AllInputsAre<float>(this)) {
      nccl::NCCL<float>::AllReduce(getNCCLElements(this, context_));
      return true;
    } else if (AllInputsAre<at::Half>(this)) {
      nccl::NCCL<at::Half>::AllReduce(getNCCLElements(this, context_));
      return true;
    } else {
      return false;
    }
  }

  static std::vector<TensorShape> ShapeInference(
      const OperatorDef& def,
      const std::vector<TensorShape>& in) {
    auto n_outputs = def.output_size();
    CAFFE_ENFORCE(
        n_outputs == 1 || n_outputs == in.size(),
        "NCCLAllreduce only supports N-1 or N-N reductions");
    for (auto i = 0; i < in.size(); i++) {
      CAFFE_ENFORCE(
          in[0].dims_size() == in[i].dims_size(),
          "NCCLAllreduce requires inputs of same dimension");
      for (auto j = 0; j < in[0].dims_size(); j++) {
        CAFFE_ENFORCE(
            in[0].dims(j) == in[i].dims(j),
            "NCCLAllreduce requires inputs to be of same shape");
      }
    }
    // Every output has the same shape as the (identically shaped) inputs.
    std::vector<TensorShape> out(n_outputs);
    for (auto i = 0; i < out.size(); i++) {
      out[i] = in[0];
    }
    return out;
  }

  static struct OpSchema::Cost CostInference(
      const OperatorDef& def,
      const vector<TensorShape>& inputs) {
    CAFFE_ENFORCE_GE(
        inputs.size(), 1, "NCCLAllreduce requires at least 1 input");
    const auto nElem = nElemFromDim(inputs[0]);

    struct OpSchema::Cost c;
    // Reducing N tensors elementwise costs (N - 1) * nElem operations.
    c.flops = (inputs.size() - 1) * nElem;
    c.bytes_read = inputs.size() * nElem;
    c.bytes_written = def.output_size() * nElem;
    c.params_bytes = 0;
    return c;
  }
};
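
// NCCLBroadcast: copy the tensor on the root GPU to every participating GPU
// (one-to-one in-place with the inputs, per the schema below).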
class NCCLBroadcastOp final : public NCCLBaseOp {
 public:
  using NCCLBaseOp::NCCLBaseOp;

  bool RunOnDevice() override {
    if (InputSize() == 1)
      return true;
    if (AllInputsAre<float>(this)) {
      nccl::NCCL<float>::Broadcast(getNCCLElements(this, context_));
      return true;
    } else if (AllInputsAre<at::Half>(this)) {
      nccl::NCCL<at::Half>::Broadcast(getNCCLElements(this, context_));
      return true;
    } else {
      return false;
    }
  }
};
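
// NCCLReduce: reduce all inputs into the single output on the root GPU.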
class NCCLReduceOp final : public NCCLBaseOp {
 public:
  using NCCLBaseOp::NCCLBaseOp;

  bool RunOnDevice() override {
    if (InputSize() == 1)
      return true;

    const auto& ex = getNCCLElements(this, context_);

    if (AllInputsAre<float>(this)) {
      nccl::NCCL<float>::Reduce(ex);
      return true;
    } else if (AllInputsAre<at::Half>(this)) {
      nccl::NCCL<at::Half>::Reduce(ex);
      return true;
    } else {
      return false;
    }
  }
};
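
// NCCLAllGather: gather every GPU's input so each output holds all inputs.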
class NCCLAllGatherOp final : public NCCLBaseOp {
 public:
  using NCCLBaseOp::NCCLBaseOp;

  bool RunOnDevice() override {
    if (InputSize() == 1)
      return true;
    if (AllInputsAre<float>(this)) {
      nccl::NCCL<float>::AllGather(getNCCLElements(this, context_));
      return true;
    } else if (AllInputsAre<at::Half>(this)) {
      nccl::NCCL<at::Half>::AllGather(getNCCLElements(this, context_));
      return true;
    } else {
      return false;
    }
  }
};
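
// NCCLReduceScatter: reduce across GPUs, then scatter disjoint chunks of the
// result so each GPU keeps one slice.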
class NCCLReduceScatterOp final : public NCCLBaseOp {
 public:
  using NCCLBaseOp::NCCLBaseOp;

  bool RunOnDevice() override {
    if (AllInputsAre<float>(this)) {
      nccl::NCCL<float>::ReduceScatter(getNCCLElements(this, context_));
      return true;
    } else if (AllInputsAre<at::Half>(this)) {
      nccl::NCCL<at::Half>::ReduceScatter(getNCCLElements(this, context_));
      return true;
    } else {
      return false;
    }
  }
};
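
// Device inference for all NCCL ops: input/output i is assumed to live on
// CUDA device i.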
std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>> ncclOpDevInfer(
    const OperatorDef& def) {
  std::vector<DeviceOption> opt;
  for (int i = 0; i < def.input().size(); ++i) {
    DeviceOption dev;
    dev.set_device_type(1); // 1 == CUDA in caffe2.proto
    dev.set_device_id(i);
    opt.push_back(dev);
  }
  return std::make_pair(opt, opt);
}
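
// Operator registrations and schemas. All NCCL ops accept up to
// C10_COMPILE_TIME_MAX_GPUS inputs, may span devices, and expose no gradient.
// For illustration only (not part of this file, blob names are made up): from
// Python such an op would typically be created with something like
//   core.CreateOperator("NCCLAllreduce",
//                       ["blob_gpu0", "blob_gpu1"],
//                       ["blob_gpu0", "blob_gpu1"])
// where core is caffe2.python.core.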
REGISTER_CUDA_OPERATOR(NCCLAllreduce, NCCLAllreduceOp);
OPERATOR_SCHEMA(NCCLAllreduce)
    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .CostInferenceFunction(NCCLAllreduceOp::CostInference)
    .TensorInferenceFunction(NCCLAllreduceOp::ShapeInference)
    .IdenticalTypeAndShape()
    .InputsCanCrossDevices()
    .AllowOneToOneInplace()
    .DeviceInferenceFunction(ncclOpDevInfer);
SHOULD_NOT_DO_GRADIENT(NCCLAllreduce);

REGISTER_CUDA_OPERATOR(NCCLBroadcast, NCCLBroadcastOp);
OPERATOR_SCHEMA(NCCLBroadcast)
    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .IdenticalTypeAndShape()
    .InputsCanCrossDevices()
    .EnforceOneToOneInplace()
    .DeviceInferenceFunction(ncclOpDevInfer);
SHOULD_NOT_DO_GRADIENT(NCCLBroadcast);

REGISTER_CUDA_OPERATOR(NCCLReduce, NCCLReduceOp);
OPERATOR_SCHEMA(NCCLReduce)
    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .NumOutputs(1)
    .IdenticalTypeAndShapeOfInput(0)
    .InputsCanCrossDevices()
    .AllowInplace([](int /*in*/, int out) -> bool { return (out == 0); })
    .DeviceInferenceFunction(ncclOpDevInfer);
SHOULD_NOT_DO_GRADIENT(NCCLReduce);

REGISTER_CUDA_OPERATOR(NCCLAllGather, NCCLAllGatherOp);
OPERATOR_SCHEMA(NCCLAllGather)
    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .InputsCanCrossDevices()
    .DeviceInferenceFunction(ncclOpDevInfer);
SHOULD_NOT_DO_GRADIENT(NCCLAllGather);

REGISTER_CUDA_OPERATOR(NCCLReduceScatter, NCCLReduceScatterOp);
OPERATOR_SCHEMA(NCCLReduceScatter)
    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .InputsCanCrossDevices()
    .DeviceInferenceFunction(ncclOpDevInfer);
SHOULD_NOT_DO_GRADIENT(NCCLReduceScatter);

} // namespace caffe2