//===- TFUtils.cpp - TFLite-based evaluation utilities --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for interfacing with TFLite.
//
//===----------------------------------------------------------------------===//
#include "llvm/Config/config.h"
#if defined(LLVM_HAVE_TFLITE)

#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/Utils/TFUtils.h"
#include "llvm/Support/Base64.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/raw_ostream.h"

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/op_resolver.h"
#include "tensorflow/lite/logger.h"

#include <cassert>
#include <cstring>
#include <numeric>
#include <optional>

using namespace llvm;

namespace llvm {
class EvaluationResultImpl {
public:
  EvaluationResultImpl(const std::vector<const TfLiteTensor *> &Outputs)
      : Outputs(Outputs) {}

  const TfLiteTensor *getOutput(size_t I) { return Outputs[I]; }

  EvaluationResultImpl(const EvaluationResultImpl &) = delete;
  EvaluationResultImpl(EvaluationResultImpl &&Other) = delete;

private:
  const std::vector<const TfLiteTensor *> Outputs;
};

class TFModelEvaluatorImpl {
public:
  TFModelEvaluatorImpl(StringRef SavedModelPath,
                       const std::vector<TensorSpec> &InputSpecs,
                       const std::vector<TensorSpec> &OutputSpecs,
                       const char *Tags);

  bool isValid() const { return IsValid; }
  size_t outputSize() const { return Output.size(); }

  std::unique_ptr<EvaluationResultImpl> evaluate() {
    Interpreter->Invoke();
    return std::make_unique<EvaluationResultImpl>(Output);
  }

  const std::vector<TfLiteTensor *> &getInput() const { return Input; }

  ~TFModelEvaluatorImpl();

private:
  std::unique_ptr<tflite::FlatBufferModel> Model;

  /// The objects necessary for carrying out an evaluation of the SavedModel.
  /// They are expensive to set up, and we maintain them across all the
  /// evaluations of the model.
  std::unique_ptr<tflite::Interpreter> Interpreter;

  /// The input tensors. We set up the tensors once and just mutate their
  /// scalars before each evaluation. The input tensors keep their value after
  /// an evaluation.
  std::vector<TfLiteTensor *> Input;

  /// The output nodes.
  std::vector<const TfLiteTensor *> Output;

  void invalidate() { IsValid = false; }

  bool IsValid = true;

  /// Reusable utility for ensuring we can bind the requested Name to a node in
  /// the SavedModel Graph.
  bool checkReportAndInvalidate(const TfLiteTensor *Tensor,
                                const TensorSpec &Spec);
};

} // namespace llvm

TFModelEvaluatorImpl::TFModelEvaluatorImpl(
    StringRef SavedModelPath, const std::vector<TensorSpec> &InputSpecs,
    const std::vector<TensorSpec> &OutputSpecs, const char *Tags = "serve")
    : Input(InputSpecs.size()), Output(OutputSpecs.size()) {
  // INFO and DEBUG messages could be numerous and not particularly interesting
  tflite::LoggerOptions::SetMinimumLogSeverity(tflite::TFLITE_LOG_WARNING);
  // FIXME: make ErrorReporter a member (may also need subclassing
  // StatefulErrorReporter) to easily get the latest error status, for
  // debugging.
  tflite::StderrReporter ErrorReporter;
  SmallVector<char, 128> TFLitePathBuff;
  llvm::sys::path::append(TFLitePathBuff, SavedModelPath, "model.tflite");
  StringRef TFLitePath(TFLitePathBuff.data(), TFLitePathBuff.size());
  Model = tflite::FlatBufferModel::BuildFromFile(TFLitePath.str().c_str(),
                                                 &ErrorReporter);
  if (!Model) {
    invalidate();
    return;
  }

  tflite::ops::builtin::BuiltinOpResolver Resolver;
  tflite::InterpreterBuilder Builder(*Model, Resolver);
  Builder(&Interpreter);

  if (!Interpreter) {
    invalidate();
    return;
  }

  // We assume the input buffers are valid for the lifetime of the interpreter.
  // By default, tflite allocates memory in an arena and will periodically take
  // away memory and reallocate it in a different location after evaluations in
  // order to improve utilization of the buffers owned in the arena. So, we
  // explicitly mark our input buffers as persistent to avoid this behavior.
  for (size_t I = 0; I < Interpreter->inputs().size(); ++I)
    Interpreter->tensor(I)->allocation_type =
        TfLiteAllocationType::kTfLiteArenaRwPersistent;
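  // Note: this marking must happen before the AllocateTensors() call below,
  // which is when the interpreter's memory planner actually places the
  // buffers; tensors in the persistent arena are then allocated once and not
  // relocated between evaluations.
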
  if (Interpreter->AllocateTensors() != TfLiteStatus::kTfLiteOk) {
    invalidate();
    return;
  }
  // Known inputs and outputs
  StringMap<int> InputsMap;
  StringMap<int> OutputsMap;
  for (size_t I = 0; I < Interpreter->inputs().size(); ++I)
    InputsMap[Interpreter->GetInputName(I)] = I;
  for (size_t I = 0; I < Interpreter->outputs().size(); ++I)
    OutputsMap[Interpreter->GetOutputName(I)] = I;

  size_t NumberFeaturesPassed = 0;
  for (size_t I = 0; I < InputSpecs.size(); ++I) {
    auto &InputSpec = InputSpecs[I];
    auto MapI = InputsMap.find(InputSpec.name() + ":" +
                               std::to_string(InputSpec.port()));
    if (MapI == InputsMap.end()) {
      Input[I] = nullptr;
      continue;
    }
    Input[I] = Interpreter->tensor(MapI->second);
    if (!checkReportAndInvalidate(Input[I], InputSpec))
      return;
    std::memset(Input[I]->data.data, 0,
                InputSpecs[I].getTotalTensorBufferSize());
    ++NumberFeaturesPassed;
  }

  if (NumberFeaturesPassed < Interpreter->inputs().size()) {
    // We haven't passed all the required features to the model; report the
    // error and invalidate the evaluator.
    errs() << "Required feature(s) have not been passed to the ML model";
    invalidate();
    return;
  }

  for (size_t I = 0; I < OutputSpecs.size(); ++I) {
    const auto &OutputSpec = OutputSpecs[I];
    Output[I] = Interpreter->output_tensor(
        OutputsMap[OutputSpec.name() + ":" +
                   std::to_string(OutputSpec.port())]);
    if (!checkReportAndInvalidate(Output[I], OutputSpec))
      return;
  }
}

TFModelEvaluator::TFModelEvaluator(StringRef SavedModelPath,
                                   const std::vector<TensorSpec> &InputSpecs,
                                   const std::vector<TensorSpec> &OutputSpecs,
                                   const char *Tags)
    : Impl(new TFModelEvaluatorImpl(SavedModelPath, InputSpecs, OutputSpecs,
                                    Tags)) {
  if (!Impl->isValid())
    Impl.reset();
}

TFModelEvaluatorImpl::~TFModelEvaluatorImpl() {}

bool TFModelEvaluatorImpl::checkReportAndInvalidate(const TfLiteTensor *Tensor,
                                                    const TensorSpec &Spec) {
  if (!Tensor) {
    errs() << "Could not find TF_Output named: " + Spec.name();
    IsValid = false;
    // Return early; the size check below would dereference a null Tensor.
    return IsValid;
  }
  if (Spec.getTotalTensorBufferSize() != Tensor->bytes)
    IsValid = false;

  // If the total sizes match, there could still be a mismatch in the shape.
  // We ignore that for now.

  return IsValid;
}

std::optional<TFModelEvaluator::EvaluationResult> TFModelEvaluator::evaluate() {
  if (!isValid())
    return std::nullopt;
  return EvaluationResult(Impl->evaluate());
}

void *TFModelEvaluator::getUntypedInput(size_t Index) {
  TfLiteTensor *T = Impl->getInput()[Index];
  if (!T)
    return nullptr;
  return T->data.data;
}

TFModelEvaluator::EvaluationResult::EvaluationResult(
    std::unique_ptr<EvaluationResultImpl> Impl)
    : Impl(std::move(Impl)) {}

TFModelEvaluator::EvaluationResult::EvaluationResult(EvaluationResult &&Other)
    : Impl(std::move(Other.Impl)) {}

TFModelEvaluator::EvaluationResult &
TFModelEvaluator::EvaluationResult::operator=(EvaluationResult &&Other) {
  Impl = std::move(Other.Impl);
  return *this;
}

void *TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) {
  return Impl->getOutput(Index)->data.data;
}

const void *
TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) const {
  return Impl->getOutput(Index)->data.data;
}

TFModelEvaluator::EvaluationResult::~EvaluationResult() {}
TFModelEvaluator::~TFModelEvaluator() {}

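// Example usage of the public API (an illustrative sketch only: the model
// directory path, the feature name "feature", the output name "output", and
// the int64 scalar shapes are hypothetical; TensorSpec::createSpec,
// getInput<T>, evaluate(), and getTensorValue<T> are the typed wrappers
// declared in TensorSpec.h / TFUtils.h):
//
//   std::vector<TensorSpec> InputSpecs{
//       TensorSpec::createSpec<int64_t>("feature", {1})};
//   std::vector<TensorSpec> OutputSpecs{
//       TensorSpec::createSpec<int64_t>("output", {1})};
//   TFModelEvaluator Evaluator("/path/to/model_dir", InputSpecs, OutputSpecs);
//   if (Evaluator.isValid()) {
//     *Evaluator.getInput<int64_t>(0) = 42; // mutate the persistent buffer
//     if (auto Result = Evaluator.evaluate())
//       int64_t Out = *Result->getTensorValue<int64_t>(0);
//   }
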
#endif // defined(LLVM_HAVE_TFLITE)