RUST笔记:candle使用基础

candle介绍

  • candle是huggingface开源的Rust的极简 ML 框架。

candle-矩阵乘法示例

cargo new myapp
cd myapp
cargo add --git https://github.com/huggingface/candle.git candle-core
cargo build # 测试,或执行 cargo ckeck
  • main.rs
use candle_core::{Device, Tensor};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let device = Device::Cpu;

    let a = Tensor::randn(0f32, 1., (2, 3), &device)?;
    let b = Tensor::randn(0f32, 1., (3, 4), &device)?;

    let c = a.matmul(&b)?;
    println!("{c}");
    Ok(())
}

  • 项目输出
~/myrust$ cargo new myapp
     Created binary (application) `myapp` package
~/myrust$ cd myapp
~/myrust/myapp$ cargo add --git https://github.com/huggingface/candle.git candle-core
    Updating git repository `https://github.com/huggingface/candle.git`
    Updating git submodule `https://github.com/NVIDIA/cutlass.git`
      Adding candle-core (git) to dependencies.
             Features:
             - accelerate
             - cuda
             - cudarc
             - cudnn
             - metal
             - mkl
    Updating git repository `https://github.com/huggingface/candle.git`
    Updating crates.io index
~/myrust/myapp$ cargo build
  Downloaded serde_derive v1.0.195
  Downloaded either v1.9.0
  Downloaded autocfg v1.1.0
  Downloaded zerofrom v0.1.3
  Downloaded zerofrom-derive v0.1.3
  Downloaded synstructure v0.13.0
  Downloaded crossbeam-deque v0.8.5
  Downloaded yoke-derive v0.7.3
  Downloaded half v2.3.1
  Downloaded bytemuck v1.14.1
  Downloaded rand_core v0.6.4
  Downloaded paste v1.0.14
  Downloaded proc-macro2 v1.0.78
  Downloaded itoa v1.0.10
  Downloaded memmap2 v0.9.4
  Downloaded syn v2.0.48
  Downloaded crossbeam-epoch v0.9.18
  Downloaded cfg-if v1.0.0
  Downloaded bitflags v1.3.2
  Downloaded num_cpus v1.16.0
  Downloaded gemm-f32 v0.17.0
  Downloaded reborrow v0.5.5
  Downloaded stable_deref_trait v1.2.0
  Downloaded rayon-core v1.12.1
  Downloaded seq-macro v0.3.5
  Downloaded thiserror-impl v1.0.56
  Downloaded dyn-stack v0.10.0
  Downloaded thiserror v1.0.56
  Downloaded unicode-xid v0.2.4
  Downloaded rand_chacha v0.3.1
  Downloaded ppv-lite86 v0.2.17
  Downloaded bytemuck_derive v1.5.0
  Downloaded getrandom v0.2.12
  Downloaded once_cell v1.19.0
  Downloaded unicode-ident v1.0.12
  Downloaded byteorder v1.5.0
  Downloaded crc32fast v1.3.2
  Downloaded num-complex v0.4.4
  Downloaded gemm-common v0.17.0
  Downloaded crossbeam-utils v0.8.19
  Downloaded quote v1.0.35
  Downloaded ryu v1.0.16
  Downloaded num-traits v0.2.17
  Downloaded zip v0.6.6
  Downloaded rand_distr v0.4.3
  Downloaded serde v1.0.195
  Downloaded rand v0.8.5
  Downloaded raw-cpuid v10.7.0
  Downloaded libm v0.2.8
  Downloaded serde_json v1.0.111
  Downloaded rayon v1.8.1
  Downloaded libc v0.2.152
  Downloaded gemm-c64 v0.17.0
  Downloaded gemm-c32 v0.17.0
  Downloaded safetensors v0.4.2
  Downloaded gemm-f64 v0.17.0
  Downloaded gemm v0.17.0
  Downloaded gemm-f16 v0.17.0
  Downloaded yoke v0.7.3
  Downloaded pulp v0.18.6
  Downloaded 60 crates (3.1 MB) in 14.91s
   Compiling proc-macro2 v1.0.78
   Compiling unicode-ident v1.0.12
   Compiling libc v0.2.152
   Compiling cfg-if v1.0.0
   Compiling libm v0.2.8
   Compiling autocfg v1.1.0
   Compiling crossbeam-utils v0.8.19
   Compiling ppv-lite86 v0.2.17
   Compiling rayon-core v1.12.1
   Compiling reborrow v0.5.5
   Compiling paste v1.0.14
   Compiling either v1.9.0
   Compiling bitflags v1.3.2
   Compiling seq-macro v0.3.5
   Compiling once_cell v1.19.0
   Compiling unicode-xid v0.2.4
   Compiling raw-cpuid v10.7.0
   Compiling serde v1.0.195
   Compiling crc32fast v1.3.2
   Compiling serde_json v1.0.111
   Compiling stable_deref_trait v1.2.0
   Compiling itoa v1.0.10
   Compiling ryu v1.0.16
   Compiling thiserror v1.0.56
   Compiling byteorder v1.5.0
   Compiling num-traits v0.2.17
   Compiling zip v0.6.6
   Compiling crossbeam-epoch v0.9.18
   Compiling quote v1.0.35
   Compiling syn v2.0.48
   Compiling crossbeam-deque v0.8.5
   Compiling getrandom v0.2.12
   Compiling memmap2 v0.9.4
   Compiling num_cpus v1.16.0
   Compiling rand_core v0.6.4
   Compiling rand_chacha v0.3.1
   Compiling rayon v1.8.1
   Compiling rand v0.8.5
   Compiling rand_distr v0.4.3
   Compiling synstructure v0.13.0
   Compiling bytemuck_derive v1.5.0
   Compiling serde_derive v1.0.195
   Compiling zerofrom-derive v0.1.3
   Compiling thiserror-impl v1.0.56
   Compiling yoke-derive v0.7.3
   Compiling bytemuck v1.14.1
   Compiling num-complex v0.4.4
   Compiling dyn-stack v0.10.0
   Compiling half v2.3.1
   Compiling zerofrom v0.1.3
   Compiling yoke v0.7.3
   Compiling pulp v0.18.6
   Compiling gemm-common v0.17.0
   Compiling gemm-f32 v0.17.0
   Compiling gemm-c64 v0.17.0
   Compiling gemm-f64 v0.17.0
   Compiling gemm-c32 v0.17.0
   Compiling gemm-f16 v0.17.0
   Compiling gemm v0.17.0
   Compiling safetensors v0.4.2
   Compiling candle-core v0.3.3 (https://github.com/huggingface/candle.git#fd7c8565)
   Compiling myapp v0.1.0 (/home/pdd/myrust/myapp)
    Finished dev [unoptimized + debuginfo] target(s) in 32.90s

candle_test的简单测试项目

Cargo.toml 文件

[package]
name = "candle_test"
version = "0.1.0"
edition = "2021" #  Rust 版本

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.2.1", features = ["cuda"] }
# `candle-core`:项目依赖的包的名称。`git` 字段指定了包的源代码仓库地址。`version` 字段指定了使用的包的版本。`features` 字段是一个数组,指定了启用的功能。在这里,启用了 "cuda" 功能。
# 可以通过以下命令添加,取消可注释掉"cuda",再cargo build
# cargo add --git https://github.com/huggingface/candle.git candle-core
# cargo add candle-core --features cuda

main.rs

use candle_core::{DType, Device, Result, Tensor};

// 定义一个模型结构体
struct Model {
    first: Tensor,
    second: Tensor,
}

impl Model {
    // 定义模型的前向传播方法
    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        let x = image.matmul(&self.first)?; // 输入乘以第一层权重
        let x = x.relu()?; // 使用 ReLU 激活函数
        x.matmul(&self.second) // 结果乘以第二层权重
    }
}

fn main() -> Result<()> {
    // 初始化设备,如果 GPU 可用则使用 GPU,否则使用 CPU
    let device = match Device::new_cuda(0) {
        Ok(device) => device,
        Err(_) => Device::Cpu,
    };

    // 创建模型的第一层和第二层权重张量
    let first = Tensor::zeros((784, 100), DType::F32, &device)
        .unwrap()
        .contiguous()?;
    let second = Tensor::zeros((100, 10), DType::F32, &device)
        .unwrap()
        .contiguous()?;
    
    // 初始化模型
    let model = Model { first, second };

    // 创建一个用于测试的虚拟图像张量
    let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)
        .unwrap()
        .contiguous()?;

    // 调用模型的前向传播方法获取预测结果
    let digit = model.forward(&dummy_image)?;

    // 打印预测结果
    println!("Digit {digit:?} digit");

    Ok(())
}

知识点总结

candle_core:: Result

在这里插入图片描述

// Result定义在/home/pdd/.cargo/git/checkouts/candle-0c2b4fa9e5801351/e8e3375/candle-core/src/error.rs
pub type Result<T> = std::result::Result<T, Error>; // 定义了一个 `Result` 类型,这是一个 `Result<T, Error>` 类型的别名。其中 `T` 是成功时的返回类型,而 `Error` 是失败时的错误类型。
// Ok(()) 定义在 /home/pdd/.rustup/toolchains/stable-x86_64-unknown-linux-gnu/lib/rustlib/src/rust/library/core/src/result.rs
// 这是 Rust 标准库中的 `Result` 公共的枚举类型,它有两个泛型参数 `T` 和 `E`。`T` 代表成功时返回的值的类型,`E` 代表错误时返回的错误类型。
// #[]是属性(attribute),提供额外信息
pub enum Result<T, E> {
    /// Contains the success value
    #[lang = "Ok"]
    #[stable(feature = "rust1", since = "1.0.0")]
    Ok(#[stable(feature = "rust1", since = "1.0.0")] T),// `Ok(T)`: 这是 `Result` 枚举的一个变体,用于表示成功的情况
                                                        // (): 是 Rust 中的单元类型(unit type),类似于其他语言中的 void。

    /// Contains the error value
    #[lang = "Err"]
    #[stable(feature = "rust1", since = "1.0.0")]
    Err(#[stable(feature = "rust1", since = "1.0.0")] E),// `Err(E)`: 这是 `Result` 枚举的另一个变体,用于表示错误的情况。
}

?符号

  • 在 Rust 中,? 符号用于处理 ResultOption 类型的返回值。这个符号的作用是将可能的错误或 None 值快速传播到调用链的最上层,使得代码更加简洁和易读。
fn forward(&self, image: &Tensor) -> Result<Tensor> {
    let x = image.matmul(&self.first)?; // 如果matmul返回Err,则整个forward函数返回Err
    let x = x.relu()?; // 如果relu返回Err,则整个forward函数返回Err
    x.matmul(&self.second) // 如果matmul返回Err,则整个forward函数返回Err;否则返回Ok(Tensor)
}

语句和表达式:语句以分号结尾,而表达式通常不需要分号。

  • 函数体:函数体是一个块表达式,其值是最后一个表达式的值。

    fn add(x: i32, y: i32) -> i32 {
        x + y // 表达式
    }
    

CG