pytorch

Форк
0
51 строка · 2.0 Кб
1

2
#include "ios_caffe.h"
3
#include "caffe2/core/tensor.h"
4
#include "caffe2/mobile/contrib/ios/ios_caffe_predictor.h"
5
#include "caffe2/predictor/predictor.h"
6

7
Caffe2IOSPredictor* MakeCaffe2Predictor(const std::string& init_net_str,
8
                                        const std::string& predict_net_str,
9
                                        bool disableMultithreadProcessing,
10
                                        bool allowMetalOperators,
11
                                        std::string& errorMessage) {
12
  caffe2::NetDef init_net, predict_net;
13
  init_net.ParseFromString(init_net_str);
14
  predict_net.ParseFromString(predict_net_str);
15

16
  Caffe2IOSPredictor* predictor = NULL;
17
  try {
18
    predictor = Caffe2IOSPredictor::NewCaffe2IOSPredictor(
19
        init_net, predict_net, disableMultithreadProcessing, allowMetalOperators);
20
  } catch (const std::exception& e) {
21
    std::string error = e.what();
22
    errorMessage.swap(error);
23
    return NULL;
24
  }
25
  return predictor;
26
}
27

28
void GenerateStylizedImage(std::vector<float>& originalImage,
29
                           const std::string& init_net_str,
30
                           const std::string& predict_net_str,
31
                           int height,
32
                           int width,
33
                           std::vector<float>& dataOut) {
34
  caffe2::NetDef init_net, predict_net;
35
  init_net.ParseFromString(init_net_str);
36
  predict_net.ParseFromString(predict_net_str);
37
  caffe2::Predictor p(init_net, predict_net);
38

39
  std::vector<int> dims({1, 3, height, width});
40
  caffe2::Tensor input(caffe2::CPU);
41
  input.Resize(dims);
42
  input.ShareExternalPointer(originalImage.data());
43
  caffe2::Predictor::TensorList input_vec;
44
  input_vec.emplace_back(std::move(input));
45
  caffe2::Predictor::TensorList output_vec;
46
  p(input_vec, &output_vec);
47
  assert(output_vec.size() == 1);
48
  caffe2::TensorCPU* output = &output_vec.front();
49
  // output is our styled image
50
  float* outputArray = output->mutable_data<float>();
51
  dataOut.assign(outputArray, outputArray + output->size());
52
}
53

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.