// caffe2/contrib/nccl/cuda_nccl_op_gpu.cc
#include "caffe2/core/context_gpu.h"
#include "caffe2/core/operator.h"

#include "caffe2/contrib/nccl/cuda_nccl_gpu.h"

namespace caffe2 {

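// Build an nccl::NCCLExecution describing a single collective call: the
// calling stream and its GPU, the root rank, and one (src, dst, device)
// element per operator input. For N-1 ops only the root element receives a
// destination; for N-N ops input i maps to output i.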
nccl::NCCLExecution getNCCLElements(
    OperatorBase* op,
    const CUDAContext& context) {
  // We either do an N-N op, or an N-1 op.
  CAFFE_ENFORCE(op->InputSize() == op->OutputSize() || op->OutputSize() == 1);
  nccl::NCCLExecution ex;
  ex.stream_gpu_id = context.device_id();
  ex.stream = context.cuda_stream();
  ex.root = op->template GetSingleArgument<int>("root", 0);
  ex.elements.resize(op->InputSize());
  for (auto i = 0; i < op->InputSize(); ++i) {
    auto& el = ex.elements[i];
    el.src = &(op->Input<Tensor>(i, CUDA));
    if (op->OutputSize() == 1) {
      // N-1 op (e.g. Reduce): only the root gets a destination.
      if (i == ex.root) {
        el.dst = op->Output<Tensor>(0, CUDA);
      }
    } else if (i < op->OutputSize()) {
      el.dst = op->Output<Tensor>(i, CUDA);
    }
    // TODO - expensive (>1ms) - cache these.
    el.device = GetGPUIDForPointer(op->Input<Tensor>(i, CUDA).raw_data());
  }

  return ex;
}

namespace {

// Check that every input tensor holds elements of type T.
template <typename T>
bool AllInputsAre(OperatorBase* op) {
  for (auto i = 0; i < op->InputSize(); ++i) {
    if (!op->Input<Tensor>(i, CUDA).IsType<T>()) {
      return false;
    }
  }
  return true;
}

// Manual count of all instantiated NCCL ops.
// If this drops to zero after the last NCCL op is destroyed,
// we can safely destroy all lazily created NCCL contexts.
std::atomic<int> kNCCLOpCounter(0);

} // namespace

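// Common base class for the NCCL ops below: construction and destruction
// keep kNCCLOpCounter in sync, so the lazily created NCCL contexts are
// destroyed exactly once, when the last live NCCL op goes away.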
class NCCLBaseOp : public Operator<CUDAContext> {
 public:
  using Operator::Operator;

  NCCLBaseOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CUDAContext>(operator_def, ws) {
    kNCCLOpCounter++;
  }

  ~NCCLBaseOp() {
    if (--kNCCLOpCounter == 0) {
      nccl::destroyContexts();
    }
  }
};

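// Allreduce: reduce all inputs element-wise and write the result either to
// every output (N-N) or, per getNCCLElements, to the single root output
// (N-1). A single input is a no-op.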
class NCCLAllreduceOp final : public NCCLBaseOp {
 public:
  using NCCLBaseOp::NCCLBaseOp;

  bool RunOnDevice() override {
    if (InputSize() == 1)
      return true;

    if (AllInputsAre<float>(this)) {
      nccl::NCCL<float>::AllReduce(getNCCLElements(this, context_));
      return true;
    } else if (AllInputsAre<at::Half>(this)) {
      nccl::NCCL<at::Half>::AllReduce(getNCCLElements(this, context_));
      return true;
    } else {
      return false;
    }
  }

  static std::vector<TensorShape> ShapeInference(
      const OperatorDef& def,
      const std::vector<TensorShape>& in) {
    auto n_outputs = def.output_size();
    CAFFE_ENFORCE(
        n_outputs == 1 || n_outputs == in.size(),
        "NCCLAllreduce only supports N-1 or N-N reductions");

    for (auto i = 0; i < in.size(); i++) {
      CAFFE_ENFORCE(
          in[0].dims_size() == in[i].dims_size(),
          "NCCLAllreduce requires inputs of same dimension");
      for (auto j = 0; j < in[0].dims_size(); j++) {
        CAFFE_ENFORCE(
            in[0].dims(j) == in[i].dims(j),
            "NCCLAllreduce requires inputs to be of same shape");
      }
    }

    std::vector<TensorShape> out(n_outputs);
    for (auto i = 0; i < out.size(); i++) {
      out[i] = in[0];
    }
    return out;
  }

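  // Rough cost model: an N-input allreduce performs (N - 1) element-wise
  // accumulations, reads every input once, and writes every output once.
  // Note that bytes_read/bytes_written are counted in elements here, not
  // scaled by the element size.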
  static struct OpSchema::Cost CostInference(
      const OperatorDef& def,
      const vector<TensorShape>& inputs) {
    CAFFE_ENFORCE_GE(
        inputs.size(), 1, "NCCLAllreduce requires at least 1 input");
    const TensorShape X0 = inputs[0];
    const auto nElem = nElemFromDim(X0);

    struct OpSchema::Cost c;
    c.flops = (inputs.size() - 1) * nElem;
    c.bytes_read = inputs.size() * nElem;
    c.bytes_written = def.output_size() * nElem;
    c.params_bytes = 0;
    return c;
  }
};

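// Broadcast: copy the root input to every output (enforced one-to-one
// in-place by the schema below).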
class NCCLBroadcastOp final : public NCCLBaseOp {
 public:
  using NCCLBaseOp::NCCLBaseOp;

  bool RunOnDevice() override {
    if (InputSize() == 1)
      return true;
    if (AllInputsAre<float>(this)) {
      nccl::NCCL<float>::Broadcast(getNCCLElements(this, context_));
      return true;
    } else if (AllInputsAre<at::Half>(this)) {
      nccl::NCCL<at::Half>::Broadcast(getNCCLElements(this, context_));
      return true;
    } else {
      return false;
    }
  }
};

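// Reduce: combine all inputs element-wise into the single output held by
// the root.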
class NCCLReduceOp final : public NCCLBaseOp {
 public:
  using NCCLBaseOp::NCCLBaseOp;

  bool RunOnDevice() override {
    if (InputSize() == 1)
      return true;
    const auto& ex = getNCCLElements(this, context_);

    if (AllInputsAre<float>(this)) {
      nccl::NCCL<float>::Reduce(ex);
      return true;
    } else if (AllInputsAre<at::Half>(this)) {
      nccl::NCCL<at::Half>::Reduce(ex);
      return true;
    } else {
      return false;
    }
  }
};

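// AllGather: every output receives the concatenation of all inputs.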
class NCCLAllGatherOp final : public NCCLBaseOp {
 public:
  using NCCLBaseOp::NCCLBaseOp;

  bool RunOnDevice() override {
    if (InputSize() == 1)
      return true;
    if (AllInputsAre<float>(this)) {
      nccl::NCCL<float>::AllGather(getNCCLElements(this, context_));
      return true;
    } else if (AllInputsAre<at::Half>(this)) {
      nccl::NCCL<at::Half>::AllGather(getNCCLElements(this, context_));
      return true;
    } else {
      return false;
    }
  }
};

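// ReduceScatter: reduce all inputs element-wise, then scatter the result so
// that each output holds one disjoint chunk.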
class NCCLReduceScatterOp final : public NCCLBaseOp {
 public:
  using NCCLBaseOp::NCCLBaseOp;

  bool RunOnDevice() override {
    if (AllInputsAre<float>(this)) {
      nccl::NCCL<float>::ReduceScatter(getNCCLElements(this, context_));
      return true;
    } else if (AllInputsAre<at::Half>(this)) {
      nccl::NCCL<at::Half>::ReduceScatter(getNCCLElements(this, context_));
      return true;
    } else {
      return false;
    }
  }
};

namespace {

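// Device inference for the schemas below: input/output i is expected to
// live on CUDA device i, one tensor per GPU.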
std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>> ncclOpDevInfer(
    const OperatorDef& def) {
  std::vector<DeviceOption> opt;
  for (int i = 0; i < def.input().size(); ++i) {
    DeviceOption dev;
    dev.set_device_type(1); // 1 == PROTO_CUDA in caffe2.proto
    dev.set_device_id(i);
    opt.push_back(dev);
  }
  return std::make_pair(opt, opt);
}

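// Registration. All collectives are CUDA-only, may take inputs that live on
// different devices, and expose no gradients. A minimal sketch of how such
// an op appears in a Caffe2 NetDef (hypothetical blob names, one blob per
// GPU, reduced in place):
//
//   op {
//     input: "grad_gpu_0"
//     input: "grad_gpu_1"
//     output: "grad_gpu_0"
//     output: "grad_gpu_1"
//     type: "NCCLAllreduce"
//   }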
REGISTER_CUDA_OPERATOR(NCCLAllreduce, NCCLAllreduceOp);
OPERATOR_SCHEMA(NCCLAllreduce)
    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .CostInferenceFunction(NCCLAllreduceOp::CostInference)
    .TensorInferenceFunction(NCCLAllreduceOp::ShapeInference)
    .IdenticalTypeAndShape()
    .InputsCanCrossDevices()
    .AllowOneToOneInplace()
    .DeviceInferenceFunction(ncclOpDevInfer);
SHOULD_NOT_DO_GRADIENT(NCCLAllreduce);

REGISTER_CUDA_OPERATOR(NCCLBroadcast, NCCLBroadcastOp);
OPERATOR_SCHEMA(NCCLBroadcast)
    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .IdenticalTypeAndShape()
    .InputsCanCrossDevices()
    .EnforceOneToOneInplace()
    .DeviceInferenceFunction(ncclOpDevInfer);
SHOULD_NOT_DO_GRADIENT(NCCLBroadcast);

REGISTER_CUDA_OPERATOR(NCCLReduce, NCCLReduceOp);
OPERATOR_SCHEMA(NCCLReduce)
    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .NumOutputs(1)
    .IdenticalTypeAndShapeOfInput(0)
    .InputsCanCrossDevices()
    .AllowInplace([](int /*in*/, int out) -> bool { return (out == 0); })
    .DeviceInferenceFunction(ncclOpDevInfer);
SHOULD_NOT_DO_GRADIENT(NCCLReduce);

REGISTER_CUDA_OPERATOR(NCCLAllGather, NCCLAllGatherOp);
OPERATOR_SCHEMA(NCCLAllGather)
    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .InputsCanCrossDevices()
    .DeviceInferenceFunction(ncclOpDevInfer);
SHOULD_NOT_DO_GRADIENT(NCCLAllGather);

REGISTER_CUDA_OPERATOR(NCCLReduceScatter, NCCLReduceScatterOp);
OPERATOR_SCHEMA(NCCLReduceScatter)
    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
    .InputsCanCrossDevices()
    .DeviceInferenceFunction(ncclOpDevInfer);
SHOULD_NOT_DO_GRADIENT(NCCLReduceScatter);

} // namespace
} // namespace caffe2