Rust Multi-Platform Demo

Rust Installation

Official site: https://rustup.rs/

curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh

Speed up rustup init by using a mirror:

export RUSTUP_DIST_SERVER="https://mirrors.ustc.edu.cn/rust-static"
export RUSTUP_UPDATE_ROOT="https://mirrors.ustc.edu.cn/rust-static/rustup"
Switching the crates.io source

Edit $HOME/.cargo/config.toml:

[source.crates-io]
replace-with = "ustc"

[source.ustc]
registry = "sparse+https://mirrors.ustc.edu.cn/crates.io-index/"
# More mirror options ---------------------------
[source.crates-io]
registry = "https://github.com/rust-lang/crates.io-index"
# select the mirror to use
replace-with = 'ustc'

# Tsinghua University
[source.tuna]
registry = "https://mirrors.tuna.tsinghua.edu.cn/git/crates.io-index.git"

# University of Science and Technology of China
[source.ustc]
registry = "https://mirrors.ustc.edu.cn/crates.io-index"

# Shanghai Jiao Tong University
[source.sjtu]
registry = "https://mirrors.sjtug.sjtu.edu.cn/git/crates.io-index"

# rustcc community
[source.rustcc]
registry = "https://code.aliyun.com/rustcc/crates.io-index.git"

VS Code Configuration

Check command: change check to clippy, to keep rust-analyzer from getting stuck.
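
A minimal settings.json sketch (assuming the rust-analyzer VS Code extension; the setting key has moved between extension versions, so treat it as illustrative):

{
    // run clippy instead of cargo check when checking on save
    "rust-analyzer.check.command": "clippy"
}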

Switching Toolchain Versions

Using SIMD for acceleration requires the nightly toolchain:

rustup update nightly
rustup override set nightly

Verify the setup:

rustc --version

Output: rustc 1.84.0-nightly (798fb83f7 2024-10-16)
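
A quick way to confirm nightly-only features work is a minimal portable-SIMD sketch (std::simd is gated behind the nightly-only portable_simd feature, which is why the override above is needed):

#![feature(portable_simd)] // nightly-only feature gate
use std::simd::f32x4;

fn main() {
    // Add two 4-lane f32 vectors with a single SIMD operation
    let a = f32x4::from_array([1.0, 2.0, 3.0, 4.0]);
    let b = f32x4::splat(10.0);
    println!("{:?}", (a + b).to_array()); // [11.0, 12.0, 13.0, 14.0]
}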

C++

Step 1: Create the Rust library
  1. Create a new Rust library project

    cargo new rust_tokenizer --lib
    cd rust_tokenizer
  2. Edit Cargo.toml: add the following to make sure a dynamic library is generated:

    [lib]
    name = "rust_tokenizer"
    crate-type = ["cdylib"]
  3. Write the Rust code: add your tokenizer logic in src/lib.rs and expose a C-compatible interface. Below is a simple example implementing a basic tokenizer function in Rust.

    use std::ffi::{CString, CStr};
    use std::os::raw::c_char;

    #[no_mangle]
    pub extern "C" fn tokenize(text: *const c_char) -> *mut c_char {
        let c_str = unsafe {
            assert!(!text.is_null());
            CStr::from_ptr(text)
        };

        let r_str = c_str.to_str().unwrap();
        let tokenized = r_str.replace(" ", "|"); // simple tokenization logic

        CString::new(tokenized).unwrap().into_raw()
    }

    #[no_mangle]
    pub extern "C" fn free_string(s: *mut c_char) {
        unsafe {
            if s.is_null() { return }
            drop(CString::from_raw(s)); // retake ownership so the memory is freed
        }
    }

    Here we define two functions: tokenize processes a string, and free_string frees the memory Rust allocated for the result.

  4. Build the Rust code

    cargo build --release

    The compiled .so file is located under target/release.

Step 2: Call the Rust functions from C++
  1. Write the C++ code: in the C++ project, use the .so library produced by Rust.

    #include <iostream>

    extern "C" {
        char* tokenize(const char* text);
        void free_string(char* s);
    }

    int main() {
        const char* text = "Hello, world!";
        char* result = tokenize(text);
        std::cout << "Tokenized: " << result << std::endl;
        free_string(result);
        return 0;
    }
  2. Build the C++ code: compile with g++, linking against the Rust-generated .so.


    g++ -o tokenizer_test main.cpp -L/path/to/target/release -lrust_tokenizer -ldl
  3. Run the C++ program: make sure the directory containing the Rust .so is on LD_LIBRARY_PATH, or copy the library to a suitable location.

    export LD_LIBRARY_PATH=/path/to/target/release:$LD_LIBRARY_PATH
    ./tokenizer_test

C++ Cross-Compilation

1. Rust configuration

Add the target architecture aarch64-unknown-linux-gnu (Linux on ARM64):

rustup target add aarch64-unknown-linux-gnu

2. Install a cross-compilation toolchain

You need a cross toolchain that can produce binaries for the ARM64 architecture. On Debian/Ubuntu, install:

sudo apt-get update
sudo apt-get install gcc-aarch64-linux-gnu

3. Configure and build

Create a .cargo/config.toml file in the project root to configure the cross toolchain:

[target.aarch64-unknown-linux-gnu]
ar = "aarch64-linux-gnu-ar"
linker = "aarch64-linux-gnu-gcc"

Then build with:

cargo build --release --target=aarch64-unknown-linux-gnu

After the build, the .so file is under target/aarch64-unknown-linux-gnu/release/.
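
To confirm the artifact really targets ARM64, you can inspect it (library name assumed from the earlier example):

file target/aarch64-unknown-linux-gnu/release/librust_tokenizer.so
# expected output mentions: ELF 64-bit LSB shared object, ARM aarch64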

4. Using a custom toolchain

Create .cargo/config.toml in the project root:

[target.aarch64-unknown-linux-gnu]
ar = "/path/to/custom/toolchain/bin/aarch64-linux-gnu-ar"
linker = "/path/to/custom/toolchain/bin/aarch64-linux-gnu-gcc"

Set environment variables to point at the toolchain's sysroot and related library paths:

export CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_LINKER=/path/to/custom/toolchain/bin/aarch64-linux-gnu-gcc
export CC_aarch64_unknown_linux_gnu=/path/to/custom/toolchain/bin/aarch64-linux-gnu-gcc
export AR_aarch64_unknown_linux_gnu=/path/to/custom/toolchain/bin/aarch64-linux-gnu-ar

5. Building a broadly compatible x86_64 library

rustup target add x86_64-unknown-linux-gnu

In the Rust project, create a .cargo/config.toml aimed at wider compatibility, building against a lower glibc version where possible and statically linking the C runtime and other dependencies:

[target.x86_64-unknown-linux-gnu]
rustflags = [
"-C", "target-feature=+crt-static",
"-C", "link-args=-Wl,--no-as-needed -Wl,--as-needed"
]

Make sure to build for the x86_64-unknown-linux-gnu target:

cargo build --release --target=x86_64-unknown-linux-gnu
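
One way to check the resulting glibc baseline (artifact name assumed):

objdump -T target/x86_64-unknown-linux-gnu/release/librust_tokenizer.so | grep GLIBC_ | sort -u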

Building an Ubuntu 16.04-Compatible Python Library

Install Docker

https://docs.docker.com/desktop/wsl/#download

Pull the image and start a container:
docker pull ubuntu:16.04
docker run -it \
--name ubuntu16-rust3 \
--net=host \
-v /home/jw/.cargo:/root/.cargo \
-v /home/jw:/home/jw \
-v /mnt/c/share:/mnt/c/share \
ubuntu:16.04
apt-get update
apt-get install -y curl build-essential
apt-get install -y libhdf5-dev

Configure Rust

curl https://sh.rustup.rs -sSf | sh
source $HOME/.cargo/env # as root

Configure Python

cd /home/jw
bin/micromamba shell init --shell bash --root-prefix=./micromamba
source ~/.bashrc
micromamba activate [env_name]

Build

cd code/
maturin build --release --out dist

If the wheel will not install: rename xxx-0.1.0-cp310-cp310-manylinux_2_24_x86_64.whl to xxx-0.1.0-cp310-cp310-manylinux1_x86_64.whl.
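
For example (wheel name as above, built into dist/):

mv dist/xxx-0.1.0-cp310-cp310-manylinux_2_24_x86_64.whl dist/xxx-0.1.0-cp310-cp310-manylinux1_x86_64.whl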

Re-entering after a reboot

Enter the container:

docker ps 
docker exec -it ubuntu16-rust3 /bin/bash

Activate the Python environment:

micromamba activate [env_name]

Build:

cd code/
maturin build --release --out dist

Building for manylinux2014

Pull:

docker pull quay.io/pypa/manylinux2014_x86_64

Create the container:

docker images

sudo docker run -it \
--name manylinux2014 \
--net=host \
-v /home/jw:/home/jw \
-v code:/code \
quay.io/pypa/manylinux2014_x86_64

Set up the environment:

yum install -y curl && \
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y && \
source $HOME/.cargo/env
yum groupinstall -y "Development Tools"
yum install -y hdf5 hdf5-devel

Configure Python:

cd /home/jw
bin/micromamba shell init --shell bash --root-prefix=./micromamba
source ~/.bashrc
micromamba activate [env_name]

Build:

cd code/rust
maturin build --release --manylinux=2014 --out dist

Example: C++ and Python Interfaces from One Crate

1. Configure Cargo.toml

First, set up the Rust project's Cargo.toml with pyo3 as a dependency and maturin package metadata:

[package]
name = "tokenizers_wrapper"
version = "0.1.0"
edition = "2018"

[dependencies]
pyo3 = { version = "0.15", features = ["extension-module"] }
serde_json = "1.0"
tokenizers = "0.13"

[lib]
crate-type = ["cdylib"]

[package.metadata.maturin]
name = "tokenizers_wrapper"

pyo3 and maturin provide the Python interface; the cdylib crate type provides the C++ interface.

2. Write the Rust code

In src/lib.rs, expose the existing functionality to both Python and C++:

use pyo3::prelude::*;
use serde_json::Value;
use std::{collections::HashMap, str::FromStr};
use tokenizers::models::bpe::BPE;
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
use tokenizers::tokenizer::Tokenizer;

#[pyclass]
pub struct TokenizerWrapper {
    tokenizer: Tokenizer,
    decode_str: String,
    id_to_token_result: String,
}

#[pymethods]
impl TokenizerWrapper {
    #[new]
    pub fn from_str(json: &str) -> Self {
        TokenizerWrapper {
            tokenizer: Tokenizer::from_str(json).unwrap(),
            decode_str: String::new(),
            id_to_token_result: String::new(),
        }
    }

    #[staticmethod]
    pub fn byte_level_bpe_from_str(vocab: &str, merges: &str, added_tokens: &str) -> Self {
        let vocab_json: Value = serde_json::from_str(vocab).unwrap();
        let added_tokens_json: Value = serde_json::from_str(added_tokens).unwrap();
        let mut vocab = HashMap::new();

        match vocab_json {
            Value::Object(m) => {
                for (token, id) in m {
                    if let Value::Number(id) = id {
                        let id = id.as_u64().unwrap() as u32;
                        vocab.insert(token, id);
                    }
                }
            }
            _ => panic!("Invalid vocab.json file."),
        };

        match added_tokens_json {
            Value::Object(m) => {
                for (token, id) in m {
                    if let Value::Number(id) = id {
                        let id = id.as_u64().unwrap() as u32;
                        vocab.insert(token, id);
                    }
                }
            }
            _ => panic!("Invalid added_tokens.json file."),
        }

        let merges = merges
            .lines()
            .filter(|line| !line.starts_with("#version"))
            .map(|line| {
                let parts = line.split(' ').collect::<Vec<_>>();
                if parts.len() != 2 {
                    panic!("Invalid merges.txt file.")
                }
                (parts[0].to_string(), parts[1].to_string())
            })
            .collect::<Vec<(String, String)>>();

        let byte_level = ByteLevel::new(
            /*add_prefix_space=*/ false, /*trim_offsets=*/ false,
            /*use_regex=*/ false,
        );

        let mut tokenizer = Tokenizer::new(BPE::new(vocab, merges));
        tokenizer.with_pre_tokenizer(byte_level).with_decoder(byte_level);

        TokenizerWrapper {
            tokenizer,
            decode_str: String::new(),
            id_to_token_result: String::new(),
        }
    }

    pub fn encode(&mut self, text: &str, add_special_tokens: bool) -> Vec<u32> {
        let encoded = self.tokenizer.encode(text, add_special_tokens).unwrap();
        encoded.get_ids().to_vec()
    }

    pub fn encode_batch(&mut self, texts: Vec<&str>, add_special_tokens: bool) -> Vec<Vec<u32>> {
        self.tokenizer
            .encode_batch(texts, add_special_tokens)
            .unwrap()
            .into_iter()
            .map(|encoded| encoded.get_ids().to_vec())
            .collect()
    }

    pub fn decode(&mut self, ids: Vec<u32>, skip_special_tokens: bool) -> String {
        self.tokenizer.decode(ids.as_slice(), skip_special_tokens).unwrap()
    }

    pub fn get_vocab_size(&self) -> usize {
        self.tokenizer.get_vocab_size(true)
    }

    pub fn id_to_token(&mut self, id: u32) -> String {
        match self.tokenizer.id_to_token(id) {
            Some(token) => token,
            None => String::new(),
        }
    }

    pub fn token_to_id(&self, token: &str) -> Option<u32> {
        self.tokenizer.token_to_id(token)
    }
}

#[pymodule]
fn tokenizers_wrapper(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<TokenizerWrapper>()?;
    Ok(())
}

// C++ FFI functions
#[no_mangle]
pub extern "C" fn tokenizers_new_from_str(input_cstr: *const u8, len: usize) -> *mut TokenizerWrapper {
    let json = unsafe { std::str::from_utf8(std::slice::from_raw_parts(input_cstr, len)).unwrap() };
    Box::into_raw(Box::new(TokenizerWrapper::from_str(json)))
}

#[no_mangle]
pub extern "C" fn byte_level_bpe_tokenizers_new_from_str(
    input_vocab_str: *const u8,
    len_vocab: usize,
    input_merges_str: *const u8,
    len_merges: usize,
    input_added_tokens_str: *const u8,
    len_added_tokens: usize,
) -> *mut TokenizerWrapper {
    let vocab = unsafe { std::str::from_utf8(std::slice::from_raw_parts(input_vocab_str, len_vocab)).unwrap() };
    let merges = unsafe { std::str::from_utf8(std::slice::from_raw_parts(input_merges_str, len_merges)).unwrap() };
    let added_tokens = unsafe { std::str::from_utf8(std::slice::from_raw_parts(input_added_tokens_str, len_added_tokens)).unwrap() };
    Box::into_raw(Box::new(TokenizerWrapper::byte_level_bpe_from_str(vocab, merges, added_tokens)))
}

#[no_mangle]
pub extern "C" fn tokenizers_encode(
    handle: *mut TokenizerWrapper,
    input_cstr: *const u8,
    len: usize,
    add_special_tokens: i32,
    out_result: *mut TokenizerEncodeResult,
) {
    let input_data = unsafe { std::str::from_utf8(std::slice::from_raw_parts(input_cstr, len)).unwrap() };
    let encoded = unsafe { &mut *handle }.encode(input_data, add_special_tokens != 0);
    let len = encoded.len();
    unsafe {
        *out_result = TokenizerEncodeResult {
            token_ids: Box::into_raw(encoded.into_boxed_slice()) as *mut u32,
            len,
        };
    }
}

#[no_mangle]
pub extern "C" fn tokenizers_encode_batch(
    handle: *mut TokenizerWrapper,
    input_cstr: *const *const u8,
    input_len: *const usize,
    num_seqs: usize,
    add_special_tokens: i32,
    out_result: *mut TokenizerEncodeResult,
) {
    let input_data = (0..num_seqs)
        .map(|i| unsafe {
            std::str::from_utf8(std::slice::from_raw_parts(
                *input_cstr.offset(i as isize),
                *input_len.offset(i as isize),
            ))
            .unwrap()
        })
        .collect::<Vec<&str>>();
    let encoded_batch = unsafe { &mut *handle }.encode_batch(input_data, add_special_tokens != 0);
    for (i, encoded) in encoded_batch.into_iter().enumerate() {
        let len = encoded.len();
        unsafe {
            *out_result.offset(i as isize) = TokenizerEncodeResult {
                token_ids: Box::into_raw(encoded.into_boxed_slice()) as *mut u32,
                len,
            };
        }
    }
}

#[no_mangle]
pub extern "C" fn tokenizers_free_encode_results(results: *mut TokenizerEncodeResult, num_seqs: usize) {
    let slice = unsafe { std::slice::from_raw_parts_mut(results, num_seqs) };
    for result in slice {
        unsafe {
            drop(Box::from_raw(std::slice::from_raw_parts_mut(result.token_ids, result.len)));
        }
    }
}

#[no_mangle]
pub extern "C" fn tokenizers_decode(
    handle: *mut TokenizerWrapper,
    input_ids: *const u32,
    len: usize,
    skip_special_tokens: i32,
) {
    let input_data = unsafe { std::slice::from_raw_parts(input_ids, len) };
    // Store the result in the wrapper so tokenizers_get_decode_str can hand out
    // a pointer that stays valid after this call returns.
    let wrapper = unsafe { &mut *handle };
    wrapper.decode_str = wrapper.decode(input_data.to_vec(), skip_special_tokens != 0);
}

#[no_mangle]
pub extern "C" fn tokenizers_get_decode_str(
    handle: *mut TokenizerWrapper,
    out_cstr: *mut *mut u8,
    out_len: *mut usize,
) {
    let decode_str = unsafe { &mut *handle }.decode_str.as_bytes();
    unsafe {
        *out_cstr = decode_str.as_ptr() as *mut u8;
        *out_len = decode_str.len();
    }
}

#[no_mangle]
pub extern "C" fn tokenizers_free(wrapper: *mut TokenizerWrapper) {
    unsafe {
        drop(Box::from_raw(wrapper));
    }
}

#[no_mangle]
pub extern "C" fn tokenizers_get_vocab_size(handle: *mut TokenizerWrapper, size: *mut usize) {
    unsafe {
        *size = (*handle).get_vocab_size();
    }
}

#[no_mangle]
pub extern "C" fn tokenizers_id_to_token(
    handle: *mut TokenizerWrapper,
    id: u32,
    out_cstr: *mut *mut u8,
    out_len: *mut usize,
) {
    // Keep the token alive inside the wrapper; returning a pointer into a local
    // String would dangle as soon as this function returns.
    let wrapper = unsafe { &mut *handle };
    wrapper.id_to_token_result = wrapper.id_to_token(id);
    let token = wrapper.id_to_token_result.as_bytes();
    unsafe {
        *out_cstr = token.as_ptr() as *mut u8;
        *out_len = token.len();
    }
}

#[no_mangle]
pub extern "C" fn tokenizers_token_to_id(
    handle: *mut TokenizerWrapper,
    token: *const u8,
    len: usize,
    out_id: *mut i32,
) {
    let token_str = unsafe { std::str::from_utf8(std::slice::from_raw_parts(token, len)).unwrap() };
    let id = unsafe { &mut *handle }.token_to_id(token_str);
    unsafe {
        *out_id = match id {
            Some(id) => id as i32,
            None => -1,
        };
    }
}

#[repr(C)]
pub struct TokenizerEncodeResult {
    token_ids: *mut u32,
    len: usize,
}

3. Build and publish

Use maturin to build and publish the Python package. Make sure maturin is installed:

pip install maturin

Run the following in the project root to build the Python package:

maturin develop --release

This builds and installs the Rust extension module so it can be imported and used from Python.

Fixing pkg_config errors:

sudo apt install pkg-config

Fixing SSL errors:

sudo apt install libssl-dev

4. Write the Python interface

You can now import and use the wrapper module from Python. A simple example:

import tokenizers_wrapper

# Create a TokenizerWrapper from a JSON string
json_str = '...' # your tokenizer JSON string
tokenizer = tokenizers_wrapper.TokenizerWrapper(json_str)

# Create a TokenizerWrapper from byte-level BPE files
vocab = '...' # your vocab.json string
merges = '...' # your merges.txt string
added_tokens = '...' # your added_tokens.json string
tokenizer = tokenizers_wrapper.TokenizerWrapper.byte_level_bpe_from_str(vocab, merges, added_tokens)

# Encode a single text
text = "Hello, world!"
token_ids = tokenizer.encode(text, add_special_tokens=True)
print(token_ids)

# Encode a batch of texts
texts = ["Hello, world!", "Goodbye, world!"]
batch_token_ids = tokenizer.encode_batch(texts, add_special_tokens=True)
print(batch_token_ids)

# Decode token IDs
decoded_text = tokenizer.decode(token_ids, skip_special_tokens=True)
print(decoded_text)

# Get the vocabulary size
vocab_size = tokenizer.get_vocab_size()
print(vocab_size)

# Look up the token for a token ID
token = tokenizer.id_to_token(0)
print(token)

# Look up the token ID for a token
token_id = tokenizer.token_to_id("Hello")
print(token_id)

5. Write the C++ interface

Create a simple C++ project that uses the wrapper module:

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>
#include "tokenizers_wrapper.h"

int main() {
    const char* json_str = R"({"type": "bpe", "unk_token": "[UNK]"})";
    TokenizerWrapper* tokenizer = tokenizers_new_from_str((const uint8_t*)json_str, strlen(json_str));

    const char* text = "Hello, world!";
    TokenizerEncodeResult result;
    tokenizers_encode(tokenizer, (const uint8_t*)text, strlen(text), 1, &result);

    std::vector<uint32_t> token_ids(result.token_ids, result.token_ids + result.len);
    for (uint32_t id : token_ids) {
        std::cout << id << " ";
    }
    std::cout << std::endl;

    tokenizers_free_encode_results(&result, 1);
    tokenizers_free(tokenizer);
    return 0;
}

6. Build and run the C++ project

Create a CMakeLists.txt to build the C++ project:

cmake_minimum_required(VERSION 3.18)
project(tokenizers_cpp CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

add_executable(tokenizers_demo main.cpp)

# Path to the Tokenizers Wrapper library
set(TOKENIZERS_WRAPPER_LIB_PATH "/path/to/your/rust/library")
include_directories(${TOKENIZERS_WRAPPER_LIB_PATH}/include)
link_directories(${TOKENIZERS_WRAPPER_LIB_PATH}/lib)

# Link against the Tokenizers Wrapper library
target_link_libraries(tokenizers_demo tokenizers_wrapper)

Then build and run the project with CMake:

mkdir build
cd build
cmake ..
make
./tokenizers_demo

Rust Calling C++

Write mnn_adapter.cpp and compile it into a .so:

extern "C" { 
int add(int a, int b) { return a + b; }
}

Compile:

gcc -fPIC -shared mnn_adapter.cpp -o libmnn_adapter.so

Write build.rs:

extern crate dunce;
use std::{env, path::PathBuf};

fn main() {
    let root_v = PathBuf::from(env::var_os("CARGO_MANIFEST_DIR").unwrap());
    let library_dir_v = dunce::canonicalize(root_v.join("variable")).unwrap();

    // Library search path; unlike with static linking, the library name itself is not appended.
    println!("cargo:rustc-link-search=native={}", env::join_paths(&[library_dir_v]).unwrap().to_str().unwrap());
    println!("cargo:rustc-link-lib=dylib=mnn_adapter");
}

Write main.rs:


#![allow(non_snake_case)]

#[link(name = "mnn_adapter")]
extern "C" {
    fn add(x: i32, y: i32) -> i32;
}

fn main() {
    let x = unsafe { add(62, 30) };
    println!("{}", x); // 92
}

Write Cargo.toml:

[package]
name = "rust_demo"
version = "0.1.0"
edition = "2018"

[dependencies]
libc = "0.2"

[build-dependencies]
dunce = "1.0.0"

Run it:

export LD_LIBRARY_PATH=./variable
cargo run

MNN Example

build.rs

extern crate dunce;
use std::{env, path::PathBuf};

fn main() {
    let root_v = PathBuf::from(env::var_os("CARGO_MANIFEST_DIR").unwrap());
    let library_dir_v = dunce::canonicalize(root_v.join("variable")).unwrap();

    // Compile the C++ file into a static library
    // (requires cc = "1.0" alongside dunce in [build-dependencies])
    cc::Build::new()
        .cpp(true)
        .include("/home/jw/code/rust_condition_se/condition_se/third_party/mnn/mnn-linux-x64-2.9.0/include") // MNN header path
        .flag_if_supported("-std=c++14") // C++ standard
        .file("variable/mnn_adapter.cpp")
        .compile("mnn_adapter"); // produces the libmnn_adapter.a static library

    // Library search path; unlike with static linking, the library name itself is not appended.
    println!("cargo:rustc-link-search=native={}", env::join_paths(&[library_dir_v]).unwrap().to_str().unwrap());
    println!("cargo:rustc-link-lib=dylib=mnn_adapter");

    // MNN library path
    println!("cargo:rustc-link-search=native=/home/jw/code/rust_condition_se/condition_se/third_party/mnn/mnn-linux-x64-2.9.0/lib");
    println!("cargo:rustc-link-lib=dylib=MNN");

    // Link the C++ standard library
    println!("cargo:rustc-link-lib=stdc++");
}

wrapper.h

#include "MNN/Interpreter.hpp"

extern "C" {

using namespace MNN;

struct MNN_Interpreter {
std::shared_ptr<Interpreter> interpreter;
};

struct MNN_Session {
Session* session;
};

struct MNN_Tensor {
Tensor* tensor;
};

MNN_Interpreter* MNN_Interpreter_create(const char* model_path);
MNN_Session* MNN_Interpreter_createSession(MNN_Interpreter* interpreter);
MNN_Tensor* MNN_Interpreter_getSessionInput(MNN_Interpreter* interpreter, MNN_Session* session);
void MNN_Tensor_setData(MNN_Tensor* tensor, float* data, int size);
void MNN_Interpreter_runSession(MNN_Interpreter* interpreter, MNN_Session* session);
MNN_Tensor* MNN_Interpreter_getSessionOutput(MNN_Interpreter* interpreter, MNN_Session* session);
void MNN_Tensor_getData(MNN_Tensor* tensor, float* data, int size);
void MNN_Interpreter_destroy(MNN_Interpreter* interpreter);
void MNN_Session_destroy(MNN_Session* session);
void MNN_Tensor_destroy(MNN_Tensor* tensor);
}

mnn_adapter.cpp

#include "wrapper.h"
#include <cstring>

extern "C" {

MNN_Interpreter* MNN_Interpreter_create(const char* model_path) {
MNN_Interpreter* mnn_interpreter = new MNN_Interpreter;
mnn_interpreter->interpreter = std::shared_ptr<MNN::Interpreter>(MNN::Interpreter::createFromFile(model_path));
return mnn_interpreter;
}

// 其他函数定义保持不变
MNN_Session* MNN_Interpreter_createSession(MNN_Interpreter* interpreter) {
MNN_Session* mnn_session = new MNN_Session;
MNN::ScheduleConfig config;
mnn_session->session = interpreter->interpreter->createSession(config);
return mnn_session;
}

MNN_Tensor* MNN_Interpreter_getSessionInput(MNN_Interpreter* interpreter, MNN_Session* session) {
MNN_Tensor* mnn_tensor = new MNN_Tensor;
mnn_tensor->tensor = interpreter->interpreter->getSessionInput(session->session, nullptr);
return mnn_tensor;
}

void MNN_Tensor_setData(MNN_Tensor* tensor, float* data, int size) {
std::memcpy(tensor->tensor->host<float>(), data, size * sizeof(float));
}

void MNN_Interpreter_runSession(MNN_Interpreter* interpreter, MNN_Session* session) {
interpreter->interpreter->runSession(session->session);
}

MNN_Tensor* MNN_Interpreter_getSessionOutput(MNN_Interpreter* interpreter, MNN_Session* session) {
MNN_Tensor* mnn_tensor = new MNN_Tensor;
mnn_tensor->tensor = interpreter->interpreter->getSessionOutput(session->session, nullptr);
return mnn_tensor;
}

void MNN_Tensor_getData(MNN_Tensor* tensor, float* data, int size) {
std::memcpy(data, tensor->tensor->host<float>(), size * sizeof(float));
}

void MNN_Interpreter_destroy(MNN_Interpreter* interpreter) {
delete interpreter;
}

void MNN_Session_destroy(MNN_Session* session) {
delete session;
}

void MNN_Tensor_destroy(MNN_Tensor* tensor) {
delete tensor;
}

}

lib.rs

#![allow(non_snake_case)]

extern crate libc;
use libc::c_char;
use std::ffi::CString;

#[repr(C)]
pub struct MNN_Interpreter {
    _private: [u8; 0],
}

#[repr(C)]
pub struct MNN_Session {
    _private: [u8; 0],
}

#[repr(C)]
pub struct MNN_Tensor {
    _private: [u8; 0],
}

#[link(name = "mnn_adapter")]
#[link(name = "MNN")]
#[link(name = "stdc++")]
extern "C" {
    fn MNN_Interpreter_create(model_path: *const c_char) -> *mut MNN_Interpreter;
    fn MNN_Interpreter_createSession(interpreter: *mut MNN_Interpreter) -> *mut MNN_Session;
    fn MNN_Interpreter_getSessionInput(interpreter: *mut MNN_Interpreter, session: *mut MNN_Session) -> *mut MNN_Tensor;
    fn MNN_Tensor_setData(tensor: *mut MNN_Tensor, data: *const f32, size: i32);
    fn MNN_Interpreter_runSession(interpreter: *mut MNN_Interpreter, session: *mut MNN_Session);
    fn MNN_Interpreter_getSessionOutput(interpreter: *mut MNN_Interpreter, session: *mut MNN_Session) -> *mut MNN_Tensor;
    fn MNN_Tensor_getData(tensor: *mut MNN_Tensor, data: *mut f32, size: i32);
    fn MNN_Interpreter_destroy(interpreter: *mut MNN_Interpreter);
    fn MNN_Session_destroy(session: *mut MNN_Session);
    fn MNN_Tensor_destroy(tensor: *mut MNN_Tensor);
}

pub struct Interpreter {
    interpreter: *mut MNN_Interpreter,
}

impl Interpreter {
    pub fn create(model_path: &str) -> Result<Self, &'static str> {
        let c_model_path = CString::new(model_path).map_err(|_| "Failed to create CString")?;
        let interpreter = unsafe { MNN_Interpreter_create(c_model_path.as_ptr() as *const c_char) };
        if interpreter.is_null() {
            Err("Failed to create MNN Interpreter")
        } else {
            Ok(Self { interpreter })
        }
    }

    pub fn create_session(&self) -> Result<Session, &'static str> {
        let session = unsafe { MNN_Interpreter_createSession(self.interpreter) };
        if session.is_null() {
            Err("Failed to create session")
        } else {
            Ok(Session { session })
        }
    }

    pub fn run_session(&self, session: &Session) {
        unsafe {
            MNN_Interpreter_runSession(self.interpreter, session.session);
        }
    }
}

impl Drop for Interpreter {
    fn drop(&mut self) {
        if !self.interpreter.is_null() {
            unsafe {
                MNN_Interpreter_destroy(self.interpreter);
            }
        }
    }
}

pub struct Session {
    session: *mut MNN_Session,
}

impl Session {
    pub fn get_input(&self, interpreter: &Interpreter) -> Result<Tensor, &'static str> {
        let tensor = unsafe { MNN_Interpreter_getSessionInput(interpreter.interpreter, self.session) };
        if tensor.is_null() {
            Err("Failed to get input tensor")
        } else {
            Ok(Tensor { tensor })
        }
    }

    pub fn get_output(&self, interpreter: &Interpreter) -> Result<Tensor, &'static str> {
        let tensor = unsafe { MNN_Interpreter_getSessionOutput(interpreter.interpreter, self.session) };
        if tensor.is_null() {
            Err("Failed to get output tensor")
        } else {
            Ok(Tensor { tensor })
        }
    }
}

impl Drop for Session {
    fn drop(&mut self) {
        if !self.session.is_null() {
            unsafe {
                MNN_Session_destroy(self.session);
            }
        }
    }
}

pub struct Tensor {
    tensor: *mut MNN_Tensor,
}

impl Tensor {
    pub fn set_data(&self, data: &[f32]) {
        unsafe {
            MNN_Tensor_setData(self.tensor, data.as_ptr(), data.len() as i32);
        }
    }

    pub fn get_data(&self, data: &mut [f32]) {
        unsafe {
            MNN_Tensor_getData(self.tensor, data.as_mut_ptr(), data.len() as i32);
        }
    }
}

impl Drop for Tensor {
    fn drop(&mut self) {
        if !self.tensor.is_null() {
            unsafe {
                MNN_Tensor_destroy(self.tensor);
            }
        }
    }
}

main.rs

extern crate rust_demo;

use rust_demo::{Interpreter, Tensor};

fn main() {
    let model_path = "/home/jw/code/rust_condition_se/condition_se/libri_aishell2_fsd50k_NAMFmid1D_160_lsnr1_l1100_wsdr_mspec_nokl_lag7_l1noise_1111_flim_mf_emb512_1dec/model_se.mnn";
    match Interpreter::create(model_path) {
        Ok(interpreter) => match interpreter.create_session() {
            Ok(session) => match session.get_input(&interpreter) {
                Ok(input_tensor) => {
                    let input_data = vec![0.0f32; 1 * 3 * 224 * 224]; // assuming a 1x3x224x224 image input
                    input_tensor.set_data(&input_data);

                    interpreter.run_session(&session);

                    match session.get_output(&interpreter) {
                        Ok(output_tensor) => {
                            let mut output_data = vec![0.0f32; 1000]; // assuming an output size of 1000
                            output_tensor.get_data(&mut output_data);

                            println!("Output data: {:?}", &output_data[..10]);
                        }
                        Err(err) => eprintln!("Error getting output tensor: {}", err),
                    }
                }
                Err(err) => eprintln!("Error getting input tensor: {}", err),
            },
            Err(err) => eprintln!("Error creating session: {}", err),
        },
        Err(err) => eprintln!("Error: {}", err),
    }
}

C++ Calling Rust

Use cbindgen to generate the header file automatically.

Install:
cargo install cbindgen

Write cbindgen.toml:
header = "// SPDX-License-Identifier: MIT OR Apache-2.0"
sys_includes = ["stddef.h", "stdint.h", "stdlib.h"]
no_includes = true
include_version = true
include_guard = "CONDITION_SE_H"
tab_width = 4
style = "Type"
language = "C"
cpp_compat = true

Add to Cargo.toml (later testing suggests this may not be required):

[package.metadata.capi.header]
name = "condition_se"
subdirectory = "condition_se"
[package.metadata.capi.pkg_config]
name = "libconditionse"
filename = "conditionse"
[package.metadata.capi.library]
name = "conditionse"

Run:
cbindgen --config cbindgen.toml --crate condition_se --output ../src/condition_se_cc.h
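
cbindgen picks up the crate's #[no_mangle] extern "C" items. A hypothetical example of the kind of function it turns into a C prototype (name and signature are illustrative, not from the real crate):

// cbindgen would emit roughly: int32_t se_process(const float *input, size_t len);
#[no_mangle]
pub extern "C" fn se_process(input: *const f32, len: usize) -> i32 {
    let data = unsafe { std::slice::from_raw_parts(input, len) };
    data.iter().filter(|x| x.abs() > 1e-6).count() as i32
}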

Problem: with cross-crate type dependencies, NoiseType cannot be exported.

Add to cbindgen.toml:

[parse]
# Enable parsing of dependencies
parse_deps = true

# List items to include in the bindings
include = ["NoiseType"]

[parse.expand]
# Enable all features if necessary
features = ["all"]

[export]
# Ensure NoiseType and TestEnum are exported correctly
include = ["NoiseType"]
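
With parse_deps enabled, cbindgen can now see the dependency crate; for the type to come out as a plain C enum it also needs a C-compatible repr. A hypothetical sketch of such a type (the real NoiseType lives in the dependency crate):

// Hypothetical definition; the variants are illustrative only.
#[repr(C)]
pub enum NoiseType {
    Stationary = 0,
    NonStationary = 1,
}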

Running a bin Target

cargo run --bin aec --release

Using features:

cargo run --bin aec --release --features="aec bin"
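
For the features to gate the binary, Cargo.toml can declare them along these lines (a sketch; feature names taken from the command above, bin path assumed):

[features]
aec = []
bin = []

[[bin]]
name = "aec"
path = "src/bin/aec.rs"
required-features = ["aec", "bin"]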

Jupyter for Rust

Jupyter is convenient for testing algorithm modules interactively.

Install JupyterLab:

conda install -c conda-forge jupyterlab
# plotly, for rendering charts
jupyter labextension install jupyterlab-plotly
# change the theme
jupyter labextension install @shahinrostami/theme-purple-please

Install evcxr:
cargo install evcxr_jupyter
evcxr_jupyter --install

Basic syntax:

# imports
:timing
:dep { rand = "0.7.3" }
:dep { log = "0.4.11" }
:dep my_crate = { path = "." }
let values = (1..13).map(fib).collect::<Vec<i32>>();
values

Show all variables:
:vars

Custom display:

use std::fmt::Debug;

pub struct Matrix<T> { pub values: Vec<T>, pub row_size: usize }

impl<T: Debug> Matrix<T> {
    pub fn evcxr_display(&self) {
        let mut html = String::new();
        html.push_str("<table>");
        for r in 0..(self.values.len() / self.row_size) {
            html.push_str("<tr>");
            for c in 0..self.row_size {
                html.push_str("<td>");
                html.push_str(&format!("{:?}", self.values[r * self.row_size + c]));
                html.push_str("</td>");
            }
            html.push_str("</tr>");
        }
        html.push_str("</table>");
        println!("EVCXR_BEGIN_CONTENT text/html\n{}\nEVCXR_END_CONTENT", html);
    }
}

let m = Matrix { values: vec![1, 2, 3, 4, 5, 6, 7, 8, 9], row_size: 3 };
m

Display images:

:dep image = "0.23"
:dep evcxr_image = "1.1"
use evcxr_image::ImageDisplay;

image::ImageBuffer::from_fn(256, 256, |x, y| {
    if (x as i32 - y as i32).abs() < 3 {
        image::Rgb([0, 0, 255])
    } else {
        image::Rgb([0, 0, 0])
    }
})

Configure JupyterLab code completion (VS Code works better):

# install jupyterlab-lsp
pip install jupyter-lsp

# install python-lsp-server
pip install python-lsp-server

Settings

Click through Settings -> Settings Editor -> Code Completion (the second entry) -> turn on Enable autocompletion.

Build or startup failures

These are usually caused by a conflicting environment; rather than hunting for the cause, it is often faster to delete the cache and rebuild:

sudo rm -r ~/.cargo/registry/

HDF5 build errors:

apt install libhdf5-dev

ROCm Mode for ort

Reinstall ROCm:

# Uninstall single-version ROCm packages
sudo apt autoremove rocm-core

# Uninstall Kernel-mode Driver
sudo apt autoremove amdgpu-dkms

# remove apt source
sudo rm /etc/apt/sources.list.d/<rocm_repository-name>.list
sudo rm /etc/apt/sources.list.d/<amdgpu_repository-name>.list
sudo rm /etc/apt/sources.list.d/rocm.list
sudo rm /etc/apt/sources.list.d/amdgpu.list

sudo rm -rf /var/cache/apt/*
sudo apt-get clean all

sudo reboot

For installation, follow the ROCm environment setup notes; version 6.0.2 is known to work.

Build onnxruntime:

ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
ONNXRUNTIME_BRANCH=main
git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\
cd onnxruntime &&\
/bin/sh ./build.sh --config Release --build_shared_lib --parallel --cmake_extra_defines\
ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --use_rocm --rocm_home=/opt/rocm