MesaTEE GBDT-RS: a fast and secure GBDT library, supporting TEEs such as Intel SGX and ARM TrustZone

Overview

MesaTEE GBDT-RS is a gradient boosting decision tree (GBDT) library written entirely in safe Rust; there is no unsafe Rust code in the library.

MesaTEE GBDT-RS provides both training and inference capabilities, and it can also run inference with models trained by xgboost.

New! The MesaTEE GBDT-RS paper has been accepted by IEEE S&P'19!

Supported Tasks

Supported tasks for both training and inference (a loss configuration sketch follows this list):

  1. Linear regression: use SquaredError and LAD loss types
  2. Binary classification (labeled with 1 and -1): use LogLikelyhood loss type
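
For example, a minimal sketch of selecting the loss type through Config (the same string-based setter used in the Quick Start below; the loss names are the ones listed above):

    use gbdt::config::Config;

    // Regression: SquaredError or LAD
    let mut regression_cfg = Config::new();
    regression_cfg.set_loss("SquaredError");

    // Binary classification (labels 1 and -1): LogLikelyhood
    let mut classification_cfg = Config::new();
    classification_cfg.set_loss("LogLikelyhood");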

Compatibility with xgboost

At this time, MesaTEE GBDT-RS supports running inference with models trained by xgboost. The model should be trained by xgboost with the following configuration:

  1. booster: gbtree
  2. objective: "reg:linear", "reg:logistic", "binary:logistic", "binary:logitraw", "multi:softprob", "multi:softmax" or "rank:pairwise".

We have tested that MesaTEE GBDT-RS is compatible with xgboost 0.81 and 0.82.

Quick Start

Training Steps

  1. Set configuration
  2. Load training data
  3. Train the model
  4. (optional) Save the model

Inference Steps

  1. Load the model
  2. Load the test data
  3. Run inference on the test data

Example

    use gbdt::config::Config;
    use gbdt::decision_tree::{DataVec, PredVec};
    use gbdt::gradient_boost::GBDT;
    use gbdt::input::{InputFormat, load};

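    // configure the training parameters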
    let mut cfg = Config::new();
    cfg.set_feature_size(22);
    cfg.set_max_depth(3);
    cfg.set_iterations(50);
    cfg.set_shrinkage(0.1);
    cfg.set_loss("LogLikelyhood"); 
    cfg.set_debug(true);
    cfg.set_data_sample_ratio(1.0);
    cfg.set_feature_sample_ratio(1.0);
    cfg.set_training_optimization_level(2);

    // load data
    let train_file = "dataset/agaricus-lepiota/train.txt";
    let test_file = "dataset/agaricus-lepiota/test.txt";

    let mut input_format = InputFormat::csv_format();
    input_format.set_feature_size(22);
    input_format.set_label_index(22);
    let mut train_dv: DataVec = load(train_file, input_format).expect("failed to load training data");
    let test_dv: DataVec = load(test_file, input_format).expect("failed to load test data");

    // train and save model
    let mut gbdt = GBDT::new(&cfg);
    gbdt.fit(&mut train_dv);
    gbdt.save_model("gbdt.model").expect("failed to save the model");

    // load model and do inference
    let model = GBDT::load_model("gbdt.model").expect("failed to load the model");
    let predicted: PredVec = model.predict(&test_dv);

Example code

  • Linear regression: examples/iris.rs
  • Binary classification: examples/agaricus-lepiota.rs

Use models trained by xgboost

Steps

  1. Use xgboost to train a model
  2. Use examples/convert_xgboost.py to convert the model
    • Usage: python convert_xgboost.py xgboost_model_path objective output_path
    • Note: convert_xgboost.py depends on the xgboost Python library, but the converted model can be used on machines without xgboost
  3. In Rust code, call GBDT::load_from_xgboost(model_path, objective) to load the converted model (see the sketch after this list)
  4. Do inference
  5. (optional) Call GBDT::save_model to save the model to MesaTEE GBDT-RS native format.
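
A minimal sketch of steps 3-5, reusing the dataset paths and feature size from the Quick Start above; the converted model path and the objective string are placeholders, and the error handling via expect is an assumption:

    use gbdt::decision_tree::{DataVec, PredVec};
    use gbdt::gradient_boost::GBDT;
    use gbdt::input::{InputFormat, load};

    // step 3: load a model converted by convert_xgboost.py (path and objective are placeholders)
    let model = GBDT::load_from_xgboost("xgb-converted.model", "binary:logistic")
        .expect("failed to load the converted model");

    // step 4: run inference, loading the test data as in the Quick Start
    let mut fmt = InputFormat::csv_format();
    fmt.set_feature_size(22);
    fmt.set_label_index(22);
    let test_dv: DataVec = load("dataset/agaricus-lepiota/test.txt", fmt)
        .expect("failed to load test data");
    let predicted: PredVec = model.predict(&test_dv);
    println!("predicted {} samples", predicted.len());

    // step 5 (optional): save the model in the MesaTEE GBDT-RS native format
    model.save_model("gbdt.model").expect("failed to save the model");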

Example code

  • "reg:linear": examples/test-xgb-reg-linear.rs
  • "reg:logistic": examples/test-xgb-reg-logistic.rs
  • "binary:logistic": examples/test-xgb-binary-logistic.rs
  • "binary:logitraw": examples/test-xgb-binary-logistic.rs
  • "multi:softprob": examples/test-xgb-multi-softprob.rs
  • "multi:softmax": examples/test-xgb-multi-softmax.rs
  • "rank:pairwise": examples/test-xgb-rank-pairwise.rs

Multi-threading

Training:

At this time, training in MesaTEE GBDT-RS is single-threaded.

Inference:

The inference functions themselves are single-threaded, but they are thread-safe. We provide a multi-threaded inference example in examples/test-multithreads.rs.
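
A minimal sketch of that pattern (not the code of the shipped example): share the trained model behind an Arc, split the test data into owned chunks, and let each worker thread call predict on its own chunk. It assumes the model and data types are Send + Sync.

    use std::sync::Arc;
    use std::thread;

    use gbdt::decision_tree::{DataVec, PredVec};
    use gbdt::gradient_boost::GBDT;

    fn predict_in_parallel(model: GBDT, mut test_dv: DataVec, n_threads: usize) -> PredVec {
        let model = Arc::new(model);
        let n_threads = n_threads.max(1);
        let chunk_size = (test_dv.len() + n_threads - 1) / n_threads;

        // split the test data into owned chunks, one per worker thread
        let mut chunks: Vec<DataVec> = Vec::new();
        while test_dv.len() > chunk_size {
            let tail = test_dv.split_off(chunk_size);
            chunks.push(test_dv);
            test_dv = tail;
        }
        chunks.push(test_dv);

        // each thread runs the thread-safe predict on its own chunk
        let handles: Vec<_> = chunks
            .into_iter()
            .map(|chunk| {
                let model = Arc::clone(&model);
                thread::spawn(move || model.predict(&chunk))
            })
            .collect();

        // collect the per-chunk predictions back in their original order
        let mut predicted: PredVec = Vec::new();
        for handle in handles {
            predicted.extend(handle.join().expect("worker thread panicked"));
        }
        predicted
    }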

SGX usage

Because MesaTEE GBDT-RS is written in pure Rust, it can easily be used in an SGX enclave with the help of rust-sgx-sdk by adding the following dependency:

gbdt_sgx = { git = "https://github.com/mesalock-linux/gbdt-rs" }

This imports a crate named gbdt_sgx. If you prefer to use it under the usual name gbdt:

gbdt = { package = "gbdt_sgx", git = "https://github.com/mesalock-linux/gbdt-rs" }

For more information and concrete examples, please look at the sgx/gbdt-sgx-test directory.
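
With the renamed dependency, enclave code uses the same import paths as the examples above; the rename is handled by Cargo, so only Cargo.toml differs. A minimal sketch:

    // inside the enclave crate, with gbdt = { package = "gbdt_sgx", ... } in Cargo.toml
    use gbdt::config::Config;
    use gbdt::gradient_boost::GBDT;

    let mut cfg = Config::new();
    cfg.set_feature_size(22);
    cfg.set_loss("LogLikelyhood");
    let gbdt = GBDT::new(&cfg);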

License

Apache 2.0

Authors

Tianyi Li @n0b0dyCN [email protected]

Tongxin Li @litongxin1991 [email protected]

Yu Ding @dingelish [email protected]

Steering Committee

Tao Wei, Yulong Zhang

Acknowledgment

Thanks to @qiyiping for the great previous work on gbdt. We read that code before starting this project.

Comments
  • Unable to use xgboost trained model

    Hello, thank you for the great library. I have an issue: I train a model in Python using xgboost, but when I try to use "from_xgboost_dump" I get the following error: thread 'main' panicked at 'failed to load the model: ParseIntError { kind: Empty }'

    I was wondering if it could be related to the version of the xgboost library being used?

    Thank you

    opened by NadirSahllal 4
  • Enhance error return to fulfill std::error::Error + 'static + Send + Sync

    Hi, I am a user of this wonderful lib, a huge big thanks first!

    Recently, I tried to use anyhow in my application and got this error:

    E0277: `dyn std::error::Error` cannot be shared between threads safely
    help: the trait `std::marker::Sync` is not implemented for `dyn std::error::Error`
    note: required because of the requirements on the impl of `std::marker::Sync` for `std::ptr::Unique<dyn std::error::Error>`
    note: required because it appears within the type `std::boxed::Box<dyn std::error::Error>`
    note: required because of the requirements on the impl of `std::convert::From<std::boxed::Box<dyn std::error::Error>>` for `anyhow::Error`
    note: required by `std::convert::From::from`
    

    so I modified all the returned error types from

    Box<dyn std::error::Error>
    

    into

    Box<dyn std::error::Error + 'static + Sync + Send>
    

    this may lead to some benefits:

    1. able to use this lib across multiple threads
    2. able to use anyhow, a modern error-handling library for applications

    hope you will accept this, thank you.

    opened by tommady 2
  • Impurity calculation question

    Here is part of the get_impurity method in decision_tree.rs:

    for pair in sorted_data.iter() {
        let (index, feature_value) = *pair;
        if feature_value == VALUE_TYPE_UNKNOWN {
            let cv: &CacheValue = &cache.cache_value[index];
            s += cv.s;
            ss += cv.ss;
            c += cv.c;
            unknown += 1;
        } else {
            break;
        }
    }

    I want to ask why break is used here instead of continue. I think the purpose of this code is to count and accumulate the samples with unknown values.

    opened by czzmmc 2
  • CI: add code coverage

    This PR adds automatic code coverage to CI.

    • https://ci.mesalock-linux.org/mesalock-linux/gbdt-rs/31/25/2

    Since all stdout is written into the database, please consider reducing the examples' logging to stdout to shorten the pipeline's execution time.

    opened by mssun 0
  • ValueType

    Hey, I would like to change the ValueType from f32 to f64. Is there any way to do it myself, or do you have to implement something new?

    I saw in your source code that you defined it this way:

    ///! For now we only support std::$t using this macro.
    /// We will generalize ValueType in future.
    macro_rules! def_value_type {
        ($t: tt) => {
            pub type ValueType = $t;
            pub const VALUE_TYPE_MAX: ValueType = std::$t::MAX;
            pub const VALUE_TYPE_MIN: ValueType = std::$t::MIN;
            pub const VALUE_TYPE_UNKNOWN: ValueType = VALUE_TYPE_MIN;
        };
    }
    
    // use continous variables for decision tree
    def_value_type!(f32);
    

    Can you bind the ValueType to f64 behind a feature flag?

    Thanks, Alexis D.
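
    A hypothetical sketch of the change requested above, building on the def_value_type! macro quoted in this comment; the f64-values feature name is illustrative and not part of the crate:

    // hypothetical: select the value type at compile time via a cargo feature
    #[cfg(not(feature = "f64-values"))]
    def_value_type!(f32);

    #[cfg(feature = "f64-values")]
    def_value_type!(f64);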

    opened by alexis2804 0
  • small difference between gbdt-rs and rust-xgboost (native)

    Hi,

    I'm experiencing a small delta between the predictions (same model, same inputs) of gbdt-rs and rust-xgboost (https://github.com/davechallis/rust-xgboost), which is based on the C++ implementation, using gbtree and logistic regression.

    I'm researching this at the moment and suspect a few possible causes:

    1. floating-point precision differences between C++ and Rust
    2. a different XGB implementation
    3. I'm training in Python and loading into Rust via the convert script, so maybe a problem in reading the dump on the Rust side (I assume the save side is OK because it's using the C++ lib)

    From your experience, is this a known issue? Or maybe you can point me in a more specific direction to research among the causes I listed above?

    Thanks

    UPDATE: I have now narrowed it down to how parameters are initialized on the Python side vs. the Rust side. It looks like some of the parameters are not loaded or are taken into account differently. When the models on both the Python and Rust sides are loaded with no parameters, the results are equal.

    opened by jondot 0
  • Crossfold analysis example?

    Do you have a way to do this?

    Or do I have to split the data (fine to do), and if so, can I call model.fit() multiple times to update the model, or will it overwrite it?

    opened by ronniec95 1
  • Use more Rust like interfaces for Config

    Just some small changes

    For gbdt::config::Config,

    • A ConfigBuilder (following the builder pattern) with a build() method that creates a Config object would be more Rust-like (a sketch follows below).
    • The to_string() method is not necessary if you implement Display.
    • The config.set_loss function takes a str rather than an enum.

    Are you open to PRs?
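
    A hypothetical sketch of the builder shape being suggested, wrapping the existing setters; ConfigBuilder and its methods are illustrative and not part of the crate's API:

    use gbdt::config::Config;

    // hypothetical builder wrapping the existing Config setters
    struct ConfigBuilder {
        cfg: Config,
    }

    impl ConfigBuilder {
        fn new() -> Self {
            ConfigBuilder { cfg: Config::new() }
        }
        fn feature_size(mut self, n: usize) -> Self {
            self.cfg.set_feature_size(n);
            self
        }
        fn loss(mut self, loss: &str) -> Self {
            self.cfg.set_loss(loss);
            self
        }
        fn build(self) -> Config {
            self.cfg
        }
    }

    // usage
    let cfg = ConfigBuilder::new()
        .feature_size(22)
        .loss("LogLikelyhood")
        .build();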

    opened by ronniec95 1
  • Feature importance

    Feature importance shows how valuable each feature was in training a model. It can be used to select features.

    It would be a good feature for gbdt-rs to provide feature importance.

    enhancement 
    opened by litongxin1991 0
Owner
MesaLock Linux
A Memory-Safe Linux Distribution