BURN: Burn Unstoppable Rusty Neurons

Overview

This library aims to be a complete deep learning framework written in Rust, with extreme flexibility. The goal is to satisfy researchers as well as practitioners, making it easier to experiment with, train, and deploy your solutions.

Why Rust?

A big benefit of using Rust instead of Python is that it allows performant multi-threaded deep learning networks, which might open new doors for more efficient models. Scale seems to be very important, but the only tool we currently have to achieve it is big matrix multiplication on GPUs. This often implies big batch sizes, which is impossible for online learning. Also, asynchronous sparsely activated networks without copying weights are essentially impossible to achieve with Python (or very hard without proper threading).

Burn-Tensor

BURN has its own tensor library supporting multiple backends, which can also be used for other scientific computing applications. See the burn-tensor crate for more details.
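
A quick sketch of creating a tensor with it, based on the from_data snippet from the issues below (the import paths and the NdArrayBackend type parameter are assumptions and may not match a given release):

    use burn_tensor::{backend::NdArrayBackend, Data, Tensor};

    type B = NdArrayBackend<f32>;

    // Build a 2x2 tensor from nested arrays; `convert()` bridges the
    // literal element type to the backend's element type.
    let t: Tensor<B, 2> =
        Tensor::from_data(Data::from([[1.0f32, 2.0], [3.0, 4.0]]).convert());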

Module Definition

Currently working on it ... 💻
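
In the meantime, here is a purely illustrative sketch of the PyTorch-style (stateful) module approach discussed in the comments below; none of these names are taken from Burn's actual API:

    // Illustrative stateful module: parameters live inside the struct,
    // in contrast to the JAX-style design proposed in the comments below.
    struct Linear<B: Backend> {
        weight: Tensor<B, 2>,
        bias: Tensor<B, 1>,
    }

    impl<B: Backend> Linear<B> {
        fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
            input.matmul(&self.weight).add(&self.bias)
        }
    }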

Comments
  • Issues running and changing the backend of the mnist example

    Ubuntu 20.04 LTS, NVIDIA 3070 GPU (Driver 510.85.02, CUDA Version 11.6)

    I am able to run the example as-is and it trains successfully, but it is very slow and appears not to be fully utilizing all the cores on my CPU. However, at what appears to be the end of epoch 2 (the last progress printout reports Iteration 80, Epoch 2/6, with 2 full bars), it crashes with this message:

    thread 'main' panicked at 'called Result::unwrap() on an Err value: SendError { .. }', burn/burn/src/train/checkpoint/async_checkpoint.rs:68:40

    I changed the example to use the Tch backend by changing main to this:

    fn main() {
        use burn::tensor::backend::TchADBackend;
    
        let device = TchDevice::Cpu;
        training::run::<TchADBackend<f32>>(device);
        println!("Done.");
    }
    

    This appears to train using my full CPU at great speed, but it then crashed on both tries, in 2 different ways. The first is the same message as above; upon using the VS Code debugger, it crashed in a different way:

    thread '' panicked at 'attempt to subtract with overflow', burn/burn/src/train/checkpoint/file.rs:41:60

    In that case, epoch was 1 and self.num_keep was 2.

    I changed the example main as follows to try to use my GPU:

    fn main() {
        use burn::tensor::backend::TchADBackend;
    
        let device = TchDevice::Cuda(0);
        training::run::<TchADBackend<f32>>(device);
        println!("Done.");
    }
    

    My first question is: what does the magic number in TchDevice::Cuda(XXX) represent?

    Then, even with various values for that number (0, 1, 1024), the application crashes on the line model.to_device(device). I always get this error message, which I have been unable to solve:

    thread 'main' panicked at 'called Result::unwrap() on an Err value: Torch("Could not run 'aten::empty_strided' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). 'aten::empty_strided' is only available for these backends: [Dense, Conjugate, Negative, QuantizedXPU, SparseCPU, SparseCUDA, SparseHIP, SparseXPU, SparseVE, NestedTensorCUDA, ...]
    [... full dispatcher registration table and long libtorch/mnist backtrace truncated ...]
    ")', /home/matthew/.cargo/registry/src/github.com-1ecc6299db9ec823/tch-0.8.0/src/wrappers/tensor_generated.rs:12977:27

    opened by mwbryant 6
  • Make it easier to create tensors manually

    The easiest way I've found to create tensors so far is this:

    let tensor: Tensor<B, 2> = Tensor::from_data(Data::<f32, 2>::from([[1.0f32, 2.0], [3.0, 4.0]]).convert());
    

    This took quite a bit of digging to find out, and I'd expect this to be more straightforward.

    Suggestion 1 - Add a method directly to Tensor

    // Note: I think it should be possible to infer the type on the LHS here (assuming the backend can be inferred from context)
    let tensor: Tensor<B, 2> = Tensor::from_array([[1.0f32, 2.0], [3.0, 4.0]]);
    

    Not sure if from_array is the best name here. We could also just use the From trait.

    Implementation Details

    This will require adding multiple impls like the following to Tensor:

    impl<B: Backend> Tensor<B, 1> {
        pub fn from_array<const D1: usize>(data: [B::Elem; D1]) -> Self {
            todo!()
        }
    }

    impl<B: Backend> Tensor<B, 2> {
        pub fn from_array<const D1: usize, const D2: usize>(data: [[B::Elem; D1]; D2]) -> Self {
            todo!()
        }
    }
    
    

    Suggestion 2 - Add a macro

    // Note: Again, the LHS type should be inferrable
    let tensor: Tensor<B, 2> = tensor![[1.0f32, 2.0], [3.0, 4.0]];
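
    A minimal sketch of such a macro, building on the from_data path shown above (untested; assumes Data::from accepts nested array literals and that .convert() bridges element types, as in the first snippet):

    macro_rules! tensor {
        // Each $x is a row expression for 2-D (e.g. [1.0f32, 2.0]),
        // or a scalar for 1-D; the nested array goes through Data.
        ($($x:expr),+ $(,)?) => {
            Tensor::from_data(Data::from([$($x),+]).convert())
        };
    }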
    
    opened by vultix 3
  • Fix `cargo run --example mnist`

    • Please check if the PR fulfills these requirements
    • [ ] The commit message follows our guidelines
    • [x] Docs have been added / updated (for bug fixes / features)
    • What kind of change does this PR introduce? (Bug fix, feature, docs update, ...)

    Fixes https://github.com/burn-rs/burn/issues/89

    • Does this PR introduce a breaking change? (What changes might users need to make in their application due to this PR?)

    Only changes the example

    opened by n8henrie 1
  • Weird compile error on MNIST example only when copying out of repo?

    Describe the bug

    I copy pasted everything out of the MNIST src into my own project because I wanted to try making a test NN.

    It does not like one of the types in data.rs, yet the same code compiles fine in the Burn project itself:

      --> TicTacToeMaster\src\data.rs:29:42
       |
    29 |             .map(|tensor| tensor.reshape([1,784]))
       |                                  ------- ^^^^^^^ expected struct `Shape`, found array `[{integer}; 2]`
       |                                  |
       |                                  arguments to this function are incorrect
       |
       = note: expected struct `Shape<_>`
                   found array `[{integer}; 2]`
    note: associated function defined here
      --> C:\Users\itscrabs\.cargo\registry\src\github.com-1ecc6299db9ec823\burn-tensor-0.2.3\src\tensor\base.rs:31:12
       |
    31 |     pub fn reshape<const D2: usize>(&self, shape: Shape<D2>) -> Tensor<B, D2> {
       |            ^^^^^^^
    help: try wrapping the expression in `burn::tensor::Shape`
       |
    29 |             .map(|tensor| tensor.reshape(burn::tensor::Shape { dims: [1,784] }))
       |                                          +++++++++++++++++++++++++++         +
    
    For more information about this error, try `rustc --explain E0308`.
    error: could not compile `TicTacToeMaster` due to previous error
    

    To Reproduce

    Steps to reproduce the behavior:

    1. Create new cargo project
    2. Copy everything in examples/mnist to the project
    3. Add serde and burn to Cargo.toml
    4. cargo run
    5. See error

    Expected behavior: the example runs.

    Additional context: I'm not so great with GitHub, but just forcing it into the correct Burn Shape using new() seems to fix it. Line 29 of examples/mnist/data.rs should become:

    .map(|tensor| tensor.reshape(burn::tensor::Shape::new([1,784])))

    opened by CrabBucket 1
  • Have a landing page to promote Burn

    Feature description

    Put in place a landing page for the framework, and attach thorough documentation and usage examples to it.

    Feature motivation

    It could help in growing our community.

    opened by olgam4 1
  • Pretty print of Tensor

    Feature description

    Implement std::fmt::Display for Tensor and BoolTensor in the same way for every Backend, including basic metadata. It could look something like the following (just a draft):

    Tensor {
      data: [[0., 0., 0.,  ..., 0., 0., 0.],
             [0., 0., 0.,  ..., 0., 0., 0.],
             [0., 0., 0.,  ..., 0., 0., 0.],
             ...,
             [0., 0., 0.,  ..., 0., 0., 0.],
             [0., 0., 0.,  ..., 0., 0., 0.],
             [0., 0., 0.,  ..., 0., 0., 0.]],
      shape:   [30, 30],
      device:  Cpu,
      backend: ndarray,
      dtype:   f32,
    }
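
    A rough sketch of such an impl (the to_data, shape, and B::name accessors used here are hypothetical, not confirmed Burn API):

    // Inside burn-tensor, where Backend and Tensor are in scope.
    use core::fmt;

    impl<B: Backend, const D: usize> fmt::Display for Tensor<B, D> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            writeln!(f, "Tensor {{")?;
            writeln!(f, "  data: {:?},", self.to_data())?; // hypothetical accessor
            writeln!(f, "  shape: {:?},", self.shape())?;  // hypothetical accessor
            writeln!(f, "  backend: {},", B::name())?;     // hypothetical accessor
            write!(f, "}}")
        }
    }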
    
    enhancement good first issue 
    opened by nathanielsimard 0
  • Support computing of partial derivatives

    Feature description

    Add an (idiomatic) way of computing (mixed) partial derivatives of models.

    Feature motivation

    Most of the automatic differentiation libraries provide a first-class way of computing gradient of models of form f(X) -> Y. With the introduction of Physics-Informed Neural Networks there is a need for efficient computation of (mixed) partial derivatives (eg. df/dxdy) of models of form f(x, y, z) -> u.

    In the currently most popular ML libraries, computing partial derivatives comes with a significant performance overhead and sometimes requires some "hackery".

    As far as I can see, burn does not provide a method for computing partial derivatives of models. Implementing this feature in Rust can benefit from the language's high performance, and if implemented as a first-class feature it could give burn a significant advantage over other ML libraries.

    (Optional) Suggest a Solution

    In terms of API, pytorch and tensorflow provide a method for computing general gradients of form grad(model_output, one_of_model_inputs) -> partial_derivative_of_model_output_wrt_input.

    This API is convenient and easy to understand, but it requires multiple function calls for second and higher derivatives, which can introduce performance overhead; in Rust, however, this might not be an issue.
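
    As a framework-free illustration of what such an API computes, here is a plain-Rust finite-difference check of the mixed partial d2f/dxdy for f(x, y) = x^2 * y (an autodiff implementation would produce the same value exactly):

    // f(x, y) = x^2 * y, so d2f/dxdy = 2x (= 6 at x = 3).
    fn f(x: f64, y: f64) -> f64 {
        x * x * y
    }

    fn main() {
        let (x, y, h) = (3.0, 5.0, 1e-4);
        // Central-difference approximation of the mixed second partial.
        let d2f_dxdy = (f(x + h, y + h) - f(x + h, y - h) - f(x - h, y + h) + f(x - h, y - h))
            / (4.0 * h * h);
        println!("d2f/dxdy ~= {d2f_dxdy} (exact: {})", 2.0 * x);
    }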

    opened by Quba1 1
  • Functional style for neural network modules (JAX)

    Hello!

    I just found your project; this is very exciting work! In my spare time, I played around with AD in Rust and came to similar design decisions as you did. I noticed that you are using a PyTorch-style approach for the modules. One thing I wanted to make you aware of is the functional style used by JAX. In this design, modules are not stateful; instead, they only have pure functions, and the user maintains all state. Many ML researchers find this style superior. It also leads to a more modular design of the library ecosystem.

    Here is an example of how this could look (the code is copied from my snippets, so it does not match Burn exactly):

    pub trait Module {
        type Params<B: Backend>;
        type Input<B: Backend>;
        type Output<B: Backend>;
    
        fn params<B: Backend>(&self) -> Self::Params<B>;
    
        fn forward<B: Backend>(&self, params: &Self::Params<B>, input: &Self::Input<B>) -> Self::Output<B>;
    }
    
    pub struct LinearParams<B: Backend> {
        w: Tensor2<B>,
        b: Tensor1<B>,
    }
    
    pub struct Linear {
        inputs: usize,
        hiddens: usize,
    }
    
    impl Linear {
        pub fn new(inputs: usize, hiddens: usize) -> Self {
            Self { inputs, hiddens }
        }
    }
    
    impl Module for Linear {
        type Params<B: Backend> = LinearParams<B>;
        type Input<B: Backend> = Tensor2<B>;
        type Output<B: Backend> = Tensor2<B>;

        fn params<B: Backend>(&self) -> Self::Params<B> {
            LinearParams {
                w: Tensor::ones((self.inputs, self.hiddens)),
                b: Tensor::ones(self.hiddens),
            }
        }

        fn forward<B: Backend>(&self, params: &Self::Params<B>, input: &Self::Input<B>) -> Self::Output<B> {
            params.w.dot(input).add(&params.b)
        }
    }
    

    And then user code to run this looks as follows:

    fn main() {
        let x = Tensor2::<NdArrayBackend>::ones((8, 10));
        let h = Linear::new(10, 30);
        let p = h.params();
        let y = h.forward(&p, &x);
        println!("{:?}", y);
    }
    
    opened by makroiss 4
Releases (v0.4.0)
Owner
Nathaniel Simard
Stable Diffusion v1.4 ported to Rust's burn framework

Stable-Diffusion-Burn Stable-Diffusion-Burn is a Rust-based project which ports the V1 stable diffusion model into the deep learning framework, Burn.

null 156 Aug 8, 2023
Stable Diffusion XL ported to Rust's burn framework

Stable-Diffusion-XL-Burn Stable-Diffusion-XL-Burn is a Rust-based project which ports stable diffusion xl into the Rust deep learning framework burn.

null 194 Sep 4, 2023
A Rusty CUDA wrapper

cuda-oxide cuda-oxide is a safe wrapper for CUDA. With cuda-oxide you can execute and coordinate CUDA kernels. Safety Philosophy cuda-oxide does not o

Max Bruce 30 Dec 7, 2022
A game made for the Rusty Jam https://itch.io/jam/rusty-jam

Murder-User Dungeon Introduction Tony is a young man. Finally having its own apartment is a good thing! He will learn how to live by himself and how t

null 62 Dec 6, 2022
An extended CW721 (v0.9.2) with update, burn, freeze, set_minter functionalities.

Extended CW721 Extended CW721 NFT with update, burn, freeze, set_minter functionalities

null 12 Oct 1, 2022
`dfx new --type=rust` + burn-rs MNIST web inference example

ic-mnist The frontend provides a canvas where users can draw a digit. The drawn digit is then sent to the backend canister running burn-rs for inferen

Marcin Nowak-Liebiediew 4 Jun 25, 2023
A Rust implementation of OpenAI's Whisper model using the burn framework

Whisper Burn: Rust Implementation of OpenAI's Whisper Transcription Model Whisper Burn is a Rust implementation of OpenAI's Whisper transcription mode

null 19 Jul 24, 2023
Implementation of sentence embeddings with BERT in Rust, using the Burn library.

Sentence Transformers in Burn This library provides an implementation of the Sentence Transformers framework for computing text representations as vec

Tyler Vergho 4 Sep 4, 2023
A diffusers API in Burn (Rust)

diffusers-burn: A diffusers API in Rust/Burn ⚠️ This is still in development - contributors welcome! The diffusers-burn crate is a conversion of diffu

OxideAI 6 Nov 29, 2023
Rusty Object Notation

Rusty Object Notation RON is a simple readable data serialization format that looks similar to Rust syntax. It's designed to support all of Serde's da

ron-rs 2.3k Jan 1, 2023
A rusty dynamically typed scripting language

dyon A rusty dynamically typed scripting language Tutorial Dyon-Interactive Dyon Snippets /r/dyon Dyon script files end with .dyon. To run Dyon script

PistonDevelopers 1.5k Dec 27, 2022
RDFM - The Rusty DotFiles Manager

Wafelack 40 Aug 14, 2022
rusty-donut - ASCII raymarching inside a terminal

ASCII raymarching inside a terminal

drip 14 Feb 9, 2022
A rusty, dual-wielding Quake and Half-Life texture WAD parser.

Ogre   A rusty, dual-wielding Quake and Half-Life texture WAD parser ogre is a rust representation and nom parser for Quake and Half-Life WAD files. I

Josh Palmer 16 Dec 5, 2022
Rusty NuGet client

ruget It's a NuGet client built in Rust. It's not really meant to replace existing nuget clients. It's more of a playground for experimenting with rel

Kat Marchán 18 Feb 2, 2022
Rusty NuGet client

turron It's a NuGet client built in Rust. It's not really meant to replace existing nuget clients. It's more of a playground for experimenting with re

Kat Marchán 18 Feb 2, 2022
Rusty Armor Builds - Monster Hunter Rise Armor Set Creation Tool

RAB Rusty Armor Builds - Monster Hunter Rise Armor Set Creation Tool Armor files used by RAB

null 28 Oct 3, 2022