pytorch

Форк
0
/
image_input_op_gpu.cc 
38 строк · 1.1 Кб
1
#include "caffe2/core/common_gpu.h"
2
#include "caffe2/core/context_gpu.h"
3
#include "caffe2/image/image_input_op.h"
4

5
namespace caffe2 {
6

7
template <>
8
bool ImageInputOp<CUDAContext>::ApplyTransformOnGPU(
9
    const std::vector<std::int64_t>& dims,
10
    const c10::Device& type) {
11
  // GPU transform kernel allows explicitly setting output type
12
  if (output_type_ == TensorProto_DataType_FLOAT) {
13
    auto* image_output =
14
        OperatorBase::OutputTensor(0, dims, at::dtype<float>().device(type));
15
    TransformOnGPU<uint8_t, float, CUDAContext>(
16
        prefetched_image_on_device_,
17
        image_output,
18
        mean_gpu_,
19
        std_gpu_,
20
        &context_);
21
  } else if (output_type_ == TensorProto_DataType_FLOAT16) {
22
    auto* image_output =
23
        OperatorBase::OutputTensor(0, dims, at::dtype<at::Half>().device(type));
24
    TransformOnGPU<uint8_t, at::Half, CUDAContext>(
25
        prefetched_image_on_device_,
26
        image_output,
27
        mean_gpu_,
28
        std_gpu_,
29
        &context_);
30
  } else {
31
    return false;
32
  }
33
  return true;
34
}
35

36
REGISTER_CUDA_OPERATOR(ImageInput, ImageInputOp<CUDAContext>);
37

38
} // namespace caffe2
39

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.