pytorch
1#include "caffe2/core/common_gpu.h"
2#include "caffe2/core/context_gpu.h"
3#include "caffe2/image/image_input_op.h"
4
5namespace caffe2 {
6
7template <>
8bool ImageInputOp<CUDAContext>::ApplyTransformOnGPU(
9const std::vector<std::int64_t>& dims,
10const c10::Device& type) {
11// GPU transform kernel allows explicitly setting output type
12if (output_type_ == TensorProto_DataType_FLOAT) {
13auto* image_output =
14OperatorBase::OutputTensor(0, dims, at::dtype<float>().device(type));
15TransformOnGPU<uint8_t, float, CUDAContext>(
16prefetched_image_on_device_,
17image_output,
18mean_gpu_,
19std_gpu_,
20&context_);
21} else if (output_type_ == TensorProto_DataType_FLOAT16) {
22auto* image_output =
23OperatorBase::OutputTensor(0, dims, at::dtype<at::Half>().device(type));
24TransformOnGPU<uint8_t, at::Half, CUDAContext>(
25prefetched_image_on_device_,
26image_output,
27mean_gpu_,
28std_gpu_,
29&context_);
30} else {
31return false;
32}
33return true;
34}
35
36REGISTER_CUDA_OPERATOR(ImageInput, ImageInputOp<CUDAContext>);
37
38} // namespace caffe2
39