Alumina

A deep learning library for Rust.

Overview

An experimental deep learning library written in pure Rust. Expect breakage on each release in the short term. See mnist.rs in the examples directory, or the Rusty_SR project, for usage samples.

Contributions

Issues are a great place for discussion, problems, requests, and coordinating future work.

Blatantly incorrect documentation contributions are encouraged as a way to guide efforts on docs: just submit a PR and fill a doc comment with anything from your best guess to passive-aggressive nursery rhymes.

Documentation

Documentation will remain patchy until the library settles down, particularly until the graph abstraction is finalised and the switch to ndarray is complete.

Progress

  • Computation hypergraph
  • Dense Connection and Bias operations
  • Loss functions
    • Mean Squared Error
    • Categorical Cross Entropy
    • SoftMax Cross Entropy
    • Binary Cross Entropy
  • Activations
    • Tanh
    • Logistic
    • Identity
    • ReLU
    • LeakyReLU
    • ELU
    • SoftMax
    • SRGB Curves
    • BeLU
    • SoftExp
    • SoftPlus
  • Spatial operations
    • Shape constraint propagation
    • N-dimensional Convolution
      • Arbitrary padding
      • Strides
    • N-dimensional AvgPooling
    • N-dimensional spaxel shuffling for "Sub-pixel Convolution"
    • N-dimensional Linear-Interpolation (backprop not finished)
    • Global Pooling
    • Broadcasting
  • Data Loading
    • Mnist
    • Cifar
    • Image Folders
    • Imagenet (ILSVRC)
  • SGD
  • RMSProp
  • ADAM
  • CAIN
    • Adaptive BatchSize
    • Adaptive Learning Rate
    • Adaptive Momentum
  • Basic numerical tests
  • Limit Optimiser evaluation batch size to stay within memory limits
  • Selectively disable calculation of forward values, node derivatives and parameter derivatives
  • Builder patterns for operation construction
  • Split Graph struct into mutable GraphBuilder and immutable Sub-Graphs
    • Replace 'accidentally quadratic' graph algorithms
    • Replace up-front allocation with Sub-Graph optimised allocation/deallocation patterns based on liveness analysis of nodes
  • Overhaul data ingestion, particularly buffering of input processing/reads
  • Move to bluss' ndarray where possible (long overdue)
  • Improve inter-/intra-library naming consistency
  • Complete Documentation
  • Reduce ability to express illegal states in API
  • Move from panics to error-chain
  • Guard unsafe code rigorously
  • Comprehensive tests
  • Arrayfire as an option for sgemm on APUs
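
The optimiser bullets above (SGD, RMSProp, ADAM, CAIN) refer to standard stochastic update rules. As a standalone numeric illustration of the ADAM entry, here is the textbook rule from Kingma & Ba in plain Rust; this is a sketch only, not alumina's implementation or API:

```rust
/// One Adam step over a flat parameter slice (Kingma & Ba, 2015).
/// Illustrative sketch only; not alumina's implementation.
fn adam_step(theta: &mut [f64], grad: &[f64], m: &mut [f64], v: &mut [f64], t: i32, lr: f64) {
    let (beta1, beta2, eps) = (0.9, 0.999, 1e-8);
    for i in 0..theta.len() {
        // biased first- and second-moment estimates
        m[i] = beta1 * m[i] + (1.0 - beta1) * grad[i];
        v[i] = beta2 * v[i] + (1.0 - beta2) * grad[i] * grad[i];
        // bias correction, then the parameter update
        let m_hat = m[i] / (1.0 - beta1.powi(t));
        let v_hat = v[i] / (1.0 - beta2.powi(t));
        theta[i] -= lr * m_hat / (v_hat.sqrt() + eps);
    }
}

/// Minimise f(x) = x^2 starting from x = 5.0; returns the final x.
fn adam_demo() -> f64 {
    let mut theta = [5.0];
    let (mut m, mut v) = ([0.0], [0.0]);
    for t in 1..=2000 {
        let grad = [2.0 * theta[0]]; // d/dx x^2 = 2x
        adam_step(&mut theta, &grad, &mut m, &mut v, t, 0.01);
    }
    theta[0]
}

fn main() {
    println!("final x = {:.3}", adam_demo()); // ends up near 0
}
```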

Distant

  • RNNs
  • Efficient probabilistic structures (e.g. generative RNNs)
  • Graph optimisation passes and inplace operations
  • Support for both dynamic and static graphs
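
The N-dimensional convolution listed under Progress advertises arbitrary padding and strides. The 1-D case below shows that arithmetic in plain Rust; it is an illustrative sketch only, does not use alumina's API, and (like most deep learning libraries) actually computes cross-correlation:

```rust
// 1-D "convolution" (cross-correlation) with zero padding and stride.
// Illustrative sketch only; not alumina's API.
fn conv1d(input: &[f64], kernel: &[f64], pad: usize, stride: usize) -> Vec<f64> {
    let padded_len = input.len() + 2 * pad;
    assert!(padded_len >= kernel.len() && stride > 0);
    let out_len = (padded_len - kernel.len()) / stride + 1;
    let mut out = vec![0.0; out_len];
    for (o, slot) in out.iter_mut().enumerate() {
        let start = o * stride; // window position in the padded input
        let mut acc = 0.0;
        for (k, &w) in kernel.iter().enumerate() {
            let p = start + k;
            // indices in [pad, pad + input.len()) hit real data;
            // everything outside is implicit zero padding
            if p >= pad && p - pad < input.len() {
                acc += w * input[p - pad];
            }
        }
        *slot = acc;
    }
    out
}

fn main() {
    // "Same" padding with a 3-tap difference kernel, stride 1.
    let y = conv1d(&[1.0, 2.0, 3.0, 4.0], &[1.0, 0.0, -1.0], 1, 1);
    println!("{:?}", y); // prints [-2.0, -2.0, -2.0, 3.0]
}
```

The same formula, out_len = (len + 2*pad - kernel_len) / stride + 1, is the usual output-shape rule that shape-constraint propagation has to carry through a graph.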

License

MIT

Comments
  • MNIST: Replace dots with hyphens in dataset filenames

    The original filenames as downloaded from http://yann.lecun.com/exdb/mnist/ contain hyphens, not dots.

    Here's a screenshot of the website: (image: mnist_filenames)

    Even in 2004, the hyphen was used in the filenames rather than dots. I suspect an application did not like handling files without extensions...

    This PR simply changes the expected file names to be loaded to contain hyphens.

    opened by nbigaouette 5
  • No license.

    I would like to fork, modify, and submit a pull request, but I'm hesitant to do so without any license. Please add one; I would recommend MIT, but it is your choice!

    opened by ergpopler 1
  • MNIST: Add download functionality

    This PR adds a download_mnist cargo feature (enabled by default) that exposes a new function to download the MNIST dataset and extract it automatically.

    The function returns the directory where the files have been extracted so that it can be passed to functions loading the dataset.

    This prevents having hardcoded paths in the example code.

    opened by nbigaouette 1
  • Consider running a (recent) rustfmt on the source code

    Having a Rust setup that autoformats code with rustfmt makes it hard to hack on alumina, since the two formats are quite different.

    Rustfmt 1.0 RC came out two weeks ago (https://www.ncameron.org/blog/rustfmt-1-rc/) and is easily installed using rustup: rustup component add rustfmt-preview.

    Then a simple cargo fmt will format the code base.

    This changes a lot of files so I think the maintainer should be in a better position to do this.

    Would you consider doing so?

    Thank you!

    opened by nbigaouette 0
  • How to use MSE loss?

    thread 'main' panicked at 'called `Result::unwrap()` on an `Err` value: Error(SubgraphInsufficientInputsForOutputs(["prediction_loss_gradient"]), ...)', /checkout/src/libcore/result.rs:916:4

    Backtrace abridged to the alumina and application frames:

       alumina::graph::find_pass_order            at src/graph.rs:1059
       alumina::graph::{{impl}}::new              at src/graph.rs:730
       alumina::graph::{{impl}}::subgraph         at src/graph.rs:194
       alumina::graph::{{impl}}::default_subgraph at src/graph.rs:210
       alumina::opt::adam::{{impl}}::new          at src/opt/adam.rs:38
       learned_index::nn::learn_test              at src/lib/nn.rs:40
       learned_btree::main                        at src/bin/btree/main.rs:13

    I am testing out this framework. Currently I am modifying the MNIST example, changing the loss function to MSE, and getting the error above.

    /// A common mnist network with two hidden layers of 800 units and tanh activation functions
    #[allow(unused)]
    fn mnist_tanh_800(regularise: f32) -> Result<GraphDef> {
        let mut g = GraphDef::new();
    
        let input = g.new_node(shape![Unknown, 1], "input", tag![])?;
        let labels = g.new_node(shape![Unknown, 1], "labels", tag![])?;
    
        let layer1 = g.new_node(shape![Unknown, 10], "layer1", tag![])?;
        let layer1_activ = g.new_node(shape![Unknown, 10], "layer1_activ", tag![])?;
    
        let layer2 = g.new_node(shape![Unknown, 10], "layer2", tag![])?;
        let layer2_activ = g.new_node(shape![Unknown, 10], "layer2_activ", tag![])?;
    
        let prediction = g.new_node(shape![Unknown, 1], "prediction", tag![])?;
        let softmax = g.new_node(shape![Unknown, 1], "softmax", tag![])?;
    
        let prediction_loss = g.new_node(shape![Unknown], "prediction_loss", tag![])?;
    
        g.new_op(Linear::new(&input, &layer1).init(Linear::msra(1.0)), tag![])?;
        g.new_op(Bias::new(&layer1), tag![])?;
        g.new_op(Tanh::new(&layer1, &layer1_activ), tag![])?;
    
        g.new_op(Linear::new(&layer1_activ, &layer2).init(Linear::msra(1.0)), tag![])?;
        g.new_op(Bias::new(&layer2), tag![])?;
        g.new_op(Tanh::new(&layer2, &layer2_activ), tag![])?;
    
        g.new_op(Linear::new(&layer2_activ, &prediction).init(Linear::msra(1.0)), tag![])?;
        g.new_op(Softmax::new(&prediction, &softmax), tag![])?;
        g.new_op(CrossEntropy::new(&softmax, &labels), tag![])?;
    
        g.new_op(Mse::new(&prediction, &labels).output(&prediction_loss), tag![])?;
    
        Ok(g)
    }
    
    opened by 0b01 2