Export and run models with ONNX

This article is part of a tutorial series on txtai, an AI-powered semantic search platform.

ONNX provides a common serialization format for machine learning models, and the ONNX runtime executes models stored in that format. ONNX supports a number of different platforms/languages and has features built in to help reduce inference time.

PyTorch has robust support for exporting Torch models to ONNX. This enables exporting Hugging Face Transformer and/or other downstream models directly to ONNX.

ONNX opens an avenue for direct inference using a number of languages and platforms. For example, a model could be run directly on Android to limit data sent to a third-party service. ONNX is an exciting development with a lot of promise. Microsoft has also released Hummingbird, which enables exporting traditional models (sklearn, decision trees, logistic regression, ...) to ONNX.

This article will cover how to export models to ONNX using txtai. These models will then be directly run in Python, JavaScript, Java and Rust. Currently, txtai supports all these languages through its API, and that is still the recommended approach.

Install dependencies

Install txtai and all dependencies. Since this article uses ONNX quantization, we need to install the pipeline extras package.

pip install txtai[pipeline]

Run a model with ONNX

Let's get right to it! The following example exports a sentiment analysis model to ONNX and runs an inference session.

import numpy as np

from onnxruntime import InferenceSession, SessionOptions
from transformers import AutoTokenizer
from txtai.pipeline import HFOnnx

# Normalize logits using sigmoid function
sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))

# Export to ONNX
onnx = HFOnnx()
model = onnx("distilbert-base-uncased-finetuned-sst-2-english", "text-classification")

# Start inference session
options = SessionOptions()
session = InferenceSession(model, options)

# Tokenize
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
tokens = tokenizer(["I am happy", "I am mad"], return_tensors="np")

# Print results
outputs = session.run(None, dict(tokens))
print(sigmoid(outputs[0]))
[[0.01295124 0.9909526 ]
 [0.9874723  0.0297817 ]]

And just like that, there are results! The text classification model is judging sentiment using two labels, 0 for negative and 1 for positive. The results above show the score of each label per text snippet.
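
To make the output easier to read, each row of scores can be mapped to its highest-scoring label. The following is a minimal sketch using the scores printed above. The label names are an assumption added for readability, and `predict_labels` is a hypothetical helper, not part of txtai.

```python
# Scores printed by the sentiment example above, one row per input text
scores = [
    [0.01295124, 0.9909526],  # "I am happy"
    [0.9874723, 0.0297817],   # "I am mad"
]

# Assumed label names: index 0 = negative, index 1 = positive
names = ["negative", "positive"]

def predict_labels(rows):
    """Return the name of the highest-scoring label for each row."""
    return [names[max(range(len(row)), key=row.__getitem__)] for row in rows]

labels = predict_labels(scores)
print(labels)
```

The first snippet resolves to the positive label and the second to the negative label, which matches what we'd expect from the text.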

The ONNX pipeline loads the model, converts the graph to ONNX and returns the result. Note that no output file was provided; in this case, the ONNX model is returned as a byte array. If an output file is provided, this method returns the output path.
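
Since the pipeline can return either a byte array or a path, downstream code that needs a file on disk may want to handle both cases. The following is a rough sketch of one way to do that. `to_model_path` is a hypothetical helper written for this example, and the bytes used are a stand-in, not a real ONNX graph.

```python
import os
import tempfile

def to_model_path(model, directory=None):
    """Return a file path for a model given either a path or raw bytes."""
    if isinstance(model, (bytes, bytearray)):
        # In-memory model: write the bytes to a new .onnx file
        fd, path = tempfile.mkstemp(suffix=".onnx", dir=directory)
        with os.fdopen(fd, "wb") as handle:
            handle.write(model)
        return path

    # Already a path on disk
    return str(model)

# Stand-in bytes for illustration only
path = to_model_path(b"not-a-real-model")
with open(path, "rb") as handle:
    content = handle.read()

# Clean up the temporary file
os.remove(path)
```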

Train and Export a model for Text Classification

Next we'll combine the ONNX pipeline with a Trainer pipeline to create a "train and export to ONNX" workflow.

from datasets import load_dataset
from txtai.pipeline import HFTrainer

trainer = HFTrainer()

# Hugging Face dataset
ds = load_dataset("glue", "sst2")
data = ds["train"].select(range(5000)).flatten_indices()

# Train new model using 5,000 SST2 records (in-memory)
model, tokenizer = trainer("google/electra-base-discriminator", data, columns=("sentence", "label"))

# Export model trained in-memory to ONNX (still in-memory)
output = onnx((model, tokenizer), "text-classification", quantize=True)

# Start inference session
options = SessionOptions()
session = InferenceSession(output, options)

# Tokenize
tokens = tokenizer(["I am happy", "I am mad"], return_tensors="np")

# Print results
outputs = session.run(None, dict(tokens))
print(sigmoid(outputs[0]))
[[0.02424305 0.9557785 ]
 [0.95884305 0.05541185]]

The results are similar to the previous step, although this model is only trained on a fraction of the sst2 dataset. Let's save this model for later.

text = onnx((model, tokenizer), "text-classification", "text-classify.onnx", quantize=True)

Export a Sentence Embeddings model

The ONNX pipeline also supports exporting sentence embeddings models trained with the sentence-transformers package.

embeddings = onnx("sentence-transformers/paraphrase-MiniLM-L6-v2", "pooling", "embeddings.onnx", quantize=True)

Now let's run the model with ONNX.

from sklearn.metrics.pairwise import cosine_similarity

options = SessionOptions()
session = InferenceSession(embeddings, options)

tokens = tokenizer(["I am happy", "I am glad"], return_tensors="np")

outputs = session.run(None, dict(tokens))[0]

print(cosine_similarity(outputs))
[[1.0000002 0.8430618]
 [0.8430618 1.       ]]

The code above tokenizes two separate text snippets ("I am happy" and "I am glad") and runs them through the ONNX model.

This outputs two embeddings arrays and those arrays are compared using cosine similarity. As we can see, the two text snippets have close semantic meaning.
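
For reference, cosine similarity itself is a short computation. The sketch below is a plain Python version with no dependencies, using small made-up vectors rather than real embeddings. The `cosine` function name is our own.

```python
import math

def cosine(v1, v2):
    """Cosine similarity: dot product divided by the product of the vector norms."""
    dot = sum(a * b for a, b in zip(v1, v2))
    norm1 = math.sqrt(sum(a * a for a in v1))
    norm2 = math.sqrt(sum(b * b for b in v2))
    return dot / (norm1 * norm2)

# Made-up vectors for illustration, not model output
same = cosine([1.0, 2.0], [2.0, 4.0])         # same direction, close to 1
orthogonal = cosine([1.0, 0.0], [0.0, 1.0])   # unrelated directions, 0
opposite = cosine([1.0, 2.0], [-1.0, -2.0])   # opposite direction, close to -1

print(same, orthogonal, opposite)
```

The same calculation shows up again later in this article in JavaScript, Java and Rust.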

Load an ONNX model with txtai

txtai has built-in support for ONNX models. Loading an ONNX model is seamless, and both Embeddings and Pipelines support it. The following section shows how to load a classification pipeline and embeddings model backed by ONNX.

from txtai.embeddings import Embeddings
from txtai.pipeline import Labels

labels = Labels(("text-classify.onnx", "google/electra-base-discriminator"), dynamic=False)
print(labels(["I am happy", "I am mad"]))

embeddings = Embeddings({"path": "embeddings.onnx", "tokenizer": "sentence-transformers/paraphrase-MiniLM-L6-v2"})
print(embeddings.similarity("I am happy", ["I am glad"]))
[[(1, 0.9988517761230469), (0, 0.0011482156114652753)], [(0, 0.997488260269165), (1, 0.0025116782635450363)]]
[(0, 0.8581848740577698)]
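
The labels output above is a list of (label id, score) tuples per input text. As a small sketch of how that structure could be consumed, the following picks the best label for each text using the values printed above. The variable names are our own.

```python
# Output printed by the Labels pipeline above: one list of (label id, score) per text
results = [
    [(1, 0.9988517761230469), (0, 0.0011482156114652753)],
    [(0, 0.997488260269165), (1, 0.0025116782635450363)],
]
texts = ["I am happy", "I am mad"]

# Keep the highest-scoring (label id, score) pair for each text
top = {text: max(scores, key=lambda pair: pair[1]) for text, scores in zip(texts, results)}

for text, (label, score) in top.items():
    print(text, label, round(score, 4))
```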

JavaScript

So far, we've exported models to ONNX and run them through Python. This already has a lot of advantages, which include fast inference times, quantization and fewer software dependencies. But ONNX really shines when we run a model trained in Python in other languages/platforms.
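
Quantization has been used a few times without much explanation, so here is a toy sketch of the general idea: store weights as small integers plus a scale factor, trading a little precision for size and speed. This is a simplified symmetric 8-bit scheme with made-up weights. It is only an illustration and not necessarily the scheme the ONNX runtime applies when `quantize=True` is set.

```python
def quantize(weights):
    """Map floats to integers in [-127, 127] using a single shared scale."""
    # Fall back to a scale of 1.0 when all weights are zero
    scale = max(abs(w) for w in weights) / 127.0 or 1.0
    return [round(w / scale) for w in weights], scale

def dequantize(values, scale):
    """Recover approximate float weights from the quantized integers."""
    return [v * scale for v in values]

# Made-up weights for illustration
weights = [0.42, -1.27, 0.05, 0.9]

quantized, scale = quantize(weights)
restored = dequantize(quantized, scale)

# Rounding error per weight is bounded by half the scale
errors = [abs(a - b) for a, b in zip(weights, restored)]
print(quantized, max(errors))
```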

Let's try running the models trained above in JavaScript. The first step is getting the Node.js environment and dependencies set up.

import os

!mkdir js
os.chdir("/content/js")
{
  "name": "onnx-test",
  "private": true,
  "version": "1.0.0",
  "description": "ONNX Runtime Node.js test",
  "main": "index.js",
  "dependencies": {
    "onnxruntime-node": ">=1.8.0",
    "tokenizers": "file:tokenizers/bindings/node"
  }
}
# Copy ONNX models
!cp ../text-classify.onnx .
!cp ../embeddings.onnx .

# Save copy of Bert Tokenizer
tokenizer.save_pretrained("bert")

# Get tokenizers project
!git clone https://github.com/huggingface/tokenizers.git

os.chdir("/content/js/tokenizers/bindings/node")

# Install Rust
!apt-get install rustc

# Build tokenizers project locally as version on NPM isn't working properly for latest version of Node.js
!npm install --also=dev
!npm run dev

# Install all dependencies
os.chdir("/content/js")
!npm install

Next we'll write the inference code in JavaScript to an index.js file.

const ort = require('onnxruntime-node');
const { promisify } = require('util');
const { Tokenizer } = require("tokenizers/dist/bindings/tokenizer");

function sigmoid(data) {
    return data.map(x => 1 / (1 + Math.exp(-x)))
}

function softmax(data) { 
    return data.map(x => Math.exp(x) / (data.map(y => Math.exp(y))).reduce((a,b) => a+b)) 
}

function similarity(v1, v2) {
    let dot = 0.0;
    let norm1 = 0.0;
    let norm2 = 0.0;

    for (let x = 0; x < v1.length; x++) {
        dot += v1[x] * v2[x];
        norm1 += Math.pow(v1[x], 2);
        norm2 += Math.pow(v2[x], 2);
    }

    return dot / (Math.sqrt(norm1) * Math.sqrt(norm2));
}

function tokenizer(path) {
    let tokenizer = Tokenizer.fromFile(path);
    return promisify(tokenizer.encode.bind(tokenizer));
}

async function predict(session, text) {
    try {
        // Tokenize input
        let encode = tokenizer("bert/tokenizer.json");
        let output = await encode(text);

        let ids = output.getIds().map(x => BigInt(x))
        let mask = output.getAttentionMask().map(x => BigInt(x))
        let tids = output.getTypeIds().map(x => BigInt(x))

        // Convert inputs to tensors    
        let tensorIds = new ort.Tensor('int64', BigInt64Array.from(ids), [1, ids.length]);
        let tensorMask = new ort.Tensor('int64', BigInt64Array.from(mask), [1, mask.length]);
        let tensorTids = new ort.Tensor('int64', BigInt64Array.from(tids), [1, tids.length]);

        let inputs = null;
        if (session.inputNames.length > 2) {
            inputs = { input_ids: tensorIds, attention_mask: tensorMask, token_type_ids: tensorTids};
        }
        else {
            inputs = { input_ids: tensorIds, attention_mask: tensorMask};
        }

        return await session.run(inputs);
    } catch (e) {
        console.error(`failed to inference ONNX model: ${e}.`);
    }
}

async function main() {
    let args = process.argv.slice(2);
    if (args.length > 1) {
        // Run sentence embeddings
        const session = await ort.InferenceSession.create('./embeddings.onnx');

        let v1 = await predict(session, args[0]);
        let v2 = await predict(session, args[1]);

        // Unpack results
        v1 = v1.embeddings.data;
        v2 = v2.embeddings.data;

        // Print similarity
        console.log(similarity(Array.from(v1), Array.from(v2)));
    }
    else {
        // Run text classifier
        const session = await ort.InferenceSession.create('./text-classify.onnx');
        let results = await predict(session, args[0]);

        // Normalize results using softmax and print
        console.log(softmax(results.logits.data));
    }
}

main();

Run Text Classification in JavaScript with ONNX

!node . "I am happy"
!node . "I am mad"
Float32Array(2) [ 0.001104647060856223, 0.9988954067230225 ]
Float32Array(2) [ 0.9976443648338318, 0.00235558208078146 ]

First off, have to say this is 🔥🔥🔥! Just amazing that this model can be fully run in JavaScript. It's a great time to be in NLP!

The steps above installed a JavaScript environment with dependencies to run ONNX and tokenize data in JavaScript. The text classification model previously created is loaded into the JavaScript ONNX runtime and inference is run.
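
One detail in the JavaScript code is that it only passes `token_type_ids` when the model declares more than two inputs. A rough Python sketch of the same idea is below: filter the tokenizer output down to the inputs the model declares. `select_inputs` is a hypothetical helper and the token ids are placeholders. With onnxruntime in Python, the input names could be read from `session.get_inputs()`, which is left out here to keep the example dependency-free.

```python
def select_inputs(input_names, encoded):
    """Keep only the tokenizer outputs that the model declares as inputs."""
    return {name: encoded[name] for name in input_names if name in encoded}

# Placeholder token ids for illustration, not real tokenizer output
encoded = {
    "input_ids": [[101, 1045, 102]],
    "attention_mask": [[1, 1, 1]],
    "token_type_ids": [[0, 0, 0]],
}

# A model with two declared inputs vs. one with three
two = select_inputs(["input_ids", "attention_mask"], encoded)
three = select_inputs(["input_ids", "attention_mask", "token_type_ids"], encoded)

print(sorted(two), sorted(three))
```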

As a reminder, the text classification model is judging sentiment using two labels, 0 for negative and 1 for positive. The results above show the probability of each label per text snippet.
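
Note that the JavaScript code normalizes logits with softmax, while the earlier Python example used a sigmoid. With softmax, the values for each text sum to 1. Below is a minimal Python sketch of softmax for comparison. The logits are made up, and subtracting the maximum before exponentiating is a common trick to avoid overflow that the JavaScript version doesn't use.

```python
import math

def softmax(logits):
    """Convert logits to probabilities that sum to 1."""
    # Subtract the max logit for numerical stability
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Made-up logits for illustration
logits = [-3.4, 3.4]
probs = softmax(logits)

print(probs, sum(probs))
```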

Build sentence embeddings and compare similarity in JavaScript with ONNX

!node . "I am happy", "I am glad"
0.8414919420066624

Once again....wow!! The sentence embeddings model produces vectors that can be used to compare semantic similarity, -1 being most dissimilar and 1 being most similar.

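While the result doesn't match the exported model exactly, it's very close. One thing to note is that the comma after the first argument in the command above is passed to the script as part of the text ("I am happy,"), which may contribute to the difference. Worth mentioning again that this is 100% JavaScript, no API or remote calls, all within node.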

Java

Let's try the same thing with Java. The following sections initialize a Java build environment and write out the code necessary to run the ONNX inference.

import os

os.chdir("/content")
!mkdir java
os.chdir("/content/java")

# Copy ONNX models
!cp ../text-classify.onnx .
!cp ../embeddings.onnx .

# Save copy of Bert Tokenizer
tokenizer.save_pretrained("bert")

!mkdir -p src/main/java

# Install gradle
!wget https://services.gradle.org/distributions/gradle-7.2-bin.zip
!unzip -o gradle-7.2-bin.zip
!gradle-7.2/bin/gradle wrapper
apply plugin: "java"

repositories {
    mavenCentral()
}

dependencies {
    implementation "com.robrua.nlp:easy-bert:1.0.3"
    implementation "com.microsoft.onnxruntime:onnxruntime:1.8.1"
}

java {
    toolchain {
        languageVersion = JavaLanguageVersion.of(8)
    }
}

jar {
    archiveBaseName = "onnxjava"
}

task onnx(type: JavaExec) {
    description = "Runs ONNX demo"
    classpath = sourceSets.main.runtimeClasspath
    main = "OnnxDemo"
}
import java.io.File;

import java.nio.LongBuffer;

import java.util.Arrays;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.OrtSession.Result;

import com.robrua.nlp.bert.FullTokenizer;

class Tokens {
    public long[] ids;
    public long[] mask;
    public long[] types;
}

class Tokenizer {
    private FullTokenizer tokenizer;

    public Tokenizer(String path) {
        File vocab = new File(path);
        this.tokenizer = new FullTokenizer(vocab, true);
    }

    public Tokens tokenize(String text) {
        // Build list of tokens
        List<String> tokensList = new ArrayList<>();
        tokensList.add("[CLS]"); 
        tokensList.addAll(Arrays.asList(tokenizer.tokenize(text)));
        tokensList.add("[SEP]");

        int[] ids = tokenizer.convert(tokensList.toArray(new String[0]));

        Tokens tokens = new Tokens();

        // input ids    
        tokens.ids = Arrays.stream(ids).mapToLong(i -> i).toArray();

        // attention mask
        tokens.mask = new long[ids.length];
        Arrays.fill(tokens.mask, 1);

        // token type ids
        tokens.types = new long[ids.length];
        Arrays.fill(tokens.types, 0);

        return tokens;
    }
}

class Inference {
    private Tokenizer tokenizer;
    private OrtEnvironment env;
    private OrtSession session;

    public Inference(String model) throws Exception {
        this.tokenizer = new Tokenizer("bert/vocab.txt");
        this.env = OrtEnvironment.getEnvironment();
        this.session = env.createSession(model, new OrtSession.SessionOptions());
    }

    public float[][] predict(String text) throws Exception {
        Tokens tokens = this.tokenizer.tokenize(text);

        Map<String, OnnxTensor> inputs = new HashMap<String, OnnxTensor>();
        inputs.put("input_ids", OnnxTensor.createTensor(env, LongBuffer.wrap(tokens.ids),  new long[]{1, tokens.ids.length}));
        inputs.put("attention_mask", OnnxTensor.createTensor(env, LongBuffer.wrap(tokens.mask),  new long[]{1, tokens.mask.length}));
        inputs.put("token_type_ids", OnnxTensor.createTensor(env, LongBuffer.wrap(tokens.types),  new long[]{1, tokens.types.length}));

        return (float[][])session.run(inputs).get(0).getValue();
    }
}

class Vectors {
    public static double similarity(float[] v1, float[] v2) {
        double dot = 0.0;
        double norm1 = 0.0;
        double norm2 = 0.0;

        for (int x = 0; x < v1.length; x++) {
            dot += v1[x] * v2[x];
            norm1 += Math.pow(v1[x], 2);
            norm2 += Math.pow(v2[x], 2);
        }

        return dot / (Math.sqrt(norm1) * Math.sqrt(norm2));
    }

    public static float[] softmax(float[] input) {
        double[] t = new double[input.length];
        double sum = 0.0;

        for (int x = 0; x < input.length; x++) {
            double val = Math.exp(input[x]);
            sum += val;
            t[x] = val;
        }

        float[] output = new float[input.length];
        for (int x = 0; x < output.length; x++) {
            output[x] = (float) (t[x] / sum);
        }

        return output;
    }
}

public class OnnxDemo {
    public static void main(String[] args) {
        try {
            if (args.length < 2) {
              Inference inference = new Inference("text-classify.onnx");

              float[][] v1 = inference.predict(args[0]);

              System.out.println(Arrays.toString(Vectors.softmax(v1[0])));
            }
            else {
              Inference inference = new Inference("embeddings.onnx");
              float[][] v1 = inference.predict(args[0]);
              float[][] v2 = inference.predict(args[1]);

              System.out.println(Vectors.similarity(v1[0], v2[0]));
            }
        }
        catch (Exception ex) {
            ex.printStackTrace();
        }
    }
}

Run Text Classification in Java with ONNX

!./gradlew -q --console=plain onnx --args='"I am happy"' 2> /dev/null
!./gradlew -q --console=plain onnx --args='"I am mad"' 2> /dev/null
[0.0011046471, 0.99889535]
[0.9976444, 0.002355582]

The command above tokenizes the input and runs inference with a text classification model previously created using a Java ONNX inference session.
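
The Java tokenizer class builds the three model inputs by hand: it wraps the tokens in [CLS] and [SEP], sets the attention mask to all ones and sets the token type ids to all zeros. The following Python sketch shows the same steps. The vocabulary is a toy one made up for this example (the real ids come from bert/vocab.txt) and `build_inputs` is a hypothetical helper.

```python
def build_inputs(tokens, vocab):
    """Build BERT-style inputs for a single text with a batch size of 1."""
    # Wrap tokens with the special start and end markers
    pieces = ["[CLS]"] + tokens + ["[SEP]"]
    ids = [vocab[piece] for piece in pieces]

    return {
        "input_ids": [ids],
        # Every position is a real token, so the mask is all ones
        "attention_mask": [[1] * len(ids)],
        # Single sentence input, so all token type ids are zero
        "token_type_ids": [[0] * len(ids)],
    }

# Toy vocabulary for illustration only
vocab = {"[CLS]": 0, "[SEP]": 1, "i": 2, "am": 3, "happy": 4}

inputs = build_inputs(["i", "am", "happy"], vocab)
print(inputs)
```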

As a reminder, the text classification model is judging sentiment using two labels, 0 for negative and 1 for positive. The results above show the probability of each label per text snippet.

Build sentence embeddings and compare similarity in Java with ONNX

!./gradlew -q --console=plain onnx --args='"I am happy" "I am glad"' 2> /dev/null
0.8581848568615768

The sentence embeddings model produces vectors that can be used to compare semantic similarity, -1 being most dissimilar and 1 being most similar.

This is 100% Java, no API or remote calls, all within the JVM. Still think it's amazing!

Rust

Last but not least, let's try Rust. The following sections initialize a Rust build environment and write out the code necessary to run the ONNX inference.

import os

os.chdir("/content")
!mkdir rust
os.chdir("/content/rust")

# Copy ONNX models
!cp ../text-classify.onnx .
!cp ../embeddings.onnx .

# Save copy of Bert Tokenizer
tokenizer.save_pretrained("bert")

# Install Rust
!apt-get install rustc

!mkdir -p src
[package]
name = "onnx-test"
version = "1.0.0"
description = """
ONNX Runtime Rust test
"""
edition = "2018"

[dependencies]
onnxruntime = { version = "0.0.14"}
tokenizers = { version = "0.10.1"}
use onnxruntime::environment::Environment;
use onnxruntime::GraphOptimizationLevel;
use onnxruntime::ndarray::{Array2, Axis};
use onnxruntime::tensor::OrtOwnedTensor;

use std::env;

use tokenizers::decoders::wordpiece::WordPiece as WordPieceDecoder;
use tokenizers::models::wordpiece::WordPiece;
use tokenizers::normalizers::bert::BertNormalizer;
use tokenizers::pre_tokenizers::bert::BertPreTokenizer;
use tokenizers::processors::bert::BertProcessing;
use tokenizers::tokenizer::{Result, Tokenizer, EncodeInput};

fn tokenize(text: String, inputs: usize) -> Vec<Array2<i64>> {
    // Load tokenizer
    let mut tokenizer = Tokenizer::new(Box::new(
        WordPiece::from_files("bert/vocab.txt")
            .build()
            .expect("Vocab file not found"),
    ));

    tokenizer.with_normalizer(Box::new(BertNormalizer::default()));
    tokenizer.with_pre_tokenizer(Box::new(BertPreTokenizer));
    tokenizer.with_decoder(Box::new(WordPieceDecoder::default()));
    tokenizer.with_post_processor(Box::new(BertProcessing::new(
        (
            String::from("[SEP]"),
            tokenizer.get_model().token_to_id("[SEP]").unwrap(),
        ),
        (
            String::from("[CLS]"),
            tokenizer.get_model().token_to_id("[CLS]").unwrap(),
        ),
    )));

    // Encode input text
    let encoding = tokenizer.encode(EncodeInput::Single(text), true).unwrap();

    let v1: Vec<i64> = encoding.get_ids().to_vec().into_iter().map(|x| x as i64).collect();
    let v2: Vec<i64> = encoding.get_attention_mask().to_vec().into_iter().map(|x| x as i64).collect();
    let v3: Vec<i64> = encoding.get_type_ids().to_vec().into_iter().map(|x| x as i64).collect();

    let ids = Array2::from_shape_vec((1, v1.len()), v1).unwrap();
    let mask = Array2::from_shape_vec((1, v2.len()), v2).unwrap();
    let tids = Array2::from_shape_vec((1, v3.len()), v3).unwrap();

    return if inputs > 2 { vec![ids, mask, tids] } else { vec![ids, mask] };
}

fn predict(text: String, softmax: bool) -> Vec<f32> {
    // Start onnx session
    let environment = Environment::builder()
        .with_name("test")
        .build().unwrap();

    // Derive model path
    let model = if softmax { "text-classify.onnx" } else { "embeddings.onnx" };

    let mut session = environment
        .new_session_builder().unwrap()
        .with_optimization_level(GraphOptimizationLevel::Basic).unwrap()
        .with_number_threads(1).unwrap()
        .with_model_from_file(model).unwrap();

    let inputs = tokenize(text, session.inputs.len());

    // Run inference and print result
    let outputs: Vec<OrtOwnedTensor<f32, _>> = session.run(inputs).unwrap();
    let output: &OrtOwnedTensor<f32, _> = &outputs[0];

    let probabilities: Vec<f32>;
    if softmax {
        probabilities = output
            .softmax(Axis(1))
            .iter()
            .copied()
            .collect::<Vec<_>>();
    }
    else {
        probabilities = output
            .iter()
            .copied()
            .collect::<Vec<_>>();
    }

    return probabilities;
}

fn similarity(v1: &Vec<f32>, v2: &Vec<f32>) -> f64 {
    let mut dot = 0.0;
    let mut norm1 = 0.0;
    let mut norm2 = 0.0;

    for x in 0..v1.len() {
        dot += v1[x] * v2[x];
        norm1 += v1[x].powf(2.0);
        norm2 += v2[x].powf(2.0);
    }

    return dot as f64 / (norm1.sqrt() * norm2.sqrt()) as f64
}

fn main() -> Result<()> {
    // Tokenize input string
    let args: Vec<String> = env::args().collect();

    if args.len() <= 2 {
      let v1 = predict(args[1].to_string(), true);
      println!("{:?}", v1);
    }
    else {
      let v1 = predict(args[1].to_string(), false);
      let v2 = predict(args[2].to_string(), false);
      println!("{:?}", similarity(&v1, &v2));
    }

    Ok(())
}

Run Text Classification in Rust with ONNX

!cargo run "I am happy" 2> /dev/null
!cargo run "I am mad" 2> /dev/null
[0.0011003953, 0.99889964]
[0.9976444, 0.0023555849]

The command above tokenizes the input and runs inference with a text classification model previously created using a Rust ONNX inference session.

As a reminder, the text classification model is judging sentiment using two labels, 0 for negative and 1 for positive. The results above show the probability of each label per text snippet.

Build sentence embeddings and compare similarity in Rust with ONNX

!cargo run "I am happy" "I am glad" 2> /dev/null
0.8583641740656903

The sentence embeddings model produces vectors that can be used to compare semantic similarity, -1 being most dissimilar and 1 being most similar.

Once again, this is 100% Rust, no API or remote calls. And yes, still think it's amazing!

Wrapping up

This notebook covered how to export models to ONNX using txtai. These models were then run in Python, JavaScript, Java and Rust. Golang was also evaluated, but there doesn't currently appear to be a stable enough ONNX runtime available.

This method provides a way to train and run machine learning models using a number of programming languages on a number of platforms.

The following is a non-exhaustive list of use cases.

  • Build locally executed models for mobile/edge devices
  • Run models with Java/JavaScript/Rust development stacks when teams prefer not to add Python to the mix
  • Export models to ONNX for Python inference to improve CPU performance and/or reduce number of software dependencies
