pytorch

Форк
0
/
open_registration_extension.cpp 
707 строк · 30.4 Кб
1
#include <cstring>
#include <mutex>
#include <unordered_map>
#include <vector>

#include <c10/core/impl/alloc_cpu.h>
#include <c10/core/Allocator.h>
#include <c10/core/ScalarType.h>
#include <c10/util/ArrayRef.h>

#include <torch/csrc/Device.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/macros/Macros.h>
#include <torch/extension.h>

#include <ATen/native/cpu/Loops.h>
#include <ATen/native/quantized/AffineQuantizer.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Resize.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/CPUFallback.h>
#include <ATen/ops/abs_native.h>
#include <ATen/EmptyTensor.h>
#include <ATen/core/GeneratorForPrivateuseone.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <ATen/ops/view.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <ATen/native/transformers/attention.h>
26

27
static uint64_t add_counter = 0;
28
static uint64_t last_saved_value = 0;
29
static c10::DeviceIndex custom_device_index = 0;
30

31
static uint64_t abs_counter = 0;
32
static uint64_t last_abs_saved_value = 0;
33

34
static uint64_t storageImpl_counter = 0;
35
static uint64_t last_storageImpl_saved_value = 0;
36
// register guard
37
namespace at {
38
namespace detail {
39

40
C10_REGISTER_GUARD_IMPL(
41
    PrivateUse1,
42
    c10::impl::NoOpDeviceGuardImpl<DeviceType::PrivateUse1>);
43

44
}} // namespace at::detail
45

46
namespace {
47

48
// Using the simplest way to obtain continuous Tensor data and process it.
49
// This is a demo for using operand API, and you can add more complex logic
50
// for input and output tensor based on your custom device kernel.
51
void abs_kernel(at::TensorIteratorBase& iter) {
52
  // Abs only have a input tensor and a output tensor.
53
  auto& output_operand = iter.operand(0);
54
  auto& input_operand = iter.operand(1);
55
  auto& output_tensor_base = output_operand.tensor_base();
56
  auto& input_tensor_base = input_operand.tensor_base();
57
  TORCH_CHECK(!input_operand.original_tensor_base().defined(),
58
    "input original tensor is defined.");
59
  TORCH_CHECK(!output_operand.original_tensor_base().defined(),
60
    "output original tensor is defined.");
61
  // For easy test, only accept contiguous input tensor for calculate.
62
  auto memory_format = input_tensor_base.suggest_memory_format();
63
  TORCH_CHECK(input_tensor_base.is_contiguous(memory_format),
64
    "Input tensor need be contiguous.");
65
  // Add necessary restrictions to ensure the security of the demo.
66
  TORCH_CHECK(input_tensor_base.sizes() == output_tensor_base.sizes(),
67
    "Intput and output tensor size are not equal.");
68
  // Common dtype is calculate in TensorIteratorBase.
69
  TORCH_CHECK(iter.common_dtype() == at::ScalarType::Float,
70
    "Only support float type.")
71
  // Using for loop for abs calculate.
72
  auto abs_function = [](float* output_ptr, const float* input_ptr,
73
                         const int64_t NUM) {
74
    for (int64_t i = 0; i < NUM; ++i) {
75
      *(output_ptr + i) = std::abs(*(input_ptr + i));
76
    }
77
  };
78
  // To simplify the logic of the test demo code,
79
  // we only use contiguous tensor to calculate on device side.
80
  // And using input tensor memory format.
81
  if (iter.is_contiguous()) {
82
    // Add for will_resize flag check. You can convert to differernt
83
    // tensor memory format when will_resize is True.
84
    // If TensorIteratorConfig resize_outputs_ flag is true, and there are two
85
    // situations:
86
    // 1) Out tensor is undefined, and TensorIterator set will_resize to true;
87
    // 2) Out tensor is defined and tensor size is not equal to input tensor size;
88
    //    TensorIterator set will_resize to true, and call set_output_raw_strided
89
    //    to resize output tensor.
90
    // When output operand will_resize flag is ture, dummy
91
    // device can convert tensor to dummy device preferred memory format.
92
    // Here we don't convert tensor memory format, because it will become complex
93
    // when dummy device want keep same memory format for training network.
94
    TORCH_CHECK(output_operand.will_resize,
95
      "output operand will_resize flag need be True.");
96
    abs_function((float*)iter.data_ptr(0), (float*)iter.data_ptr(1), iter.numel());
97
  } else {
98
    // Stride copy is not support for foo device, using cpu device instead.
99
    // For abs op, the last situation is: output tensor is not contiguous with
100
    // operand will_resize is False.
101
    TORCH_CHECK(!output_operand.will_resize, "output operand will_resize is True.");
102
    // Get a contiguous tensor with input memory format.
103
    at::Tensor output = at::empty(output_tensor_base.sizes(),
104
                                  input_tensor_base.options()
105
                                                   .memory_format(memory_format));
106
    // For structured op which inheried from TensorIteratorBase, maybe you need to
107
    // call set_output_raw_strided function to update output stored in op sturctured.
108
    // abs op is no need to do this.
109
    output_operand.exchange_tensor(c10::MaybeOwned<at::TensorBase>::owned(std::in_place, output));
110
    abs_function((float*)output_operand.tensor_base().mutable_data_ptr(),
111
                 (float*)iter.data_ptr(1), iter.numel());
112
    // Copy tensor base to original tensor base, and keep same scalar type and
113
    // stride with cpu and gpu.
114
    if (output_operand.original_tensor_base().defined() &&
115
        !output_operand.original_tensor_base().is_same(output_operand.tensor_base())) {
116
      output_operand.original_tensor().copy_(output_operand.tensor());
117
      output_operand.restore_original_tensor();
118
    }
119
  }
120
}
121

122
void quantize_tensor_per_tensor_affine_privateuse1(
123
    const at::Tensor& rtensor,
124
    at::Tensor& qtensor,
125
    double scale,
126
    int64_t zero_point) {
127
    // do nothing
128
}
129

130
int64_t _fused_sdp_choice_privateuse1(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value,
131
    const std::optional<at::Tensor> & attn_mask, double dropout_p, bool is_causal, std::optional<double> scale, bool enable_gqa){
132
  auto backend = sdp::SDPBackend::overrideable;
133
  return static_cast<int64_t>(backend);
134
}
135
} // namespace
136

137
namespace at::native {
138

139
REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &abs_kernel);
140
REGISTER_PRIVATEUSE1_DISPATCH(quantize_tensor_per_tensor_affine_stub, &quantize_tensor_per_tensor_affine_privateuse1);
141
REGISTER_PRIVATEUSE1_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_privateuse1);
142

143
} // namespace at::native
144
struct CustomBackendMetadata : public c10::BackendMeta {
145
  // for testing this field will mutate when clone() is called by shallow_copy_from.
146
  int backend_version_format_{-1};
147
  int format_number_{-1};
148
  mutable bool cloned_{false};
149
  // define the constructor
150
  CustomBackendMetadata(int backend_version_format, int format_number) :
151
      backend_version_format_(backend_version_format), format_number_(format_number) {}
152
  c10::intrusive_ptr<c10::BackendMeta> clone(
153
      const c10::intrusive_ptr<c10::BackendMeta>& ptr) const override {
154
    cloned_ = true;
155
    return c10::BackendMeta::clone(ptr);
156
  }
157
};
158

159
// we need to register two functions for serialization
160
void for_serialization(const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
161
  if (t.unsafeGetTensorImpl()->get_backend_meta_intrusive_ptr() == nullptr) {
162
    return;
163
  }
164
  auto tmeta = dynamic_cast<CustomBackendMetadata*>(t.unsafeGetTensorImpl()->get_backend_meta());
165
  if (tmeta->backend_version_format_ == 1) {
166
    m["backend_version_format"] = true;
167
  }
168
  if (tmeta->format_number_ == 29) {
169
    m["format_number"] = true;
170
  }
171
}
172

173
// Deserialization hook: rebuilds a CustomBackendMetadata from the keys written
// by for_serialization and attaches it to `t`. A present key restores the
// expected test value (1 or 29); an absent key restores -1.
void for_deserialization(const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
  const int backend_version_format = m.count("backend_version_format") != 0 ? 1 : -1;
  const int format_number = m.count("format_number") != 0 ? 29 : -1;
  // make_intrusive replaces the manual `new` + unique_ptr wrapping.
  t.unsafeGetTensorImpl()->set_backend_meta(
      c10::make_intrusive<CustomBackendMetadata>(backend_version_format, format_number));
}
186

187
void custom_serialization_registry() {
188
  torch::jit::TensorBackendMetaRegistry(c10::DeviceType::PrivateUse1,
189
                                        &for_serialization,
190
                                        &for_deserialization);
191
}
192

193
//check if BackendMeta serialization correctly
194
bool check_backend_meta(const at::Tensor& t) {
195
  if (t.unsafeGetTensorImpl()->get_backend_meta_intrusive_ptr()) {
196
    CustomBackendMetadata* tmeta = dynamic_cast<CustomBackendMetadata*>(
197
        t.unsafeGetTensorImpl()->get_backend_meta());
198
    if (tmeta->backend_version_format_==1 && tmeta->format_number_==29) {
199
      return true;
200
    }
201
  }
202
  return false;
203
}
204

205
// a fake set function is exposed to the Python side
206
void custom_set_backend_meta(const at::Tensor& t) {
207
  int backend_version_format{1};
208
  int format_number{29};
209
  c10::intrusive_ptr<c10::BackendMeta> new_tmeta{std::unique_ptr<c10::BackendMeta>(
210
      new CustomBackendMetadata(backend_version_format, format_number))};
211
  t.unsafeGetTensorImpl()->set_backend_meta(new_tmeta);
212
}
213

214
// A dummy storageImpl for our custom device, that secretly uses the CPU
215
c10::intrusive_ptr<c10::StorageImpl> make_custom_storage_impl(c10::StorageImpl::use_byte_size_t,
216
                                                              c10::SymInt size_bytes,
217
                                                              c10::DataPtr data_ptr,
218
                                                              c10::Allocator* allocator,
219
                                                              bool resizable) {
220
  c10::intrusive_ptr<c10::StorageImpl> custom_storage_impl;
221
  if (data_ptr == nullptr){
222
    custom_storage_impl = c10::make_intrusive<c10::StorageImpl>(
223
      c10::StorageImpl::use_byte_size_t(), size_bytes, allocator, resizable);
224
  } else {
225
    custom_storage_impl = c10::make_intrusive<c10::StorageImpl>(
226
      c10::StorageImpl::use_byte_size_t(), size_bytes, std::move(data_ptr), allocator, resizable);
227
  }
228
  storageImpl_counter += 1;
229
  return custom_storage_impl;
230
}
231

232
// Register our dummy storageImpl create method.
233
void custom_storage_registry() {
234
  c10::SetStorageImplCreate(c10::DeviceType::PrivateUse1, &make_custom_storage_impl);
235
}
236

237
bool custom_storageImpl_called() {
238
  if (storageImpl_counter > last_storageImpl_saved_value) {
239
    last_storageImpl_saved_value = storageImpl_counter;
240
    return true;
241
  }
242
  return false;
243
}
244

245
// basic dummy add function
246
at::Tensor custom_add_Tensor(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha) {
247
  add_counter += 1;
248
  // Since this custom device is just for testing, not bothering to implement kernels.
249
  return at::empty(self.sizes(), self.options());
250
}
251

252
// basic abs function
253
at::Tensor& custom_abs_out(const at::Tensor& self, at::Tensor& out) {
254
  return at::native::abs_out(self, out);
255
}
256

257
// A dummy allocator for our custom device, that secretly uses the CPU
258
struct DummyCustomAllocator final : at::Allocator {
259
  DummyCustomAllocator() = default;
260
  at::DataPtr allocate(size_t nbytes) override {
261
    void* data = c10::alloc_cpu(nbytes);
262
    return {data, data, &ReportAndDelete, at::Device(at::DeviceType::PrivateUse1, custom_device_index)};
263
  }
264

265
  static void ReportAndDelete(void* ptr) {
266
    if (!ptr) {
267
      return;
268
    }
269
    c10::free_cpu(ptr);
270
  }
271

272
  at::DeleterFnPtr raw_deleter() const override {
273
    return &ReportAndDelete;
274
  }
275

276
  void copy_data(void* dest, const void* src, std::size_t count) const final {
277
    default_copy_data(dest, src, count);
278
  }
279
};
280

281
// Register our dummy allocator
282
static DummyCustomAllocator global_custom_alloc;
283
REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_custom_alloc);
284

285
// basic dummy empty function, so we can directly construct tensors on the custom device
286
// This dummy test device will just use the CPU allocator, and ignores pinned memory.
287
at::Tensor custom_empty_memory_format(at::IntArrayRef size,
288
                                      std::optional<at::ScalarType> dtype,
289
                                      std::optional<at::Layout> layout,
290
                                      std::optional<at::Device> device,
291
                                      std::optional<bool> pin_memory,
292
                                      std::optional<at::MemoryFormat> memory_format) {
293
  constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
294
  return at::detail::empty_generic(size,
295
                                   &global_custom_alloc,
296
                                   private_use_ks,
297
                                   c10::dtype_or_default(dtype),
298
                                   memory_format);
299
}
300
// empty.memory_format kernel for the custom device. It allocates through
// global_custom_alloc with the PrivateUse1 dispatch key. `layout`, `device`
// and `pin_memory` are ignored.
at::Tensor custom_empty_symint(c10::IntArrayRef size,
                               std::optional<at::ScalarType> dtype,
                               std::optional<at::Layout> layout,
                               std::optional<at::Device> device,
                               std::optional<bool> pin_memory,
                               std::optional<at::MemoryFormat> memory_format) {
  const c10::ScalarType scalar_type = c10::dtype_or_default(dtype);
  constexpr c10::DispatchKeySet dispatch_keys(c10::DispatchKey::PrivateUse1);
  return at::detail::empty_generic(
      size, &global_custom_alloc, dispatch_keys, scalar_type, memory_format);
}
310

311
// fill_.Scalar kernel stub: deliberately does nothing and returns `self`
// unchanged (`value` is ignored).
at::Tensor & custom_fill__scalar(at::Tensor & self, const at::Scalar & value) {
  (void)value;
  return self;
}
315

316
// Unsafe using dummy device data_ptr to creat a cpu tensor, and shared data_ptr.
317
at::Tensor unsafe_create_cpu_tensor_from_dummy_tensor(const at::Tensor& src) {
318
  TORCH_CHECK(src.device().type() == c10::DeviceType::PrivateUse1,
319
              "Only support dummy device.");
320
  const auto& sizes_ = src.sizes();
321
  const auto& strides_ = src.strides();
322
  auto storage_offset_ = src.storage_offset();
323
  at::detail::check_size_nonnegative(sizes_);
324

325
  size_t size_bytes = at::detail::computeStorageNbytes(sizes_, strides_,
326
                                                       src.element_size(),
327
                                                       storage_offset_);
328

329
  at::DataPtr data_ptr =
330
    c10::InefficientStdFunctionContext::makeDataPtr(src.storage().mutable_data_ptr().get(),
331
                                                    [](void*){}, at::kCPU);
332

333
  c10::Storage storage{c10::Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr),
334
    /*allocator=*/&global_custom_alloc, /*resizeable=*/false};
335

336
  constexpr c10::DispatchKeySet cpu_ks(c10::DispatchKey::CPU);
337
  at::Tensor tensor = at::detail::make_tensor<c10::TensorImpl>(
338
       std::move(storage), cpu_ks, src.dtype());
339

340
  c10::TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
341
  tensor_impl->set_sizes_and_strides(sizes_, strides_);
342
  tensor_impl->set_storage_offset(storage_offset_);
343
  return tensor;
344
}
345

346
// basic dummy copy_() function, so we can copy from the custom device to/from CPU
347
at::Tensor custom__copy_from(const at::Tensor& self, const at::Tensor& dst, bool non_blocking) {
348
  TORCH_CHECK(
349
      self.is_cpu() || self.device().type() == c10::DeviceType::PrivateUse1,
350
      "Dummy test only allows copy from cpu -> dummy device.");
351
  TORCH_CHECK(
352
      dst.is_cpu() || dst.device().type() == c10::DeviceType::PrivateUse1,
353
      "Dummy test only allows copy from cpu -> dummy device.");
354

355
  // Some dummy asserts for the basic use case: inputs are the same size / dtype, all contiguous.
356
  TORCH_CHECK(self.sizes() == dst.sizes());
357
  TORCH_CHECK(self.scalar_type() == dst.scalar_type());
358

359
  if (self.is_contiguous() && dst.is_contiguous()) {
360
    std::memcpy(dst.storage().data_ptr().get(),
361
                self.storage().data_ptr().get(),
362
                self.storage().nbytes());
363
  } else {
364
    // Using cpu tensor to accomplishment stride copy.
365
    auto convert_to_cpu_tensor = [](const at::Tensor& src) -> at::Tensor {
366
      if (src.device().type() == c10::DeviceType::PrivateUse1) {
367
        return unsafe_create_cpu_tensor_from_dummy_tensor(src);
368
      } else {
369
        return src;
370
      }
371
    };
372
    at::Tensor cpu_self = convert_to_cpu_tensor(self);
373
    at::Tensor cpu_dst = convert_to_cpu_tensor(dst);
374
    cpu_dst.copy_(cpu_self);
375
  }
376

377
  return dst;
378
}
379

380
// _copy_from_and_resize kernel. Resizes `dst` to `self`'s sizes when they
// differ, then copies with custom__copy_from. Returns `dst`.
at::Tensor custom__copy_from_and_resize(const at::Tensor& self, const at::Tensor& dst) {
  // Resize only on mismatch: for PrivateUse1 tensors resize_ goes to
  // custom_resize_, which resets strides to contiguous. That would needlessly
  // re-layout an already correctly-sized dst.
  if (!dst.sizes().equals(self.sizes())) {
    dst.resize_(self.sizes());
  }
  return custom__copy_from(self, dst, false);
}
383

384
// empty_strided kernel for the custom device. It allocates through
// global_custom_alloc with the requested sizes and strides. `layout_opt`,
// `device_opt` and `pin_memory_opt` are ignored.
at::Tensor custom_empty_strided(c10::IntArrayRef size,
                                c10::IntArrayRef stride,
                                std::optional<at::ScalarType> dtype_opt,
                                std::optional<at::Layout> layout_opt,
                                std::optional<at::Device> device_opt,
                                std::optional<bool> pin_memory_opt) {
  constexpr c10::DispatchKeySet dispatch_keys(c10::DispatchKey::PrivateUse1);
  return at::detail::empty_strided_generic(
      size, stride, &global_custom_alloc, dispatch_keys, c10::dtype_or_default(dtype_opt));
}
394

395
// Some set operations for the basic use case
396
at::Tensor& custom_set_source_Storage(at::Tensor& result, c10::Storage src) {
397
  int64_t new_size = static_cast<int64_t>(src.nbytes() / result.dtype().itemsize());
398
  c10::IntArrayRef stride = {};
399
  result.unsafeGetTensorImpl()->set_storage_offset(0);
400
  at::OptionalIntArrayRef stride_opt = stride.data() != nullptr ? at::OptionalIntArrayRef(stride) : std::nullopt;
401
  at::native::resize_impl_cpu_(result.unsafeGetTensorImpl(),
402
                               new_size, stride_opt,
403
                               /*resize_storage=*/!result.is_meta());
404
  return result;
405
}
406

407
// Some set operations for the basic use case
408
at::Tensor& custom_set_source_Storage_storage_offset(at::Tensor& result,
409
                                                     c10::Storage storage,
410
                                                     int64_t storage_offset,
411
                                                     c10::IntArrayRef size,
412
                                                     c10::IntArrayRef stride) {
413
  result.unsafeGetTensorImpl()->set_storage_offset(storage_offset);
414
  at::OptionalIntArrayRef stride_opt = stride.data() != nullptr ? at::OptionalIntArrayRef(stride) : std::nullopt;
415
  at::native::resize_impl_cpu_(result.unsafeGetTensorImpl(),
416
                               size, stride_opt,
417
                               /*resize_storage=*/!result.is_meta());
418
  return result;
419
}
420

421
// resize_ kernel for the dummy device. Sets `self` to contiguous `size` and
// grows its (CPU-backed) storage if needed. When a memory format is given,
// the tensor is restrided for it; MemoryFormat::Preserve is rejected.
// Returns `self`.
const at::Tensor& custom_resize_(const at::Tensor& self, at::IntArrayRef size,
                          std::optional<at::MemoryFormat> optional_memory_format) {
  auto* impl = self.unsafeGetTensorImpl();
  impl->set_sizes_contiguous(size);
  const auto needed_bytes = at::detail::computeStorageNbytesContiguous(
      size, impl->dtype().itemsize(), impl->storage_offset());
  // The dummy device allocates CPU memory, so the CPU helper from
  // ATen/native/Resize.h can grow the storage.
  at::native::maybe_resize_storage_cpu(impl, needed_bytes);
  if (!optional_memory_format.has_value()) {
    return self;
  }
  const auto memory_format = *optional_memory_format;
  TORCH_CHECK(
      memory_format != at::MemoryFormat::Preserve,
      "Unsupported memory format",
      memory_format);
  impl->empty_tensor_restride(memory_format);
  return self;
}
443

444
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, c10::SymInt, c10::SymInt, at::Tensor, at::Tensor, at::Tensor>
445
custom_scaled_dot_product_fused_attention_overrideable(
446
    const at::Tensor & query,
447
    const at::Tensor & key,
448
    const at::Tensor & value,
449
    const std::optional<at::Tensor> & attn_bias,
450
    double dropout_p,
451
    bool is_causal,
452
    bool return_debug_mask,
453
    std::optional<double> scale) {
454
  const int64_t batch_size = query.size(0);
455
  const int64_t num_heads = query.size(1);
456
  const int64_t head_dim_qk = query.size(3);
457
  const int64_t head_dim_v = value.size(3);
458
  const int64_t max_seqlen_q = query.size(2);
459
  const int64_t max_seqlen_kv = key.size(2);
460

461
  auto opts = query.options();
462
  auto output = at::empty({batch_size, num_heads, max_seqlen_q, head_dim_v}, opts);
463
  auto logsumexp = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
464
  auto debug_attn_mask = at::empty({batch_size, num_heads, max_seqlen_q, max_seqlen_kv},
465
                                   opts.dtype(at::kFloat));
466
  auto philox_seed = at::empty({}, at::dtype(at::kLong));
467
  auto philox_offset = at::empty({}, at::dtype(at::kLong));
468

469
  return std::make_tuple(output, logsumexp, at::Tensor(), at::Tensor(), max_seqlen_q, max_seqlen_kv, philox_seed, philox_offset, debug_attn_mask);
470
}
471
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
472
custom_scaled_dot_product_fused_attention_overrideable_backward(
473
    const at::Tensor & grad_out,
474
    const at::Tensor & query,
475
    const at::Tensor & key,
476
    const at::Tensor & value,
477
    const at::Tensor & attn_bias,
478
    std::array<bool,4> grad_input_mask,
479
    const at::Tensor & out,
480
    const at::Tensor & logsumexp,
481
    const at::Tensor & cum_seq_q,
482
    const at::Tensor & cum_seq_k,
483
    int64_t max_q,
484
    int64_t max_k,
485
    double dropout_p,
486
    bool is_causal,
487
    const at::Tensor & philox_seed,
488
    const at::Tensor & philox_offset,
489
    std::optional<double> scale) {
490
  return std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>(
491
          at::empty_like(query),
492
          at::empty_like(key),
493
          at::empty_like(value),
494
          at::empty_like(attn_bias));
495
}
496

497
// This macro does the heavy lifting.
498
// With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend.
499
// For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key.
500
// Later in this file, we map a custom device to the PrivateUse1 device type,
501
// which allows user code that puts a tensor on your custom_device to eventually get plumbed
502
// into the kernels registered here.
503
//
504
// This macro registers your kernels to the PyTorch Dispatcher.
505
// More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/.
506
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
507
  m.impl("abs.out", &custom_abs_out);
508
  m.impl("add.Tensor", &custom_add_Tensor);
509
  m.impl("empty.memory_format", &custom_empty_symint);
510
  m.impl("fill_.Scalar", &custom_fill__scalar);
511
  m.impl("_copy_from", &custom__copy_from);
512
  m.impl("_copy_from_and_resize", &custom__copy_from_and_resize);
513
  m.impl("empty_strided", &custom_empty_strided);
514
  m.impl("set_.source_Storage", &custom_set_source_Storage);
515
  m.impl("set_.source_Storage_storage_offset",&custom_set_source_Storage_storage_offset);
516
  m.impl("resize_", &custom_resize_);
517
  m.impl("as_strided", at::native::as_strided_tensorimpl);
518
  m.impl("quantize_per_tensor", at::native::quantize_per_tensor);
519
  m.impl("_fused_sdp_choice", &_fused_sdp_choice_privateuse1);
520
  m.impl("_scaled_dot_product_fused_attention_overrideable", &custom_scaled_dot_product_fused_attention_overrideable);
521
  m.impl("_scaled_dot_product_fused_attention_overrideable_backward", &custom_scaled_dot_product_fused_attention_overrideable_backward);
522
}
523

524
// Boxed kernel that executes the operator on the CPU through
// at::native::cpu_fallback; used for ops this backend does not implement.
void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  at::native::cpu_fallback(op, stack);
}
527

528
// aten ops on PrivateUse1 that are served by custom_cpu_fallback instead of a
// dedicated kernel.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  for (const char* op_name : {"sub.Tensor",
                              "_foreach_add.List",
                              "_fused_adamw_",
                              "index.Tensor",
                              "triu_indices"}) {
    m.impl(op_name, torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
  }
}
535

536
// This basic implementation doesn't bother dealing with different device indices
537
// (e.g. custom_device:0 vs. custom_device:1).
538
// We could do that by letting the user pass in a device index in our exposed device function.
539
// Note that if you do that, you'll also need to register a device guard to core.
540
// See `c10/core/impl/DeviceGuardImplInterface.h:C10_REGISTER_GUARD_IMPL`.
541
c10::Device get_custom_device() {
542
  return c10::Device(c10::DeviceType::PrivateUse1, 0);
543
}
544

545
bool custom_add_called() {
546
  bool called = false;
547
  if (add_counter > last_saved_value) {
548
    called = true;
549
    last_saved_value = add_counter;
550
  }
551
  return called;
552
}
553

554
class PrivateGeneratorImpl : public at::CPUGeneratorImpl {
555
public:
556
  // Constructors
557
  PrivateGeneratorImpl(c10::DeviceIndex device_index) {
558
    device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index);
559
    key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1);
560
  }
561
  ~PrivateGeneratorImpl() override = default;
562
};
563

564
// this is used to register generator
565
at::Generator make_generator_privateuse1(c10::DeviceIndex device_index) {
566
  return at::make_generator<PrivateGeneratorImpl>(device_index);
567
}
568

569
void register_generator_first() {
570
  REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1)
571
}
572

573
void register_generator_second() {
574
  REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1)
575
}
576

577
// Sets the device index that DummyCustomAllocator stamps on later allocations.
void set_custom_device_index(c10::DeviceIndex device_index) {
  custom_device_index = device_index;
}
580

581
// Global flag behind the dummy pin_memory support of the custom device.
// FooHooksInterface sets it in getPinnedMemoryAllocator() and reports it
// from isPinnedPtr().
bool custom_pinned_flag{false};
583

584
// Constructor-argument type for FooHooksInterface; adds nothing to the base.
struct FooHooksArgs : at::PrivateUse1HooksArgs {};
585

586
struct FooHooksInterface : public at::PrivateUse1HooksInterface {
587
    FooHooksInterface(FooHooksArgs) {}
588
    ~FooHooksInterface() override = default;
589
    const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index) override {
590
      static auto device_gen = make_generator_privateuse1(device_index);
591
      return device_gen;
592
    }
593
    // this is a simple implementation, custom_pinned_flag will be set as true
594
    // once tensor.pin_memory() is called. And then tensor.is_pinned()
595
    // always return true no matter what tensor it's called on.
596
    bool isPinnedPtr(const void* data) const override {
597
      return custom_pinned_flag;
598
    }
599
    c10::Allocator* getPinnedMemoryAllocator() const override {
600
      custom_pinned_flag = true;
601
      return c10::GetCPUAllocator();
602
    }
603
};
604

605
// Registry of FooHooksInterface factories keyed by name, with a single
// "FooHooks" entry. get_private_hooks() calls Create("FooHooks", ...) on it
// to obtain the hooks object.
TORCH_DECLARE_REGISTRY(PrivateUse1HooksRegistry, FooHooksInterface, FooHooksArgs);
C10_DEFINE_REGISTRY(PrivateUse1HooksRegistry, FooHooksInterface, FooHooksArgs)
C10_REGISTER_TYPED_CLASS(PrivateUse1HooksRegistry, "FooHooks", FooHooksInterface)
609

610
// Hooks singleton. It is released from the registry's unique_ptr and never
// deleted, so it lives for the rest of the process.
static at::PrivateUse1HooksInterface* privateuse1_hooks_local = nullptr;

// Returns the hooks singleton, creating it exactly once. If the registry has
// no "FooHooks" entry, a FooHooksInterface is constructed directly instead.
static at::PrivateUse1HooksInterface* get_private_hooks() {
  static c10::once_flag once;
  c10::call_once(once, [] {
    auto created = PrivateUse1HooksRegistry()->Create("FooHooks", {});
    privateuse1_hooks_local = created != nullptr
        ? created.release()
        : new FooHooksInterface(FooHooksArgs{});
  });
  return privateuse1_hooks_local;
}
621

622
void register_hook() {
623
  at::RegisterPrivateUse1HooksInterface(get_private_hooks());
624
}
625

626
bool is_register_hook() {
627
  return privateuse1_hooks_local != nullptr;
628
}
629

630
// Returns the default generator that at::globalContext() reports for the
// PrivateUse1 device `device_index`. (Removes the stray extra semicolon.)
const at::Generator& default_generator(c10::DeviceIndex device_index) {
  const at::Device device(c10::DeviceType::PrivateUse1, device_index);
  return at::globalContext().defaultGenerator(device);
}
633

634
void fallback_with_undefined_tensor() {
635
  at::Tensor first = at::empty((2,3)).to(at::DeviceType::PrivateUse1);
636
  at::Tensor second = at::Tensor();
637
  at::Tensor step = at::empty({}).fill_(2).to(at::DeviceType::PrivateUse1);
638
  at::Tensor grad_scale = at::empty({}).fill_(0.00001).to(at::DeviceType::PrivateUse1);
639
  at::Tensor found_inf = at::empty({}).fill_(1).to(at::DeviceType::PrivateUse1);
640
  at::TensorList tensors = {first, first};
641
  at::TensorList undefined_tensors = {first, second};
642
  at::TensorList steps = {step, step};
643
  return at::_fused_adamw_(tensors, tensors, tensors, tensors, undefined_tensors,
644
                           steps, 0.001, 0.9, 0.999, 1e-2, 1e-8, false, false,
645
                           grad_scale, found_inf);
646
}
647

648
// Custom autograd Function whose forward returns its input tensor itself (no
// copy, no view). Backward scales the incoming gradient by 0.5.
struct CustomAutogradFnReturnsSelf : public torch::autograd::Function<CustomAutogradFnReturnsSelf> {
  static at::Tensor forward(torch::autograd::AutogradContext* /*ctx*/, at::Tensor self) {
    return self;
  }

  static torch::autograd::variable_list backward(torch::autograd::AutogradContext* /*ctx*/,
                                                 torch::autograd::variable_list grad_output) {
    const auto& grad = grad_output[0];
    return {grad * 0.5};
  }
};
658

659
// Custom autograd Function whose forward returns a same-shaped view of its
// input, so the output aliases the input. Backward scales the incoming
// gradient by 0.5.
struct CustomAutogradFnAliasing : public torch::autograd::Function<CustomAutogradFnAliasing> {
  static at::Tensor forward(torch::autograd::AutogradContext* /*ctx*/, at::Tensor self) {
    const auto shape = self.sym_sizes();
    return self.view_symint(shape);
  }

  static torch::autograd::variable_list backward(torch::autograd::AutogradContext* /*ctx*/,
                                                 torch::autograd::variable_list grad_output) {
    const auto& grad = grad_output[0];
    return {grad * 0.5};
  }
};
669

670
// Applies CustomAutogradFnReturnsSelf to `x`; exposed to Python via pybind.
at::Tensor custom_autograd_fn_returns_self(at::Tensor x) {
  auto result = CustomAutogradFnReturnsSelf::apply(x);
  return result;
}
673

674
// Applies CustomAutogradFnAliasing to `x`. Registered as the AutogradCPU
// kernel of _test_funcs::custom_autograd_fn_aliasing.
at::Tensor custom_autograd_fn_aliasing(at::Tensor x) {
  auto result = CustomAutogradFnAliasing::apply(x);
  return result;
}
677

678
// Here, we're exposing a custom device object that corresponds to our custom backend.
679
// We do this using pybind: exposing an "extension_name.custom_device()" function in python,
680
// that's implemented in C++.
681
// The implementation in this file maps directly to the `PrivateUse1` device type.
682
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
683
    m.def("custom_device", &get_custom_device, "get custom device object");
684
    m.def("custom_add_called", &custom_add_called, "check if our custom add function was called");
685
    m.def("register_generator_first", &register_generator_first, "register generator for custom device firstly");
686
    m.def("register_generator_second", &register_generator_second, "register generator for custom device secondly");
687
    m.def("set_custom_device_index", &set_custom_device_index, "set custom device index");
688
    m.def("custom_storage_registry", &custom_storage_registry, "set custom storageImpl creat method");
689
    m.def("custom_storageImpl_called", &custom_storageImpl_called, "check if our custom abs function was called");
690
    m.def("custom_set_backend_meta", &custom_set_backend_meta, "a fake set tensor BackendMeta function");
691
    m.def("check_backend_meta", &check_backend_meta, "check if BackendMeta serialization correctly");
692
    m.def("custom_serialization_registry", &custom_serialization_registry, "register custom serialization function");
693
    m.def("register_hook", &register_hook, "register_hook for privateuse1");
694
    m.def("is_register_hook", &is_register_hook, "is_register_hook for privateuse1");
695
    m.def("default_generator", &default_generator, "default_generator for privateuse1");
696
    m.def("fallback_with_undefined_tensor", &fallback_with_undefined_tensor, "fallback_with_undefined_tensor for privateuse1");
697

698
    // Co-opting this file to more easily test torch.compile'ing of custom autograd functions in C++
699
    m.def("custom_autograd_fn_returns_self", &custom_autograd_fn_returns_self);
700
}
701

702
// Schema for the aliasing test op: the "(a)" annotations declare that the
// returned tensor aliases `input`.
TORCH_LIBRARY(_test_funcs, m) {
  m.def("custom_autograd_fn_aliasing(Tensor(a) input)-> Tensor(a)");
}
705
// Binds _test_funcs::custom_autograd_fn_aliasing to its C++ implementation
// under the AutogradCPU dispatch key.
TORCH_LIBRARY_IMPL(_test_funcs, AutogradCPU, m) {
  m.impl("custom_autograd_fn_aliasing", &custom_autograd_fn_aliasing);
}
708

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.