llvm-project

Форк
0
/
TFLiteUtils.cpp 
249 строк · 8.3 Кб
1
//===- TFLiteUtils.cpp - TFLite-based evaluation utilities ----------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This file implements utilities for interfacing with TFLite.
10
//
11
//===----------------------------------------------------------------------===//
12
#include "llvm/Config/config.h"
13
#if defined(LLVM_HAVE_TFLITE)
14

15
#include "llvm/ADT/Twine.h"
16
#include "llvm/Analysis/Utils/TFUtils.h"
17
#include "llvm/Support/Base64.h"
18
#include "llvm/Support/CommandLine.h"
19
#include "llvm/Support/Debug.h"
20
#include "llvm/Support/JSON.h"
21
#include "llvm/Support/MemoryBuffer.h"
22
#include "llvm/Support/Path.h"
23
#include "llvm/Support/raw_ostream.h"
24

25
#include "tensorflow/lite/interpreter.h"
26
#include "tensorflow/lite/kernels/register.h"
27
#include "tensorflow/lite/model.h"
28
#include "tensorflow/lite/model_builder.h"
29
#include "tensorflow/lite/op_resolver.h"
30
#include "tensorflow/lite/logger.h"
31

32
#include <cassert>
33
#include <numeric>
34
#include <optional>
35

36
using namespace llvm;
37

38
namespace llvm {
39
class EvaluationResultImpl {
40
public:
41
  EvaluationResultImpl(const std::vector<const TfLiteTensor *> &Outputs)
42
      : Outputs(Outputs){};
43

44
  const TfLiteTensor *getOutput(size_t I) { return Outputs[I]; }
45

46
  EvaluationResultImpl(const EvaluationResultImpl &) = delete;
47
  EvaluationResultImpl(EvaluationResultImpl &&Other) = delete;
48

49
private:
50
  const std::vector<const TfLiteTensor *> Outputs;
51
};
52

53
class TFModelEvaluatorImpl {
54
public:
55
  TFModelEvaluatorImpl(StringRef SavedModelPath,
56
                       const std::vector<TensorSpec> &InputSpecs,
57
                       const std::vector<TensorSpec> &OutputSpecs,
58
                       const char *Tags);
59

60
  bool isValid() const { return IsValid; }
61
  size_t outputSize() const { return Output.size(); }
62

63
  std::unique_ptr<EvaluationResultImpl> evaluate() {
64
    Interpreter->Invoke();
65
    return std::make_unique<EvaluationResultImpl>(Output);
66
  }
67

68
  const std::vector<TfLiteTensor *> &getInput() const { return Input; }
69

70
  ~TFModelEvaluatorImpl();
71

72
private:
73
  std::unique_ptr<tflite::FlatBufferModel> Model;
74

75
  /// The objects necessary for carrying out an evaluation of the SavedModel.
76
  /// They are expensive to set up, and we maintain them accross all the
77
  /// evaluations of the model.
78
  std::unique_ptr<tflite::Interpreter> Interpreter;
79

80
  /// The input tensors. We set up the tensors once and just mutate theirs
81
  /// scalars before each evaluation. The input tensors keep their value after
82
  /// an evaluation.
83
  std::vector<TfLiteTensor *> Input;
84

85
  /// The output nodes.
86
  std::vector<const TfLiteTensor *> Output;
87

88
  void invalidate() { IsValid = false; }
89

90
  bool IsValid = true;
91

92
  /// Reusable utility for ensuring we can bind the requested Name to a node in
93
  /// the SavedModel Graph.
94
  bool checkReportAndInvalidate(const TfLiteTensor *Tensor,
95
                                const TensorSpec &Spec);
96
};
97

98
} // namespace llvm
99

100
TFModelEvaluatorImpl::TFModelEvaluatorImpl(
101
    StringRef SavedModelPath, const std::vector<TensorSpec> &InputSpecs,
102
    const std::vector<TensorSpec> &OutputSpecs, const char *Tags = "serve")
103
    : Input(InputSpecs.size()), Output(OutputSpecs.size()) {
104
  // INFO and DEBUG messages could be numerous and not particularly interesting
105
  tflite::LoggerOptions::SetMinimumLogSeverity(tflite::TFLITE_LOG_WARNING);
106
  // FIXME: make ErrorReporter a member (may also need subclassing
107
  // StatefulErrorReporter) to easily get the latest error status, for
108
  // debugging.
109
  tflite::StderrReporter ErrorReporter;
110
  SmallVector<char, 128> TFLitePathBuff;
111
  llvm::sys::path::append(TFLitePathBuff, SavedModelPath, "model.tflite");
112
  StringRef TFLitePath(TFLitePathBuff.data(), TFLitePathBuff.size());
113
  Model = tflite::FlatBufferModel::BuildFromFile(TFLitePath.str().c_str(),
114
                                                 &ErrorReporter);
115
  if (!Model) {
116
    invalidate();
117
    return;
118
  }
119

120
  tflite::ops::builtin::BuiltinOpResolver Resolver;
121
  tflite::InterpreterBuilder Builder(*Model, Resolver);
122
  Builder(&Interpreter);
123

124
  if (!Interpreter) {
125
    invalidate();
126
    return;
127
  }
128

129
  // We assume the input buffers are valid for the lifetime of the interpreter.
130
  // By default, tflite allocates memory in an arena and will periodically take
131
  // away memory and reallocate it in a different location after evaluations in
132
  // order to improve utilization of the buffers owned in the arena. So, we
133
  // explicitly mark our input buffers as persistent to avoid this behavior.
134
  for (size_t I = 0; I < Interpreter->inputs().size(); ++I)
135
    Interpreter->tensor(I)->allocation_type =
136
        TfLiteAllocationType::kTfLiteArenaRwPersistent;
137

138
  if (Interpreter->AllocateTensors() != TfLiteStatus::kTfLiteOk) {
139
    invalidate();
140
    return;
141
  }
142
  // Known inputs and outputs
143
  StringMap<int> InputsMap;
144
  StringMap<int> OutputsMap;
145
  for (size_t I = 0; I < Interpreter->inputs().size(); ++I)
146
    InputsMap[Interpreter->GetInputName(I)] = I;
147
  for (size_t I = 0; I < Interpreter->outputs().size(); ++I)
148
    OutputsMap[Interpreter->GetOutputName(I)] = I;
149

150
  size_t NumberFeaturesPassed = 0;
151
  for (size_t I = 0; I < InputSpecs.size(); ++I) {
152
    auto &InputSpec = InputSpecs[I];
153
    auto MapI = InputsMap.find(InputSpec.name() + ":" +
154
                               std::to_string(InputSpec.port()));
155
    if (MapI == InputsMap.end()) {
156
      Input[I] = nullptr;
157
      continue;
158
    }
159
    Input[I] = Interpreter->tensor(MapI->second);
160
    if (!checkReportAndInvalidate(Input[I], InputSpec))
161
      return;
162
    std::memset(Input[I]->data.data, 0,
163
                InputSpecs[I].getTotalTensorBufferSize());
164
    ++NumberFeaturesPassed;
165
  }
166

167
  if (NumberFeaturesPassed < Interpreter->inputs().size()) {
168
    // we haven't passed all the required features to the model, throw an error.
169
    errs() << "Required feature(s) have not been passed to the ML model";
170
    invalidate();
171
    return;
172
  }
173

174
  for (size_t I = 0; I < OutputSpecs.size(); ++I) {
175
    const auto &OutputSpec = OutputSpecs[I];
176
    Output[I] = Interpreter->output_tensor(
177
        OutputsMap[OutputSpec.name() + ":" +
178
                   std::to_string(OutputSpec.port())]);
179
    if (!checkReportAndInvalidate(Output[I], OutputSpec))
180
      return;
181
  }
182
}
183

184
/// Creates the TFLite-backed evaluator. If the implementation reports that it
/// is not valid (model not loaded or tensors not bound), it is discarded and
/// Impl stays empty.
TFModelEvaluator::TFModelEvaluator(StringRef SavedModelPath,
                                   const std::vector<TensorSpec> &InputSpecs,
                                   const std::vector<TensorSpec> &OutputSpecs,
                                   const char *Tags) {
  auto Candidate = std::make_unique<TFModelEvaluatorImpl>(
      SavedModelPath, InputSpecs, OutputSpecs, Tags);
  // Only keep a fully initialized implementation around.
  if (Candidate->isValid())
    Impl = std::move(Candidate);
}
193

194
// Nothing to do explicitly: the model and interpreter are held by unique_ptr
// members and are released by their own destructors.
TFModelEvaluatorImpl::~TFModelEvaluatorImpl() = default;
195

196
bool TFModelEvaluatorImpl::checkReportAndInvalidate(const TfLiteTensor *Tensor,
197
                                                    const TensorSpec &Spec) {
198
  if (!Tensor) {
199
    errs() << "Could not find TF_Output named: " + Spec.name();
200
    IsValid = false;
201
  }
202
  if (Spec.getTotalTensorBufferSize() != Tensor->bytes)
203
    IsValid = false;
204

205
  // If the total sizes match, there could still be a mismatch in the shape.
206
  // We ignore that for now.
207

208
  return IsValid;
209
}
210

211
std::optional<TFModelEvaluator::EvaluationResult> TFModelEvaluator::evaluate() {
212
  if (!isValid())
213
    return std::nullopt;
214
  return EvaluationResult(Impl->evaluate());
215
}
216

217
void *TFModelEvaluator::getUntypedInput(size_t Index) {
218
  TfLiteTensor *T = Impl->getInput()[Index];
219
  if (!T)
220
    return nullptr;
221
  return T->data.data;
222
}
223

224
/// Takes ownership of the per-evaluation output view.
TFModelEvaluator::EvaluationResult::EvaluationResult(
    std::unique_ptr<EvaluationResultImpl> Impl) {
  this->Impl = std::move(Impl);
}

/// Move construction hands the implementation over; \p Other is left empty.
TFModelEvaluator::EvaluationResult::EvaluationResult(EvaluationResult &&Other)
    : EvaluationResult(std::move(Other.Impl)) {}

/// Move assignment releases the current implementation and takes over the one
/// held by \p Other.
TFModelEvaluator::EvaluationResult &
TFModelEvaluator::EvaluationResult::operator=(EvaluationResult &&Other) {
  Impl = std::move(Other.Impl);
  return *this;
}
236

237
void *TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) {
238
  return Impl->getOutput(Index)->data.data;
239
}
240

241
const void *
242
TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) const {
243
  return Impl->getOutput(Index)->data.data;
244
}
245

246
// Defined in this file, where EvaluationResultImpl and TFModelEvaluatorImpl
// are complete types (presumably the header only forward-declares them).
TFModelEvaluator::EvaluationResult::~EvaluationResult() = default;
TFModelEvaluator::~TFModelEvaluator() = default;
248

249
#endif // defined(LLVM_HAVE_TFLITE)
250

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.