Flexible, reusable reinforcement learning (Q learning) implementation in Rust

Overview

Rurel


Rurel is a flexible, reusable reinforcement learning (Q learning) implementation in Rust.

In Cargo.toml:

rurel = "0.2.0"

An example is included. It teaches an agent on a 21x21 grid how to arrive at (10,10), using the actions go left, go up, go right and go down:

cargo run --example eucdist

Getting started

There are two main traits you need to implement: rurel::mdp::State and rurel::mdp::Agent.

A State is something that defines a Vec of actions which can be taken from it, and that has a certain reward. A State also needs to define the corresponding action type A.

An Agent is something which has a current state, and given an action, can take the action and evaluate the next state.
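
For orientation, the two traits have roughly the following shape. This is only a sketch, with the trait bounds inferred from the derives used in the example below; check the rurel::mdp documentation for the exact definitions and any provided methods.

// Rough sketch of the trait shapes (illustrative only, not the actual rurel source).
use std::hash::Hash;

pub trait State: Eq + Hash + Clone {
	type A: Eq + Hash + Clone;
	fn reward(&self) -> f64;
	fn actions(&self) -> Vec<Self::A>;
}

pub trait Agent<S: State> {
	fn current_state(&self) -> &S;
	fn take_action(&mut self, action: &S::A);
}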

Example

Let's implement the example behind cargo run --example eucdist: we want to make an agent that is taught how to arrive at (10,10) on a 21x21 grid.

First, let's define a State, which should represent a position on the 21x21 grid, and the corresponding Action, which is either up, down, left or right.

use rurel::mdp::State;

#[derive(PartialEq, Eq, Hash, Clone)]
struct MyState { x: i32, y: i32 }
#[derive(PartialEq, Eq, Hash, Clone)]
struct MyAction { dx: i32, dy: i32 }

impl State for MyState {
	type A = MyAction;
	fn reward(&self) -> f64 {
		// Negative Euclidean distance
		-((((10 - self.x).pow(2) + (10 - self.y).pow(2)) as f64).sqrt())
	}
	fn actions(&self) -> Vec<MyAction> {
		vec![MyAction { dx: 0, dy: -1 },	// up
			 MyAction { dx: 0, dy: 1 },	// down
			 MyAction { dx: -1, dy: 0 },	// left
			 MyAction { dx: 1, dy: 0 },	// right
		]
	}
}

Then define the agent:

use rurel::mdp::Agent;

struct MyAgent { state: MyState }
impl Agent<MyState> for MyAgent {
	fn current_state(&self) -> &MyState {
		&self.state
	}
	fn take_action(&mut self, action: &MyAction) {
		let &MyAction { dx, dy } = action;
		self.state = MyState {
			x: (((self.state.x + dx) % 21) + 21) % 21, // (x + dx) mod 21, wrapping around the grid
			y: (((self.state.y + dy) % 21) + 21) % 21, // (y + dy) mod 21, wrapping around the grid
		};
	}
}

That's all. Now make a trainer and train the agent with Q learning, using a learning rate of 0.2, a discount factor of 0.01 and an initial Q value of 2.0. We let the trainer run for 100,000 iterations, randomly exploring new states.

use rurel::AgentTrainer;
use rurel::strategy::learn::QLearning;
use rurel::strategy::explore::RandomExploration;
use rurel::strategy::terminate::FixedIterations;

let mut trainer = AgentTrainer::new();
let mut agent = MyAgent { state: MyState { x: 0, y: 0 }};
trainer.train(&mut agent,
              &QLearning::new(0.2, 0.01, 2.),
              &mut FixedIterations::new(100000),
              &RandomExploration::new());

After this, you can query the learned value (Q) for a certain action in a certain state by:

trainer.expected_value(&state, &action) // : Option<f64>
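
For example, building on the types defined above, a small helper (not part of rurel) could pick the action with the highest learned value for a given state. This is only a sketch; it assumes expected_value returns None for state/action pairs the trainer never encountered, so such pairs are simply skipped.

use std::cmp::Ordering;

// Hypothetical helper: return the action with the highest learned Q value
// for `state`, or None if no value has been learned for any of its actions.
fn best_action(trainer: &AgentTrainer<MyState>, state: &MyState) -> Option<MyAction> {
	state
		.actions()
		.into_iter()
		.filter_map(|a| trainer.expected_value(state, &a).map(|v| (a, v)))
		.max_by(|(_, v1), (_, v2)| v1.partial_cmp(v2).unwrap_or(Ordering::Equal))
		.map(|(a, _)| a)
}

let preferred = best_action(&trainer, &MyState { x: 0, y: 0 });
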
Comments
  • Empty list of actions panics

    Empty list of actions panics

    It seems that if State::actions() returns an empty vector, the whole system crashes. I am new to reinforcement learning, so I might be using it incorrectly. My setup:

    • the state is an NxN empty board
    • for each board state, actions() generates a list of actions available for the current board state
    • in many cases the board state cannot be improved any further, and it should backtrack to try a different action
    opened by nyurik 3
  • Eucdist example doesn't seem to be correct

    Eucdist example doesn't seem to be correct

    I've modified the eucdist example to add a Display impl for MyAction which prints an arrow based on the action, and added a function entry_to_action which gets the most likely action for a given state (if I'm not mistaken):

    fn entry_to_action(entry: &HashMap<MyAction, f64>) -> Option<&MyAction> {
        entry
            .iter()
            .max_by(|(_, v1), (_, v2)| v1.partial_cmp(v2).unwrap_or(Ordering::Equal))
            .map(|(a, _)| a)
    }
    

    And after running the example, it prints this:

    →  →  →  →  →  →  →  →  →  →  →  →  →  →  →  →  →  →  →  →  ↑  
    ↓  →  →  →  →  →  →  →  →  →  →  →  →  →  →  →  →  →  →  ↑  ↑  
    ↓  ↓  →  →  →  →  →  →  →  →  →  →  →  →  →  →  →  →  ↑  ↑  ↑  
    ↓  ↓  ↓  ↓  →  →  →  →  →  →  →  →  →  →  →  →  →  ↑  ↑  ↑  ↑  
    ↓  ↓  ↓  ↓  →  →  →  →  →  →  →  →  →  →  →  →  ↑  ↑  ↑  ↑  ↑  
    ↓  ↓  ↓  ↓  ↓  →  →  →  →  →  →  →  →  →  →  ↑  ↑  ↑  ↑  ↑  ↑  
    ↓  ↓  ↓  ↓  ↓  ↓  ↓  →  →  →  →  →  →  →  ↑  ↑  ↑  ↑  ↑  ↑  ↑  
    ↓  ↓  ↓  ↓  ↓  ↓  ↓  →  →  →  →  →  →  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  
    ↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  →  →  →  →  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  
    ↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  →  →  →  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  
    ↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  
    ↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  ←  ←  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  
    ↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  ↓  ←  ←  ←  ←  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  
    ↓  ↓  ↓  ↓  ↓  ↓  ↓  ←  ←  ←  ←  ←  ←  ←  ↑  ↑  ↑  ↑  ↑  ↑  ↑  
    ↓  ↓  ↓  ↓  ↓  ↓  ←  ←  ←  ←  ←  ←  ←  ←  ↑  ↑  ↑  ↑  ↑  ↑  ↑  
    ↓  ↓  ↓  ↓  ↓  ↓  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ↑  ↑  ↑  ↑  ↑  
    ↓  ↓  ↓  ↓  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ↑  ↑  ↑  ↑  ↑  
    ↓  ↓  ↓  ↓  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ↑  ↑  ↑  
    ↓  ↓  ↓  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ↑  ↑  ↑  
    ↓  ↓  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ↑  ↑  
    ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←  ←
    

    I expected that the arrows would all point toward the center, right?

    Here's the code for printing the arrows:

    impl fmt::Display for MyAction {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            match *self {
                MyAction::Move { dx, dy } => {
                    match (dx, dy) {
                        (-1, 0) => write!(f, "{LEFT_ARROW}"),
                        (1, 0) => write!(f, "{RIGHT_ARROW}"),
                        (0, -1) => write!(f, "{UP_ARROW}"),
                        (0, 1) => write!(f, "{DOWN_ARROW}"),
                        _ => unreachable!()
                    }
                }
            }
        }
    }
    
    
    opened by tqwewe 2
  • Compute rewards from the agent due to irrelevant data

    Compute rewards from the agent due to irrelevant data

    The State used during training should only contain data that is relevant to learning, so as not to distort the results. Any other data needed to compute the rewards can be kept on the Agent instead, so the State itself stays correct.

    opened by tetuaoro 2
  • Example doesn't "learn" anything

    Example doesn't "learn" anything

    Running the example and adding some debugging code, I'm finding that the agent is not learning anything at all.

        let mut trainer = AgentTrainer::new();
        let mut agent = MyAgent {
            state: MyState { x: 0, y: 0 },
        };
        trainer.train(
            &mut agent,
            &QLearning::new(0.2, 0.01, 2.),
            &mut FixedIterations::new(10000000),
            &RandomExploration::new(),
        );
        let state1 = MyState { x: 1, y: 0 };
        let state2 = MyState { x: 0, y: 1 };
        let actions = vec![MyAction { dx: 0, dy: -1 }, MyAction { dx: -1, dy: 0 }];
        for action in actions {
            println!(
                "1: {:?} {:?} {:?}",
                state1,
                action,
                trainer.expected_value(&state1, &action),
            );
            println!(
                "2: {:?} {:?} {:?}",
                state2,
                action,
                trainer.expected_value(&state2, &action),
            );
            println!();
        }
    
    1: MyState { x: 1, y: 0 } MyAction { dx: 0, dy: -1 } Some(-13.582118848154376)
    2: MyState { x: 0, y: 1 } MyAction { dx: 0, dy: -1 } Some(-14.27795681221249)
    
    1: MyState { x: 1, y: 0 } MyAction { dx: -1, dy: 0 } Some(-14.27795681221249)
    2: MyState { x: 0, y: 1 } MyAction { dx: -1, dy: 0 } Some(-13.582118848154376)
    

    It seems that it hasn't learned that even with x:1 and y:0, dx:-1 and dy:0 is the best move. Am I misunderstanding the example or anything here?

    opened by sigaloid 2
  • Provide a reference to learned_values

    Provide a reference to learned_values

    It can be useful to have access to the learned values without cloning them. For example, for serialization, where only a reference is needed, saving the time and memory overhead of a clone.

    opened by paholg 1
  • Implementing serde for AgentTrainer and methods for JSON save/load.

    Implementing serde for AgentTrainer and methods for JSON save/load.

    I've made some additions to rurel that make it easy to save and load data to and from JSON. In the pull request I've included examples of how to do this, and a snippet in the readme. The only catch is that serde's Serialize and Deserialize have to be derived for States and Actions in order for rurel to work.

    This also means we have to include serde in the Cargo.toml, but I made serde pub extern so it can be accessed directly through rurel. End users will not have to put serde in their own Cargo.toml unless they want all of its features.

    I'm a CS student, I'm not a professional. Just seemed like a good way to contribute to an awesome crate.

    opened by RylanYancey 1
  • 2021 edition, formatting, clippy

    2021 edition, formatting, clippy

    I did a rough pass over the code while attempting to understand it. Let me know if these changes are ok.

    • used nightly cargo +nightly fmt to optimize and sort use statements
    • used clippy to fix a few minor things
    • removed externs - not needed with 2021 edition
    • added the readme to the doctests (they all fail, but at least they are now part of the test code and can be improved in another PR)
    opened by nyurik 1
  • Adding import/export and fixing some stuff

    Adding import/export and fixing some stuff

    Related to #3, adding import and export features for the HashMap and applying some formatting and cargo check feedback. Not fully tested yet, should be fine though.

    opened by mtib 1
  • Usage for future values?

    Usage for future values?

    Question... can this library be used for something like Pong - i.e. where the reward isn't known right away, but rather eventually becomes known and is somehow backpropagated?

    I guess the reward could sort of be immediately known, by checking the y-distance from the ball... but that's not a huge win over just manually making it chase the ball. It'd be nicer to have the AI figure out other things - like trying to hit the ball such that it makes the opponent chase it a longer distance.

    Sorry for the naive question - I haven't jumped into ML yet and I'm just tinkering around with a webassembly pong thing and thought this library might be a nice way to drive the AI :)

    opened by dakom 1
  • Dumping results to file

    Dumping results to file

    I think the library should offer a way to dump the learned state to a file (basically just the HashMap of the AgentTrainer), in order to save it, checkpoint, or continue learning from an earlier state. This could be hacked from the outside with some unsafe code, though I think offering access to q within AgentTrainer would be useful.

    opened by mtib 0
  • Train by individual steps

    Train by individual steps

    Is it possible to train a single step at a time, so I can have it run within something like a bevy system?

    It seems like the current train method is usually run for a certain number of iterations, but perhaps a step or train_step function would be convenient?

    opened by tqwewe 1
Owner

Milan Boers

Related projects

Reinforcement learning library written in Rust

REnforce Reinforcement library written in Rust This library is still in early stages, and the API has not yet been finalized. The documentation can be

Niven Achenjang 20 Jun 14, 2022
Reinforcement learning with Rust

ReLearn: A Reinforcement Learning Library A reinforcement learning library and experiment runner. Uses pytorch as the neural network backend via the t

Eric Langlois 10 Jun 14, 2022
Border is a reinforcement learning library in Rust

Border Border is a reinforcement learning library in Rust. For reusability of both RL environments and agents, this library provides a reference imple

Taku Yoshioka 1 Dec 15, 2022
NEATeRS is a library for training a genetic neural net through reinforcement learning.

NEATeRS NEATeRS is a library for training a genetic neural net through reinforcement learning. It uses the NEAT algorithm developed by Ken Stanley whi

TecTrixer 3 Nov 28, 2022
A Rust🦀 implementation of CRAFTML, an Efficient Clustering-based Random Forest for Extreme Multi-label Learning

craftml-rs A Rust implementation of CRAFTML, an Efficient Clustering-based Random Forest for Extreme Multi-label Learning (Siblini et al., 2018). Perf

Tom Dong 15 Nov 6, 2022
Practice repo for learning Rust. Currently going through "Rust for JavaScript Developers" course.

rust-practice Practice repo for learning Rust. Directories /rust-for-js-dev Files directed towards "Rust for JavaScript Developers" course. Thank y

Sammy Samkough 0 Dec 25, 2021
A Rust library with homemade machine learning models to classify the MNIST dataset. Built in an attempt to get familiar with advanced Rust concepts.

mnist-classifier Ideas UPDATED: Finish CLI Flags Parallelize computationally intensive functions Class-based naive bayes README Image parsing Confusio

Neil Kaushikkar 0 Sep 2, 2021
A Rust machine learning framework.

Linfa linfa (Italian) / sap (English): The vital circulating fluid of a plant. linfa aims to provide a comprehensive toolkit to build Machine Learning

Rust-ML 2.2k Jan 2, 2023
Machine Learning library for Rust

rusty-machine This library is no longer actively maintained. The crate is currently on version 0.5.4. Read the API Documentation to learn more. And he

James Lucas 1.2k Dec 31, 2022
Machine learning crate for Rust

rustlearn A machine learning package for Rust. For full usage details, see the API documentation. Introduction This crate contains reasonably effectiv

Maciej Kula 547 Dec 28, 2022
A deep learning library for rust

Alumina An experimental deep learning library written in pure rust. Breakage expected on each release in the short term. See mnist.rs in examples or R

zza 95 Nov 30, 2022
Machine learning in Rust.

Rustml Rustml is a library for doing machine learning in Rust. The documentation of the project with a description of the modules can be found here. F

null 60 Dec 15, 2022
Rust based Cross-GPU Machine Learning

HAL : Hyper Adaptive Learning Rust based Cross-GPU Machine Learning. Why Rust? This project is for those that miss strongly typed compiled languages.

Jason Ramapuram 83 Dec 20, 2022
Machine Learning Library for Rust

autograph Machine Learning Library for Rust undergoing maintenance Features Portable accelerated compute Run SPIR-V shaders on GPU's that support Vulk

null 223 Jan 1, 2023
Fwumious Wabbit, fast on-line machine learning toolkit written in Rust

Fwumious Wabbit is a very fast machine learning tool built with Rust inspired by and partially compatible with Vowpal Wabbit (much love! read more abo

Outbrain 115 Dec 9, 2022
🦀 Example of serving deep learning models in Rust with batched prediction

rust-dl-webserver This project provides an example of serving a deep learning model with batched prediction using Rust. In particular it runs a GPT2 m

Evan Pete Walsh 28 Dec 15, 2022
A Machine Learning Framework for High Performance written in Rust

polarlight polarlight is a machine learning framework for high performance written in Rust. Key Features TBA Quick Start TBA How To Contribute Contrib

Chris Ohk 25 Aug 23, 2022
Example of Rust API for Machine Learning

rust-machine-learning-api-example Example of Rust API for Machine Learning API example that uses resnet224 to infer images received in base64 and retu

vaaaaanquish 16 Oct 3, 2022
Label Propagation Algorithm by Rust. Label propagation (LP) is graph-based semi-supervised learning (SSL). LGC and CAMLP have been implemented.

label-propagation-rs Label Propagation Algorithm by Rust. Label propagation (LP) is graph-based semi-supervised learning (SSL). A simple LGC and a mor

vaaaaanquish 4 Sep 15, 2021