pytorch

Форк
0
/
aten_op.cc 
56 строк · 1.3 Кб
1
#include "caffe2/contrib/aten/aten_op.h"
2
#include "caffe2/utils/math.h"
3

4
namespace caffe2 {
5

6
namespace internal {
7
at::Tensor index_with_uint8_handling(
8
    const at::Tensor& self,
9
    const torch::List<c10::optional<at::Tensor>>& indices) {
10
  // Support BC only for the simplest case of mask indexing
11
  if (indices.size() == 1) {
12
    c10::optional<at::Tensor> first = indices[0];
13
    if (first.has_value()
14
        && first->scalar_type() == at::kByte) {
15
      TORCH_WARN(
16
          "Indexing with uint8 mask tensor in ATenOp is now deprecated,"
17
          " please use a bool mask instead.");
18
      return at::index(self, {first->to(at::kBool)});
19
    }
20
  }
21
  return at::index(self, indices);
22
}
23
} // namespace internal
24

25
REGISTER_CPU_OPERATOR(ATen, ATenOp<CPUContext>);
26
template <>
27
at::Backend ATenOp<CPUContext>::backend() const {
28
  return at::Backend::CPU;
29
}
30

31
OPERATOR_SCHEMA(ATen);
32

33
namespace math {
34

35
template <>
36
void Set<at::Half, CPUContext>(
37
    const std::int64_t /* N */,
38
    const at::Half h,
39
    at::Half* v,
40
    CPUContext* c) {
41
  Set(0, h.x, (uint16_t*)v, c);
42
}
43

44
template <>
45
void Set<at::BFloat16, CPUContext>(
46
    const std::int64_t /* N */,
47
    const at::BFloat16 b,
48
    at::BFloat16* v,
49
    CPUContext* c) {
50
  Set(0, b.x, (uint16_t*)v, c);
51
}
52

53

54
} // namespace math
55

56
} // namespace caffe2
57

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.