pytorch
707 строк · 30.4 Кб
#include <mutex>
#include <unordered_map>
#include <vector>

#include <c10/core/Allocator.h>
#include <c10/core/ScalarType.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/alloc_cpu.h>
#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>

#include <torch/csrc/Device.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/extension.h>

#include <ATen/native/cpu/Loops.h>
#include <ATen/native/quantized/AffineQuantizer.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Resize.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/CPUFallback.h>
#include <ATen/ops/abs_native.h>
#include <ATen/EmptyTensor.h>
#include <ATen/core/GeneratorForPrivateuseone.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <ATen/ops/view.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <ATen/native/transformers/attention.h>
26
27static uint64_t add_counter = 0;
28static uint64_t last_saved_value = 0;
29static c10::DeviceIndex custom_device_index = 0;
30
31static uint64_t abs_counter = 0;
32static uint64_t last_abs_saved_value = 0;
33
34static uint64_t storageImpl_counter = 0;
35static uint64_t last_storageImpl_saved_value = 0;
36// register guard
37namespace at {
38namespace detail {
39
40C10_REGISTER_GUARD_IMPL(
41PrivateUse1,
42c10::impl::NoOpDeviceGuardImpl<DeviceType::PrivateUse1>);
43
44}} // namespace at::detail
45
46namespace {
47
48// Using the simplest way to obtain continuous Tensor data and process it.
49// This is a demo for using operand API, and you can add more complex logic
50// for input and output tensor based on your custom device kernel.
51void abs_kernel(at::TensorIteratorBase& iter) {
52// Abs only have a input tensor and a output tensor.
53auto& output_operand = iter.operand(0);
54auto& input_operand = iter.operand(1);
55auto& output_tensor_base = output_operand.tensor_base();
56auto& input_tensor_base = input_operand.tensor_base();
57TORCH_CHECK(!input_operand.original_tensor_base().defined(),
58"input original tensor is defined.");
59TORCH_CHECK(!output_operand.original_tensor_base().defined(),
60"output original tensor is defined.");
61// For easy test, only accept contiguous input tensor for calculate.
62auto memory_format = input_tensor_base.suggest_memory_format();
63TORCH_CHECK(input_tensor_base.is_contiguous(memory_format),
64"Input tensor need be contiguous.");
65// Add necessary restrictions to ensure the security of the demo.
66TORCH_CHECK(input_tensor_base.sizes() == output_tensor_base.sizes(),
67"Intput and output tensor size are not equal.");
68// Common dtype is calculate in TensorIteratorBase.
69TORCH_CHECK(iter.common_dtype() == at::ScalarType::Float,
70"Only support float type.")
71// Using for loop for abs calculate.
72auto abs_function = [](float* output_ptr, const float* input_ptr,
73const int64_t NUM) {
74for (int64_t i = 0; i < NUM; ++i) {
75*(output_ptr + i) = std::abs(*(input_ptr + i));
76}
77};
78// To simplify the logic of the test demo code,
79// we only use contiguous tensor to calculate on device side.
80// And using input tensor memory format.
81if (iter.is_contiguous()) {
82// Add for will_resize flag check. You can convert to differernt
83// tensor memory format when will_resize is True.
84// If TensorIteratorConfig resize_outputs_ flag is true, and there are two
85// situations:
86// 1) Out tensor is undefined, and TensorIterator set will_resize to true;
87// 2) Out tensor is defined and tensor size is not equal to input tensor size;
88// TensorIterator set will_resize to true, and call set_output_raw_strided
89// to resize output tensor.
90// When output operand will_resize flag is ture, dummy
91// device can convert tensor to dummy device preferred memory format.
92// Here we don't convert tensor memory format, because it will become complex
93// when dummy device want keep same memory format for training network.
94TORCH_CHECK(output_operand.will_resize,
95"output operand will_resize flag need be True.");
96abs_function((float*)iter.data_ptr(0), (float*)iter.data_ptr(1), iter.numel());
97} else {
98// Stride copy is not support for foo device, using cpu device instead.
99// For abs op, the last situation is: output tensor is not contiguous with
100// operand will_resize is False.
101TORCH_CHECK(!output_operand.will_resize, "output operand will_resize is True.");
102// Get a contiguous tensor with input memory format.
103at::Tensor output = at::empty(output_tensor_base.sizes(),
104input_tensor_base.options()
105.memory_format(memory_format));
106// For structured op which inheried from TensorIteratorBase, maybe you need to
107// call set_output_raw_strided function to update output stored in op sturctured.
108// abs op is no need to do this.
109output_operand.exchange_tensor(c10::MaybeOwned<at::TensorBase>::owned(std::in_place, output));
110abs_function((float*)output_operand.tensor_base().mutable_data_ptr(),
111(float*)iter.data_ptr(1), iter.numel());
112// Copy tensor base to original tensor base, and keep same scalar type and
113// stride with cpu and gpu.
114if (output_operand.original_tensor_base().defined() &&
115!output_operand.original_tensor_base().is_same(output_operand.tensor_base())) {
116output_operand.original_tensor().copy_(output_operand.tensor());
117output_operand.restore_original_tensor();
118}
119}
120}
121
122void quantize_tensor_per_tensor_affine_privateuse1(
123const at::Tensor& rtensor,
124at::Tensor& qtensor,
125double scale,
126int64_t zero_point) {
127// do nothing
128}
129
130int64_t _fused_sdp_choice_privateuse1(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value,
131const std::optional<at::Tensor> & attn_mask, double dropout_p, bool is_causal, std::optional<double> scale, bool enable_gqa){
132auto backend = sdp::SDPBackend::overrideable;
133return static_cast<int64_t>(backend);
134}
135} // namespace
136
137namespace at::native {
138
139REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &abs_kernel);
140REGISTER_PRIVATEUSE1_DISPATCH(quantize_tensor_per_tensor_affine_stub, &quantize_tensor_per_tensor_affine_privateuse1);
141REGISTER_PRIVATEUSE1_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_privateuse1);
142
143} // namespace at::native
144struct CustomBackendMetadata : public c10::BackendMeta {
145// for testing this field will mutate when clone() is called by shallow_copy_from.
146int backend_version_format_{-1};
147int format_number_{-1};
148mutable bool cloned_{false};
149// define the constructor
150CustomBackendMetadata(int backend_version_format, int format_number) :
151backend_version_format_(backend_version_format), format_number_(format_number) {}
152c10::intrusive_ptr<c10::BackendMeta> clone(
153const c10::intrusive_ptr<c10::BackendMeta>& ptr) const override {
154cloned_ = true;
155return c10::BackendMeta::clone(ptr);
156}
157};
158
159// we need to register two functions for serialization
160void for_serialization(const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
161if (t.unsafeGetTensorImpl()->get_backend_meta_intrusive_ptr() == nullptr) {
162return;
163}
164auto tmeta = dynamic_cast<CustomBackendMetadata*>(t.unsafeGetTensorImpl()->get_backend_meta());
165if (tmeta->backend_version_format_ == 1) {
166m["backend_version_format"] = true;
167}
168if (tmeta->format_number_ == 29) {
169m["format_number"] = true;
170}
171}
172
173void for_deserialization(const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
174int backend_version_format{-1};
175int format_number{-1};
176if (m.find("backend_version_format") != m.end()) {
177backend_version_format = 1;
178}
179if (m.find("format_number") != m.end()) {
180format_number = 29;
181}
182c10::intrusive_ptr<c10::BackendMeta> new_tmeta{std::unique_ptr<c10::BackendMeta>(
183new CustomBackendMetadata(backend_version_format, format_number))};
184t.unsafeGetTensorImpl()->set_backend_meta(new_tmeta);
185}
186
187void custom_serialization_registry() {
188torch::jit::TensorBackendMetaRegistry(c10::DeviceType::PrivateUse1,
189&for_serialization,
190&for_deserialization);
191}
192
193//check if BackendMeta serialization correctly
194bool check_backend_meta(const at::Tensor& t) {
195if (t.unsafeGetTensorImpl()->get_backend_meta_intrusive_ptr()) {
196CustomBackendMetadata* tmeta = dynamic_cast<CustomBackendMetadata*>(
197t.unsafeGetTensorImpl()->get_backend_meta());
198if (tmeta->backend_version_format_==1 && tmeta->format_number_==29) {
199return true;
200}
201}
202return false;
203}
204
205// a fake set function is exposed to the Python side
206void custom_set_backend_meta(const at::Tensor& t) {
207int backend_version_format{1};
208int format_number{29};
209c10::intrusive_ptr<c10::BackendMeta> new_tmeta{std::unique_ptr<c10::BackendMeta>(
210new CustomBackendMetadata(backend_version_format, format_number))};
211t.unsafeGetTensorImpl()->set_backend_meta(new_tmeta);
212}
213
214// A dummy storageImpl for our custom device, that secretly uses the CPU
215c10::intrusive_ptr<c10::StorageImpl> make_custom_storage_impl(c10::StorageImpl::use_byte_size_t,
216c10::SymInt size_bytes,
217c10::DataPtr data_ptr,
218c10::Allocator* allocator,
219bool resizable) {
220c10::intrusive_ptr<c10::StorageImpl> custom_storage_impl;
221if (data_ptr == nullptr){
222custom_storage_impl = c10::make_intrusive<c10::StorageImpl>(
223c10::StorageImpl::use_byte_size_t(), size_bytes, allocator, resizable);
224} else {
225custom_storage_impl = c10::make_intrusive<c10::StorageImpl>(
226c10::StorageImpl::use_byte_size_t(), size_bytes, std::move(data_ptr), allocator, resizable);
227}
228storageImpl_counter += 1;
229return custom_storage_impl;
230}
231
232// Register our dummy storageImpl create method.
233void custom_storage_registry() {
234c10::SetStorageImplCreate(c10::DeviceType::PrivateUse1, &make_custom_storage_impl);
235}
236
237bool custom_storageImpl_called() {
238if (storageImpl_counter > last_storageImpl_saved_value) {
239last_storageImpl_saved_value = storageImpl_counter;
240return true;
241}
242return false;
243}
244
245// basic dummy add function
246at::Tensor custom_add_Tensor(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha) {
247add_counter += 1;
248// Since this custom device is just for testing, not bothering to implement kernels.
249return at::empty(self.sizes(), self.options());
250}
251
252// basic abs function
253at::Tensor& custom_abs_out(const at::Tensor& self, at::Tensor& out) {
254return at::native::abs_out(self, out);
255}
256
257// A dummy allocator for our custom device, that secretly uses the CPU
258struct DummyCustomAllocator final : at::Allocator {
259DummyCustomAllocator() = default;
260at::DataPtr allocate(size_t nbytes) override {
261void* data = c10::alloc_cpu(nbytes);
262return {data, data, &ReportAndDelete, at::Device(at::DeviceType::PrivateUse1, custom_device_index)};
263}
264
265static void ReportAndDelete(void* ptr) {
266if (!ptr) {
267return;
268}
269c10::free_cpu(ptr);
270}
271
272at::DeleterFnPtr raw_deleter() const override {
273return &ReportAndDelete;
274}
275
276void copy_data(void* dest, const void* src, std::size_t count) const final {
277default_copy_data(dest, src, count);
278}
279};
280
281// Register our dummy allocator
282static DummyCustomAllocator global_custom_alloc;
283REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_custom_alloc);
284
285// basic dummy empty function, so we can directly construct tensors on the custom device
286// This dummy test device will just use the CPU allocator, and ignores pinned memory.
287at::Tensor custom_empty_memory_format(at::IntArrayRef size,
288std::optional<at::ScalarType> dtype,
289std::optional<at::Layout> layout,
290std::optional<at::Device> device,
291std::optional<bool> pin_memory,
292std::optional<at::MemoryFormat> memory_format) {
293constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
294return at::detail::empty_generic(size,
295&global_custom_alloc,
296private_use_ks,
297c10::dtype_or_default(dtype),
298memory_format);
299}
300at::Tensor custom_empty_symint(c10::IntArrayRef size,
301std::optional<at::ScalarType> dtype,
302std::optional<at::Layout> layout,
303std::optional<at::Device> device,
304std::optional<bool> pin_memory,
305std::optional<at::MemoryFormat> memory_format) {
306constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
307return at::detail::empty_generic(size,
308&global_custom_alloc, private_use_ks, c10::dtype_or_default(dtype), memory_format);
309}
310
311at::Tensor & custom_fill__scalar(at::Tensor & self, const at::Scalar & value) {
312// Not bothering to implement.
313return self;
314}
315
316// Unsafe using dummy device data_ptr to creat a cpu tensor, and shared data_ptr.
317at::Tensor unsafe_create_cpu_tensor_from_dummy_tensor(const at::Tensor& src) {
318TORCH_CHECK(src.device().type() == c10::DeviceType::PrivateUse1,
319"Only support dummy device.");
320const auto& sizes_ = src.sizes();
321const auto& strides_ = src.strides();
322auto storage_offset_ = src.storage_offset();
323at::detail::check_size_nonnegative(sizes_);
324
325size_t size_bytes = at::detail::computeStorageNbytes(sizes_, strides_,
326src.element_size(),
327storage_offset_);
328
329at::DataPtr data_ptr =
330c10::InefficientStdFunctionContext::makeDataPtr(src.storage().mutable_data_ptr().get(),
331[](void*){}, at::kCPU);
332
333c10::Storage storage{c10::Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr),
334/*allocator=*/&global_custom_alloc, /*resizeable=*/false};
335
336constexpr c10::DispatchKeySet cpu_ks(c10::DispatchKey::CPU);
337at::Tensor tensor = at::detail::make_tensor<c10::TensorImpl>(
338std::move(storage), cpu_ks, src.dtype());
339
340c10::TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
341tensor_impl->set_sizes_and_strides(sizes_, strides_);
342tensor_impl->set_storage_offset(storage_offset_);
343return tensor;
344}
345
346// basic dummy copy_() function, so we can copy from the custom device to/from CPU
347at::Tensor custom__copy_from(const at::Tensor& self, const at::Tensor& dst, bool non_blocking) {
348TORCH_CHECK(
349self.is_cpu() || self.device().type() == c10::DeviceType::PrivateUse1,
350"Dummy test only allows copy from cpu -> dummy device.");
351TORCH_CHECK(
352dst.is_cpu() || dst.device().type() == c10::DeviceType::PrivateUse1,
353"Dummy test only allows copy from cpu -> dummy device.");
354
355// Some dummy asserts for the basic use case: inputs are the same size / dtype, all contiguous.
356TORCH_CHECK(self.sizes() == dst.sizes());
357TORCH_CHECK(self.scalar_type() == dst.scalar_type());
358
359if (self.is_contiguous() && dst.is_contiguous()) {
360std::memcpy(dst.storage().data_ptr().get(),
361self.storage().data_ptr().get(),
362self.storage().nbytes());
363} else {
364// Using cpu tensor to accomplishment stride copy.
365auto convert_to_cpu_tensor = [](const at::Tensor& src) -> at::Tensor {
366if (src.device().type() == c10::DeviceType::PrivateUse1) {
367return unsafe_create_cpu_tensor_from_dummy_tensor(src);
368} else {
369return src;
370}
371};
372at::Tensor cpu_self = convert_to_cpu_tensor(self);
373at::Tensor cpu_dst = convert_to_cpu_tensor(dst);
374cpu_dst.copy_(cpu_self);
375}
376
377return dst;
378}
379
380at::Tensor custom__copy_from_and_resize(const at::Tensor& self, const at::Tensor& dst) {
381return custom__copy_from(self, dst, false);
382}
383
384at::Tensor custom_empty_strided(c10::IntArrayRef size,
385c10::IntArrayRef stride,
386std::optional<at::ScalarType> dtype_opt,
387std::optional<at::Layout> layout_opt,
388std::optional<at::Device> device_opt,
389std::optional<bool> pin_memory_opt) {
390constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
391auto dtype = c10::dtype_or_default(dtype_opt);
392return at::detail::empty_strided_generic(size, stride, &global_custom_alloc, private_use_ks, dtype);
393}
394
395// Some set operations for the basic use case
396at::Tensor& custom_set_source_Storage(at::Tensor& result, c10::Storage src) {
397int64_t new_size = static_cast<int64_t>(src.nbytes() / result.dtype().itemsize());
398c10::IntArrayRef stride = {};
399result.unsafeGetTensorImpl()->set_storage_offset(0);
400at::OptionalIntArrayRef stride_opt = stride.data() != nullptr ? at::OptionalIntArrayRef(stride) : std::nullopt;
401at::native::resize_impl_cpu_(result.unsafeGetTensorImpl(),
402new_size, stride_opt,
403/*resize_storage=*/!result.is_meta());
404return result;
405}
406
407// Some set operations for the basic use case
408at::Tensor& custom_set_source_Storage_storage_offset(at::Tensor& result,
409c10::Storage storage,
410int64_t storage_offset,
411c10::IntArrayRef size,
412c10::IntArrayRef stride) {
413result.unsafeGetTensorImpl()->set_storage_offset(storage_offset);
414at::OptionalIntArrayRef stride_opt = stride.data() != nullptr ? at::OptionalIntArrayRef(stride) : std::nullopt;
415at::native::resize_impl_cpu_(result.unsafeGetTensorImpl(),
416size, stride_opt,
417/*resize_storage=*/!result.is_meta());
418return result;
419}
420
421const at::Tensor& custom_resize_(const at::Tensor& self, at::IntArrayRef size,
422std::optional<at::MemoryFormat> optional_memory_format) {
423at::TensorImpl* tensor_impl = self.unsafeGetTensorImpl();
424tensor_impl->set_sizes_contiguous(size);
425const auto itemsize = tensor_impl->dtype().itemsize();
426const auto offset = tensor_impl->storage_offset();
427const auto storage_size = at::detail::computeStorageNbytesContiguous(size, itemsize, offset);
428// Dummy device is using cpu allocator, so here just call cpu
429// function maybe_resize_storage_cpu in aten/src/ATen/native/Resize.h
430// to get a sufficient memory space.
431at::native::maybe_resize_storage_cpu(tensor_impl, storage_size);
432if (optional_memory_format.has_value()) {
433auto memory_format =
434optional_memory_format.value();
435TORCH_CHECK(
436memory_format != at::MemoryFormat::Preserve,
437"Unsupported memory format",
438memory_format);
439tensor_impl->empty_tensor_restride(memory_format);
440}
441return self;
442}
443
444std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, c10::SymInt, c10::SymInt, at::Tensor, at::Tensor, at::Tensor>
445custom_scaled_dot_product_fused_attention_overrideable(
446const at::Tensor & query,
447const at::Tensor & key,
448const at::Tensor & value,
449const std::optional<at::Tensor> & attn_bias,
450double dropout_p,
451bool is_causal,
452bool return_debug_mask,
453std::optional<double> scale) {
454const int64_t batch_size = query.size(0);
455const int64_t num_heads = query.size(1);
456const int64_t head_dim_qk = query.size(3);
457const int64_t head_dim_v = value.size(3);
458const int64_t max_seqlen_q = query.size(2);
459const int64_t max_seqlen_kv = key.size(2);
460
461auto opts = query.options();
462auto output = at::empty({batch_size, num_heads, max_seqlen_q, head_dim_v}, opts);
463auto logsumexp = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
464auto debug_attn_mask = at::empty({batch_size, num_heads, max_seqlen_q, max_seqlen_kv},
465opts.dtype(at::kFloat));
466auto philox_seed = at::empty({}, at::dtype(at::kLong));
467auto philox_offset = at::empty({}, at::dtype(at::kLong));
468
469return std::make_tuple(output, logsumexp, at::Tensor(), at::Tensor(), max_seqlen_q, max_seqlen_kv, philox_seed, philox_offset, debug_attn_mask);
470}
471std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
472custom_scaled_dot_product_fused_attention_overrideable_backward(
473const at::Tensor & grad_out,
474const at::Tensor & query,
475const at::Tensor & key,
476const at::Tensor & value,
477const at::Tensor & attn_bias,
478std::array<bool,4> grad_input_mask,
479const at::Tensor & out,
480const at::Tensor & logsumexp,
481const at::Tensor & cum_seq_q,
482const at::Tensor & cum_seq_k,
483int64_t max_q,
484int64_t max_k,
485double dropout_p,
486bool is_causal,
487const at::Tensor & philox_seed,
488const at::Tensor & philox_offset,
489std::optional<double> scale) {
490return std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>(
491at::empty_like(query),
492at::empty_like(key),
493at::empty_like(value),
494at::empty_like(attn_bias));
495}
496
497// This macro does the heavy lifting.
498// With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend.
499// For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key.
500// Later in this file, we map a custom device to the PrivateUse1 device type,
501// which allows user code that puts a tensor on your custom_device to eventually get plumbed
502// into the kernels registered here.
503//
504// This macro registers your kernels to the PyTorch Dispatcher.
505// More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/.
506TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
507m.impl("abs.out", &custom_abs_out);
508m.impl("add.Tensor", &custom_add_Tensor);
509m.impl("empty.memory_format", &custom_empty_symint);
510m.impl("fill_.Scalar", &custom_fill__scalar);
511m.impl("_copy_from", &custom__copy_from);
512m.impl("_copy_from_and_resize", &custom__copy_from_and_resize);
513m.impl("empty_strided", &custom_empty_strided);
514m.impl("set_.source_Storage", &custom_set_source_Storage);
515m.impl("set_.source_Storage_storage_offset",&custom_set_source_Storage_storage_offset);
516m.impl("resize_", &custom_resize_);
517m.impl("as_strided", at::native::as_strided_tensorimpl);
518m.impl("quantize_per_tensor", at::native::quantize_per_tensor);
519m.impl("_fused_sdp_choice", &_fused_sdp_choice_privateuse1);
520m.impl("_scaled_dot_product_fused_attention_overrideable", &custom_scaled_dot_product_fused_attention_overrideable);
521m.impl("_scaled_dot_product_fused_attention_overrideable_backward", &custom_scaled_dot_product_fused_attention_overrideable_backward);
522}
523
524void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
525at::native::cpu_fallback(op, stack);
526}
527
528TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
529m.impl("sub.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
530m.impl("_foreach_add.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
531m.impl("_fused_adamw_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
532m.impl("index.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
533m.impl("triu_indices", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
534}
535
536// This basic implementation doesn't bother dealing with different device indices
537// (e.g. custom_device:0 vs. custom_device:1).
538// We could do that by letting the user pass in a device index in our exposed device function.
539// Note that if you do that, you'll also need to register a device guard to core.
540// See `c10/core/impl/DeviceGuardImplInterface.h:C10_REGISTER_GUARD_IMPL`.
541c10::Device get_custom_device() {
542return c10::Device(c10::DeviceType::PrivateUse1, 0);
543}
544
545bool custom_add_called() {
546bool called = false;
547if (add_counter > last_saved_value) {
548called = true;
549last_saved_value = add_counter;
550}
551return called;
552}
553
554class PrivateGeneratorImpl : public at::CPUGeneratorImpl {
555public:
556// Constructors
557PrivateGeneratorImpl(c10::DeviceIndex device_index) {
558device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index);
559key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1);
560}
561~PrivateGeneratorImpl() override = default;
562};
563
564// this is used to register generator
565at::Generator make_generator_privateuse1(c10::DeviceIndex device_index) {
566return at::make_generator<PrivateGeneratorImpl>(device_index);
567}
568
569void register_generator_first() {
570REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1)
571}
572
573void register_generator_second() {
574REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1)
575}
576
577void set_custom_device_index(c10::DeviceIndex device_index) {
578custom_device_index = device_index;
579}
580
// Global flag behind the dummy pin_memory support. FooHooksInterface's
// getPinnedMemoryAllocator sets it to true, and isPinnedPtr reports it.
bool custom_pinned_flag{false};
583
584struct FooHooksArgs : public at::PrivateUse1HooksArgs {};
585
586struct FooHooksInterface : public at::PrivateUse1HooksInterface {
587FooHooksInterface(FooHooksArgs) {}
588~FooHooksInterface() override = default;
589const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index) override {
590static auto device_gen = make_generator_privateuse1(device_index);
591return device_gen;
592}
593// this is a simple implementation, custom_pinned_flag will be set as true
594// once tensor.pin_memory() is called. And then tensor.is_pinned()
595// always return true no matter what tensor it's called on.
596bool isPinnedPtr(const void* data) const override {
597return custom_pinned_flag;
598}
599c10::Allocator* getPinnedMemoryAllocator() const override {
600custom_pinned_flag = true;
601return c10::GetCPUAllocator();
602}
603};
604
605TORCH_DECLARE_REGISTRY(PrivateUse1HooksRegistry, FooHooksInterface, FooHooksArgs);
606C10_DEFINE_REGISTRY(PrivateUse1HooksRegistry, FooHooksInterface, FooHooksArgs)
607// Using Create function to get PrivateUse1HooksInterface point from PrivateUse1HooksRegistry class.
608C10_REGISTER_TYPED_CLASS(PrivateUse1HooksRegistry, "FooHooks", FooHooksInterface)
609
610static at::PrivateUse1HooksInterface* privateuse1_hooks_local = nullptr;
611static at::PrivateUse1HooksInterface* get_private_hooks() {
612static c10::once_flag once;
613c10::call_once(once, [] {
614privateuse1_hooks_local = PrivateUse1HooksRegistry()->Create("FooHooks", {}).release();
615if (!privateuse1_hooks_local) {
616privateuse1_hooks_local = new FooHooksInterface(FooHooksArgs{});
617}
618});
619return privateuse1_hooks_local;
620}
621
622void register_hook() {
623at::RegisterPrivateUse1HooksInterface(get_private_hooks());
624}
625
626bool is_register_hook() {
627return privateuse1_hooks_local != nullptr;
628}
629
630const at::Generator& default_generator(c10::DeviceIndex device_index) {
631return at::globalContext().defaultGenerator(at::Device(c10::DeviceType::PrivateUse1, device_index));;
632}
633
634void fallback_with_undefined_tensor() {
635at::Tensor first = at::empty((2,3)).to(at::DeviceType::PrivateUse1);
636at::Tensor second = at::Tensor();
637at::Tensor step = at::empty({}).fill_(2).to(at::DeviceType::PrivateUse1);
638at::Tensor grad_scale = at::empty({}).fill_(0.00001).to(at::DeviceType::PrivateUse1);
639at::Tensor found_inf = at::empty({}).fill_(1).to(at::DeviceType::PrivateUse1);
640at::TensorList tensors = {first, first};
641at::TensorList undefined_tensors = {first, second};
642at::TensorList steps = {step, step};
643return at::_fused_adamw_(tensors, tensors, tensors, tensors, undefined_tensors,
644steps, 0.001, 0.9, 0.999, 1e-2, 1e-8, false, false,
645grad_scale, found_inf);
646}
647
648struct CustomAutogradFnReturnsSelf : public torch::autograd::Function<CustomAutogradFnReturnsSelf> {
649
650static at::Tensor forward(torch::autograd::AutogradContext* ctx, at::Tensor self) {
651return self;
652}
653
654static torch::autograd::variable_list backward(torch::autograd::AutogradContext* ctx, torch::autograd::variable_list grad_output) {
655return {grad_output[0] * 0.5};
656}
657};
658
659struct CustomAutogradFnAliasing : public torch::autograd::Function<CustomAutogradFnAliasing> {
660
661static at::Tensor forward(torch::autograd::AutogradContext* ctx, at::Tensor self) {
662return self.view_symint(self.sym_sizes());
663}
664
665static torch::autograd::variable_list backward(torch::autograd::AutogradContext* ctx, torch::autograd::variable_list grad_output) {
666return {grad_output[0] * 0.5};
667}
668};
669
670at::Tensor custom_autograd_fn_returns_self(at::Tensor x) {
671return CustomAutogradFnReturnsSelf::apply(x);
672}
673
674at::Tensor custom_autograd_fn_aliasing(at::Tensor x) {
675return CustomAutogradFnAliasing::apply(x);
676}
677
678// Here, we're exposing a custom device object that corresponds to our custom backend.
679// We do this using pybind: exposing an "extension_name.custom_device()" function in python,
680// that's implemented in C++.
681// The implementation in this file maps directly to the `PrivateUse1` device type.
682PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
683m.def("custom_device", &get_custom_device, "get custom device object");
684m.def("custom_add_called", &custom_add_called, "check if our custom add function was called");
685m.def("register_generator_first", ®ister_generator_first, "register generator for custom device firstly");
686m.def("register_generator_second", ®ister_generator_second, "register generator for custom device secondly");
687m.def("set_custom_device_index", &set_custom_device_index, "set custom device index");
688m.def("custom_storage_registry", &custom_storage_registry, "set custom storageImpl creat method");
689m.def("custom_storageImpl_called", &custom_storageImpl_called, "check if our custom abs function was called");
690m.def("custom_set_backend_meta", &custom_set_backend_meta, "a fake set tensor BackendMeta function");
691m.def("check_backend_meta", &check_backend_meta, "check if BackendMeta serialization correctly");
692m.def("custom_serialization_registry", &custom_serialization_registry, "register custom serialization function");
693m.def("register_hook", ®ister_hook, "register_hook for privateuse1");
694m.def("is_register_hook", &is_register_hook, "is_register_hook for privateuse1");
695m.def("default_generator", &default_generator, "default_generator for privateuse1");
696m.def("fallback_with_undefined_tensor", &fallback_with_undefined_tensor, "fallback_with_undefined_tensor for privateuse1");
697
698// Co-opting this file to more easily test torch.compile'ing of custom autograd functions in C++
699m.def("custom_autograd_fn_returns_self", &custom_autograd_fn_returns_self);
700}
701
702TORCH_LIBRARY(_test_funcs, m) {
703m.def("custom_autograd_fn_aliasing(Tensor(a) input)-> Tensor(a)");
704}
705TORCH_LIBRARY_IMPL(_test_funcs, AutogradCPU, m) {
706m.impl("custom_autograd_fn_aliasing", &custom_autograd_fn_aliasing);
707}
708