A simplified example in Rust of training a neural network and then using it based on the Candle Framework by Hugging Face.

Overview

candle-simplified-example


How it works

This program implements a neural network that predicts the winner of the second round of an election from the results of the first round.

Key points:

  1. A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons, and each is followed by a ReLU activation.
  2. The input is a vector of 2 numbers: the percentage of votes for the first and second candidates in the first round.
  3. The output is 0 or 1, where 1 means that the first candidate wins the second round and 0 means that they lose. Internally, the last layer produces two values, one per outcome, and the prediction is the index of the larger one.
  4. Training uses samples of real results from the first and second rounds of different elections.
  5. The model is trained by backpropagation, using gradient descent and the cross-entropy loss function.
  6. The model parameters (neuron weights) are initialized randomly and then optimized during training.
  7. After training, the model is evaluated on a held-out test set to measure its accuracy.
  8. If the test accuracy is below 100%, the model is considered underfit and training is repeated.
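To make points 1 to 3 concrete, here is a minimal plain-Rust sketch of the forward pass: 2 inputs, hidden layers of 4 and 2 neurons with ReLU, 2 output logits, and argmax to pick the class. It does not use Candle; the `Dense` type, the helper functions and the hand-picked weights are hypothetical and chosen only so the arithmetic is easy to follow (the real network learns its weights).

```rust
/// A dense layer: `weights[o][i]` connects input `i` to output `o`.
struct Dense {
    weights: Vec<Vec<f32>>,
    bias: Vec<f32>,
}

impl Dense {
    /// y_o = sum_i weights[o][i] * x_i + bias[o]
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        self.weights
            .iter()
            .zip(&self.bias)
            .map(|(row, b)| row.iter().zip(x).map(|(w, xi)| w * xi).sum::<f32>() + b)
            .collect()
    }
}

/// ReLU activation: negative values become 0.
fn relu(v: Vec<f32>) -> Vec<f32> {
    v.into_iter().map(|x| x.max(0.0)).collect()
}

/// Index of the largest logit, i.e. the predicted class.
fn argmax(v: &[f32]) -> usize {
    let mut best = 0;
    for (i, &x) in v.iter().enumerate() {
        if x > v[best] {
            best = i;
        }
    }
    best
}

/// 2 -> 4 (ReLU) -> 2 (ReLU) -> 2 logits.
fn forward(l1: &Dense, l2: &Dense, l3: &Dense, votes: &[f32]) -> Vec<f32> {
    let h1 = relu(l1.forward(votes));
    let h2 = relu(l2.forward(&h1));
    l3.forward(&h2)
}

/// Hypothetical hand-picked weights: hidden units compute (a - b) and (b - a).
fn example_network() -> (Dense, Dense, Dense) {
    let l1 = Dense {
        weights: vec![vec![1.0, -1.0], vec![-1.0, 1.0], vec![0.0, 0.0], vec![0.0, 0.0]],
        bias: vec![0.0; 4],
    };
    let l2 = Dense {
        weights: vec![vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]],
        bias: vec![0.0; 2],
    };
    // Logit 0 = "first candidate loses", logit 1 = "first candidate wins".
    let l3 = Dense {
        weights: vec![vec![0.0, 1.0], vec![1.0, 0.0]],
        bias: vec![0.0; 2],
    };
    (l1, l2, l3)
}

fn main() {
    let (l1, l2, l3) = example_network();
    for votes in [[15.0, 10.0], [13.0, 22.0]] {
        let logits = forward(&l1, &l2, &l3, &votes);
        println!("{:?} -> logits {:?}, class {}", votes, logits, argmax(&logits));
    }
}
```

With these weights, the votes [15, 10] give logits [0, 5] (class 1, first candidate wins), while [13, 22] give logits [9, 0] (class 0).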

In this way, the network learns the relationship between first-round and second-round results and uses it to make predictions for new data.
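Point 5 can also be illustrated without Candle. The sketch below is deliberately simplified: it trains a single linear layer (2 inputs to 2 logits) instead of the 2-4-2-2 perceptron, so there is no backpropagation through hidden layers, but it uses the same ingredients: the softmax cross-entropy loss, whose gradient with respect to logit k is p_k - 1 when k is the label and p_k otherwise, and full-batch gradient descent. The training data is the article's, divided by 100; the learning rate of 0.5, the 2000 epochs and all function names are assumptions.

```rust
/// Numerically stable softmax over two logits.
fn softmax2(z: [f64; 2]) -> [f64; 2] {
    let m = z[0].max(z[1]);
    let e = [(z[0] - m).exp(), (z[1] - m).exp()];
    let s = e[0] + e[1];
    [e[0] / s, e[1] / s]
}

/// Logits of the linear model for one sample: z_k = w[k] . x + b[k].
fn logits(w: &[[f64; 2]; 2], b: &[f64; 2], x: [f64; 2]) -> [f64; 2] {
    [
        w[0][0] * x[0] + w[0][1] * x[1] + b[0],
        w[1][0] * x[0] + w[1][1] * x[1] + b[1],
    ]
}

/// Mean cross-entropy over the samples: average of -ln p[label].
fn mean_loss(w: &[[f64; 2]; 2], b: &[f64; 2], xs: &[[f64; 2]], ys: &[usize]) -> f64 {
    let total: f64 = xs
        .iter()
        .zip(ys)
        .map(|(&x, &y)| -softmax2(logits(w, b, x))[y].ln())
        .sum();
    total / xs.len() as f64
}

/// Full-batch gradient descent from zero weights; returns (w, b, loss history).
fn train(xs: &[[f64; 2]], ys: &[usize], lr: f64, epochs: usize) -> ([[f64; 2]; 2], [f64; 2], Vec<f64>) {
    let mut w = [[0.0; 2]; 2];
    let mut b = [0.0; 2];
    let mut history = vec![mean_loss(&w, &b, xs, ys)];
    let n = xs.len() as f64;
    for _ in 0..epochs {
        let mut gw = [[0.0; 2]; 2];
        let mut gb = [0.0; 2];
        for (&x, &y) in xs.iter().zip(ys) {
            let p = softmax2(logits(&w, &b, x));
            for k in 0..2 {
                // d(-ln p[y]) / d(z_k) = p[k] - 1 if k == y, else p[k].
                let target = if k == y { 1.0 } else { 0.0 };
                let dz = p[k] - target;
                gw[k][0] += dz * x[0] / n;
                gw[k][1] += dz * x[1] / n;
                gb[k] += dz / n;
            }
        }
        // Step against the averaged gradient.
        for k in 0..2 {
            w[k][0] -= lr * gw[k][0];
            w[k][1] -= lr * gw[k][1];
            b[k] -= lr * gb[k];
        }
        history.push(mean_loss(&w, &b, xs, ys));
    }
    (w, b, history)
}

/// The article's training set, with vote counts divided by 100.
fn article_data() -> (Vec<[f64; 2]>, Vec<usize>) {
    let votes = [[15, 10], [10, 15], [5, 12], [30, 20], [16, 12], [13, 25], [6, 14], [31, 21]];
    let xs = votes.iter().map(|v| [v[0] as f64 / 100.0, v[1] as f64 / 100.0]).collect();
    (xs, vec![1, 0, 0, 1, 1, 0, 0, 1])
}

fn main() {
    let (xs, ys) = article_data();
    let (w, b, history) = train(&xs, &ys, 0.5, 2000);
    println!("loss: start {:.5}, end {:.5}", history[0], history[history.len() - 1]);
    println!("weights {:?}, biases {:?}", w, b);
}
```

Starting from zero weights the loss is ln 2 (both classes at 50%). Because the step size is small relative to the curvature of this loss on inputs scaled to fractions, each step should lower it; the multilayer model in the article follows the same idea, with backpropagation carrying the gradient through the hidden layers.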

What the code looks like

use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{Linear, Module, VarBuilder};

const VOTE_DIM: usize = 2;
const RESULTS: usize = 1;
const EPOCHS: usize = 20;
const LAYER1_OUT_SIZE: usize = 4;
const LAYER2_OUT_SIZE: usize = 2;
const LEARNING_RATE: f64 = 0.05;

#[derive(Clone)]
pub struct Dataset {
    pub train_votes: Tensor,
    pub train_results: Tensor,
    pub test_votes: Tensor,
    pub test_results: Tensor,
}

struct MultiLevelPerceptron {
    ln1: Linear,
    ln2: Linear,
    ln3: Linear,
}

impl MultiLevelPerceptron {
    fn new(vs: VarBuilder) -> Result<Self> {
        let ln1 = candle_nn::linear(VOTE_DIM, LAYER1_OUT_SIZE, vs.pp("ln1"))?;
        let ln2 = candle_nn::linear(LAYER1_OUT_SIZE, LAYER2_OUT_SIZE, vs.pp("ln2"))?;
        let ln3 = candle_nn::linear(LAYER2_OUT_SIZE, RESULTS + 1, vs.pp("ln3"))?;
        Ok(Self { ln1, ln2, ln3 })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.ln1.forward(xs)?;
        let xs = xs.relu()?;
        let xs = self.ln2.forward(&xs)?;
        let xs = xs.relu()?;
        self.ln3.forward(&xs)
    }
}

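// `train` (not shown in this excerpt) trains a MultiLevelPerceptron on the dataset
// and returns Err when its test accuracy is below 100%.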
pub fn main() -> anyhow::Result<()> {
    let dev = Device::cuda_if_available(0)?;

    let train_votes_vec: Vec<u32> = vec![
        15, 10,
        10, 15,
        5, 12,
        30, 20,
        16, 12,
        13, 25,
        6, 14,
        31, 21,
    ];
    let train_votes_tensor = Tensor::from_vec(train_votes_vec.clone(), (train_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;

    let train_results_vec: Vec<u32> = vec![
        1,
        0,
        0,
        1,
        1,
        0,
        0,
        1,
    ];
    let train_results_tensor = Tensor::from_vec(train_results_vec.clone(), train_results_vec.len(), &dev)?;

    let test_votes_vec: Vec<u32> = vec![
        13, 9,
        8, 14,
        3, 10,
    ];
    let test_votes_tensor = Tensor::from_vec(test_votes_vec.clone(), (test_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;

    let test_results_vec: Vec<u32> = vec![
        1,
        0,
        0,
    ];
    let test_results_tensor = Tensor::from_vec(test_results_vec.clone(), test_results_vec.len(), &dev)?;

    let m = Dataset {
        train_votes: train_votes_tensor,
        train_results: train_results_tensor,
        test_votes: test_votes_tensor,
        test_results: test_results_tensor,
    };

    // Keep retrying until `train` succeeds.
    let trained_model: MultiLevelPerceptron = loop {
        println!("Trying to train neural network.");
        match train(m.clone(), &dev) {
            Ok(model) => break model,
            Err(e) => println!("Error: {:?}", e),
        }
    };

    let real_world_votes: Vec<u32> = vec![
        13, 22,
    ];

    let real_world_votes_tensor = Tensor::from_vec(real_world_votes.clone(), (1, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;

    let final_result = trained_model.forward(&real_world_votes_tensor)?;

    // Predicted class = index of the larger logit (1: first candidate wins, 0: loses).
    let result = final_result
        .argmax(D::Minus1)?
        .to_dtype(DType::F32)?
        .get(0).map(|x| x.to_scalar::<f32>())??;
    println!("real_life_votes: {:?}", real_world_votes);
    println!("neural_network_prediction_result: {:?}", result);

    Ok(())
}
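The main loop above keeps calling `train` until it returns `Ok`. Since `train` itself is not shown, the following is only a hypothetical plain-Rust sketch of the policy described in point 8: an attempt is rejected with an `Err` unless it classifies every test sample correctly, and the caller retries. `accuracy`, `accept`, `retry_until_ok` and the sample predictions are illustrative names and data, not part of the project.

```rust
/// Fraction of predictions that match the labels.
fn accuracy(predictions: &[u32], labels: &[u32]) -> f32 {
    let correct = predictions.iter().zip(labels).filter(|(p, l)| p == l).count();
    correct as f32 / labels.len() as f32
}

/// Accept an attempt only if test accuracy is 100%.
fn accept(predictions: &[u32], labels: &[u32]) -> Result<(), String> {
    let acc = accuracy(predictions, labels);
    if acc < 1.0 {
        Err(format!("test accuracy {:.2}% is below 100%", acc * 100.0))
    } else {
        Ok(())
    }
}

/// Call `attempt` until it succeeds; returns the value and the number of attempts.
fn retry_until_ok<T, E: std::fmt::Debug>(mut attempt: impl FnMut() -> Result<T, E>) -> (T, usize) {
    let mut tries = 0;
    loop {
        tries += 1;
        match attempt() {
            Ok(value) => break (value, tries),
            Err(e) => println!("Error: {:?}", e),
        }
    }
}

fn main() {
    let labels: [u32; 3] = [1, 0, 0]; // test_results_vec from the article
    // Hypothetical predictions produced by three successive training attempts.
    let attempts: [[u32; 3]; 3] = [[0, 0, 0], [1, 1, 0], [1, 0, 0]];
    let mut next = attempts.iter();
    let (_, tries) = retry_until_ok(|| accept(next.next().unwrap(), &labels));
    println!("accepted after {} attempts", tries);
}
```

A loop like this never ends if no attempt reaches 100%, so a maximum number of attempts is a simple safeguard to consider.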

How to run

cargo run

Example output


Trying to train neural network.
Epoch:   1 Train loss:  4.42555 Test accuracy:  0.00%
Epoch:   2 Train loss:  0.84677 Test accuracy: 33.33%
Epoch:   3 Train loss:  2.54335 Test accuracy: 33.33%
Epoch:   4 Train loss:  0.37806 Test accuracy: 33.33%
Epoch:   5 Train loss:  0.36647 Test accuracy: 100.00%
real_life_votes: [13, 22]
neural_network_prediction_result: 0.0
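The test accuracy in the log moves in steps of 33.33% because the test set holds only three samples, so the only possible values are 0/3, 1/3, 2/3 and 3/3. A small sketch of that arithmetic follows; it assumes the percentage is printed with a `{:5.2}` format, which matches the padding in the log but is not shown in the article, and `accuracy_label` is a hypothetical helper.

```rust
/// Format an accuracy the way the log appears to (assumed "{:5.2}%" format).
fn accuracy_label(correct: usize, total: usize) -> String {
    format!("{:5.2}%", 100.0 * correct as f64 / total as f64)
}

fn main() {
    // With 3 test samples, these are the only accuracies that can appear.
    for correct in 0..=3 {
        println!("{} of 3 correct -> Test accuracy: {}", correct, accuracy_label(correct, 3));
    }
}
```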

Owner
Evgeny Igumnov
Rust Developer at Jetico company