1
#include "caffe2/contrib/aten/aten_op.h"
2
#include "caffe2/utils/math.h"
7
at::Tensor index_with_uint8_handling(
8
const at::Tensor& self,
9
const torch::List<c10::optional<at::Tensor>>& indices) {
10
// Support BC only for the simplest case of mask indexing
11
if (indices.size() == 1) {
12
c10::optional<at::Tensor> first = indices[0];
14
&& first->scalar_type() == at::kByte) {
16
"Indexing with uint8 mask tensor in ATenOp is now deprecated,"
17
" please use a bool mask instead.");
18
return at::index(self, {first->to(at::kBool)});
21
return at::index(self, indices);
23
} // namespace internal
25
// Registers ATenOp<CPUContext> under the operator name "ATen" for CPU.
// NOTE(review): presumably this adds it to caffe2's CPU operator registry so
// graphs can instantiate "ATen" ops on CPU — confirm against the macro definition.
REGISTER_CPU_OPERATOR(ATen, ATenOp<CPUContext>);
27
at::Backend ATenOp<CPUContext>::backend() const {
28
return at::Backend::CPU;
36
void Set<at::Half, CPUContext>(
37
const std::int64_t /* N */,
41
Set(0, h.x, (uint16_t*)v, c);
45
void Set<at::BFloat16, CPUContext>(
46
const std::int64_t /* N */,
50
Set(0, b.x, (uint16_t*)v, c);