Skip to content

Commit

Permalink
Initial Rust bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Baker committed Nov 11, 2024
1 parent 413719b commit 5b67cc0
Show file tree
Hide file tree
Showing 11 changed files with 925 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,7 @@ go.work
go.work.sum
bindings/go/pkg/
bindings/go/go.sum
*~
\#*
bindings/rust/target
bindings/rust/Cargo.lock
12 changes: 12 additions & 0 deletions bindings/rust/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[package]
name = "ucx-sys"
version = "0.1.0"
edition = "2021"
description = "Low level Rust bindings"
license = "BSD"

[build-dependencies]
bindgen = "0.70.1"

[dependencies.bitflags]
version = "2.6.0"
45 changes: 45 additions & 0 deletions bindings/rust/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use std::env;
use std::path::PathBuf;

fn main() {
// Tell cargo to look for shared libraries in the specified directory
println!("cargo:rustc-link-search=../../src/ucp/.libs/");

// Tell cargo to tell rustc to link the system bzip2
// shared library.
println!("cargo:rustc-link-lib=ucp");

// The bindgen::Builder is the main entry point
// to bindgen, and lets you build up options for
// the resulting bindings.
let bindings = bindgen::Builder::default()
// Some of the UCX detailed examples in comments can confuse the
// bindgen parser and it will make bad code instead of comments
.generate_comments(false)
// ucs_status_t is defined as a packed enum and that will lead to
// badness without the flag which tells bindgen to repeat that
// trick with the rust enums
.rustified_enum(".*")
.clang_arg("-I../../src/ucp/api/")
.clang_arg("-I../../")
.clang_arg("-I../../src/")
// Annotate ucs_status_t and ucs_status_ptr_t as #[must_use]
.must_use_type("ucs_status_t")
.must_use_type("ucs_status_ptr_t")
// The input header we would like to generate
// bindings for.
.header("wrapper.h")
// Tell cargo to invalidate the built crate whenever any of the
// included header files changed.
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
// Finish the builder and generate the bindings.
.generate()
// Unwrap the Result and panic on failure.
.expect("Unable to generate bindings");

// Write the bindings to the $OUT_DIR/bindings.rs file.
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
bindings
.write_to_file(out_path.join("bindings.rs"))
.expect("Couldn't write bindings!");
}
2 changes: 2 additions & 0 deletions bindings/rust/rustfmt.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
hard_tabs = false
tab_spaces = 4
120 changes: 120 additions & 0 deletions bindings/rust/src/am.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
use crate::ep::Ep;
use crate::ffi::*;
use crate::status_ptr_to_result;
use crate::status_to_result;
use crate::worker::Worker;
use crate::Request;
use crate::RequestParam;
use bitflags::bitflags;

type AmRecvCb = unsafe extern "C" fn(
arg: *mut ::std::os::raw::c_void,
header: *const ::std::os::raw::c_void,
header_length: usize,
data: *mut ::std::os::raw::c_void,
length: usize,
param: *const ucp_am_recv_param_t,
) -> ucs_status_t;

impl Worker<'_> {
#[inline]
pub fn am_register(&self, am_param: &HandlerParams) -> Result<(), ucs_status_t> {
status_to_result(unsafe { ucp_worker_set_am_recv_handler(self.handle, &am_param.handle) })
}
}

impl Ep<'_> {
#[inline]
pub fn am_send(
&self,
id: u32,
header: &[u8],
data: &[u8],
params: &RequestParam,
) -> Result<Option<Request>, ucs_status_t> {
status_ptr_to_result(unsafe {
ucp_am_send_nbx(
self.handle,
id,
header.as_ptr() as _,
header.len(),
data.as_ptr() as _,
data.len(),
&params.handle,
)
})
}
}

bitflags! {
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CbFlags: u32 {
const WholeMsg = ucp_am_cb_flags::UCP_AM_FLAG_WHOLE_MSG as u32;
const PersistentData = ucp_am_cb_flags::UCP_AM_FLAG_PERSISTENT_DATA as u32;
}
}

#[derive(Debug, Clone)]
pub struct HandlerParamsBuilder {
uninit_handle: std::mem::MaybeUninit<ucp_am_handler_param_t>,
flags: u64,
}

impl HandlerParamsBuilder {
#[inline]
pub fn new() -> HandlerParamsBuilder {
let uninit_params = std::mem::MaybeUninit::<ucp_am_handler_param_t>::uninit();
HandlerParamsBuilder {
uninit_handle: uninit_params,
flags: 0,
}
}

#[inline]
pub fn id(&mut self, id: u32) -> &mut HandlerParamsBuilder {
self.flags |= ucp_am_handler_param_field::UCP_AM_HANDLER_PARAM_FIELD_ID as u64;
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.id = id;
self
}

#[inline]
pub fn flags(&mut self, flags: CbFlags) -> &mut HandlerParamsBuilder {
self.flags |= ucp_am_handler_param_field::UCP_AM_HANDLER_PARAM_FIELD_FLAGS as u64;
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.flags = flags.bits();
self
}

#[inline]
pub fn cb(&mut self, cb: AmRecvCb) -> &mut HandlerParamsBuilder {
self.flags |= ucp_am_handler_param_field::UCP_AM_HANDLER_PARAM_FIELD_CB as u64;
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.cb = Some(cb);
self
}

#[inline]
pub fn arg(&mut self, arg: *mut std::os::raw::c_void) -> &mut HandlerParamsBuilder {
self.flags |= ucp_am_handler_param_field::UCP_AM_HANDLER_PARAM_FIELD_ARG as u64;
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.arg = arg;
self
}

#[inline]
pub fn build(&mut self) -> HandlerParams {
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.field_mask = self.flags;

let handler_param = HandlerParams {
handle: unsafe { self.uninit_handle.assume_init() },
};

handler_param
}
}

pub struct HandlerParams {
pub(crate) handle: ucp_am_handler_param_t,
}
201 changes: 201 additions & 0 deletions bindings/rust/src/context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
use crate::ffi::*;
use crate::status_to_result;
use crate::worker;
use crate::worker::Worker;
use bitflags::bitflags;
use std::ffi::CString;

type RequestInitCb = unsafe extern "C" fn(request: *mut ::std::os::raw::c_void);
type RequestCleanUpCb = unsafe extern "C" fn(request: *mut ::std::os::raw::c_void);

bitflags! {
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Flags: u64 {
const Tag = ucp_feature::UCP_FEATURE_TAG as u64;
const Rma = ucp_feature::UCP_FEATURE_RMA as u64;
const Amo32 = ucp_feature::UCP_FEATURE_AMO32 as u64;
const Amo64 = ucp_feature::UCP_FEATURE_AMO64 as u64;
const Wakeup = ucp_feature::UCP_FEATURE_WAKEUP as u64;
const Stream = ucp_feature::UCP_FEATURE_STREAM as u64;
const Am = ucp_feature::UCP_FEATURE_AM as u64;
const ExportedMemH = ucp_feature::UCP_FEATURE_EXPORTED_MEMH as u64;
}
}

pub struct Config {
handle: *mut ucp_config_t,
}

impl Config {
pub fn read(name: &str, file: &str) -> Result<*mut ucp_config_t, ucs_status_t> {
let mut config: *mut ucp_config_t = std::ptr::null_mut();
let c_name = CString::new(name).unwrap();
let c_file = CString::new(file).unwrap();
status_to_result(unsafe { ucp_config_read(c_name.as_ptr(), c_file.as_ptr(), &mut config) })
.unwrap();
return Ok(config);
}
}

impl Default for Config {
fn default() -> Self {
let config = Config::read("", "").unwrap();
Config { handle: config }
}
}

impl Drop for Config {
fn drop(&mut self) {
unsafe { ucp_config_release(self.handle) };
}
}

#[derive(Debug, Clone)]
pub struct ParamsBuilder {
uninit_handle: std::mem::MaybeUninit<ucp_params_t>,
field_mask: u64,
name: Option<CString>,
}

#[derive(Debug, Clone)]
pub struct Params {
handle: ucp_params_t,
name: Option<CString>,
}

// This builder wraps up the unsafe parts of building the ucp_param_t struct. On construction
// it makes a zero filled ucp_params_t which Rust considers unitialized. Each call on the builder
// will fill in the fields of the struct and add the mask for that field. On the final build()
// it will fill in the final value of the features field_mask and proclame the rest of the struct
// as initalized. This is Rust safe because all of the other fields are guarenteed to not be used
// by the library since the proper feature flag is not set.

impl ParamsBuilder {
pub fn new() -> ParamsBuilder {
let uninit_params = std::mem::MaybeUninit::<ucp_params_t>::uninit();
ParamsBuilder {
uninit_handle: uninit_params,
field_mask: 0,
name: None,
}
}

pub fn features(&mut self, features: Flags) -> &mut ParamsBuilder {
self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_FEATURES as u64;
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.features = features.bits();
self
}

pub fn request_size(&mut self, size: usize) -> &mut ParamsBuilder {
self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_REQUEST_SIZE as u64;
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.request_size = size;
self
}

pub fn request_init(&mut self, cb: RequestInitCb) -> &mut ParamsBuilder {
self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_REQUEST_INIT as u64;
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };

params.request_init = Some(cb);
self
}

pub fn request_cleanup(&mut self, cb: RequestCleanUpCb) -> &mut ParamsBuilder {
self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_REQUEST_CLEANUP as u64;
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.request_cleanup = Some(cb);
self
}

pub fn tag_sender_mask(&mut self, mask: u64) -> &mut ParamsBuilder {
self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_TAG_SENDER_MASK as u64;
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.tag_sender_mask = mask;
self
}

pub fn mt_workers_shared(&mut self, shared: i32) -> &mut ParamsBuilder {
self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_MT_WORKERS_SHARED as u64;
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.mt_workers_shared = shared;
self
}

pub fn estimated_num_eps(&mut self, num_eps: usize) -> &mut ParamsBuilder {
self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_ESTIMATED_NUM_EPS as u64;
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.estimated_num_eps = num_eps;
self
}

pub fn estimated_num_ppn(&mut self, num_ppn: usize) -> &mut ParamsBuilder {
self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_ESTIMATED_NUM_PPN as u64;
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.estimated_num_ppn = num_ppn;
self
}

pub fn name(&mut self, name: &str) -> &mut ParamsBuilder {
self.field_mask |= ucp_params_field::UCP_PARAM_FIELD_NAME as u64;
let name_cs = CString::new(name).unwrap();
self.name = Some(name_cs);
self
}

pub fn build(&mut self) -> Params {
let params = unsafe { &mut *self.uninit_handle.as_mut_ptr() };
params.field_mask = self.field_mask;

let mut ucp_param = Params {
name: None,
handle: unsafe { self.uninit_handle.assume_init() },
};

if self.name.is_some() {
let new_name = self.name.clone().unwrap();
ucp_param.handle.name = new_name.as_ptr();
ucp_param.name = Some(new_name);
}

ucp_param
}
}

impl Context {
pub fn new(config: &Config, params: &Params) -> Result<Context, ucs_status_t> {
let mut context: ucp_context_h = std::ptr::null_mut();

let result = status_to_result(unsafe {
ucp_init_version(
UCP_API_MAJOR,
UCP_API_MINOR,
&params.handle,
config.handle,
&mut context,
)
});
match result {
Ok(()) => Ok(Context { handle: context }),
Err(ucs_status_t) => Err(ucs_status_t),
}
}

pub fn worker_create<'a>(
&'a self,
params: &'a worker::Params,
) -> Result<Worker<'a>, ucs_status_t> {
Worker::new(self, params)
}
}

pub struct Context {
pub(crate) handle: ucp_context_h,
}

impl Drop for Context {
fn drop(&mut self) {
unsafe { ucp_cleanup(self.handle) };
}
}
Loading

0 comments on commit 5b67cc0

Please sign in to comment.