diff --git a/build.sbt b/build.sbt index ea73cb7..73efd43 100644 --- a/build.sbt +++ b/build.sbt @@ -14,7 +14,13 @@ val path = settingKey[Map[String, String]]("paths") fork := true javaOptions += "--enable-preview" -javacOptions ++= Seq("--enable-preview", "--release", "20") +javacOptions ++= Seq( + "--enable-preview", + "--release", + "20", + "-Xlint:unchecked", + "-Xlint:deprecation" +) lazy val root = project .in(file(".")) diff --git a/glue/src/e2e.rs b/glue/src/e2e.rs index 3e05a0c..0cc7e8f 100644 --- a/glue/src/e2e.rs +++ b/glue/src/e2e.rs @@ -1,37 +1,82 @@ +use std::ptr::drop_in_place; use std::sync::Once; +use std::time::Duration; +use jni::objects::GlobalRef; use no_std_net::SocketAddr; use toad::net::Addrd; use toad::platform::Platform; -use toad_jni::java::{self, Object}; +use toad_jni::java::lang::System; +use toad_jni::java::{self, Object, Signature}; use toad_msg::alloc::Message; use toad_msg::{Code, Id, Token, Type}; +use crate::message_ref::MessageRef; +use crate::message_type::MessageType; use crate::runtime::Runtime; use crate::runtime_config::RuntimeConfig; -pub fn runtime_init<'a>() -> (Runtime, java::Env<'a>) { +#[non_exhaustive] +struct State { + pub runtime: Runtime, + pub env: java::Env<'static>, + pub client: crate::Runtime, + pub srv_addr: SocketAddr, +} + +fn init() -> State { let mut _env = crate::test::init(); let env = &mut _env; let cfg = RuntimeConfig::new(env); let runtime = Runtime::get_or_init(env, cfg); - (runtime, _env) + let client = crate::Runtime::try_new("0.0.0.0:5684", Default::default()).unwrap(); + + State { runtime, + env: _env, + client, + srv_addr: "0.0.0.0:5683".parse().unwrap() } } -fn runtime_poll_req(runtime: &Runtime, env: &mut java::Env) { +fn runtime_poll_req(State { runtime, + env, + client, + srv_addr, + .. }: &mut State) { assert!(runtime.poll_req(env).is_none()); - let client = crate::Runtime::try_new("0.0.0.0:5684", Default::default()).unwrap(); let request = Message::new(Type::Con, Code::GET, Id(0), Token(Default::default())); - client.send_msg(Addrd(request, "0.0.0.0:5683".parse().unwrap())) - .unwrap(); + client.send_msg(Addrd(request, *srv_addr)).unwrap(); assert!(runtime.poll_req(env).is_some()); } +fn message_ref_should_throw_when_used_after_close(State {runtime, env, client, srv_addr, ..}: &mut State) +{ + let request = Message::new(Type::Con, Code::GET, Id(0), Token(Default::default())); + client.send_msg(Addrd(request, *srv_addr)).unwrap(); + let req = runtime.poll_req(env).unwrap(); + + assert_eq!(req.ty(env), Type::Con); + req.close(env); + + let req_o = req.downcast(env); + env.call_method(req_o.as_local(), + "type", + Signature::of:: MessageType>(), + &[]) + .ok(); + + let err = env.exception_occurred().unwrap(); + env.exception_clear().unwrap(); + assert!(env.is_instance_of(err, + concat!(package!(dev.toad.RefHawk), "$IllegalStorageOfRefError")) + .unwrap()); +} + #[test] fn e2e_test_suite() { - let (runtime, mut env) = runtime_init(); - runtime_poll_req(&runtime, &mut env); + let mut state = init(); + runtime_poll_req(&mut state); + message_ref_should_throw_when_used_after_close(&mut state); } diff --git a/glue/src/lib.rs b/glue/src/lib.rs index 379511b..b487670 100644 --- a/glue/src/lib.rs +++ b/glue/src/lib.rs @@ -3,7 +3,7 @@ use std::ffi::c_void; use jni::JavaVM; -use mem::RuntimeAllocator; +use mem::SharedMemoryRegion; pub type Runtime = toad::std::Platform>; @@ -34,7 +34,7 @@ pub extern "system" fn JNI_OnLoad(jvm: JavaVM, _: *const c_void) -> i32 { #[no_mangle] pub extern "system" fn JNI_OnUnload(_: JavaVM, _: *const c_void) { - unsafe { mem::Runtime::dealloc() } + unsafe { mem::Shared::dealloc() } } #[cfg(all(test, feature = "e2e"))] diff --git a/glue/src/mem.rs b/glue/src/mem.rs index 39b3266..7b74916 100644 --- a/glue/src/mem.rs +++ b/glue/src/mem.rs @@ -1,54 +1,88 @@ -/// global [`RuntimeAllocator`] implementation -pub type Runtime = RuntimeGlobalStaticAllocator; +use std::sync::Mutex; -/// Trait managing the memory region(s) associated with the toad runtime -/// data structure. -/// -/// Notably, any and all references produced by the runtime will be to data -/// within the Runtime's memory region, meaning that we can easily leverage -/// strict provenance to prevent addresses from leaking outside of that memory region. -pub trait RuntimeAllocator: core::default::Default + core::fmt::Debug + Copy { +use toad_msg::alloc::Message; + +/// global [`RuntimeAllocator`] implementation +pub type Shared = GlobalStatic; + +/// Trait managing the memory region(s) which java will store pointers to +pub trait SharedMemoryRegion: core::default::Default + core::fmt::Debug + Copy { /// Allocate memory for the runtime and yield a stable pointer to it - unsafe fn alloc(r: impl FnOnce() -> crate::Runtime) -> *mut crate::Runtime; + /// + /// This is idempotent and will only invoke the provided callback if the runtime + /// has not already been initialized. + unsafe fn init(r: impl FnOnce() -> crate::Runtime) -> *mut crate::Runtime; + + /// Pass ownership of a [`Message`] to the shared memory region, + /// yielding a stable pointer to this message. + unsafe fn alloc_message(m: Message) -> *mut Message; + + /// Delete a message from the shared memory region. + unsafe fn dealloc_message(m: *mut Message); /// Teardown - unsafe fn dealloc() {} + unsafe fn dealloc(); - /// Coerce a `long` rep of the stable pointer created by [`Self::alloc`] to - /// a pointer (preferably using strict_provenance) - unsafe fn deref(addr: i64) -> *mut crate::Runtime; + unsafe fn shared_region(addr: i64) -> *mut u8; /// Coerce a `long` rep of a pointer to some data within the - /// Runtime data structure. - /// - /// Requires the Runtime address in order for the new pointer - /// to inherit its provenance. - unsafe fn deref_inner(runtime_addr: i64, addr: i64) -> *mut T { - Self::deref(runtime_addr).with_addr(addr as usize) - .cast::() + /// shared memory region. + unsafe fn deref(shared_region_addr: i64, addr: i64) -> *mut T { + Self::shared_region(shared_region_addr).with_addr(addr as usize) + .cast::() } } -static mut RUNTIME: *mut crate::Runtime = core::ptr::null_mut(); +static mut MEM: *mut Mem = core::ptr::null_mut(); + +struct Mem { + runtime: crate::Runtime, + messages: Vec, + + /// Lock used by `alloc_message` and `dealloc_message` to ensure + /// they are run serially. + /// + /// This doesn't provide any guarantees that message pointers will + /// stay valid or always point to the correct location, but it does + /// ensure we don't accidentally yield the wrong pointer from `alloc_message` + /// or delete the wrong message in `dealloc_message`. + messages_lock: Mutex<()>, +} #[derive(Default, Debug, Clone, Copy)] -pub struct RuntimeGlobalStaticAllocator; -impl RuntimeAllocator for RuntimeGlobalStaticAllocator { - /// Nops on already-init - unsafe fn alloc(r: impl FnOnce() -> crate::Runtime) -> *mut crate::Runtime { - if RUNTIME.is_null() { - RUNTIME = Box::into_raw(Box::new(r())); - RUNTIME - } else { - RUNTIME - } - } - +pub struct GlobalStatic; +impl SharedMemoryRegion for GlobalStatic { unsafe fn dealloc() { - drop(Box::from_raw(RUNTIME)); + drop(Box::from_raw(MEM)); } - unsafe fn deref(_: i64) -> *mut crate::Runtime { - RUNTIME + unsafe fn init(r: impl FnOnce() -> crate::Runtime) -> *mut crate::Runtime { + if MEM.is_null() { + MEM = Box::into_raw(Box::new(Mem { runtime: r(), + messages: vec![], + messages_lock: Mutex::new(()) })); + } + + &mut (*MEM).runtime as _ + } + + unsafe fn alloc_message(m: Message) -> *mut Message { + let _lock = (*MEM).messages_lock.lock(); + (*MEM).messages.push(m); + &mut (*MEM).messages[(*MEM).messages.len() - 1] as _ + } + + unsafe fn dealloc_message(m: *mut Message) { + let _lock = (*MEM).messages_lock.lock(); + let ix = m.offset_from((*MEM).messages.as_slice().as_ptr()); + if ix.is_negative() { + panic!() + } + + (*MEM).messages.remove(ix as usize); + } + + unsafe fn shared_region(_: i64) -> *mut u8 { + MEM as _ } } diff --git a/glue/src/message_opt_ref.rs b/glue/src/message_opt_ref.rs index de4cc21..21a5ee2 100644 --- a/glue/src/message_opt_ref.rs +++ b/glue/src/message_opt_ref.rs @@ -4,7 +4,7 @@ use jni::JNIEnv; use toad_jni::java::{self, Object}; use toad_msg::{OptNumber, OptValue}; -use crate::mem::RuntimeAllocator; +use crate::mem::SharedMemoryRegion; use crate::message_opt_value_ref::MessageOptValueRef; pub struct MessageOptRef(pub java::lang::Object); @@ -27,8 +27,8 @@ impl MessageOptRef { } pub unsafe fn values_ptr<'a>(addr: i64) -> &'a mut Vec>> { - crate::mem::Runtime::deref_inner::>>>(/* TODO */ 0, addr).as_mut() - .unwrap() + crate::mem::Shared::deref::>>>(/* TODO */ 0, addr).as_mut() + .unwrap() } } @@ -53,5 +53,5 @@ pub extern "system" fn Java_dev_toad_msg_MessageOptionRef_values<'local>(mut e: .map(|v| MessageOptValueRef::new(&mut e, (&v.0 as *const Vec).addr() as i64)) .collect::>(); - refs.downcast(&mut e).as_raw() + refs.yield_to_java(&mut e) } diff --git a/glue/src/message_opt_value_ref.rs b/glue/src/message_opt_value_ref.rs index f2c8d85..5636f03 100644 --- a/glue/src/message_opt_value_ref.rs +++ b/glue/src/message_opt_value_ref.rs @@ -3,7 +3,7 @@ use jni::sys::jobject; use toad_jni::java; use toad_msg::OptValue; -use crate::mem::RuntimeAllocator; +use crate::mem::SharedMemoryRegion; pub struct MessageOptValueRef(java::lang::Object); @@ -19,8 +19,8 @@ impl MessageOptValueRef { } pub unsafe fn ptr<'a>(addr: i64) -> &'a mut OptValue> { - crate::mem::Runtime::deref_inner::>>(/* TODO */ 0, addr).as_mut() - .unwrap() + crate::mem::Shared::deref::>>(/* TODO */ 0, addr).as_mut() + .unwrap() } } diff --git a/glue/src/message_ref.rs b/glue/src/message_ref.rs index f1f3965..c846c5b 100644 --- a/glue/src/message_ref.rs +++ b/glue/src/message_ref.rs @@ -2,8 +2,9 @@ use jni::objects::JClass; use jni::sys::jobject; use toad_jni::java::{self, Object}; use toad_msg::alloc::Message; +use toad_msg::Type; -use crate::mem::RuntimeAllocator; +use crate::mem::{Shared, SharedMemoryRegion}; use crate::message_code::MessageCode; use crate::message_opt_ref::MessageOptRef; use crate::message_type::MessageType; @@ -16,14 +17,25 @@ impl java::Class for MessageRef { } impl MessageRef { - pub fn new(env: &mut java::Env, addr: *const Message) -> Self { + pub fn new(env: &mut java::Env, message: Message) -> Self { + let ptr = unsafe { Shared::alloc_message(message) }; static CTOR: java::Constructor = java::Constructor::new(); - CTOR.invoke(env, addr.addr() as i64) + CTOR.invoke(env, ptr.addr() as i64) + } + + pub fn close(&self, env: &mut java::Env) { + static CLOSE: java::Method = java::Method::new("close"); + CLOSE.invoke(env, self) + } + + pub fn ty(&self, env: &mut java::Env) -> Type { + static TYPE: java::Method MessageType> = java::Method::new("type"); + TYPE.invoke(env, self).to_toad(env) } pub unsafe fn ptr<'a>(addr: i64) -> &'a mut Message { - crate::mem::Runtime::deref_inner::(/* TODO */ 0, addr).as_mut() - .unwrap() + crate::mem::Shared::deref::(/* TODO */ 0, addr).as_mut() + .unwrap() } } @@ -60,7 +72,7 @@ pub extern "system" fn Java_dev_toad_msg_MessageRef_type<'local>(mut e: java::En addr: i64) -> jobject { let msg = unsafe { MessageRef::ptr(addr) }; - MessageType::new(&mut e, msg.ty).downcast(&mut e).as_raw() + MessageType::new(&mut e, msg.ty).yield_to_java(&mut e) } #[no_mangle] @@ -69,7 +81,7 @@ pub extern "system" fn Java_dev_toad_msg_MessageRef_code<'local>(mut e: java::En addr: i64) -> jobject { let msg = unsafe { MessageRef::ptr(addr) }; - MessageCode::new(&mut e, msg.code).downcast(&mut e).as_raw() + MessageCode::new(&mut e, msg.code).yield_to_java(&mut e) } #[no_mangle] @@ -84,5 +96,5 @@ pub extern "system" fn Java_dev_toad_msg_MessageRef_opts<'local>(mut e: java::En .map(|(n, v)| MessageOptRef::new(&mut e, v as *const _ as i64, n.0.into())) .collect::>(); - refs.downcast(&mut e).as_raw() + refs.yield_to_java(&mut e) } diff --git a/glue/src/message_type.rs b/glue/src/message_type.rs index a27beaa..8708abb 100644 --- a/glue/src/message_type.rs +++ b/glue/src/message_type.rs @@ -22,4 +22,16 @@ impl MessageType { FROM_STRING.invoke(env, str.to_string()) } + + pub fn to_toad(&self, env: &mut java::Env) -> Type { + static TO_STRING: java::Method String> = java::Method::new("toString"); + + match TO_STRING.invoke(env, self).trim().to_uppercase().as_str() { + | "CON" => Type::Con, + | "NON" => Type::Non, + | "ACK" => Type::Ack, + | "RESET" => Type::Reset, + | o => panic!("malformed message type {}", o), + } + } } diff --git a/glue/src/runtime.rs b/glue/src/runtime.rs index 1437b48..132516f 100644 --- a/glue/src/runtime.rs +++ b/glue/src/runtime.rs @@ -3,7 +3,7 @@ use jni::sys::jobject; use toad::platform::Platform; use toad_jni::java::{self, Object}; -use crate::mem::RuntimeAllocator; +use crate::mem::SharedMemoryRegion; use crate::message_ref::MessageRef; use crate::runtime_config::RuntimeConfig; use crate::Runtime as ToadRuntime; @@ -29,19 +29,22 @@ impl Runtime { } pub fn ref_(&self, e: &mut java::Env) -> &'static ToadRuntime { - unsafe { crate::mem::Runtime::deref(self.addr(e)).as_ref().unwrap() } + unsafe { + crate::mem::Shared::deref::(0, self.addr(e)).as_ref() + .unwrap() + } } fn init_impl(e: &mut java::Env, cfg: RuntimeConfig) -> i64 { let r = || ToadRuntime::try_new(format!("0.0.0.0:{}", cfg.net(e).port(e)), cfg.to_toad(e)).unwrap(); - unsafe { crate::mem::Runtime::alloc(r).addr() as i64 } + unsafe { crate::mem::Shared::init(r).addr() as i64 } } fn poll_req_impl(&self, e: &mut java::Env) -> java::util::Optional { match self.ref_(e).poll_req() { | Ok(req) => { - let mr = MessageRef::new(e, req.data().msg()); + let mr = MessageRef::new(e, req.unwrap().into()); java::util::Optional::::of(e, mr) }, | Err(nb::Error::WouldBlock) => java::util::Optional::::empty(e), @@ -77,7 +80,5 @@ pub extern "system" fn Java_dev_toad_Runtime_pollReq<'local>(mut e: java::Env<'l let e = &mut e; java::lang::Object::from_local(e, runtime).upcast_to::(e) .poll_req_impl(e) - .downcast(e) - .to_local(e) - .as_raw() + .yield_to_java(e) } diff --git a/src/main/java/dev.toad/RefHawk.java b/src/main/java/dev.toad/RefHawk.java new file mode 100644 index 0000000..fea98e5 --- /dev/null +++ b/src/main/java/dev.toad/RefHawk.java @@ -0,0 +1,89 @@ +package dev.toad; + +import java.lang.ref.Cleaner; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Map; +import java.util.Optional; + +/** + * Static class used to track pointers issued by rust code + * + * When an object instance containing a pointer tracked by RefHawk + * is not automatically freed before `RefHawk.ensureReleased` invoked, + * an error is thrown indicating incorrect usage of an object containing + * a native pointer. + */ +public class RefHawk { + + private static final HashSet addrs = new HashSet<>(); + + private RefHawk() {} + + public static class IllegalStorageOfRefError extends Error { + + private static final String fmt = + "Instance of %s may not be stored by user code.\n" + + "Object was registered by:\n" + + ">>>>>>\n" + + "%s\n" + + "<<<<<<\n"; + + IllegalStorageOfRefError(Ptr ptr) { + super(String.format(fmt, ptr.clazz, ptr.trace)); + } + } + + public static class Ptr { + + protected final long addr; + private final String clazz; + private final String trace; + + Ptr(long addr, String clazz, String trace) { + this.clazz = clazz; + this.addr = addr; + this.trace = trace; + } + + public long addr() { + RefHawk.ensureValid(this); + return this.addr; + } + } + + /** + * Associate an object with a raw `long` pointer and a short text + * describing the scope in which the object is intended to be valid for + * (e.g. "lambda") + */ + public static Ptr register(Class c, long addr) { + var trace = Thread.currentThread().getStackTrace(); + var traceStr = Arrays + .asList(trace) + .stream() + .skip(2) + .map(StackTraceElement::toString) + .reduce("", (s, tr) -> s == "" ? tr : s + "\n\t" + tr); + + RefHawk.addrs.add(addr); + return new Ptr(addr, c.toString(), traceStr); + } + + /** + * Invokes the cleaning action on the object associated with an address + */ + public static void release(Ptr ptr) { + RefHawk.addrs.remove(ptr.addr); + } + + /** + * Throw `IllegalStorageOfRefError` if object has been leaked + * outside of its appropriate context. + */ + public static void ensureValid(Ptr ptr) { + if (!RefHawk.addrs.contains(ptr.addr)) { + throw new IllegalStorageOfRefError(ptr); + } + } +} diff --git a/src/main/java/dev.toad/msg/MessageOptionRef.java b/src/main/java/dev.toad/msg/MessageOptionRef.java index 2fb38a6..80e7c16 100644 --- a/src/main/java/dev.toad/msg/MessageOptionRef.java +++ b/src/main/java/dev.toad/msg/MessageOptionRef.java @@ -1,17 +1,19 @@ package dev.toad.msg; +import dev.toad.RefHawk; +import dev.toad.RefHawk.Ptr; import java.util.Arrays; import java.util.List; -public class MessageOptionRef implements MessageOption { +public class MessageOptionRef implements MessageOption, AutoCloseable { - private final long addr; + private Ptr ptr; private final long number; - private native MessageOptionValueRef[] values(long addr); + private native MessageOptionValueRef[] values(long ptr); public MessageOptionRef(long addr, long number) { - this.addr = addr; + this.ptr = RefHawk.register(this.getClass(), addr); this.number = number; } @@ -20,14 +22,19 @@ public class MessageOptionRef implements MessageOption { } public MessageOptionValueRef[] valueRefs() { - return this.values(this.addr); + return this.values(this.ptr.addr()); } public List values() { - return Arrays.asList(this.values(this.addr)); + return Arrays.asList(this.values(this.ptr.addr())); } public MessageOption clone() { return new MessageOptionOwned(this); } + + @Override + public void close() { + RefHawk.release(this.ptr); + } } diff --git a/src/main/java/dev.toad/msg/MessageOptionValueRef.java b/src/main/java/dev.toad/msg/MessageOptionValueRef.java index 04ae6ab..f34533e 100644 --- a/src/main/java/dev.toad/msg/MessageOptionValueRef.java +++ b/src/main/java/dev.toad/msg/MessageOptionValueRef.java @@ -1,24 +1,33 @@ package dev.toad.msg; -public class MessageOptionValueRef implements MessageOptionValue { +import dev.toad.RefHawk; +import dev.toad.RefHawk.Ptr; - private final long addr; +public class MessageOptionValueRef + implements MessageOptionValue, AutoCloseable { + + private final Ptr ptr; private native byte[] bytes(long addr); public MessageOptionValueRef(long addr) { - this.addr = addr; + this.ptr = RefHawk.register(this.getClass(), addr); } public byte[] asBytes() { - return this.bytes(this.addr); + return this.bytes(this.ptr.addr()); } public String asString() { - return new String(this.bytes(this.addr)); + return new String(this.bytes(this.ptr.addr())); } public MessageOptionValue clone() { return this; } + + @Override + public void close() { + RefHawk.release(this.ptr); + } } diff --git a/src/main/java/dev.toad/msg/MessageRef.java b/src/main/java/dev.toad/msg/MessageRef.java index 6509e4e..6740c74 100644 --- a/src/main/java/dev.toad/msg/MessageRef.java +++ b/src/main/java/dev.toad/msg/MessageRef.java @@ -1,5 +1,7 @@ package dev.toad.msg; +import dev.toad.RefHawk; +import dev.toad.RefHawk.Ptr; import java.util.Arrays; import java.util.List; @@ -10,9 +12,9 @@ import java.util.List; * control is yielded back to the rust runtime, meaning instances of * MessageRef should never be stored in state; invoke `.clone()` first. */ -public class MessageRef implements Message { +public class MessageRef implements Message, AutoCloseable { - private final long addr; + private Ptr ptr; private static native int id(long addr); @@ -27,7 +29,7 @@ public class MessageRef implements Message { private static native MessageOptionRef[] opts(long addr); public MessageRef(long addr) { - this.addr = addr; + this.ptr = RefHawk.register(this.getClass(), addr); } public Message clone() { @@ -35,34 +37,39 @@ public class MessageRef implements Message { } public int id() { - return this.id(this.addr); + return this.id(this.ptr.addr()); } public byte[] token() { - return this.token(this.addr); + return this.token(this.ptr.addr()); } public MessageCode code() { - return this.code(this.addr); + return this.code(this.ptr.addr()); } public MessageType type() { - return this.type(this.addr); + return this.type(this.ptr.addr()); } public MessageOptionRef[] optionRefs() { - return this.opts(this.addr); + return this.opts(this.ptr.addr()); } public List options() { - return Arrays.asList(this.opts(this.addr)); + return Arrays.asList(this.opts(this.ptr.addr())); } public byte[] payloadBytes() { - return this.payload(this.addr); + return this.payload(this.ptr.addr()); } public String payloadString() { - return new String(this.payload(this.addr)); + return new String(this.payload(this.ptr.addr())); + } + + @Override + public void close() { + RefHawk.release(this.ptr); } }