pytorch

Форк
0
/
pybind_state_ideep.cc 
222 строки · 6.8 Кб
1
// Note(jiayq): the import_array function is done inside
2
// caffe2_python.cc. Read
3
// http://docs.scipy.org/doc/numpy-1.10.1/reference/c-api.array.html#miscellaneous
4
// for more details.
5
#define NO_IMPORT_ARRAY
6

7
#include "pybind_state.h"
8

9
#include <pybind11/pybind11.h>
10
#include <pybind11/stl.h>
11

12
#include <caffe2/ideep/ideep_utils.h>
13
#include "caffe2/ideep/operators/operator_fallback_ideep.h"
14

15
namespace caffe2 {
16
namespace python {
17

18
USE_IDEEP_DEF_ALIASES();
19

20
class IDeepFetcher;
21
class IDeepFeeder;
22

23
REGISTER_IDEEP_OPERATOR(Python, IDEEPFallbackOp<PythonOp<CPUContext, false>>);
24

25
REGISTER_BLOB_FETCHER((TypeMeta::Id<itensor>()), IDeepFetcher);
26
REGISTER_BLOB_FEEDER(IDEEP, IDeepFeeder);
27

28
class IDeepFetcher : public BlobFetcherBase {
29
  TypeMeta type_transform(const itensor& atensor) {
30
    switch (atensor.get_data_type()) {
31
      case itensor::data_type::f32:
32
        return TypeMeta::Make<float>();
33
      case itensor::data_type::s32:
34
        return TypeMeta::Make<int>();
35
      case itensor::data_type::s8:
36
        return TypeMeta::Make<int8_t>();
37
      case itensor::data_type::u8:
38
        return TypeMeta::Make<uint8_t>();
39
      default:
40
        // Should we throw exception?
41
        return TypeMeta();
42
    }
43
  }
44

45
 public:
46
  pybind11::object Fetch(const Blob& blob) override {
47
    try {
48
      return FetchTensor(blob.Get<itensor>(), true).obj;
49
    } catch (ideep::error& e) {
50
      LOG(ERROR) << "IDEEP error: " << e.message;
51
      throw;
52
    }
53
  }
54

55
  FetchedBlob FetchTensor(const itensor& atensor, bool force_copy) {
56
#ifdef USE_NUMPY
57
    FetchedBlob result;
58
    CAFFE_ENFORCE(
59
        (atensor.ndims() != 0) &&
60
            (atensor.get_nelems() == 0 || atensor.get_data_handle() != nullptr),
61
        "Trying to fetch uninitialized tensor");
62
    // NOTE: Only support float so far.
63
    const int numpy_type = NPY_FLOAT;
64
    CAFFE_ENFORCE(
65
        numpy_type != -1,
66
        "Unsupported ideep memory data type? This usually should not happen "
67
        "since ideep memory usually only do float and double.");
68
    itensor::dims dims;
69
    bool need_reorder = atensor.need_reorder();
70
    if (atensor.get_data_type() == idtype::f32 && !atensor.has_scale()) {
71
      // For FP32 path, only support NCHW format input, so if atensor
72
      // has NHWC format, we need reorder it to NCHW format.
73
      dims = atensor.get_dims();
74
      need_reorder = need_reorder || atensor.get_desc().is_nhwc();
75
    } else {
76
      dims = atensor.get_public_format_dims();
77
    }
78
    std::vector<npy_intp> npy_dims(dims.begin(), dims.end());
79

80
    result.copied = force_copy || need_reorder;
81
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
82
    void* outPtr;
83
    if (result.copied) {
84
      result.obj = py::reinterpret_steal<py::object>(
85
          PyArray_SimpleNew(atensor.ndims(), npy_dims.data(), numpy_type));
86
      outPtr = static_cast<void*>(
87
          PyArray_DATA(reinterpret_cast<PyArrayObject*>(result.obj.ptr())));
88
    } else {
89
      outPtr = atensor.get_data_handle();
90
      result.obj = py::reinterpret_steal<py::object>(PyArray_SimpleNewFromData(
91
          atensor.ndims(), npy_dims.data(), numpy_type, outPtr));
92
    }
93

94
    if (numpy_type == NPY_OBJECT) {
95
      CAFFE_THROW("We don't support strings.");
96
    }
97

98
    if (result.copied) {
99
      if (atensor.get_data_type() == idtype::f32 && !atensor.has_scale()) {
100
        itensor temp_ten(atensor.get_desc().to_default_format(), outPtr);
101
        atensor.reorder_to(temp_ten);
102
      } else {
103
        atensor.to_public(outPtr);
104
      }
105
    }
106

107
    return result;
108
#else
109
    CAFFE_THROW("Caffe2 was compiled without NumPy support.");
110
#endif // USE_NUMPY
111
  }
112
};
113

114
class IDeepFeeder : public BlobFeederBase {
115
  itensor::data_type type_transform(const TypeMeta meta) {
116
    if (meta == TypeMeta::Make<float>())
117
      return itensor::data_type::f32;
118
    else if (meta == TypeMeta::Make<int>())
119
      return itensor::data_type::s32;
120
    else if (meta == TypeMeta::Make<int8_t>())
121
      return itensor::data_type::s8;
122
    else if (meta == TypeMeta::Make<uint8_t>())
123
      return itensor::data_type::u8;
124
    else
125
      return itensor::data_type::undef;
126
  }
127

128
 public:
129
  void FeedTensor(
130
      const DeviceOption& option,
131
      PyArrayObject* original_array,
132
      itensor* tensor) {
133
#ifdef USE_NUMPY
134
    PyArrayObject* array = PyArray_GETCONTIGUOUS(original_array);
135
    auto g = MakeGuard([&]() { Py_XDECREF(array); });
136
    const auto npy_type = PyArray_TYPE(array);
137
    const TypeMeta meta = NumpyTypeToCaffe(npy_type);
138
    CAFFE_ENFORCE_NE(
139
        meta,
140
        ScalarType::Undefined,
141
        "This numpy data type is not supported: ",
142
        PyArray_TYPE(array),
143
        ".");
144

145
    int ndim = PyArray_NDIM(array);
146
    npy_intp* npy_dims = PyArray_DIMS(array);
147

148
    itensor::dims adims;
149
    for (int i = 0; i < ndim; i++) {
150
      adims.push_back(static_cast<itensor::dims::value_type>(npy_dims[i]));
151
    }
152

153
    switch (npy_type) {
154
      case NPY_OBJECT:
155
      case NPY_UNICODE:
156
        CAFFE_THROW("IDeep doesn't support string");
157
        break;
158
      default:
159
        auto type = type_transform(meta);
160
        if (tensor->get_dims() != adims || type != tensor->get_data_type()) {
161
          tensor->resize(adims, type);
162
        }
163
        tensor->feed_from(adims, type, static_cast<void*>(PyArray_DATA(array)));
164
    }
165
#else
166
    CAFFE_THROW("Caffe2 was compiled without NumPy support.");
167
#endif // USE_NUMPY
168
  }
169

170
  bool ZeroDim(PyArrayObject* array) {
171
#ifdef USE_NUMPY
172
    int ndim = PyArray_NDIM(array);
173
    return ndim == 0;
174
#else
175
    CAFFE_THROW("Caffe2 was compiled without NumPy support.");
176
#endif
177
  }
178

179
  void Feed(
180
      const DeviceOption& option,
181
      PyArrayObject* original_array,
182
      Blob* blob,
183
      bool in_place) override {
184
#ifdef USE_NUMPY
185
    try {
186
      PyArrayObject* array = PyArray_GETCONTIGUOUS(original_array);
187
      auto g = MakeGuard([&]() { Py_XDECREF(array); });
188

189
      const auto npy_type = PyArray_TYPE(array);
190
      const TypeMeta meta = NumpyTypeToCaffe(npy_type);
191

192
      // TODO: if necessary, use dispatcher.
193
      if ((in_place && blob->IsType<itensor>()) ||
194
          (meta.Match<float>() && !ZeroDim(original_array))) {
195
        FeedTensor(option, original_array, blob->GetMutable<itensor>());
196
      } else {
197
        DeviceOption cpu_option(option);
198
        cpu_option.set_device_type(DeviceTypeProto::PROTO_CPU);
199
        TensorFeeder<CPUContext> cpu_tensor_feeder;
200
        if (in_place) {
201
          cpu_tensor_feeder.FeedTensor(
202
              cpu_option,
203
              original_array,
204
              BlobGetMutableTensor(blob, OptionToDevice(cpu_option).type()),
205
              true);
206
        } else {
207
          blob->Reset<Tensor>(new Tensor(
208
              cpu_tensor_feeder.FeedTensor(cpu_option, original_array)));
209
        }
210
      }
211
    } catch (ideep::error& e) {
212
      LOG(ERROR) << "IDEEP error: " << e.message;
213
      throw;
214
    }
215
#else
216
    CAFFE_THROW("Caffe2 was compiled without NumPy support.");
217
#endif
218
  }
219
};
220

221
} // namespace python
222
} // namespace caffe2
223

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.