from argparse import ArgumentParser
from os import listdir, makedirs
from pathlib import Path
from typing import Dict, List, Optional, Tuple

from packaging.version import Version, parse

from transformers import is_tf_available, is_torch_available
from transformers.file_utils import ModelOutput
from transformers.pipelines import Pipeline, pipeline
from transformers.tokenization_utils import BatchEncoding


# This is the minimal required version to
# support some ONNX Runtime features
ORT_QUANTIZE_MINIMUM_VERSION = parse("1.4.0")


SUPPORTED_PIPELINES = [
    "feature-extraction",
    "ner",
    "sentiment-analysis",
    "fill-mask",
    "question-answering",
    "text-generation",
    "translation_en_to_fr",
    "translation_en_to_de",
    "translation_en_to_ro",
]


class OnnxConverterArgumentParser(ArgumentParser):
    """
    Wraps all the script arguments supported to export transformers models to ONNX IR
    """

    def __init__(self):
        super().__init__("ONNX Converter")

        self.add_argument(
            "--pipeline", type=str, choices=SUPPORTED_PIPELINES, default="feature-extraction",
        )
        self.add_argument(
            "--model", type=str, required=True, help="Model's id or path (ex: bert-base-cased)",
        )
        self.add_argument("--tokenizer", type=str, help="Tokenizer's id or path (ex: bert-base-cased)")
        self.add_argument(
            "--framework", type=str, choices=["pt", "tf"], help="Framework for loading the model",
        )
        self.add_argument("--opset", type=int, default=11, help="ONNX opset to use")
        self.add_argument(
            "--check-loading", action="store_true", help="Check ONNX Runtime is able to load the model",
        )
        self.add_argument(
            "--use-external-format", action="store_true", help="Allow exporting models larger than 2GB",
        )
        self.add_argument(
            "--quantize", action="store_true", help="Quantize the neural network to be run with int8",
        )
        self.add_argument("output")


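# For illustration (not part of the original script): parsing a minimal command line with
# OnnxConverterArgumentParser, e.g.
#   OnnxConverterArgumentParser().parse_args(["--model", "bert-base-cased", "onnx/model.onnx"])
# yields a Namespace with pipeline="feature-extraction" and opset=11 from the defaults,
# model="bert-base-cased", and output="onnx/model.onnx".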
def generate_identified_filename(filename: Path, identifier: str) -> Path:
    """
    Append a string identifier at the end (before the extension, if any) of the provided filepath.
    Args:
        filename: pathlib.Path The actual path object we would like to add an identifier suffix to
        identifier: The suffix to add

    Returns: Path with the identifier concatenated at the end of the filename
    """
    return filename.parent.joinpath(filename.stem + identifier).with_suffix(filename.suffix)


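# Example behavior of generate_identified_filename (derived from the helper above, shown as a
# comment so it is not executed on import):
#   generate_identified_filename(Path("onnx/model.onnx"), "-optimized") -> Path("onnx/model-optimized.onnx")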
def check_onnxruntime_requirements(minimum_version: Version):
    """
    Check onnxruntime is installed and that the installed version is recent enough.
    Raises:
        ImportError: If onnxruntime is not installed or the version found is too old
    """
    try:
        import onnxruntime
    except ImportError:
        raise ImportError(
            "onnxruntime doesn't seem to be currently installed. "
            "Please install the onnxruntime by running `pip install onnxruntime`"
            " and relaunch the conversion."
        )

    # Parse the version of the installed onnxruntime
    # (the version check is done outside the try block so a too-old version is not
    # misreported as a missing installation)
    ort_version = parse(onnxruntime.__version__)

    # We require at least the requested minimum version (1.4.0 for quantization)
    if ort_version < minimum_version:
        raise ImportError(
            f"We found an older version of onnxruntime ({onnxruntime.__version__}) "
            f"but we require onnxruntime to be >= {minimum_version} to enable all the conversions options.\n"
            f"Please update onnxruntime by running `pip install --upgrade onnxruntime`"
        )


def ensure_valid_input(model, tokens, input_names):
    """
    Ensure inputs are presented in the correct order, without any None
    Args:
        model: The model used to forward the input data
        tokens: BatchEncoding holding the input data
        input_names: The names of the inputs

    Returns: Tuple
    """
    print("Ensuring inputs are in correct order")

    model_args_name = model.forward.__code__.co_varnames
    model_args, ordered_input_names = [], []
    for arg_name in model_args_name[1:]:  # start at index 1 to skip "self" argument
        if arg_name in input_names:
            ordered_input_names.append(arg_name)
            model_args.append(tokens[arg_name])
        else:
            print(f"{arg_name} is not present in the generated input list.")
            break

    print(f"Generated inputs order: {ordered_input_names}")
    return ordered_input_names, tuple(model_args)


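# Illustrative sketch of ensure_valid_input (assumes a BERT-like forward signature; not part of
# the original file): given forward(self, input_ids, attention_mask, token_type_ids, ...) and a
# BatchEncoding with matching keys, it returns
#   (["input_ids", "attention_mask", "token_type_ids"], (input_ids, attention_mask, token_type_ids))
# so the positional arguments passed to torch.onnx.export follow the model's own argument order.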
def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
    """
    Attempt to infer the static vs dynamic axes for each input and output tensor of a specific model.
    Args:
        nlp: The pipeline object holding the model to be exported
        framework: The framework identifier to dispatch to the correct inference scheme (pt/tf)

    Returns:
        - List of the inferred input variable names
        - List of the inferred output variable names
        - Dictionary with input/output variable names as keys and their dynamic axes as values
        - a BatchEncoding reference which was used to infer all the above information
    """

    def build_shape_dict(name: str, tensor, is_input: bool, seq_len: int):
        if isinstance(tensor, (tuple, list)):
            return [build_shape_dict(name, t, is_input, seq_len) for t in tensor]

        else:
            # Let's assume batch is the first axis with only 1 element (~~ might not be always true ...)
            axes = {[axis for axis, numel in enumerate(tensor.shape) if numel == 1][0]: "batch"}
            if is_input:
                if len(tensor.shape) == 2:
                    axes[1] = "sequence"
                else:
                    raise ValueError(f"Unable to infer tensor axes ({len(tensor.shape)})")
            else:
                seq_axes = [dim for dim, shape in enumerate(tensor.shape) if shape == seq_len]
                axes.update({dim: "sequence" for dim in seq_axes})

        print(f"Found {'input' if is_input else 'output'} {name} with shape: {axes}")
        return axes

    tokens = nlp.tokenizer("This is a sample output", return_tensors=framework)
    seq_len = tokens.input_ids.shape[-1]
    # TF models are called on the underlying dict (consistent with nlp.model.predict(tokens.data) below)
    outputs = nlp.model(**tokens) if framework == "pt" else nlp.model(tokens.data)
    if isinstance(outputs, ModelOutput):
        outputs = outputs.to_tuple()
    if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)

    # Generate input names & axes
    input_vars = list(tokens.keys())
    input_dynamic_axes = {k: build_shape_dict(k, v, True, seq_len) for k, v in tokens.items()}

    # flatten potentially grouped outputs (past for gpt2, attentions)
    outputs_flat = []
    for output in outputs:
        if isinstance(output, (tuple, list)):
            outputs_flat.extend(output)
        else:
            outputs_flat.append(output)

    # Generate output names & axes
    output_names = [f"output_{i}" for i in range(len(outputs_flat))]
    output_dynamic_axes = {k: build_shape_dict(k, v, False, seq_len) for k, v in zip(output_names, outputs_flat)}

    # Create the aggregated axes representation
    dynamic_axes = dict(input_dynamic_axes, **output_dynamic_axes)
    return input_vars, output_names, dynamic_axes, tokens


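# For illustration (hypothetical run, not executed here): for an encoder whose inputs have shape
# (1, seq_len), infer_shapes typically yields dynamic axes such as
#   {"input_ids": {0: "batch", 1: "sequence"},
#    "attention_mask": {0: "batch", 1: "sequence"},
#    "output_0": {0: "batch", 1: "sequence"}}
# which tells the ONNX exporter that the batch and sequence dimensions may vary at inference time.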
def load_graph_from_args(pipeline_name: str, framework: str, model: str, tokenizer: Optional[str] = None) -> Pipeline:
    """
    Convert the set of arguments provided through the CLI to an actual pipeline reference (tokenizer + model)
    Args:
        pipeline_name: The kind of pipeline to use (ner, question-answering, etc.)
        framework: The framework to convert the pipeline from ("pt" or "tf")
        model: The model name which will be loaded by the pipeline
        tokenizer: The tokenizer name which will be loaded by the pipeline, defaults to the model's value

    Returns: Pipeline object
    """
    # If no tokenizer provided
    if tokenizer is None:
        tokenizer = model

    # Check the wanted framework is available
    if framework == "pt" and not is_torch_available():
        raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
    if framework == "tf" and not is_tf_available():
        raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")

    print(f"Loading pipeline (model: {model}, tokenizer: {tokenizer})")

    # Allocate tokenizer and model
    return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework=framework)


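# Example call of load_graph_from_args (a sketch; "bert-base-cased" is the model id used in the
# --model help text above):
#   nlp = load_graph_from_args("feature-extraction", "pt", "bert-base-cased")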
def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format: bool):
    """
    Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR)
    Args:
        nlp: The pipeline to be exported
        opset: The actual version of the ONNX operator set to use
        output: Path where the generated ONNX model will be stored
        use_external_format: Split the model definition from its parameters to allow models bigger than 2GB

    Returns:
    """
    if not is_torch_available():
        raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")

    import torch
    from torch.onnx import export

    print(f"Using framework PyTorch: {torch.__version__}")

    with torch.no_grad():
        input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
        ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names)

        export(
            nlp.model,
            model_args,
            f=output.as_posix(),
            input_names=ordered_input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            use_external_data_format=use_external_format,
            enable_onnx_checker=True,
            opset_version=opset,
        )


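# Once exported by convert_pytorch, the graph can be loaded back for inference (a sketch,
# assuming onnxruntime is installed; the feed names and arrays are illustrative and must match
# the ordered_input_names printed during export):
#   session = InferenceSession(output.as_posix())
#   session.run(None, {"input_ids": input_ids_array, "attention_mask": attention_mask_array})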
def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
    """
    Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR)
    Args:
        nlp: The pipeline to be exported
        opset: The actual version of the ONNX operator set to use
        output: Path where the generated ONNX model will be stored

    Notes: TensorFlow cannot export models bigger than 2GB due to an internal TensorFlow constraint

    """
    if not is_tf_available():
        raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")

    print("/!\\ Please note TensorFlow doesn't support exporting models > 2GB /!\\")

    try:
        import tensorflow as tf
        from keras2onnx import convert_keras, save_model, __version__ as k2ov

        print(f"Using framework TensorFlow: {tf.version.VERSION}, keras2onnx: {k2ov}")

        # Build
        input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf")

        # Forward
        nlp.model.predict(tokens.data)
        onnx_model = convert_keras(nlp.model, nlp.model.name, target_opset=opset)
        save_model(onnx_model, output.as_posix())

    except ImportError as e:
        raise Exception(f"Cannot import {e.name} required to convert TF model to ONNX. Please install {e.name} first.")


def convert(
    framework: str,
    model: str,
    output: Path,
    opset: int,
    tokenizer: Optional[str] = None,
    use_external_format: bool = False,
    pipeline_name: str = "feature-extraction",
):
    """
    Convert the pipeline object to the ONNX Intermediate Representation (IR) format.
    Args:
        framework: The framework the pipeline is backed by ("pt" or "tf")
        model: The name of the model to load for the pipeline
        output: The path where the ONNX graph will be stored
        opset: The actual version of the ONNX operator set to use
        tokenizer: The name of the tokenizer to load for the pipeline, defaults to the model's name if not provided
        use_external_format: Split the model definition from its parameters to allow models bigger than 2GB (PyTorch only)
        pipeline_name: The kind of pipeline to instantiate (ner, question-answering, etc.)

    Returns:
    """
    print(f"ONNX opset version set to: {opset}")

    # Load the pipeline
    nlp = load_graph_from_args(pipeline_name, framework, model, tokenizer)

    if not output.parent.exists():
        print(f"Creating folder {output.parent}")
        makedirs(output.parent.as_posix())
    elif len(listdir(output.parent.as_posix())) > 0:
        raise Exception(f"Folder {output.parent.as_posix()} is not empty, aborting conversion")

    # Export the graph
    if framework == "pt":
        convert_pytorch(nlp, opset, output, use_external_format)
    else:
        convert_tensorflow(nlp, opset, output)


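# Example programmatic use of convert (a sketch; the model id and output path are illustrative,
# and the output folder must be empty or absent, as enforced above):
#   convert("pt", "bert-base-cased", Path("onnx/bert-base-cased.onnx"), opset=11)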
def optimize(onnx_model_path: Path) -> Path:
    """
    Load the model at the specified path and let onnxruntime look at transformations on the graph
    to enable all the optimizations possible.
    Args:
        onnx_model_path: filepath where the model binary description is stored

    Returns: Path where the optimized model binary description has been saved
    """
    from onnxruntime import SessionOptions, InferenceSession

    # Generate model name with suffix "optimized"
    opt_model_path = generate_identified_filename(onnx_model_path, "-optimized")
    sess_option = SessionOptions()
    sess_option.optimized_model_filepath = opt_model_path.as_posix()
    _ = InferenceSession(onnx_model_path.as_posix(), sess_option)

    print(f"Optimized model has been written at {opt_model_path}: \N{heavy check mark}")
    print("/!\\ Optimized model contains hardware specific operators which might not be portable. /!\\")

    return opt_model_path


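# Typical chaining (mirrors the __main__ flow below): quantization is run on the optimizer's
# output rather than on the raw export, e.g. quantize(optimize(Path("model.onnx"))).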
def quantize(onnx_model_path: Path) -> Path:
    """
    Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPU.
    Args:
        onnx_model_path: Path to the location where the exported ONNX model is stored

    Returns: The Path generated for the quantized model
    """
    try:
        import onnx
        from onnxruntime.quantization import quantize, QuantizationMode

        onnx_model = onnx.load(onnx_model_path.as_posix())

        # Discussed with @yufenglee from ONNX runtime, this will be addressed in the next release of onnxruntime
        print(
            "As of onnxruntime 1.4.0, models larger than 2GB will fail to quantize due to protobuf constraint.\n"
            "This limitation will be removed in the next release of onnxruntime."
        )

        quantized_model = quantize(
            model=onnx_model, quantization_mode=QuantizationMode.IntegerOps, force_fusions=True, symmetric_weight=True,
        )

        # Append "-quantized" at the end of the model's name
        quantized_model_path = generate_identified_filename(onnx_model_path, "-quantized")

        # Save model, then report where it was written
        onnx.save_model(quantized_model, quantized_model_path.as_posix())
        print(f"Quantized model has been written at {quantized_model_path}: \N{heavy check mark}")

        return quantized_model_path
    except Exception as ie:
        print(f"Error while quantizing the model:\n{str(ie)}")
        # Re-raise so callers don't receive None where a Path is expected
        raise


def verify(path: Path):
    from onnxruntime import InferenceSession, SessionOptions
    from onnxruntime.capi.onnxruntime_pybind11_state import RuntimeException

    print(f"Checking ONNX model loading from: {path} ...")
    try:
        onnx_options = SessionOptions()
        _ = InferenceSession(path.as_posix(), onnx_options, providers=["CPUExecutionProvider"])
        print(f"Model {path} correctly loaded: \N{heavy check mark}")
    except RuntimeException as re:
        print(f"Error while loading the model {re}: \N{heavy ballot x}")


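# Example invocation (the script filename is illustrative; the flags come from
# OnnxConverterArgumentParser above):
#   python convert_graph_to_onnx.py --framework pt --model bert-base-cased \
#       --quantize --check-loading onnx/bert-base-cased.onnx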
if __name__ == "__main__":
    parser = OnnxConverterArgumentParser()
    args = parser.parse_args()

    # Make sure output is an absolute path
    args.output = Path(args.output).absolute()

    try:
        print("\n====== Converting model to ONNX ======")
        # Convert
        convert(
            args.framework,
            args.model,
            args.output,
            args.opset,
            args.tokenizer,
            args.use_external_format,
            args.pipeline,
        )

        if args.quantize:
            # Ensure the requirements for quantization on onnxruntime are met
            check_onnxruntime_requirements(ORT_QUANTIZE_MINIMUM_VERSION)

            # onnxruntime optimizations don't provide the same level of performance on TensorFlow as on PyTorch
            if args.framework == "tf":
                print(
                    "\t Using TensorFlow might not provide the same optimization level compared to PyTorch.\n"
                    "\t For TensorFlow users you can try optimizing the model directly through onnxruntime_tools.\n"
                    "\t For more information, please refer to the onnxruntime documentation:\n"
                    "\t\thttps://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers\n"
                )

            print("\n====== Optimizing ONNX model ======")

            # Quantization works best when using the optimized version of the model
            args.optimized_output = optimize(args.output)

            # Do the quantization on the right graph
            args.quantized_output = quantize(args.optimized_output)

        # And verify
        if args.check_loading:
            print("\n====== Check exported ONNX model(s) ======")
            verify(args.output)

            if hasattr(args, "optimized_output"):
                verify(args.optimized_output)

            if hasattr(args, "quantized_output"):
                verify(args.quantized_output)

    except Exception as e:
        print(f"Error while converting the model: {e}")
        exit(1)