{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4Pjmz-RORV8E"
      },
      "source": [
        "# Export and run models with ONNX\n",
        "\n",
        "[ONNX](https://onnx.ai/) provides a common serialization format for machine learning models. The ONNX runtime supports a number of [different platforms/languages](https://onnxruntime.ai/docs/how-to/install.html#requirements) and has features built in to help reduce inference time.\n",
        "\n",
        "PyTorch has robust support for exporting Torch models to ONNX. This enables exporting Hugging Face Transformers and/or other downstream models directly to ONNX.\n",
        "\n",
        "ONNX opens an avenue for direct inference using a number of languages and platforms. For example, a model could be run directly on Android to limit data sent to a third party service. ONNX is an exciting development with a lot of promise. Microsoft has also released [Hummingbird](https://github.com/microsoft/hummingbird), which enables exporting traditional models (sklearn, decision trees, logistic regression...) to ONNX.\n",
        "\n",
        "This notebook will cover how to export models to ONNX using txtai. These models will then be run directly in Python, JavaScript, Java and Rust. Currently, txtai supports all of these languages through its API, and that is still the recommended approach."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Dk31rbYjSTYm"
      },
      "source": [
        "# Install dependencies\n",
        "\n",
        "Install `txtai` and all dependencies. Since this notebook uses ONNX quantization, we need to install the pipeline extras package."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XMQuuun2R06J"
      },
      "source": [
        "%%capture\n",
        "!pip install datasets git+https://github.com/neuml/txtai#egg=txtai[pipeline]"
      ],
      "execution_count": 25,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PNPJ95cdTKSS"
      },
      "source": [
        "# Run a model with ONNX\n",
        "\n",
        "Let's get right to it! The following example exports a sentiment analysis model to ONNX and runs an inference session.\n",
        "\n"
      ]
    },
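    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The raw model outputs are logits. As the code comment below notes, the example normalizes them into the (0, 1) range with the sigmoid function:\n",
        "\n",
        "$$\\sigma(x) = \\frac{1}{1 + e^{-x}}$$"
      ]
    },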
    {
      "cell_type": "code",
      "metadata": {
        "id": "USb4JXZHxqTA",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "28d3e70e-efa9-4b07-a602-6ffd89d1279f"
      },
      "source": [
        "import numpy as np\n",
        "\n",
        "from onnxruntime import InferenceSession, SessionOptions\n",
        "from transformers import AutoTokenizer\n",
        "from txtai.pipeline import HFOnnx\n",
        "\n",
        "# Normalize logits using sigmoid function\n",
        "sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))\n",
        "\n",
        "# Export to ONNX\n",
        "onnx = HFOnnx()\n",
        "model = onnx(\"distilbert-base-uncased-finetuned-sst-2-english\", \"text-classification\")\n",
        "\n",
        "# Start inference session\n",
        "options = SessionOptions()\n",
        "session = InferenceSession(model, options)\n",
        "\n",
        "# Tokenize\n",
        "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased-finetuned-sst-2-english\")\n",
        "tokens = tokenizer([\"I am happy\", \"I am mad\"], return_tensors=\"np\")\n",
        "\n",
        "# Print results\n",
        "outputs = session.run(None, dict(tokens))\n",
        "print(sigmoid(outputs[0]))"
      ],
      "execution_count": 26,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[[0.01295124 0.9909526 ]\n",
            " [0.9874723  0.0297817 ]]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jkmQoQvlmHfQ"
      },
      "source": [
        "And just like that, there are results! The text classification model is judging sentiment using two labels, 0 for negative and 1 for positive. The results above show the probability of each label per text snippet.\n",
        "\n",
        "The ONNX pipeline loads the model, converts the graph to ONNX and returns it. Note that no output file was provided; in this case, the ONNX model is returned as a byte array. If an output file is provided, this method returns the output path."
      ]
    },
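    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a minimal sketch of the file-based path, the cell below re-runs the export with an output file name. Note that `sentiment.onnx` is just an illustrative name, not part of the original run; the same `HFOnnx` call then returns the path instead of a byte array."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Sketch: provide an output file to write the model to disk\n",
        "# \"sentiment.onnx\" is a hypothetical file name used for illustration\n",
        "path = onnx(\"distilbert-base-uncased-finetuned-sst-2-english\", \"text-classification\", \"sentiment.onnx\")\n",
        "\n",
        "# Returns the output path instead of a byte array\n",
        "print(path)"
      ],
      "execution_count": null,
      "outputs": []
    },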
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yFAOHVmXml8o"
      },
      "source": [
        "# Train and Export a model for Text Classification\n",
        "\n",
        "Next we'll combine the ONNX pipeline with a Trainer pipeline to create a \"train and export to ONNX\" workflow."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 579
        },
        "id": "Wh8TkszumlIe",
        "outputId": "864f2074-ae50-40d6-bc34-2b1d86a71488"
      },
      "source": [
        "from datasets import load_dataset\n",
        "from txtai.pipeline import HFTrainer\n",
        "\n",
        "trainer = HFTrainer()\n",
        "\n",
        "# Hugging Face dataset\n",
        "ds = load_dataset(\"glue\", \"sst2\")\n",
        "data = ds[\"train\"].select(range(10000)).flatten_indices()\n",
        "\n",
        "# Train new model using 10,000 SST2 records (in-memory)\n",
        "model, tokenizer = trainer(\"google/electra-base-discriminator\", data, columns=(\"sentence\", \"label\"))\n",
        "\n",
        "# Export model trained in-memory to ONNX (still in-memory)\n",
        "output = onnx((model, tokenizer), \"text-classification\", quantize=True)\n",
        "\n",
        "# Start inference session\n",
        "options = SessionOptions()\n",
        "session = InferenceSession(output, options)\n",
        "\n",
        "# Tokenize\n",
        "tokens = tokenizer([\"I am happy\", \"I am mad\"], return_tensors=\"np\")\n",
        "\n",
        "# Print results\n",
        "outputs = session.run(None, dict(tokens))\n",
        "print(sigmoid(outputs[0]))"
      ],
      "execution_count": 27,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-e28d0e20a676bad0.arrow\n",
            "WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-d7b5d80ca22204f9.arrow\n",
            "Some weights of the model checkpoint at google/electra-base-discriminator were not used when initializing ElectraForSequenceClassification: ['discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense.bias']\n",
            "- This IS expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
            "- This IS NOT expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
            "Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at google/electra-base-discriminator and are newly initialized: ['classifier.out_proj.bias', 'classifier.dense.weight', 'classifier.out_proj.weight', 'classifier.dense.bias']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "/usr/local/lib/python3.7/dist-packages/transformers/optimization.py:310: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
            "  FutureWarning,\n",
            "You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='3750' max='3750' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [3750/3750 07:56, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>500</td>\n",
              "      <td>0.396800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1000</td>\n",
              "      <td>0.330900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1500</td>\n",
              "      <td>0.232400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2000</td>\n",
              "      <td>0.188200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2500</td>\n",
              "      <td>0.173600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3000</td>\n",
              "      <td>0.068600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3500</td>\n",
              "      <td>0.069800</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[[0.01525715 0.975399  ]\n",
            " [0.97395283 0.04432926]]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lE7dPj3tsn5S"
      },
      "source": [
        "The results are similar to the previous step, although this model is only trained on a fraction of the sst2 dataset. Let's save this model for later."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Q_kAFYd_s_Bi"
      },
      "source": [
        "onnx = HFOnnx()\n",
        "text = onnx((model, tokenizer), \"text-classification\", \"text-classify.onnx\", quantize=True)"
      ],
      "execution_count": 29,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ugNZO4c-uAS-"
      },
      "source": [
        "# Export a Sentence Embeddings model\n",
        "\n",
        "The ONNX pipeline also supports exporting sentence embeddings models trained with the [sentence-transformers](https://github.com/UKPLab/sentence-transformers) package. "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "x9B7qOk_uQRN"
      },
      "source": [
        "embeddings = onnx(\"sentence-transformers/paraphrase-MiniLM-L6-v2\", \"pooling\", \"embeddings.onnx\", quantize=True)"
      ],
      "execution_count": 30,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rirMSM2kvgJF"
      },
      "source": [
        "Now let's run the model with ONNX."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6MBraENcu8Oz",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "ba4528d2-6d6a-4181-e9c4-3d2a98d6663a"
      },
      "source": [
        "from sklearn.metrics.pairwise import cosine_similarity\n",
        "\n",
        "options = SessionOptions()\n",
        "session = InferenceSession(embeddings, options)\n",
        "\n",
        "tokens = tokenizer([\"I am happy\", \"I am glad\"], return_tensors=\"np\")\n",
        "\n",
        "outputs = session.run(None, dict(tokens))[0]\n",
        "\n",
        "print(cosine_similarity(outputs))"
      ],
      "execution_count": 31,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[[0.99999994 0.8474637 ]\n",
            " [0.8474637  0.9999997 ]]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pwgU4vu8vk0T"
      },
      "source": [
        "The code above tokenizes two separate text snippets (\"I am happy\" and \"I am glad\") and runs them through the ONNX model.\n",
        "\n",
        "This outputs two embeddings arrays, which are then compared using cosine similarity. As we can see, the two text snippets have close semantic meaning."
      ]
    },
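    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick refresher, cosine similarity measures the angle between two embeddings vectors $u$ and $v$:\n",
        "\n",
        "$$\\cos(u, v) = \\frac{u \\cdot v}{\\lVert u \\rVert \\, \\lVert v \\rVert}$$\n",
        "\n",
        "A value of 1 means the vectors point in the same direction and -1 the opposite, which is why the diagonal of the matrix above (each embedding compared with itself) is ~1."
      ]
    },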
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t_OQaQeIb7UB"
      },
      "source": [
        "# Load an ONNX model with txtai\n",
        "\n",
        "txtai has built-in support for ONNX models. Loading an ONNX model is seamless; both Embeddings and Pipelines support it. The following section shows how to load a classification pipeline and an embeddings model backed by ONNX."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vhsFzCRBby-h",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "745d69b0-035b-4e44-a881-57f9ece171ab"
      },
      "source": [
        "from txtai.embeddings import Embeddings\n",
        "from txtai.pipeline import Labels\n",
        "\n",
        "labels = Labels((\"text-classify.onnx\", \"google/electra-base-discriminator\"), dynamic=False)\n",
        "print(labels([\"I am happy\", \"I am mad\"]))\n",
        "\n",
        "embeddings = Embeddings({\"path\": \"embeddings.onnx\", \"tokenizer\": \"sentence-transformers/paraphrase-MiniLM-L6-v2\"})\n",
        "print(embeddings.similarity(\"I am happy\", [\"I am glad\"]))"
      ],
      "execution_count": 32,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[[(1, 0.999687910079956), (0, 0.0003121310146525502)], [(0, 0.9991233944892883), (1, 0.0008765518432483077)]]\n",
            "[(0, 0.8298245072364807)]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xx8G29hkwdNY"
      },
      "source": [
        "# JavaScript\n",
        "\n",
        "So far, we've exported models to ONNX and run them through Python. This already has a lot of advantages, including fast inference times, quantization and fewer software dependencies. But ONNX really shines when we run a model trained in Python in other languages/platforms.\n",
        "\n",
        "Let's try running the models trained above in JavaScript. The first step is getting the Node.js environment and dependencies set up.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_RK79O9c4Z_y"
      },
      "source": [
        "%%capture\n",
        "import os\n",
        "\n",
        "os.chdir(\"/content\")\n",
        "!mkdir js\n",
        "os.chdir(\"/content/js\")\n",
        "\n",
        "# Copy ONNX models\n",
        "!cp ../text-classify.onnx .\n",
        "!cp ../embeddings.onnx .\n",
        "\n",
        "# Get tokenizers project\n",
        "!git clone https://github.com/huggingface/tokenizers.git\n",
        "\n",
        "os.chdir(\"/content/js/tokenizers/bindings/node\")\n",
        "\n",
        "# Install Rust to compile tokenizer bindings\n",
        "!apt-get install rustc cargo\n",
        "\n",
        "# Build tokenizers package locally as binary version on npm doesn't work for latest version of Node.js\n",
        "!npm install --also=dev\n",
        "!npm run dev\n",
        "\n",
        "os.chdir(\"/content/js\")"
      ],
      "execution_count": 33,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0HtVEl74xrZ7",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "7ffea868-6dd7-4603-e04c-ce8d4c557ff6"
      },
      "source": [
        "%%writefile package.json\n",
        "{\n",
        "  \"name\": \"onnx-test\",\n",
        "  \"private\": true,\n",
        "  \"version\": \"1.0.0\",\n",
        "  \"description\": \"ONNX Runtime Node.js test\",\n",
        "  \"main\": \"index.js\",\n",
        "  \"dependencies\": {\n",
        "    \"onnxruntime-node\": \">=1.12.1\",\n",
        "    \"tokenizers\": \"file:tokenizers/bindings/node\"\n",
        "  }\n",
        "}"
      ],
      "execution_count": 34,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Writing package.json\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "%%capture\n",
        "\n",
        "# Install all dependencies\n",
        "!npm install"
      ],
      "metadata": {
        "id": "4naPtk-iBI-g"
      },
      "execution_count": 35,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "At85iA8U63iV"
      },
      "source": [
        "Next we'll write the inference code in JavaScript to an index.js file."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RImohEnFyFg0",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "094ef937-0650-4d5b-f4e3-79f114bd9807",
        "cellView": "form"
      },
      "source": [
        "#@title\n",
        "%%writefile index.js\n",
        "const ort = require('onnxruntime-node');\n",
        "const { promisify } = require('util');\n",
        "const { Tokenizer } = require(\"tokenizers/dist/bindings/tokenizer\");\n",
        "\n",
        "function sigmoid(data) {\n",
        "    return data.map(x => 1 / (1 + Math.exp(-x)))\n",
        "}\n",
        "\n",
        "function softmax(data) { \n",
        "    return data.map(x => Math.exp(x) / (data.map(y => Math.exp(y))).reduce((a,b) => a+b)) \n",
        "}\n",
        "\n",
        "function similarity(v1, v2) {\n",
        "    let dot = 0.0;\n",
        "    let norm1 = 0.0;\n",
        "    let norm2 = 0.0;\n",
        "\n",
        "    for (let x = 0; x < v1.length; x++) {\n",
        "        dot += v1[x] * v2[x];\n",
        "        norm1 += Math.pow(v1[x], 2);\n",
        "        norm2 += Math.pow(v2[x], 2);\n",
        "    }\n",
        "\n",
        "    return dot / (Math.sqrt(norm1) * Math.sqrt(norm2));\n",
        "}\n",
        "\n",
        "function tokenizer() {\n",
        "    let tokenizer = Tokenizer.fromPretrained(\"bert-base-uncased\");\n",
        "    return promisify(tokenizer.encode.bind(tokenizer));\n",
        "}\n",
        "\n",
        "async function predict(session, text) {\n",
        "    try {\n",
        "        // Tokenize input\n",
        "        let encode = tokenizer();\n",
        "        let output = await encode(text);\n",
        "\n",
        "        let ids = output.getIds().map(x => BigInt(x))\n",
        "        let mask = output.getAttentionMask().map(x => BigInt(x))\n",
        "        let tids = output.getTypeIds().map(x => BigInt(x))\n",
        "\n",
        "        // Convert inputs to tensors    \n",
        "        let tensorIds = new ort.Tensor('int64', BigInt64Array.from(ids), [1, ids.length]);\n",
        "        let tensorMask = new ort.Tensor('int64', BigInt64Array.from(mask), [1, mask.length]);\n",
        "        let tensorTids = new ort.Tensor('int64', BigInt64Array.from(tids), [1, tids.length]);\n",
        "\n",
        "        let inputs = null;\n",
        "        if (session.inputNames.length > 2) {\n",
        "            inputs = { input_ids: tensorIds, attention_mask: tensorMask, token_type_ids: tensorTids};\n",
        "        }\n",
        "        else {\n",
        "            inputs = { input_ids: tensorIds, attention_mask: tensorMask};\n",
        "        }\n",
        "\n",
        "        return await session.run(inputs);\n",
        "    } catch (e) {\n",
        "        console.error(`failed to inference ONNX model: ${e}.`);\n",
        "    }\n",
        "}\n",
        "\n",
        "async function main() {\n",
        "    let args = process.argv.slice(2);\n",
        "    if (args.length > 1) {\n",
        "        // Run sentence embeddings\n",
        "        const session = await ort.InferenceSession.create('./embeddings.onnx');\n",
        "\n",
        "        let v1 = await predict(session, args[0]);\n",
        "        let v2 = await predict(session, args[1]);\n",
        "\n",
        "        // Unpack results\n",
        "        v1 = v1.embeddings.data;\n",
        "        v2 = v2.embeddings.data;\n",
        "\n",
        "        // Print similarity\n",
        "        console.log(similarity(Array.from(v1), Array.from(v2)));\n",
        "    }\n",
        "    else {\n",
        "        // Run text classifier\n",
        "        const session = await ort.InferenceSession.create('./text-classify.onnx');\n",
        "        let results = await predict(session, args[0]);\n",
        "\n",
        "        // Normalize results using softmax and print\n",
        "        console.log(softmax(results.logits.data));\n",
        "    }\n",
        "}\n",
        "\n",
        "main();"
      ],
      "execution_count": 36,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Writing index.js\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rZI9PJzi6_bO"
      },
      "source": [
        "## Run Text Classification in JavaScript with ONNX"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bdz68KZT1Jfm",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "48c4a427-3108-436c-eac9-564835ea061c"
      },
      "source": [
        "!node . \"I am happy\"\n",
        "!node . \"I am mad\""
      ],
      "execution_count": 37,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Float32Array(2) [ \u001b[33m0.0003121308400295675\u001b[39m, \u001b[33m0.9996878504753113\u001b[39m ]\n",
            "Float32Array(2) [ \u001b[33m0.9991234540939331\u001b[39m, \u001b[33m0.0008765519596636295\u001b[39m ]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "swSEmqto33VP"
      },
      "source": [
        "First off, have to say this is 🔥🔥🔥! Just amazing that this model can be fully run in JavaScript. It's a great time to be in NLP!\n",
        "\n",
        "The steps above installed a JavaScript environment with the dependencies to run ONNX and tokenize data in JavaScript. The text classification model previously created is loaded into the JavaScript ONNX runtime and inference is run.\n",
        "\n",
        "As a reminder, the text classification model is judging sentiment using two labels, 0 for negative and 1 for positive. The results above show the probability of each label per text snippet."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5Az9YaDc6u9P"
      },
      "source": [
        "## Build sentence embeddings and compare similarity in JavaScript with ONNX"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "10jcUbUx6MAI",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "ac751cee-5f44-4dad-c164-5d124be75ec3"
      },
      "source": [
        "!node . \"I am happy\", \"I am glad\""
      ],
      "execution_count": 38,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[33m0.8285076844387538\u001b[39m\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8Jyk-9Ko78Ma"
      },
      "source": [
        "Once again... wow!! The sentence embeddings model produces vectors that can be used to compare semantic similarity, -1 being most dissimilar and 1 being most similar.\n",
        "\n",
        "While the results don't exactly match the Python results above, they're very close. Worth mentioning again that this is 100% JavaScript, no API or remote calls, all within Node."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BQeMBNWO9Hpr"
      },
      "source": [
        "# Java\n",
        "\n",
        "Let's try the same thing with Java. The following sections initialize a Java build environment and write out the code necessary to run ONNX inference."
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "%%capture\n",
        "import os\n",
        "\n",
        "os.chdir(\"/content\")\n",
        "!mkdir java\n",
        "os.chdir(\"/content/java\")\n",
        "\n",
        "# Copy ONNX models\n",
        "!cp ../text-classify.onnx .\n",
        "!cp ../embeddings.onnx .\n",
        "\n",
        "# Save copy of Bert Tokenizer\n",
        "tokenizer.save_pretrained(\"bert\")\n",
        "\n",
        "!mkdir -p src/main/java\n",
        "\n",
        "# Install gradle\n",
        "!wget https://services.gradle.org/distributions/gradle-7.5.1-bin.zip\n",
        "!unzip -o gradle-7.5.1-bin.zip"
      ],
      "metadata": {
        "id": "L1YMoO7WwkEk"
      },
      "execution_count": 39,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gjZ2p7Jf9mOV",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "0cc3f4ef-bac6-4c13-dc6b-8482353bb741"
      },
      "source": [
        "%%writefile build.gradle\n",
        "apply plugin: \"java\"\n",
        "\n",
        "repositories {\n",
        "    mavenCentral()\n",
        "}\n",
        "\n",
        "dependencies {\n",
        "    implementation \"com.robrua.nlp:easy-bert:1.0.3\"\n",
        "    implementation \"com.microsoft.onnxruntime:onnxruntime:1.12.1\"\n",
        "}\n",
        "\n",
        "java {\n",
        "    toolchain {\n",
        "        languageVersion = JavaLanguageVersion.of(8)\n",
        "    }\n",
        "}\n",
        "\n",
        "jar {\n",
        "    archiveBaseName = \"onnxjava\"\n",
        "}\n",
        "\n",
        "task onnx(type: JavaExec) {\n",
        "    description = \"Runs ONNX demo\"\n",
        "    classpath = sourceSets.main.runtimeClasspath\n",
        "    main = \"OnnxDemo\"\n",
        "}"
      ],
      "execution_count": 40,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Writing build.gradle\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9wlVWVky9NZ3"
      },
      "source": [
        "%%capture\n",
        "\n",
        "# Create environment\n",
        "!gradle-7.5.1/bin/gradle wrapper"
      ],
      "execution_count": 41,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vnxKGSuz_fnj",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "c1b73bf8-c5f4-46fa-df90-29bd60507287",
        "cellView": "form"
      },
      "source": [
        "#@title\n",
        "%%writefile src/main/java/OnnxDemo.java\n",
        "import java.io.File;\n",
        "\n",
        "import java.nio.LongBuffer;\n",
        "\n",
        "import java.util.Arrays;\n",
        "import java.util.ArrayList;\n",
        "import java.util.HashMap;\n",
        "import java.util.List;\n",
        "import java.util.Map;\n",
        "\n",
        "import ai.onnxruntime.OnnxTensor;\n",
        "import ai.onnxruntime.OrtEnvironment;\n",
        "import ai.onnxruntime.OrtSession;\n",
        "import ai.onnxruntime.OrtSession.Result;\n",
        "\n",
        "import com.robrua.nlp.bert.FullTokenizer;\n",
        "\n",
        "class Tokens {\n",
        "    public long[] ids;\n",
        "    public long[] mask;\n",
        "    public long[] types;\n",
        "}\n",
        "\n",
        "class Tokenizer {\n",
        "    private FullTokenizer tokenizer;\n",
        "\n",
        "    public Tokenizer(String path) {\n",
        "        File vocab = new File(path);\n",
        "        this.tokenizer = new FullTokenizer(vocab, true);\n",
        "    }\n",
        "\n",
        "    public Tokens tokenize(String text) {\n",
        "        // Build list of tokens\n",
        "        List<String> tokensList = new ArrayList();\n",
        "        tokensList.add(\"[CLS]\"); \n",
        "        tokensList.addAll(Arrays.asList(tokenizer.tokenize(text)));\n",
        "        tokensList.add(\"[SEP]\");\n",
        "\n",
        "        int[] ids = tokenizer.convert(tokensList.toArray(new String[0]));\n",
        "\n",
        "        Tokens tokens = new Tokens();\n",
        "\n",
        "        // input ids    \n",
        "        tokens.ids = Arrays.stream(ids).mapToLong(i -> i).toArray();\n",
        "\n",
        "        // attention mask\n",
        "        tokens.mask = new long[ids.length];\n",
        "        Arrays.fill(tokens.mask, 1);\n",
        "\n",
        "        // token type ids\n",
        "        tokens.types = new long[ids.length];\n",
        "        Arrays.fill(tokens.types, 0);\n",
        "\n",
        "        return tokens;\n",
        "    }\n",
        "}\n",
        "\n",
        "class Inference {\n",
        "    private Tokenizer tokenizer;\n",
        "    private OrtEnvironment env;\n",
        "    private OrtSession session;\n",
        "\n",
        "    public Inference(String model) throws Exception {\n",
        "        this.tokenizer = new Tokenizer(\"bert/vocab.txt\");\n",
        "        this.env = OrtEnvironment.getEnvironment();\n",
        "        this.session = env.createSession(model, new OrtSession.SessionOptions());\n",
        "    }\n",
        "\n",
        "    public float[][] predict(String text) throws Exception {\n",
        "        Tokens tokens = this.tokenizer.tokenize(text);\n",
        "\n",
        "        Map<String, OnnxTensor> inputs = new HashMap<String, OnnxTensor>();\n",
        "        inputs.put(\"input_ids\", OnnxTensor.createTensor(env, LongBuffer.wrap(tokens.ids),  new long[]{1, tokens.ids.length}));\n",
        "        inputs.put(\"attention_mask\", OnnxTensor.createTensor(env, LongBuffer.wrap(tokens.mask),  new long[]{1, tokens.mask.length}));\n",
        "        inputs.put(\"token_type_ids\", OnnxTensor.createTensor(env, LongBuffer.wrap(tokens.types),  new long[]{1, tokens.types.length}));\n",
        "\n",
        "        return (float[][])session.run(inputs).get(0).getValue();\n",
        "    }\n",
        "}\n",
        "\n",
        "class Vectors {\n",
        "    public static double similarity(float[] v1, float[] v2) {\n",
        "        double dot = 0.0;\n",
        "        double norm1 = 0.0;\n",
        "        double norm2 = 0.0;\n",
        "\n",
        "        for (int x = 0; x < v1.length; x++) {\n",
        "            dot += v1[x] * v2[x];\n",
        "            norm1 += Math.pow(v1[x], 2);\n",
        "            norm2 += Math.pow(v2[x], 2);\n",
        "        }\n",
        "\n",
        "        return dot / (Math.sqrt(norm1) * Math.sqrt(norm2));\n",
        "    }\n",
        "\n",
        "    public static float[] softmax(float[] input) {\n",
        "        double[] t = new double[input.length];\n",
        "        double sum = 0.0;\n",
        "\n",
        "        for (int x = 0; x < input.length; x++) {\n",
        "            double val = Math.exp(input[x]);\n",
        "            sum += val;\n",
        "            t[x] = val;\n",
        "        }\n",
        "\n",
        "        float[] output = new float[input.length];\n",
        "        for (int x = 0; x < output.length; x++) {\n",
        "            output[x] = (float) (t[x] / sum);\n",
        "        }\n",
        "\n",
        "        return output;\n",
        "    }\n",
        "}\n",
        "\n",
        "public class OnnxDemo {\n",
        "    public static void main(String[] args) {\n",
        "        try {\n",
        "            if (args.length < 2) {\n",
        "              Inference inference = new Inference(\"text-classify.onnx\");\n",
        "\n",
        "              float[][] v1 = inference.predict(args[0]);\n",
        "\n",
        "              System.out.println(Arrays.toString(Vectors.softmax(v1[0])));\n",
        "            }\n",
        "            else {\n",
        "              Inference inference = new Inference(\"embeddings.onnx\");\n",
        "              float[][] v1 = inference.predict(args[0]);\n",
        "              float[][] v2 = inference.predict(args[1]);\n",
        "\n",
        "              System.out.println(Vectors.similarity(v1[0], v2[0]));\n",
        "            }\n",
        "        }\n",
        "        catch (Exception ex) {\n",
        "            ex.printStackTrace();\n",
        "        }\n",
        "    }\n",
        "}"
      ],
      "execution_count": 42,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Writing src/main/java/OnnxDemo.java\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qQuuXw97Z_I7"
      },
      "source": [
        "## Run Text Classification in Java with ONNX"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hFXyH96gAZpu",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "efd5b783-b23a-407c-8577-c18e3a6cb984"
      },
      "source": [
        "!./gradlew -q --console=plain onnx --args='\"I am happy\"' 2> /dev/null\n",
        "!./gradlew -q --console=plain onnx --args='\"I am mad\"' 2> /dev/null"
      ],
      "execution_count": 43,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[3.1213084E-4, 0.99968785]\n",
            "\u001b[m[0.99912345, 8.7655196E-4]\n",
            "\u001b[m"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pE3FSsAAaJHe"
      },
      "source": [
        "The command above tokenizes the input and runs inference with the previously created text classification model, all within a Java ONNX inference session.\n",
        "\n",
        "As a reminder, the text classification model is judging sentiment using two labels, 0 for negative and 1 for positive. The results above show the probability of each label per text snippet."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Bux8v0C4aDyP"
      },
      "source": [
        "## Build sentence embeddings and compare similarity in Java with ONNX"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "f6zE9VrwCcUa",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "988e59d0-943f-45b6-d37e-5fc1ebbbcefe"
      },
      "source": [
        "!./gradlew -q --console=plain onnx --args='\"I am happy\" \"I am glad\"' 2> /dev/null"
      ],
      "execution_count": 44,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "0.8298244656285757\n",
            "\u001b[m"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0uepOZvJDOCB"
      },
      "source": [
        "The sentence embeddings model produces vectors that can be used to compare semantic similarity, -1 being most dissimilar and 1 being most similar. \n",
        "\n",
        "This is 100% Java, no API or remote calls, all within the JVM. Still think it's amazing!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "faRu9EAJDUXw"
      },
      "source": [
        "# Rust\n",
        "\n",
        "Last but not least, let's try Rust. The following sections initialize a Rust build environment and write out the code necessary to run ONNX inference."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "X3Xp1KLhelqw"
      },
      "source": [
        "%%capture\n",
        "import os\n",
        "\n",
        "os.chdir(\"/content\")\n",
        "!mkdir rust\n",
        "os.chdir(\"/content/rust\")\n",
        "\n",
        "# Copy ONNX models\n",
        "!cp ../text-classify.onnx .\n",
        "!cp ../embeddings.onnx .\n",
        "\n",
        "# Install Rust\n",
        "!apt-get install rustc cargo\n",
        "\n",
        "!mkdir -p src"
      ],
      "execution_count": 45,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "c7hz--Gne6Oa",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "d98ad709-5675-4193-e598-e6cbe12edda3"
      },
      "source": [
        "%%writefile Cargo.toml\n",
        "[package]\n",
        "name = \"onnx-test\"\n",
        "version = \"1.0.0\"\n",
        "description = \"\"\"\n",
        "ONNX Runtime Rust test\n",
        "\"\"\"\n",
        "edition = \"2018\"\n",
        "\n",
        "[dependencies]\n",
        "onnxruntime = { version = \"0.0.14\"}\n",
        "tokenizers = { version = \"0.13.1\"}"
      ],
      "execution_count": 46,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Writing Cargo.toml\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_8fdRvO1fFBm",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "53168684-12c2-46fc-e0e6-54c4c6d03cb1",
        "cellView": "form"
      },
      "source": [
        "#@title\n",
        "%%writefile src/main.rs\n",
        "use onnxruntime::environment::Environment;\n",
        "use onnxruntime::GraphOptimizationLevel;\n",
        "use onnxruntime::ndarray::{Array2, Axis};\n",
        "use onnxruntime::tensor::OrtOwnedTensor;\n",
        "\n",
        "use std::env;\n",
        "\n",
        "use tokenizers::tokenizer::{Result, Tokenizer};\n",
        "\n",
        "fn tokenize(text: String, inputs: usize) -> Vec<Array2<i64>> {\n",
        "    // Load tokenizer from HF Hub\n",
        "    let tokenizer = Tokenizer::from_pretrained(\"bert-base-uncased\", None).unwrap();\n",
        "\n",
        "    // Encode input text\n",
        "    let encoding = tokenizer.encode(text, true).unwrap();\n",
        "\n",
        "    let v1: Vec<i64> = encoding.get_ids().to_vec().into_iter().map(|x| x as i64).collect();\n",
        "    let v2: Vec<i64> = encoding.get_attention_mask().to_vec().into_iter().map(|x| x as i64).collect();\n",
        "    let v3: Vec<i64> = encoding.get_type_ids().to_vec().into_iter().map(|x| x as i64).collect();\n",
        "\n",
        "    let ids = Array2::from_shape_vec((1, v1.len()), v1).unwrap();\n",
        "    let mask = Array2::from_shape_vec((1, v2.len()), v2).unwrap();\n",
        "    let tids = Array2::from_shape_vec((1, v3.len()), v3).unwrap();\n",
        "\n",
        "    return if inputs > 2 { vec![ids, mask, tids] } else { vec![ids, mask] };\n",
        "}\n",
        "\n",
        "fn predict(text: String, softmax: bool) -> Vec<f32> {\n",
        "    // Start onnx session\n",
        "    let environment = Environment::builder()\n",
        "        .with_name(\"test\")\n",
        "        .build().unwrap();\n",
        "\n",
        "    // Derive model path\n",
        "    let model = if softmax { \"text-classify.onnx\" } else { \"embeddings.onnx\" };\n",
        "\n",
        "    let mut session = environment\n",
        "        .new_session_builder().unwrap()\n",
        "        .with_optimization_level(GraphOptimizationLevel::Basic).unwrap()\n",
        "        .with_number_threads(1).unwrap()\n",
        "        .with_model_from_file(model).unwrap();\n",
        "\n",
        "    let inputs = tokenize(text, session.inputs.len());\n",
        "\n",
        "    // Run inference and print result\n",
        "    let outputs: Vec<OrtOwnedTensor<f32, _>> = session.run(inputs).unwrap();\n",
        "    let output: &OrtOwnedTensor<f32, _> = &outputs[0];\n",
        "\n",
        "    let probabilities: Vec<f32>;\n",
        "    if softmax {\n",
        "        probabilities = output\n",
        "            .softmax(Axis(1))\n",
        "            .iter()\n",
        "            .copied()\n",
        "            .collect::<Vec<_>>();\n",
        "    }\n",
        "    else {\n",
        "        probabilities = output\n",
        "            .iter()\n",
        "            .copied()\n",
        "            .collect::<Vec<_>>();\n",
        "    }\n",
        "\n",
        "    return probabilities;\n",
        "}\n",
        "\n",
        "fn similarity(v1: &Vec<f32>, v2: &Vec<f32>) -> f64 {\n",
        "    let mut dot = 0.0;\n",
        "    let mut norm1 = 0.0;\n",
        "    let mut norm2 = 0.0;\n",
        "\n",
        "    for x in 0..v1.len() {\n",
        "        dot += v1[x] * v2[x];\n",
        "        norm1 += v1[x].powf(2.0);\n",
        "        norm2 += v2[x].powf(2.0);\n",
        "    }\n",
        "\n",
        "    return dot as f64 / (norm1.sqrt() * norm2.sqrt()) as f64\n",
        "}\n",
        "\n",
        "fn main() -> Result<()> {\n",
        "    // Tokenize input string\n",
        "    let args: Vec<String> = env::args().collect();\n",
        "\n",
        "    if args.len() <= 2 {\n",
        "      let v1 = predict(args[1].to_string(), true);\n",
        "      println!(\"{:?}\", v1);\n",
        "    }\n",
        "    else {\n",
        "      let v1 = predict(args[1].to_string(), false);\n",
        "      let v2 = predict(args[2].to_string(), false);\n",
        "      println!(\"{:?}\", similarity(&v1, &v2));\n",
        "    }\n",
        "\n",
        "    Ok(())\n",
        "}"
      ],
      "execution_count": 47,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Writing src/main.rs\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OdfQFY-MiA-n"
      },
      "source": [
        "## Run Text Classification in Rust with ONNX"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "b0ymX4ftgWcT",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "84b42d3c-82d4-46fc-fb84-94967bd5330f"
      },
      "source": [
        "!cargo run \"I am happy\" 2> /dev/null\n",
        "!cargo run \"I am mad\" 2> /dev/null"
      ],
      "execution_count": 48,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[0.00030939875, 0.9996906]\n",
            "[0.99912345, 0.0008765513]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NKccz6bBiIgW"
      },
      "source": [
        "The command above tokenizes the input and runs inference with the previously created text classification model, all within a Rust ONNX inference session.\n",
        "\n",
        "As a reminder, the text classification model is judging sentiment using two labels, 0 for negative and 1 for positive. The results above show the probability of each label per text snippet."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1D1kN0yNiEg7"
      },
      "source": [
        "## Build sentence embeddings and compare similarity in Rust with ONNX"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "A9p6F_ODhenH",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "b43ad47e-e1f3-4748-d854-a0dc2024b780"
      },
      "source": [
        "!cargo run \"I am happy\" \"I am glad\" 2> /dev/null"
      ],
      "execution_count": 49,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "0.8298246060854143\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TQ7Wvn0OiRr4"
      },
      "source": [
        "The sentence embeddings model produces vectors that can be used to compare semantic similarity, -1 being most dissimilar and 1 being most similar. \n",
        "\n",
        "Once again, this is 100% Rust, no API or remote calls. And yes, still think it's amazing!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-_FNKUWtjLsO"
      },
      "source": [
        "# Wrapping up\n",
        "\n",
        "This notebook covered how to export models to ONNX using txtai. These models were then run in Python, JavaScript, Java and Rust. Golang was also evaluated, but there doesn't currently appear to be a stable enough ONNX runtime available.\n",
        "\n",
        "This method provides a way to train and run machine learning models using a number of programming languages on a number of platforms.\n",
        "\n",
        "The following is a non-exhaustive list of use cases:\n",
        "\n",
        "*   Build locally executed models for mobile/edge devices\n",
        "*   Run models with Java/JavaScript/Rust development stacks when teams prefer not to add Python to the mix\n",
        "*   Export models to ONNX for Python inference to improve CPU performance and/or reduce the number of software dependencies"
      ]
    }
  ]
}