craftml-rs

A Rust implementation of CRAFTML, an Efficient Clustering-based Random Forest for Extreme Multi-label Learning (Siblini et al., 2018).

Performance

This implementation has been tested on datasets from the Extreme Classification Repository. Each dataset comes either as a single data file with separate train / test split files, or as separate train and test data files.

A data file starts with a header line containing three space-separated integers: the total number of examples, the number of features, and the number of labels. After the header line there is one line per example, starting with comma-separated labels, followed by space-separated feature:value pairs:

label1,label2,...labelk ft1:ft1_val ft2:ft2_val ft3:ft3_val .. ftd:ftd_val
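The header and example lines can be parsed with a few string splits. The following is a minimal Rust sketch, not code taken from craftml-rs. The Example struct and the parse_header / parse_example function names are hypothetical. One further point is an assumption about how an empty label list appears: a line that begins with whitespace is treated as an example with no labels.

```rust
/// One parsed example: label indices plus sparse (feature index, value) pairs.
#[derive(Debug, PartialEq)]
pub struct Example {
    pub labels: Vec<u32>,
    pub features: Vec<(u32, f32)>,
}

/// Parses the header line: number of examples, number of features, number of labels.
pub fn parse_header(line: &str) -> Result<(usize, usize, usize), String> {
    let nums = line
        .split_whitespace()
        .map(|t| t.parse::<usize>().map_err(|e| format!("bad header token {:?}: {}", t, e)))
        .collect::<Result<Vec<usize>, String>>()?;
    match nums.as_slice() {
        [n_examples, n_features, n_labels] => Ok((*n_examples, *n_features, *n_labels)),
        _ => Err(format!("expected 3 integers in header, got {}", nums.len())),
    }
}

/// Parses "label1,label2,...,labelk ft1:val1 ft2:val2 ...".
/// Assumption (hypothetical): a line starting with whitespace has an empty label list.
pub fn parse_example(line: &str) -> Result<Example, String> {
    let line = line.trim_end();
    let (label_part, feature_part) = if line.starts_with(char::is_whitespace) {
        ("", line)
    } else {
        // Labels end at the first whitespace; no whitespace means no features.
        line.split_once(char::is_whitespace).unwrap_or((line, ""))
    };
    let labels = label_part
        .split(',')
        .filter(|t| !t.is_empty())
        .map(|t| t.parse::<u32>().map_err(|e| format!("bad label {:?}: {}", t, e)))
        .collect::<Result<Vec<u32>, String>>()?;
    let features = feature_part
        .split_whitespace()
        .map(|tok| -> Result<(u32, f32), String> {
            let (idx, val) = tok
                .split_once(':')
                .ok_or_else(|| format!("expected index:value, got {:?}", tok))?;
            let idx = idx.parse::<u32>().map_err(|e| format!("bad feature index {:?}: {}", idx, e))?;
            let val = val.parse::<f32>().map_err(|e| format!("bad feature value {:?}: {}", val, e))?;
            Ok((idx, val))
        })
        .collect::<Result<Vec<(u32, f32)>, String>>()?;
    Ok(Example { labels, features })
}
```

For example, parse_example("1,3 0:0.5 4:1.25") yields labels [1, 3] and features [(0, 0.5), (4, 1.25)].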

A split file is an integer matrix, with one line per row and columns separated by spaces. The integers are example indices (1-indexed) in the corresponding data file, and each column corresponds to a separate split.
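The minimal sketch below turns one column of a split file into 0-based positions into the data file's examples. The split_column helper is hypothetical and not part of craftml-rs. It takes the file contents as a string, so file I/O is left out.

```rust
/// Hypothetical helper: returns the 0-indexed example positions listed in column `split`
/// of a split file (one row per line, space-separated, 1-indexed indices).
pub fn split_column(contents: &str, split: usize) -> Result<Vec<usize>, String> {
    contents
        .lines()
        .filter(|line| !line.trim().is_empty())
        .enumerate()
        .map(|(row, line)| -> Result<usize, String> {
            let tok = line
                .split_whitespace()
                .nth(split)
                .ok_or_else(|| format!("row {} has no column {}", row + 1, split))?;
            let idx: usize = tok
                .parse()
                .map_err(|e| format!("row {}: bad index {:?}: {}", row + 1, tok, e))?;
            // Indices in the file start from 1; convert to 0-based positions.
            idx.checked_sub(1)
                .ok_or_else(|| format!("row {}: index 0 is invalid (indices start from 1)", row + 1))
        })
        .collect()
}
```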

Precision at 1, 3, and 5 (P@1, P@3, P@5, in percent) is calculated for models trained with default hyper-parameters, for example:

  • craftml train Mediamill/data.txt --cv_splits_path Mediamill/train_split.txt for Mediamill, which has a single data file and separate train / test split files;
  • craftml train EURLex-4K/train.txt --test_data EURLex-4K/test.txt for EURLex-4K, which has separate train / test data files.

Dataset         P@1    P@3    P@5
Mediamill       85.51  69.94  56.39
Bibtex          61.47  37.20  27.32
Delicious       67.78  62.15  57.63
EURLex-4K       79.52  66.42  55.25
Wiki10-31K      83.57  72.69  63.65
WikiLSHTC-325K  51.79  32.41  23.43
Delicious-200K  47.34  40.85  37.67
Amazon-670K     38.40  34.21  31.41
AmazonCat-13K   92.88  77.48  61.32

These numbers are generally consistent with those reported in the original paper.
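P@k counts how many of the k highest-scoring predicted labels are true labels and divides by k. The result is averaged over the test examples and reported as a percentage, as in the table above. The sketch below computes it, assuming predictions are given as (label, score) pairs. The function names and the tie-breaking rule are illustrative, not taken from craftml-rs.

```rust
use std::cmp::Ordering;
use std::collections::HashSet;

/// P@k for one example: (number of true labels among the k top-scored labels) / k.
/// Ties in score are broken by smaller label index (an arbitrary, illustrative choice).
pub fn precision_at_k(scores: &[(u32, f32)], true_labels: &HashSet<u32>, k: usize) -> f64 {
    assert!(k > 0, "k must be positive");
    let mut ranked = scores.to_vec();
    ranked.sort_by(|a, b| {
        b.1.partial_cmp(&a.1)
            .unwrap_or(Ordering::Equal)
            .then_with(|| a.0.cmp(&b.0))
    });
    let hits = ranked
        .iter()
        .take(k)
        .filter(|(label, _)| true_labels.contains(label))
        .count();
    // Divide by k even when fewer than k labels were predicted.
    hits as f64 / k as f64
}

/// Mean P@k over a test set, expressed as a percentage.
pub fn mean_precision_at_k(predictions: &[Vec<(u32, f32)>], truths: &[HashSet<u32>], k: usize) -> f64 {
    assert_eq!(predictions.len(), truths.len(), "one label set per prediction list");
    if predictions.is_empty() {
        return 0.0;
    }
    let total: f64 = predictions
        .iter()
        .zip(truths)
        .map(|(p, t)| precision_at_k(p, t, k))
        .sum();
    100.0 * total / predictions.len() as f64
}
```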

Note that if there isn't enough memory to train on a large dataset, the --test_trees_singly flag can be set to train and test one tree at a time, discarding each tree once it has been tested. This makes it possible to obtain test results without fitting the entire model in memory, although a model trained this way cannot be saved. One can also tune the --centroid_preserve_ratio option to trade off model size against accuracy.
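The train help output describes the two pruning options, --centroid_preserve_ratio and --centroid_min_n_preserve, as keeping the entries with the largest absolute values, but never fewer than a minimum count. The sketch below illustrates such a rule on a sparse centroid. It is a hypothetical reconstruction, not the crate's code. Rounding the ratio-based count up and re-sorting the kept entries by index are assumptions.

```rust
use std::cmp::Ordering;

/// Hypothetical centroid pruning: keep the entries with the largest absolute values.
/// Kept count = ceil(preserve_ratio * nnz), raised to at least `min_n_preserve`,
/// capped at nnz. Rounding up is an assumption, not documented behavior.
pub fn prune_centroid(entries: &mut Vec<(u32, f32)>, preserve_ratio: f64, min_n_preserve: usize) {
    assert!((0.0..=1.0).contains(&preserve_ratio), "preserve_ratio must be in [0, 1]");
    let nnz = entries.len();
    let by_ratio = (preserve_ratio * nnz as f64).ceil() as usize;
    let keep = by_ratio.max(min_n_preserve).min(nnz);
    if keep == nnz {
        return; // Nothing to prune.
    }
    // Move the `keep` largest-magnitude entries to the front (average O(nnz)).
    entries.select_nth_unstable_by(keep, |a, b| {
        b.1.abs().partial_cmp(&a.1.abs()).unwrap_or(Ordering::Equal)
    });
    entries.truncate(keep);
    // Restore index order so the centroid stays a sorted sparse vector.
    entries.sort_unstable_by_key(|&(idx, _)| idx);
}
```

With the defaults (ratio 0.1, minimum 10), a centroid with 1,000 non-zero entries would keep 100 under this sketch. One with 50 entries would keep 10.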

Build

The project can be easily built with Cargo:

$ cargo build --release

The compiled binary file will be available at target/release/craftml.

Usage

$ craftml train --help

craftml-train
Train a new CRAFTML model

USAGE:
    craftml train [FLAGS] [OPTIONS] <training_data>

FLAGS:
    -h, --help                 Prints help information
        --test_trees_singly    Test forest tree by tree, freeing each before training the next to reduce memory usage.
                               Model cannot be saved.
    -V, --version              Prints version information

OPTIONS:
        --centroid_min_n_preserve <centroid_min_n_preserve>
            The minimum number of entries to preserve from pruning, regardless of the preserve ratio setting. [default: 10]

        --centroid_preserve_ratio <centroid_preserve_ratio>
            A real number between 0 and 1, which is the ratio of entries with largest absolute values to preserve. The
            rest of the entries are pruned. [default: 0.1]
        --cluster_sample_size <cluster_sample_size>
            Number of examples drawn for clustering on a branching node [default: 20000]

        --cv_splits_path <PATH>
            Path to the k-fold cross validation splits file, with k space-separated columns of indices (starting from 1)
            for training splits.
        --k_clusters <k_clusters>                              Number of clusters on a branching node [default: 10]
        --leaf_max_size <leaf_max_size>
            Maximum number of distinct examples on a leaf node [default: 10]

        --model_path <PATH>                                    Path to which the trained model will be saved if provided
        --n_cluster_iters <n_cluster_iters>
            Number of clustering iterations to run on each branching node [default: 2]

        --n_feature_buckets <n_feature_buckets>
            Number of buckets into which features are hashed [default: 10000]

        --n_label_buckets <n_label_buckets>
            Number of buckets into which labels are hashed [default: 10000]

        --n_threads <n_threads>
            Number of worker threads. If 0, the number is selected automatically. [default: 0]

        --n_trees <n_trees>                                    Number of trees in the random forest [default: 50]
        --out_path <PATH>
            Path to which predictions will be written, if provided

        --test_data <PATH>
            Path to test dataset file used to calculate metrics if provided (in the format of the Extreme Classification
            Repository)

ARGS:
    <training_data>    Path to training dataset file (in the format of the Extreme Classification Repository)
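The --n_feature_buckets and --n_label_buckets options hash feature and label indices into a fixed number of buckets. The following minimal sketch shows that hashing trick. Several choices in it are illustrative and not necessarily what craftml-rs does: the SplitMix64 mixing function, the per-tree seed parameter, and summing the values of colliding indices.

```rust
use std::collections::HashMap;

/// SplitMix64 finalizer: scrambles a 64-bit value so nearby indices land in unrelated buckets.
fn mix64(mut x: u64) -> u64 {
    x = x.wrapping_add(0x9E37_79B9_7F4A_7C15);
    x = (x ^ (x >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    x = (x ^ (x >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    x ^ (x >> 31)
}

/// Hypothetical feature hashing: maps sparse (index, value) pairs into `n_buckets`
/// buckets, summing values that collide. `seed` (an assumption) lets each tree
/// use a different hash function.
pub fn hash_features(features: &[(u32, f32)], n_buckets: u32, seed: u64) -> Vec<(u32, f32)> {
    assert!(n_buckets > 0, "n_buckets must be positive");
    let mut buckets: HashMap<u32, f32> = HashMap::new();
    for &(idx, val) in features {
        let bucket = (mix64(seed ^ idx as u64) % n_buckets as u64) as u32;
        *buckets.entry(bucket).or_insert(0.0) += val;
    }
    // Sort by bucket so the output is a deterministic sparse vector.
    let mut out: Vec<(u32, f32)> = buckets.into_iter().collect();
    out.sort_unstable_by_key(|&(bucket, _)| bucket);
    out
}
```

Hashing bounds the dimensionality at n_buckets regardless of the original feature or label space, at the cost of occasional collisions.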

$ craftml test --help

craftml-test
Test an existing CRAFTML model

USAGE:
    craftml test [OPTIONS] <model_path> <test_data>

FLAGS:
    -h, --help       Prints help information
    -V, --version    Prints version information

OPTIONS:
        --k_top <k_top>            Number of top predictions to write out for each test example [default: 5]
        --n_threads <n_threads>    Number of worker threads. If 0, the number is selected automatically. [default: 0]
        --out_path <PATH>          Path to which predictions will be written, if provided

ARGS:
    <model_path>    Path to the trained model
    <test_data>     Path to test dataset file (in the format of the Extreme Classification Repository)
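The train options --k_clusters and --n_cluster_iters control the clustering run at each branching node, and --cluster_sample_size sets how many examples are drawn for it. As an illustration only, the sketch below runs spherical (cosine-similarity) k-means on dense vectors. The following are all assumptions: that CRAFTML's node clustering resembles spherical k-means, the spherical_kmeans name, the dense representation, and the evenly spaced initialization (used instead of k-means++ for brevity). Drawing the cluster sample beforehand is left out.

```rust
/// L2-normalizes `v` in place; all-zero vectors are left unchanged.
fn normalize(v: &mut [f32]) {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        v.iter_mut().for_each(|x| *x /= norm);
    }
}

fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Hypothetical spherical k-means: assigns each normalized point to the centroid with the
/// highest cosine similarity, then sets each centroid to the normalized sum of its members.
/// Initialization takes every (n / k)-th point, a simplification (not k-means++).
pub fn spherical_kmeans(points: &[Vec<f32>], k: usize, n_iters: usize) -> Vec<usize> {
    assert!(k > 0 && k <= points.len(), "need 1 <= k <= number of points");
    let dim = points[0].len();
    let data: Vec<Vec<f32>> = points
        .iter()
        .map(|p| {
            let mut q = p.clone();
            normalize(&mut q);
            q
        })
        .collect();
    let step = data.len() / k;
    let mut centroids: Vec<Vec<f32>> = (0..k).map(|c| data[c * step].clone()).collect();
    let mut assignment = vec![0usize; data.len()];
    for _ in 0..n_iters.max(1) {
        // Assignment step: highest cosine similarity; ties go to the lower cluster index.
        for (i, p) in data.iter().enumerate() {
            let mut best = 0;
            let mut best_sim = f32::NEG_INFINITY;
            for (c, centroid) in centroids.iter().enumerate() {
                let sim = dot(p, centroid);
                if sim > best_sim {
                    best = c;
                    best_sim = sim;
                }
            }
            assignment[i] = best;
        }
        // Update step: sum members, then project back onto the unit sphere.
        let mut sums = vec![vec![0.0f32; dim]; k];
        for (p, &c) in data.iter().zip(&assignment) {
            for (s, x) in sums[c].iter_mut().zip(p) {
                *s += x;
            }
        }
        for (c, s) in sums.iter_mut().enumerate() {
            // Empty clusters keep their previous centroid.
            if s.iter().any(|&x| x != 0.0) {
                normalize(s);
                centroids[c] = s.clone();
            }
        }
    }
    assignment
}
```

Because every point is normalized first, only direction matters. For example, [1, 0] and [2, 0.1] end up in the same cluster despite their different lengths.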

References

  • Siblini, W., Kuntz, P., & Meyer, F. (2018). CRAFTML, an Efficient Clustering-based Random Forest for Extreme Multi-label Learning. In Proceedings of the 35th International Conference on Machine Learning (Vol. 80, pp. 4664–4673). Stockholmsmässan, Stockholm Sweden: PMLR. http://proceedings.mlr.press/v80/siblini18a.html