A deep learning library for Rust



An experimental deep learning library written in pure Rust. Expect breaking changes with each release in the short term. See mnist.rs in the examples directory, or Rusty_SR, for usage samples.


Issues are a great place for discussion, problems, requests, and coordinating future work.

Blatantly incorrect documentation contributions are encouraged as a way to guide documentation efforts: just submit a PR that fills a doc comment with anything from your best guess to passive-aggressive nursery rhymes.


Documentation will remain patchy until the library settles down, particularly until the graph abstraction is finalised and the switch to ndarray is completed.


  • Computation hypergraph
  • Dense Connection and Bias operations
  • Loss functions
    • Mean Squared Error
    • Categorical Cross Entropy
    • SoftMax Cross Entropy
    • Binary Cross Entropy
  • Activations
    • Tanh
    • Logistic
    • Identity
    • ReLU
    • LeakyReLU
    • ELU
    • SoftMax
    • sRGB Curves
    • BeLU
    • SoftExp
    • SoftPlus
  • Spatial operations
    • Shape constraint propagation
    • N-dimensional Convolution
      • Arbitrary padding
      • Strides
    • N-dimensional AvgPooling
    • N-dimensional spaxel shuffling for "Sub-pixel Convolution"
    • N-dimensional Linear-Interpolation (backprop not finished)
    • Global Pooling
    • Broadcasting
  • Data Loading
    • Mnist
    • Cifar
    • Image Folders
    • Imagenet (ILSVRC)
  • SGD
  • RMSProp
  • ADAM
  • CAIN
    • Adaptive BatchSize
    • Adaptive Learning Rate
    • Adaptive Momentum
  • Basic numerical tests
  • Limit Optimiser evaluation batch size to stay within memory limits
  • Selectively disable calculation of forward values, node derivatives and parameter derivatives
  • Builder patterns for operation construction
  • Split Graph struct into mutable GraphBuilder and immutable Sub-Graphs
    • Replace 'accidentally quadratic' graph algorithms
    • Replace up-front allocation with Sub-Graph optimised allocation/deallocation patterns based on liveness analysis of nodes
  • Overhaul data ingestion, particularly buffering of input reads and processing
  • Move to bluss' ndarray where possible (long overdue)
  • Improve naming inter/intra-library consistency
  • Complete Documentation
  • Reduce ability to express illegal states in API
  • Move from panics to error-chain
  • Guard unsafe code rigorously
  • Comprehensive tests
  • Arrayfire as an option for sgemm on APUs
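As a rough reference for what two of the listed losses compute, here are Mean Squared Error and SoftMax Cross Entropy with their gradients, written in plain Rust for a single example. This is a sketch of the standard definitions, not alumina's operation API or implementation; the 1/n averaging in MSE and taking raw logits as the SoftMax Cross Entropy input are assumptions about scaling and inputs.

```rust
/// Mean squared error over all elements: (1/n) * sum((p - t)^2).
fn mse(prediction: &[f32], target: &[f32]) -> f32 {
    assert_eq!(prediction.len(), target.len(), "prediction and target must match");
    let n = prediction.len() as f32;
    prediction.iter().zip(target).map(|(p, t)| (p - t) * (p - t)).sum::<f32>() / n
}

/// Gradient of `mse` with respect to each prediction: (2/n) * (p - t).
fn mse_grad(prediction: &[f32], target: &[f32]) -> Vec<f32> {
    assert_eq!(prediction.len(), target.len(), "prediction and target must match");
    let n = prediction.len() as f32;
    prediction.iter().zip(target).map(|(p, t)| 2.0 * (p - t) / n).collect()
}

/// Largest logit, subtracted before exponentiating so exp() cannot overflow.
fn max_logit(logits: &[f32]) -> f32 {
    logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max)
}

/// Softmax cross entropy for one example, from raw logits:
/// -sum(t_i * ln(softmax(z)_i)), using ln(softmax(z)_i) = z_i - logsumexp(z).
fn softmax_cross_entropy(logits: &[f32], target: &[f32]) -> f32 {
    let m = max_logit(logits);
    let log_sum_exp = m + logits.iter().map(|z| (z - m).exp()).sum::<f32>().ln();
    -logits.iter().zip(target).map(|(z, t)| t * (z - log_sum_exp)).sum::<f32>()
}

/// Gradient with respect to the logits: softmax(z) - t (valid when the targets sum to 1).
fn softmax_cross_entropy_grad(logits: &[f32], target: &[f32]) -> Vec<f32> {
    let m = max_logit(logits);
    let exps: Vec<f32> = logits.iter().map(|z| (z - m).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().zip(target).map(|(e, t)| e / sum - t).collect()
}

fn main() {
    let prediction = [0.5f32, 1.0, 2.0];
    let target = [0.0f32, 1.0, 1.0];
    println!("mse = {}, grad = {:?}", mse(&prediction, &target), mse_grad(&prediction, &target));

    let logits = [2.0f32, 1.0, 0.1];
    let one_hot = [1.0f32, 0.0, 0.0];
    println!(
        "softmax xent = {}, grad = {:?}",
        softmax_cross_entropy(&logits, &one_hot),
        softmax_cross_entropy_grad(&logits, &one_hot)
    );
}
```

Fusing softmax with cross entropy is the usual reason to offer a combined loss: the gradient with respect to the logits collapses to softmax(z) - t, and working from logits avoids taking the log of a softmax output that may have underflowed to zero.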


  • RNNs
  • Efficient probabilistic structures (e.g. generative RNNs)
  • Graph optimisation passes and inplace operations
  • Support for both dynamic and static graphs
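Two of the planned items above, Sub-Graph allocation based on liveness analysis of nodes and the long-term in-place operations, both come down to knowing the last step at which each node's value is still needed. The sketch below shows that analysis for ops in a fixed execution order; the `Op` struct, the node names, and the `keep` list are hypothetical simplifications, not alumina's internal types.

```rust
use std::collections::HashMap;

/// Hypothetical, simplified operation: reads `inputs`, writes `outputs`.
struct Op {
    inputs: Vec<&'static str>,
    outputs: Vec<&'static str>,
}

/// For ops in execution order, map each node to the index of the last op that touches it.
/// A node that is written but never read maps to the op that writes it.
fn last_use(ops: &[Op]) -> HashMap<&'static str, usize> {
    let mut last = HashMap::new();
    for (i, op) in ops.iter().enumerate() {
        for &node in op.inputs.iter().chain(op.outputs.iter()) {
            last.insert(node, i); // later ops overwrite earlier entries
        }
    }
    last
}

/// Nodes whose buffers can be released after each step, skipping nodes the caller
/// must keep (e.g. requested outputs such as the loss).
fn release_schedule(ops: &[Op], keep: &[&str]) -> Vec<Vec<&'static str>> {
    let mut schedule = vec![Vec::new(); ops.len()];
    for (node, step) in last_use(ops) {
        if !keep.iter().any(|&k| k == node) {
            schedule[step].push(node);
        }
    }
    for names in &mut schedule {
        names.sort(); // HashMap order is arbitrary; sort for stable output
    }
    schedule
}

/// An element-wise op at `step` could write its output over `input`'s buffer if no later
/// op reads `input` and it is not kept (shape and type compatibility are not checked here).
fn can_run_in_place(ops: &[Op], step: usize, input: &str, keep: &[&str]) -> bool {
    !keep.iter().any(|&k| k == input) && last_use(ops).get(input) == Some(&step)
}

fn main() {
    // A tiny forward chain, listed in execution order.
    let ops = vec![
        Op { inputs: vec!["input"], outputs: vec!["layer1"] },
        Op { inputs: vec!["layer1"], outputs: vec!["layer1_activ"] },
        Op { inputs: vec!["layer1_activ"], outputs: vec!["prediction"] },
        Op { inputs: vec!["prediction", "labels"], outputs: vec!["loss"] },
    ];
    let keep = ["loss"];
    for (step, names) in release_schedule(&ops, &keep).iter().enumerate() {
        println!("after op {}: free {:?}", step, names);
    }
    // The activation (op 1) is the last reader of layer1, so it could overwrite it in place.
    println!("op 1 in place over layer1: {}", can_run_in_place(&ops, 1, "layer1", &keep));
}
```

Once a node's last use has passed, its buffer can go back to a pool for later nodes instead of being allocated up front for the whole graph; the same last-use table is what an in-place pass would consult.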



  • MNIST: Replace dots with hyphens in dataset filenames

    The original files downloaded from http://yann.lecun.com/exdb/mnist/ use hyphens in their filenames, not dots.

    [Screenshot of the website: mnist_filenames]

    Even in 2004 the filenames used hyphens rather than dots. I suspect some application did not like handling files without extensions...

    This PR simply changes the expected file names to be loaded to contain hyphens.

    opened by nbigaouette 5
  • No license.

    I would like to fork, modify, and then open a pull request, but I'm hesitant to do so without a license. Please add one; I would recommend MIT, but it is your choice!

    opened by ergpopler 1
  • MNIST: Add download functionality

    This PR adds a download_mnist cargo feature (enabled by default) that exposes a new function to download the MNIST dataset and extract it automatically.

    The function returns the directory where the files have been extracted so that it can be passed to functions loading the dataset.

    This prevents having hardcoded paths in the example code.

    opened by nbigaouette 1
  • Consider running a (recent) rustfmt on the source code

    If your Rust setup autoformats code with rustfmt, hacking on alumina is hard, since the two formatting styles are quite different.

    Rustfmt 1.0 RC came out two weeks ago (https://www.ncameron.org/blog/rustfmt-1-rc/) and is easily installed using rustup: rustup component add rustfmt-preview.

    Then running cargo fmt will format the whole code base.

    This changes a lot of files, so I think the maintainer is in a better position to do it.

    Would you consider doing so?

    Thank you!

    opened by nbigaouette 0
  • How to use MSE loss?

    thread 'main' panicked at 'called `Result::unwrap()` on an `Err` value: Error(SubgraphInsufficientInputsForOutputs(["prediction_loss_gradient"]), State { next_error: None, backtrace: Some(stack backtrace:
       0:     0x55d11fc81d54 - backtrace::backtrace::libunwind::trace
                            at /home/g/.cargo/registry/src/github.com-1ecc6299db9ec823/backtrace-0.3.4/src/backtrace/libunwind.rs:53
                             - backtrace::backtrace::trace<closure>
                            at /home/g/.cargo/registry/src/github.com-1ecc6299db9ec823/backtrace-0.3.4/src/backtrace/mod.rs:42
       1:     0x55d11fc7bc2c - backtrace::capture::{{impl}}::new_unresolved
                            at /home/g/.cargo/registry/src/github.com-1ecc6299db9ec823/backtrace-0.3.4/src/capture.rs:88
       2:     0x55d11fc7bb7e - backtrace::capture::{{impl}}::new
                            at /home/g/.cargo/registry/src/github.com-1ecc6299db9ec823/backtrace-0.3.4/src/capture.rs:63
       3:     0x55d11fc79406 - error_chain::make_backtrace
                            at /home/g/.cargo/registry/src/github.com-1ecc6299db9ec823/error-chain-0.11.0/src/lib.rs:616
       4:     0x55d11fc7946f - error_chain::{{impl}}::default
                            at /home/g/.cargo/registry/src/github.com-1ecc6299db9ec823/error-chain-0.11.0/src/lib.rs:710
       5:     0x55d11fc2ab91 - alumina::graph::{{impl}}::from_kind
                            at /home/g/Desktop/learned-index/<impl_error_chain_processed macros>:53
       6:     0x55d11fc2ae7c - alumina::graph::{{impl}}::from
                            at /home/g/.cargo/git/checkouts/alumina-6fbb489aa505e418/7c1b79e/src/graph.rs:16
       7:     0x55d11fc131fc - core::convert::{{impl}}::into<alumina::graph::ErrorKind,alumina::graph::Error>
                            at /checkout/src/libcore/convert.rs:415
       8:     0x55d11fc233a7 - alumina::graph::find_pass_order
                            at /home/g/.cargo/git/checkouts/alumina-6fbb489aa505e418/7c1b79e/src/graph.rs:1059
       9:     0x55d11fc1faad - alumina::graph::{{impl}}::new
                            at /home/g/.cargo/git/checkouts/alumina-6fbb489aa505e418/7c1b79e/src/graph.rs:730
      10:     0x55d11fc1693b - alumina::graph::{{impl}}::subgraph
                            at /home/g/.cargo/git/checkouts/alumina-6fbb489aa505e418/7c1b79e/src/graph.rs:194
      11:     0x55d11fc16d73 - alumina::graph::{{impl}}::default_subgraph
                            at /home/g/.cargo/git/checkouts/alumina-6fbb489aa505e418/7c1b79e/src/graph.rs:210
      12:     0x55d11fb08dd2 - alumina::opt::adam::{{impl}}::new
                            at /home/g/.cargo/git/checkouts/alumina-6fbb489aa505e418/7c1b79e/src/opt/adam.rs:38
      13:     0x55d11fa307bc - learned_index::nn::learn_test
                            at src/lib/nn.rs:40
      14:     0x55d11fa1f4d6 - learned_btree::main
                            at src/bin/btree/main.rs:13
      15:     0x55d11fd349ce - panic_unwind::__rust_maybe_catch_panic
                            at /checkout/src/libpanic_unwind/lib.rs:101
      16:     0x55d11fd20a73 - std::panic::catch_unwind<closure,()>
                            at /checkout/src/libstd/panicking.rs:459
                             - std::rt::lang_start
                            at /checkout/src/libstd/rt.rs:58
      17:     0x55d11fa1f51d - main
      18:     0x7f8ee148382f - __libc_start_main
      19:     0x55d11fa1be68 - _start
      20:                0x0 - <unknown>) })', /checkout/src/libcore/result.rs:916:4
    stack backtrace:
       0: std::sys::unix::backtrace::tracing::imp::unwind_backtrace
                 at /checkout/src/libstd/sys/unix/backtrace/tracing/gcc_s.rs:49
       1: std::sys_common::backtrace::print
                 at /checkout/src/libstd/sys_common/backtrace.rs:68
                 at /checkout/src/libstd/sys_common/backtrace.rs:57
       2: std::panicking::default_hook::{{closure}}
                 at /checkout/src/libstd/panicking.rs:381
       3: std::panicking::default_hook
                 at /checkout/src/libstd/panicking.rs:397
       4: std::panicking::rust_panic_with_hook
                 at /checkout/src/libstd/panicking.rs:577
       5: std::panicking::begin_panic
                 at /checkout/src/libstd/panicking.rs:538
       6: std::panicking::begin_panic_fmt
                 at /checkout/src/libstd/panicking.rs:522
       7: rust_begin_unwind
                 at /checkout/src/libstd/panicking.rs:498
       8: core::panicking::panic_fmt
                 at /checkout/src/libcore/panicking.rs:71
       9: core::result::unwrap_failed
                 at /checkout/src/libcore/macros.rs:23
      10: <core::result::Result<T, E>>::unwrap
                 at /checkout/src/libcore/result.rs:782
      11: learned_btree::main
                 at src/bin/btree/main.rs:13
      12: __rust_maybe_catch_panic
                 at /checkout/src/libpanic_unwind/lib.rs:101
      13: std::rt::lang_start
                 at /checkout/src/libstd/panicking.rs:459
                 at /checkout/src/libstd/rt.rs:58
      14: main
      15: __libc_start_main
      16: _start

    I am testing out this framework. I modified the MNIST example to use the MSE loss function and got the error above.

    /// A common mnist network with two hidden layers of 800 units and tanh activation functions
    fn mnist_tanh_800(regularise: f32) -> Result<GraphDef> {
        let mut g = GraphDef::new();
        let input = g.new_node(shape![Unknown, 1], "input", tag![])?;
        let labels = g.new_node(shape![Unknown, 1], "labels", tag![])?;
        let layer1 = g.new_node(shape![Unknown, 10], "layer1", tag![])?;
        let layer1_activ = g.new_node(shape![Unknown, 10], "layer1_activ", tag![])?;
        let layer2 = g.new_node(shape![Unknown, 10], "layer2", tag![])?;
        let layer2_activ = g.new_node(shape![Unknown, 10], "layer2_activ", tag![])?;
        let prediction = g.new_node(shape![Unknown, 1], "prediction", tag![])?;
        let softmax = g.new_node(shape![Unknown, 1], "softmax", tag![])?;
        let prediction_loss = g.new_node(shape![Unknown], "prediction_loss", tag![])?;
        g.new_op(Linear::new(&input, &layer1).init(Linear::msra(1.0)), tag![])?;
        g.new_op(Bias::new(&layer1), tag![])?;
        g.new_op(Tanh::new(&layer1, &layer1_activ), tag![])?;
        g.new_op(Linear::new(&layer1_activ, &layer2).init(Linear::msra(1.0)), tag![])?;
        g.new_op(Bias::new(&layer2), tag![])?;
        g.new_op(Tanh::new(&layer2, &layer2_activ), tag![])?;
        g.new_op(Linear::new(&layer2_activ, &prediction).init(Linear::msra(1.0)), tag![])?;
        g.new_op(Softmax::new(&prediction, &softmax), tag![])?;
        g.new_op(CrossEntropy::new(&softmax, &labels), tag![])?;
        g.new_op(Mse::new(&prediction, &labels).output(&prediction_loss), tag![])?;

    opened by 0b01 2
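The panic in the "How to use MSE loss?" issue is the SubgraphInsufficientInputsForOutputs(["prediction_loss_gradient"]) error returned from alumina::graph::find_pass_order while the Adam optimiser builds its default subgraph (see the alumina::opt::adam and default_subgraph frames). As the name suggests, the subgraph's inputs were not enough to compute everything it needed. The toy below illustrates that kind of reachability check under loudly stated assumptions: passes are hypothetical structs with named inputs and outputs, the three-pass graph is made up, and none of it is alumina's actual code or a diagnosis of the graph in that issue.

```rust
use std::collections::HashSet;

/// Hypothetical pass: can run once every value in `inputs` is known, and then yields `outputs`.
struct Pass {
    name: &'static str,
    inputs: Vec<&'static str>,
    outputs: Vec<&'static str>,
}

/// Returns an order in which the passes can run given the `available` values, or the
/// `requested` values that stay unreachable. It simply runs every runnable pass; a real
/// scheduler would also prune passes that do not contribute to the requested values.
fn pass_order(
    passes: &[Pass],
    available: &[&'static str],
    requested: &[&'static str],
) -> Result<Vec<&'static str>, Vec<&'static str>> {
    let mut known: HashSet<&'static str> = available.iter().copied().collect();
    let mut done = vec![false; passes.len()];
    let mut order = Vec::new();
    // Keep sweeping until no pass becomes runnable.
    loop {
        let mut progressed = false;
        for (i, pass) in passes.iter().enumerate() {
            if !done[i] && pass.inputs.iter().all(|n| known.contains(n)) {
                done[i] = true;
                progressed = true;
                order.push(pass.name);
                known.extend(pass.outputs.iter().copied());
            }
        }
        if !progressed {
            break;
        }
    }
    let missing: Vec<&'static str> =
        requested.iter().copied().filter(|n| !known.contains(n)).collect();
    if missing.is_empty() { Ok(order) } else { Err(missing) }
}

fn main() {
    // Toy graph: the backward pass of the loss needs a gradient value for the loss node,
    // and neither the supplied inputs nor any pass in this toy graph produces it.
    let passes = vec![
        Pass { name: "linear_fwd", inputs: vec!["input"], outputs: vec!["prediction"] },
        Pass { name: "mse_fwd", inputs: vec!["prediction", "labels"], outputs: vec!["prediction_loss"] },
        Pass {
            name: "mse_bwd",
            inputs: vec!["prediction", "labels", "prediction_loss_gradient"],
            outputs: vec!["prediction_gradient"],
        },
    ];
    match pass_order(&passes, &["input", "labels"], &["prediction_gradient"]) {
        Ok(order) => println!("pass order: {:?}", order),
        Err(missing) => println!("insufficient inputs for outputs: {:?}", missing),
    }
}
```

In a check like this, the fix is always on the graph side: either supply the missing value as a subgraph input or add an operation that produces it.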
