Wonnx is a GPU-accelerated ONNX inference run-time written 100% in Rust, ready for the web.

Supported Platforms (enabled by wgpu)

API     Windows       Linux & Android   macOS & iOS
DX12    (W10 only)
DX11    🚧

✅ = First Class Support; 🆗 = Best Effort Support; 🚧 = Unsupported, but support in progress

Getting started

  • Install Rust.
  • Install Vulkan, Metal, or DX12 for the GPU API.
  • Ensure Git LFS is installed.
  • Clone this repo:
git clone https://github.com/webonnx/wonnx.git
git lfs install

From the command line

Ensure Git LFS is initialized and has downloaded the model files (in wonnx/examples/data/models). Then, you're all set! You can run an example:

cargo run --example squeeze --release

Or you can try the CLI (see the README for more information):

cargo run --release -- info ./data/models/opt-squeeze.onnx
cargo run --release -- infer ./data/models/opt-squeeze.onnx -i data=./data/images/pelican.jpeg --labels ./data/models/squeeze-labels.txt --top 3

From Python

pip install wonnx

And then:

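from wonnx import PySession

# Placeholder path: substitute your ONNX model. The input/output pair below is
# consistent with a model containing a single Relu node (y = max(x, 0)).
session = PySession.from_path("path/to/model.onnx")
inputs = {"x": [-1.0, 2.0]}
assert session.run(inputs) == {"y": [0.0, 2.0]}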

To build the Python module for development:

cd wonnx-py
python3 -m venv .env
source .env/bin/activate
pip install maturin
maturin develop

Then run python3 with the above Python code!

Running a model from scratch

  • To run an ONNX model, first simplify it with onnx-simplifier:
# pip install -U pip && pip install onnx-simplifier
python -m onnxsim mnist-8.onnx opt-mnist.onnx
  • Then you can run it following the example in the examples folder:
cargo run --example mnist --release
use std::collections::HashMap;
use wonnx::utils::InputTensor;

fn main() {
    let mut input_data = HashMap::new();
    // `load_squeezenet_image` is a helper defined in the squeeze example, not part of wonnx.
    let image = load_squeezenet_image(); // Load image
    input_data.insert("data".to_string(), InputTensor::F32(image.as_slice().unwrap()));

    // Placeholder path: point this at the simplified SqueezeNet ONNX model.
    let session = pollster::block_on(wonnx::Session::from_path("path/to/model.onnx"))
        .expect("session did not create");
    let result = pollster::block_on(session.run(input_data)).unwrap();
    // Borrow the output instead of trying to move it out of the map.
    let result = &result["squeezenet0_flatten0_reshape0"];
    let mut probabilities = result.iter().enumerate().collect::<Vec<_>>();

    // Sort classes by descending probability.
    probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(a.1).unwrap());

    // The example expects class index 22 to rank first.
    assert_eq!(probabilities[0].0, 22);
}
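The example above only checks the single best class. A small top-k helper, in the spirit of the CLI's `--top 3` and `--labels` options, could look like the following sketch. It is plain Rust with no wonnx dependency; `top_k` and `label_top_k` are hypothetical names, and it assumes the labels file lists one label per line in class-index order. It uses `f32::total_cmp` rather than `partial_cmp(...).unwrap()`, so a NaN score cannot cause a panic.

```rust
/// Return the `k` highest-scoring classes as (class index, score) pairs,
/// ordered from most to least likely.
fn top_k(scores: &[f32], k: usize) -> Vec<(usize, f32)> {
    let mut indexed: Vec<(usize, f32)> = scores.iter().copied().enumerate().collect();
    // `total_cmp` is a total order on f32, so sorting never needs to unwrap.
    indexed.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
    indexed.truncate(k);
    indexed
}

/// Pair each of the top-k indices with its label. Assumes (hypothetically) one
/// label per line, in class-index order; missing labels become "<unknown>".
fn label_top_k<'a>(scores: &[f32], labels: &'a str, k: usize) -> Vec<(&'a str, f32)> {
    let labels: Vec<&'a str> = labels.lines().collect();
    top_k(scores, k)
        .into_iter()
        .map(|(i, score)| (labels.get(i).copied().unwrap_or("<unknown>"), score))
        .collect()
}

fn main() {
    let scores = [0.1_f32, 0.7, 0.05, 0.15];
    let labels = "cat\ndog\nbird\nfish";
    for (label, score) in label_top_k(&scores, labels, 3) {
        println!("{label}: {score:.2}");
    }
}
```

Sorting the whole vector is simple and fine for ~1000 classes; for much larger outputs, `select_nth_unstable_by` would avoid a full sort.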

Examples are available in the examples folder

Tested models

  • Squeezenet

GPU selection

You may set the following environment variables to influence GPU selection by WGPU:

  • WGPU_ADAPTER_NAME with a substring of the name of the adapter you want to use (e.g. 1080 will match NVIDIA GeForce 1080ti).
  • WGPU_BACKEND with a comma-separated list of the backends you want to use (vulkan, metal, dx12, dx11, or gl).
  • WGPU_POWER_PREFERENCE with the power preference to choose when a specific adapter name isn't specified (high or low).
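wgpu reads these variables itself; the sketch below only illustrates the value formats described above: a comma-separated backend list, a substring match on the adapter name, and a high/low power preference. The `Backend` and `PowerPreference` enums and the three parse functions are hypothetical stand-ins, not wgpu's types or code, and the case-insensitive name matching is an assumption.

```rust
/// Hypothetical stand-in for the backends listed above (not wgpu's own type).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Backend {
    Vulkan,
    Metal,
    Dx12,
    Dx11,
    Gl,
}

/// Hypothetical stand-in for the two power preferences.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PowerPreference {
    High,
    Low,
}

/// Parse a comma-separated WGPU_BACKEND value such as "vulkan,metal".
/// Names are trimmed and lowercased; unknown names are skipped in this sketch.
fn parse_backends(value: &str) -> Vec<Backend> {
    value
        .split(',')
        .map(|name| name.trim().to_lowercase())
        .filter_map(|name| match name.as_str() {
            "vulkan" => Some(Backend::Vulkan),
            "metal" => Some(Backend::Metal),
            "dx12" => Some(Backend::Dx12),
            "dx11" => Some(Backend::Dx11),
            "gl" => Some(Backend::Gl),
            _ => None,
        })
        .collect()
}

/// Parse WGPU_POWER_PREFERENCE ("high" or "low"); anything else yields None.
fn parse_power_preference(value: &str) -> Option<PowerPreference> {
    match value.trim().to_lowercase().as_str() {
        "high" => Some(PowerPreference::High),
        "low" => Some(PowerPreference::Low),
        _ => None,
    }
}

/// WGPU_ADAPTER_NAME picks an adapter whose name contains the given substring.
/// Case-insensitive comparison is an assumption of this sketch.
fn adapter_matches(adapter_name: &str, filter: &str) -> bool {
    adapter_name.to_lowercase().contains(&filter.to_lowercase())
}

fn main() {
    // Unset variables fall back to "no preference".
    let backends = std::env::var("WGPU_BACKEND")
        .map(|v| parse_backends(&v))
        .unwrap_or_default();
    let power = std::env::var("WGPU_POWER_PREFERENCE")
        .ok()
        .and_then(|v| parse_power_preference(&v));
    println!("backends: {backends:?}, power preference: {power:?}");

    if let Ok(filter) = std::env::var("WGPU_ADAPTER_NAME") {
        let example = "NVIDIA GeForce 1080ti";
        println!("{example:?} matches {filter:?}: {}", adapter_matches(example, &filter));
    }
}
```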

Contribution: On implementing a new Operator

Contributions are very welcome, even without extensive experience in DL, WGSL, or Rust. I hope this project can be a sandbox for all of us to learn more about those technologies beyond its initial scope.

To implement an operator all you have to do is:

  1. Add a new matching pattern in compiler.rs.
  2. Retrieve its attribute values using the get_attribute function:
    let alpha = get_attribute("alpha", Some(1.0), node);
    // or without default value
    let alpha = get_attribute::<f32>("alpha", None, node);
  3. Add any variables you want to use in the WGSL shader to the context.
  4. Write a new WGSL template in the templates folder.

Available types are in structs.wgsl, but you can also generate new ones within your templates.

  5. Respect the binding layout: binding numbers start at 0 and increase by 1 for each entry, with inputs first and outputs last. If there are more than 4 bindings, increment the binding group. You can change the inputs in sequencer.rs.
  6. Write the logic.

The following variables are available in the context by default:

  • {{ i_lens[0] }}: the length of input 0. This also works for outputs ({{ o_lens[0] }}) and for other inputs ({{ i_lens[1] }}).
  • {{ i_shape[0] }}: the array of dimensions of input 0. To get the first dimension, use {{ i_shape[0][0] }}.
  • {{ i_chunks[0] }}: the chunk size of each dimension of input 0. Each tensor is stored as one long, flat array of values, so to reach a specific value you move through the array in chunks; this variable holds those chunk sizes. To get the chunk size of the first dimension, use {{ i_chunks[0][0] }}.
  • {{ op_type }}: the op type, since some op types (such as activations) share the same template.
  7. Test it using the utility functions and place the test in the tests folder. The test can look as follows:
use std::collections::HashMap;
use wonnx::utils::{graph, model, node, tensor};

#[test]
fn test_matmul_square_matrix() {
    // A is the n x n identity matrix, so A * B should equal B.
    let n = 16;
    let mut input_data = HashMap::new();

    let data_a = ndarray::Array2::eye(n);
    let mut data_b = ndarray::Array2::<f32>::zeros((n, n));
    data_b[[0, 0]] = 0.2;
    data_b[[0, 1]] = 0.5;

    // Reference result computed on the CPU with ndarray.
    let expected = data_a.dot(&data_b);

    input_data.insert("A".to_string(), data_a.as_slice().unwrap());
    input_data.insert("B".to_string(), data_b.as_slice().unwrap());

    // One-node ONNX model: C = MatMul(A, B), all tensors n x n.
    let n = n as i64;
    let model = model(graph(
        vec![tensor("A", &[n, n]), tensor("B", &[n, n])],
        vec![tensor("C", &[n, n])],
        vec![node(vec!["A", "B"], vec!["C"], "MatMul", "MatMul", vec![])],
    ));

    let session =
        pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create");

    let result = pollster::block_on(session.run(input_data)).unwrap();

    assert_eq!(result["C"].as_slice(), expected.as_slice().unwrap());
}

Check out the Tera documentation for other templating operations: https://tera.netlify.app/docs/

  8. If at any point you want to optimize across several nodes, you can do so in sequencer.rs.
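Two pieces of index bookkeeping from the contribution steps can be sketched in plain Rust: the binding-layout rule (more than 4 bindings spill into the next group) and the chunk sizes behind {{ i_chunks }}. This is an illustration, not wonnx's code. It assumes bindings restart at 0 in each new group of 4, and that chunks are row-major strides (ONNX's default layout, last dimension contiguous). Under that assumption, a template would reach element [i, j] of a 2-D input 0 at offset i * {{ i_chunks[0][0] }} + j * {{ i_chunks[0][1] }}.

```rust
/// Maximum number of bindings per group, following the "above 4" rule.
const MAX_BINDINGS_PER_GROUP: usize = 4;

/// Map the n-th buffer entry (inputs first, outputs last) to (group, binding),
/// assuming binding numbers restart at 0 in each new group.
fn group_and_binding(entry_index: usize) -> (usize, usize) {
    (entry_index / MAX_BINDINGS_PER_GROUP, entry_index % MAX_BINDINGS_PER_GROUP)
}

/// Row-major chunk sizes: out[i] is how many flat elements to skip to advance
/// dimension i by one. The last dimension always has chunk size 1.
fn chunks(shape: &[usize]) -> Vec<usize> {
    let mut out = vec![1; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        out[i] = out[i + 1] * shape[i + 1];
    }
    out
}

/// Flat offset of a multi-dimensional index: the sum of index[i] * chunks[i].
fn flat_index(index: &[usize], chunks: &[usize]) -> usize {
    index.iter().zip(chunks).map(|(i, c)| i * c).sum()
}

fn main() {
    let shape = [2, 3, 4];
    let c = chunks(&shape);
    println!("chunks of {shape:?}: {c:?}");
    println!("flat index of [1, 2, 3]: {}", flat_index(&[1, 2, 3], &c));

    // e.g. four inputs and two outputs: six entries in total.
    for entry in 0..6 {
        let (group, binding) = group_and_binding(entry);
        println!("entry {entry} -> group {group}, binding {binding}");
    }
}
```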

Supported Operators (ref ONNX IR)

Operator Since version Implemented
Abs 13, 6, 1
Acos 7
Acosh 9
Add 14, 13, 7, 6, 1
And 7, 1
ArgMax 13, 12, 11, 1
ArgMin 13, 12, 11, 1
Asin 7
Asinh 9
Atan 7
Atanh 9
AveragePool 11, 10, 7, 1
BatchNormalization 15, 14, 9, 7, 6, 1
BitShift 11
Cast 13, 9, 6, 1
Ceil 13, 6, 1
Clip 13, 12, 11, 6, 1
Compress 11, 9
Concat 13, 11, 4, 1
ConcatFromSequence 11
Constant 13, 12, 11, 9, 1
ConstantOfShape 9
Conv 11, 1
ConvInteger 10
ConvTranspose 11, 1
Cos 7
Cosh 9
CumSum 14, 11
DepthToSpace 13, 11, 1
DequantizeLinear 13, 10
Det 11
Div 14, 13, 7, 6, 1
Dropout 13, 12, 10, 7, 6, 1
Einsum 12
Elu 6, 1
Equal 13, 11, 7, 1
Erf 13, 9
Exp 13, 6, 1
Expand 13, 8
EyeLike 9
Flatten 13, 11, 9, 1
Floor 13, 6, 1
GRU 14, 7, 3, 1
Gather 13, 11, 1 (axis=0)
GatherElements 13, 11
GatherND 13, 12, 11
Gemm 13, 11, 9, 7, 6, 1
GlobalAveragePool 1
GlobalLpPool 2, 1
GlobalMaxPool 1
Greater 13, 9, 7, 1
GridSample 16
HardSigmoid 6, 1
Hardmax 13, 11, 1
Identity 16, 14, 13, 1
If 16, 13, 11, 1
InstanceNormalization 6, 1
IsInf 10
IsNaN 13, 9
LRN 13, 1
LSTM 14, 7, 1
LeakyRelu 6, 1
Less 13, 9, 7, 1
Log 13, 6, 1
Loop 16, 13, 11, 1
LpNormalization 1
LpPool 11, 2, 1
MatMul 13, 9, 1
MatMulInteger 10
Max 13, 12, 8, 6, 1
MaxPool 12, 11, 10, 8, 1
MaxRoiPool 1
MaxUnpool 11, 9
Mean 13, 8, 6, 1
Min 13, 12, 8, 6, 1
Mod 13, 10
Mul 14, 13, 7, 6, 1
Multinomial 7
Neg 13, 6, 1
NonMaxSuppression 11, 10
NonZero 13, 9
Not 1
OneHot 11, 9
Optional 15
OptionalGetElement 15
OptionalHasElement 15
Or 7, 1
PRelu 9, 7, 6, 1
Pad 13, 11, 2, 1
Pow 15, 13, 12, 7, 1
QLinearConv 10
QLinearMatMul 10
QuantizeLinear 13, 10
RNN 14, 7, 1
RandomNormal 1
RandomNormalLike 1
RandomUniform 1
RandomUniformLike 1
Reciprocal 13, 6, 1
ReduceL1 13, 11, 1
ReduceL2 13, 11, 1
ReduceLogSum 13, 11, 1
ReduceLogSumExp 13, 11, 1
ReduceMax 13, 12, 11, 1
ReduceMean 13, 11, 1
ReduceMin 13, 12, 11, 1
ReduceProd 13, 11, 1
ReduceSum 13, 11, 1
ReduceSumSquare 13, 11, 1
Relu 14, 13, 6, 1
Reshape 14, 13, 5, 1
Resize 13, 11, 10
ReverseSequence 10
RoiAlign 16, 10
Round 11
Scan 11, 9, 8
Scatter (deprecated) 11, 9
ScatterElements 16, 13, 11
ScatterND 16, 13, 11
Selu 6, 1
SequenceAt 11
SequenceConstruct 11
SequenceEmpty 11
SequenceErase 11
SequenceInsert 11
SequenceLength 11
Shape 15, 13, 1
Shrink 9
Sigmoid 13, 6, 1
Sign 13, 9
Sin 7
Sinh 9
Size 13, 1
Slice 13, 11, 10, 1
Softplus 1
Softsign 1
SpaceToDepth 13, 1
Split 13, 11, 2, 1
SplitToSequence 11
Sqrt 13, 6, 1
Squeeze 13, 11, 1
StringNormalizer 10
Sub 14, 13, 7, 6, 1
Sum 13, 8, 6, 1
Tan 7
Tanh 13, 6, 1
TfIdfVectorizer 9
ThresholdedRelu 10
Tile 13, 6, 1
TopK 11, 10, 1
Transpose 13, 1
Trilu 14
Unique 11
Unsqueeze 13, 11, 1
Upsample (deprecated) 10, 9, 7
Where 16, 9
Xor 7, 1
Function Since version
Bernoulli 15
CastLike 15
Celu 12
DynamicQuantizeLinear 11
GreaterOrEqual 12
HardSwish 14
LessOrEqual 12
LogSoftmax 13, 11, 1
MeanVarianceNormalization 13, 9
NegativeLogLikelihoodLoss 13, 12
Range 11
Softmax 13, 11, 1
SoftmaxCrossEntropyLoss 13, 12
  • Implement `Shape` operator

    Hello! I gave the Shape operator a try. I pretty much got it working; however, there are some small issues:

    1. The tests (run with env OP_TESTED=shape pytest tests/test_specific_op.py) almost pass, but they fail because of a type mismatch: they expect int64 while our output is float32. I'm not sure how to fix that (the output array is i32, and I even tried setting scalar_type to I64, but this doesn't help).
    2. The shader is actually constant, since everything is determined at shader compile time (at template render time, to be precise). This might be an issue (e.g. for models with dynamic inference size), but I don't know how this could be solved.
    3. A minor thing: because of the point above, we don't actually use input_0 in the shader code. But if we don't use it, it gets removed, and then the bindings don't line up. I solved this by binding it to an unused variable, but there could be a prettier way.

    P.S. I'll clean up the code once the other issues are solved ;]

    opened by LoipesMas 3
  • Support Stable Diffusion model

    Is your feature request related to a problem? Please describe. I would like to be able to run Stable Diffusion using wonnx.

    Describe the solution you'd like: At least these operators are missing and would need to be implemented before even trying to run Stable Diffusion on wonnx: Einsum, Erf, Expand, InstanceNormalization, Shape, Slice.

    This is the minimum based on this guide that simplifies the onnx model (see the simplification table): https://www.photoroom.com/tech/stable-diffusion-25-percent-faster-and-save-seconds/

    Probably many more things will be needed, but I'm creating this issue because it can be a really interesting use case to be able to run SD in rust on the GPU directly.

    I don't have much experience with wonnx or even ML, but I decided to create this issue because it surprised me how few operators are missing to run this model. I would need to get more experience with stable diffusion, diffusers library and onnx in python before attempting to port it here, but maybe there are more experienced users interested too.

    opened by siriux 5
  • Can't run a single linear layer

    Describe the bug: I tried to export a single linear layer from PyTorch and got one of the following errors.

    Error 1: GpuError(CompileError { node: "Gemm_0", error: InvalidInputShape { input_index: 1, input_shape: Shape { dims: [10, 784], data_type: F32 } } })
    Error 2: IrError(OutputNodeNotFound("onnx::Add_4"))

    I viewed the resulting ONNX file in netron.app and it appeared to be correct.

    To Reproduce

    1. Run the following script
    import torch

    torch_model = torch.nn.Linear(784, 10)
    model_input = torch.zeros((1, 784))  # This results in error 1. Changing the shape to (784,) results in error 2
    torch.onnx.export(torch_model,               # model being run
                      model_input,               # model input (or a tuple for multiple inputs)
                      "onnx/model.onnx",         # where to save the model (can be a file or file-like object)
                      export_params=True,        # store the trained parameter weights inside the model file
                      opset_version=11,          # the ONNX version to export the model to
                      do_constant_folding=True,  # whether to execute constant folding for optimization
                      input_names=['input'],     # the model's input names
                      output_names=['output'])   # the model's output names
    2. Optionally run onnx-simplifier, but it doesn't do anything on such a simple model.
    3. Run the following Rust program
    fn main() {
        #[cfg(not(target_arch = "wasm32"))]
        pollster::block_on(run());
    }
    async fn run() {
        let model_path = std::path::Path::new("onnx/model.onnx");
        let _session = wonnx::Session::from_path(model_path).await.unwrap();
    }

    Expected behavior: The model should load successfully.

    Desktop: PopOS 20.04

    opened by Ryul0rd 6
  • IrError(Type(ParametrizedDimensionUnsupported(


    Describe the bug: Exporting a HuggingFace model using the recommended method results in the following error: thread 'main' panicked at 'called 'Result::unwrap()' on an 'Err' value: IrError(Type(ParametrizedDimensionUnsupported("batch")))'. Including the batch dimension is not only what the HuggingFace tool does but also what the official PyTorch docs recommend for exporting to ONNX.

    To Reproduce

    1. pip install transformers[onnx]
    2. python -m transformers.onnx --model=bert-base-uncased --feature=default onnx/
    3. Run the following Rust program
    fn main() {
        #[cfg(not(target_arch = "wasm32"))]
        pollster::block_on(run());
    }
    async fn run() {
        let model_path = std::path::Path::new("onnx/model.onnx");
        let _session = wonnx::Session::from_path(model_path).await.unwrap();
    }

    Expected behavior: The unwrap call should not encounter an error.

    Desktop: Linux PopOS 20.04

    opened by Ryul0rd 6
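For reference when reading error 1 in the single-linear-layer issue: the ONNX Gemm operator computes Y = alpha * A' * B' + beta * C, where A' and B' are A and B, transposed when the transA / transB attributes are 1. A torch.nn.Linear(784, 10) stores its weight as [out_features, in_features] = [10, 784], which is the shape reported for input_index 1 (B). With A of shape [1, 784], that B only fits the Gemm definition when transB = 1. The issue does not say why wonnx rejected the shape. The sketch below is a minimal CPU reference of Gemm, not wonnx code. As a simplification, C is a length-N bias row; ONNX allows any C that is unidirectionally broadcastable to [M, N].

```rust
/// Row-major matrix used by this sketch.
struct Matrix {
    rows: usize,
    cols: usize,
    data: Vec<f32>,
}

impl Matrix {
    /// Element at row `r`, column `c`.
    fn at(&self, r: usize, c: usize) -> f32 {
        self.data[r * self.cols + c]
    }
}

/// CPU reference for Gemm: Y = alpha * A' * B' + beta * C, with C a length-N bias row.
fn gemm(a: &Matrix, trans_a: bool, b: &Matrix, trans_b: bool, c: &[f32], alpha: f32, beta: f32) -> Matrix {
    // Shapes after the optional transposes: A' is [m, k], B' is [k, n].
    let (m, k) = if trans_a { (a.cols, a.rows) } else { (a.rows, a.cols) };
    let (k_b, n) = if trans_b { (b.cols, b.rows) } else { (b.rows, b.cols) };
    assert_eq!(k, k_b, "inner dimensions of A' and B' must match");
    assert_eq!(c.len(), n, "bias length must equal N");

    // Read A' and B' elements without materializing the transposes.
    let a_at = |i: usize, p: usize| if trans_a { a.at(p, i) } else { a.at(i, p) };
    let b_at = |p: usize, j: usize| if trans_b { b.at(j, p) } else { b.at(p, j) };

    let mut data = vec![0.0; m * n];
    for i in 0..m {
        for j in 0..n {
            let dot: f32 = (0..k).map(|p| a_at(i, p) * b_at(p, j)).sum();
            data[i * n + j] = alpha * dot + beta * c[j];
        }
    }
    Matrix { rows: m, cols: n, data }
}

fn main() {
    // A Linear(3 -> 2) layer in miniature: x is [1, 3], and the weight is stored
    // as [out_features, in_features] = [2, 3], so it is used with trans_b = true.
    let x = Matrix { rows: 1, cols: 3, data: vec![1.0, 2.0, 3.0] };
    let w = Matrix { rows: 2, cols: 3, data: vec![1.0, 0.0, 0.0, 0.0, 1.0, 1.0] };
    let bias = [0.5_f32, -1.0];
    let y = gemm(&x, false, &w, true, &bias, 1.0, 1.0);
    println!("y = {:?} with shape [{}, {}]", y.data, y.rows, y.cols);
}
```

Passing `trans_b = false` with the same [2, 3] weight trips the inner-dimension assertion, which mirrors the kind of shape mismatch a validator reports when transB is not taken into account.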
