feat: rework memory allocation to throw when pointers are invalid, fix pointers yielded to java, add pattern for yielding stable pointers to data not owned by Runtime

This commit is contained in:
Orion Kindel 2023-04-09 20:44:39 -07:00
parent e94f94d1a2
commit 44ffe4f073
Signed by untrusted user who does not match committer: orion
GPG Key ID: 6D4165AE4C928719
13 changed files with 315 additions and 93 deletions

View File

@ -14,7 +14,13 @@ val path = settingKey[Map[String, String]]("paths")
fork := true
javaOptions += "--enable-preview"
javacOptions ++= Seq("--enable-preview", "--release", "20")
javacOptions ++= Seq(
"--enable-preview",
"--release",
"20",
"-Xlint:unchecked",
"-Xlint:deprecation"
)
lazy val root = project
.in(file("."))

View File

@ -1,37 +1,82 @@
use std::ptr::drop_in_place;
use std::sync::Once;
use std::time::Duration;
use jni::objects::GlobalRef;
use no_std_net::SocketAddr;
use toad::net::Addrd;
use toad::platform::Platform;
use toad_jni::java::{self, Object};
use toad_jni::java::lang::System;
use toad_jni::java::{self, Object, Signature};
use toad_msg::alloc::Message;
use toad_msg::{Code, Id, Token, Type};
use crate::message_ref::MessageRef;
use crate::message_type::MessageType;
use crate::runtime::Runtime;
use crate::runtime_config::RuntimeConfig;
pub fn runtime_init<'a>() -> (Runtime, java::Env<'a>) {
#[non_exhaustive]
struct State {
pub runtime: Runtime,
pub env: java::Env<'static>,
pub client: crate::Runtime,
pub srv_addr: SocketAddr,
}
fn init() -> State {
let mut _env = crate::test::init();
let env = &mut _env;
let cfg = RuntimeConfig::new(env);
let runtime = Runtime::get_or_init(env, cfg);
(runtime, _env)
let client = crate::Runtime::try_new("0.0.0.0:5684", Default::default()).unwrap();
State { runtime,
env: _env,
client,
srv_addr: "0.0.0.0:5683".parse().unwrap() }
}
fn runtime_poll_req(runtime: &Runtime, env: &mut java::Env) {
fn runtime_poll_req(State { runtime,
env,
client,
srv_addr,
.. }: &mut State) {
assert!(runtime.poll_req(env).is_none());
let client = crate::Runtime::try_new("0.0.0.0:5684", Default::default()).unwrap();
let request = Message::new(Type::Con, Code::GET, Id(0), Token(Default::default()));
client.send_msg(Addrd(request, "0.0.0.0:5683".parse().unwrap()))
.unwrap();
client.send_msg(Addrd(request, *srv_addr)).unwrap();
assert!(runtime.poll_req(env).is_some());
}
fn message_ref_should_throw_when_used_after_close(State {runtime, env, client, srv_addr, ..}: &mut State)
{
let request = Message::new(Type::Con, Code::GET, Id(0), Token(Default::default()));
client.send_msg(Addrd(request, *srv_addr)).unwrap();
let req = runtime.poll_req(env).unwrap();
assert_eq!(req.ty(env), Type::Con);
req.close(env);
let req_o = req.downcast(env);
env.call_method(req_o.as_local(),
"type",
Signature::of::<fn() -> MessageType>(),
&[])
.ok();
let err = env.exception_occurred().unwrap();
env.exception_clear().unwrap();
assert!(env.is_instance_of(err,
concat!(package!(dev.toad.RefHawk), "$IllegalStorageOfRefError"))
.unwrap());
}
#[test]
fn e2e_test_suite() {
let (runtime, mut env) = runtime_init();
runtime_poll_req(&runtime, &mut env);
let mut state = init();
runtime_poll_req(&mut state);
message_ref_should_throw_when_used_after_close(&mut state);
}

View File

@ -3,7 +3,7 @@
use std::ffi::c_void;
use jni::JavaVM;
use mem::RuntimeAllocator;
use mem::SharedMemoryRegion;
pub type Runtime =
toad::std::Platform<toad::std::dtls::N, toad::step::runtime::std::Runtime<toad::std::dtls::N>>;
@ -34,7 +34,7 @@ pub extern "system" fn JNI_OnLoad(jvm: JavaVM, _: *const c_void) -> i32 {
#[no_mangle]
pub extern "system" fn JNI_OnUnload(_: JavaVM, _: *const c_void) {
unsafe { mem::Runtime::dealloc() }
unsafe { mem::Shared::dealloc() }
}
#[cfg(all(test, feature = "e2e"))]

View File

@ -1,54 +1,88 @@
/// global [`RuntimeAllocator`] implementation
pub type Runtime = RuntimeGlobalStaticAllocator;
use std::sync::Mutex;
/// Trait managing the memory region(s) associated with the toad runtime
/// data structure.
///
/// Notably, any and all references produced by the runtime will be to data
/// within the Runtime's memory region, meaning that we can easily leverage
/// strict provenance to prevent addresses from leaking outside of that memory region.
pub trait RuntimeAllocator: core::default::Default + core::fmt::Debug + Copy {
use toad_msg::alloc::Message;
/// global [`RuntimeAllocator`] implementation
pub type Shared = GlobalStatic;
/// Trait managing the memory region(s) which java will store pointers to
pub trait SharedMemoryRegion: core::default::Default + core::fmt::Debug + Copy {
/// Allocate memory for the runtime and yield a stable pointer to it
unsafe fn alloc(r: impl FnOnce() -> crate::Runtime) -> *mut crate::Runtime;
///
/// This is idempotent and will only invoke the provided callback if the runtime
/// has not already been initialized.
unsafe fn init(r: impl FnOnce() -> crate::Runtime) -> *mut crate::Runtime;
/// Pass ownership of a [`Message`] to the shared memory region,
/// yielding a stable pointer to this message.
unsafe fn alloc_message(m: Message) -> *mut Message;
/// Delete a message from the shared memory region.
unsafe fn dealloc_message(m: *mut Message);
/// Teardown
unsafe fn dealloc() {}
unsafe fn dealloc();
/// Coerce a `long` rep of the stable pointer created by [`Self::alloc`] to
/// a pointer (preferably using strict_provenance)
unsafe fn deref(addr: i64) -> *mut crate::Runtime;
unsafe fn shared_region(addr: i64) -> *mut u8;
/// Coerce a `long` rep of a pointer to some data within the
/// Runtime data structure.
///
/// Requires the Runtime address in order for the new pointer
/// to inherit its provenance.
unsafe fn deref_inner<T>(runtime_addr: i64, addr: i64) -> *mut T {
Self::deref(runtime_addr).with_addr(addr as usize)
/// shared memory region.
unsafe fn deref<T>(shared_region_addr: i64, addr: i64) -> *mut T {
Self::shared_region(shared_region_addr).with_addr(addr as usize)
.cast::<T>()
}
}
static mut RUNTIME: *mut crate::Runtime = core::ptr::null_mut();
static mut MEM: *mut Mem = core::ptr::null_mut();
struct Mem {
runtime: crate::Runtime,
messages: Vec<Message>,
/// Lock used by `alloc_message` and `dealloc_message` to ensure
/// they are run serially.
///
/// This doesn't provide any guarantees that message pointers will
/// stay valid or always point to the correct location, but it does
/// ensure we don't accidentally yield the wrong pointer from `alloc_message`
/// or delete the wrong message in `dealloc_message`.
messages_lock: Mutex<()>,
}
#[derive(Default, Debug, Clone, Copy)]
pub struct RuntimeGlobalStaticAllocator;
impl RuntimeAllocator for RuntimeGlobalStaticAllocator {
/// Nops on already-init
unsafe fn alloc(r: impl FnOnce() -> crate::Runtime) -> *mut crate::Runtime {
if RUNTIME.is_null() {
RUNTIME = Box::into_raw(Box::new(r()));
RUNTIME
} else {
RUNTIME
}
}
pub struct GlobalStatic;
impl SharedMemoryRegion for GlobalStatic {
unsafe fn dealloc() {
drop(Box::from_raw(RUNTIME));
drop(Box::from_raw(MEM));
}
unsafe fn deref(_: i64) -> *mut crate::Runtime {
RUNTIME
unsafe fn init(r: impl FnOnce() -> crate::Runtime) -> *mut crate::Runtime {
if MEM.is_null() {
MEM = Box::into_raw(Box::new(Mem { runtime: r(),
messages: vec![],
messages_lock: Mutex::new(()) }));
}
&mut (*MEM).runtime as _
}
unsafe fn alloc_message(m: Message) -> *mut Message {
let _lock = (*MEM).messages_lock.lock();
(*MEM).messages.push(m);
&mut (*MEM).messages[(*MEM).messages.len() - 1] as _
}
unsafe fn dealloc_message(m: *mut Message) {
let _lock = (*MEM).messages_lock.lock();
let ix = m.offset_from((*MEM).messages.as_slice().as_ptr());
if ix.is_negative() {
panic!()
}
(*MEM).messages.remove(ix as usize);
}
unsafe fn shared_region(_: i64) -> *mut u8 {
MEM as _
}
}

View File

@ -4,7 +4,7 @@ use jni::JNIEnv;
use toad_jni::java::{self, Object};
use toad_msg::{OptNumber, OptValue};
use crate::mem::RuntimeAllocator;
use crate::mem::SharedMemoryRegion;
use crate::message_opt_value_ref::MessageOptValueRef;
pub struct MessageOptRef(pub java::lang::Object);
@ -27,7 +27,7 @@ impl MessageOptRef {
}
pub unsafe fn values_ptr<'a>(addr: i64) -> &'a mut Vec<OptValue<Vec<u8>>> {
crate::mem::Runtime::deref_inner::<Vec<OptValue<Vec<u8>>>>(/* TODO */ 0, addr).as_mut()
crate::mem::Shared::deref::<Vec<OptValue<Vec<u8>>>>(/* TODO */ 0, addr).as_mut()
.unwrap()
}
}
@ -53,5 +53,5 @@ pub extern "system" fn Java_dev_toad_msg_MessageOptionRef_values<'local>(mut e:
.map(|v| MessageOptValueRef::new(&mut e, (&v.0 as *const Vec<u8>).addr() as i64))
.collect::<Vec<_>>();
refs.downcast(&mut e).as_raw()
refs.yield_to_java(&mut e)
}

View File

@ -3,7 +3,7 @@ use jni::sys::jobject;
use toad_jni::java;
use toad_msg::OptValue;
use crate::mem::RuntimeAllocator;
use crate::mem::SharedMemoryRegion;
pub struct MessageOptValueRef(java::lang::Object);
@ -19,7 +19,7 @@ impl MessageOptValueRef {
}
pub unsafe fn ptr<'a>(addr: i64) -> &'a mut OptValue<Vec<u8>> {
crate::mem::Runtime::deref_inner::<OptValue<Vec<u8>>>(/* TODO */ 0, addr).as_mut()
crate::mem::Shared::deref::<OptValue<Vec<u8>>>(/* TODO */ 0, addr).as_mut()
.unwrap()
}
}

View File

@ -2,8 +2,9 @@ use jni::objects::JClass;
use jni::sys::jobject;
use toad_jni::java::{self, Object};
use toad_msg::alloc::Message;
use toad_msg::Type;
use crate::mem::RuntimeAllocator;
use crate::mem::{Shared, SharedMemoryRegion};
use crate::message_code::MessageCode;
use crate::message_opt_ref::MessageOptRef;
use crate::message_type::MessageType;
@ -16,13 +17,24 @@ impl java::Class for MessageRef {
}
impl MessageRef {
pub fn new(env: &mut java::Env, addr: *const Message) -> Self {
pub fn new(env: &mut java::Env, message: Message) -> Self {
let ptr = unsafe { Shared::alloc_message(message) };
static CTOR: java::Constructor<MessageRef, fn(i64)> = java::Constructor::new();
CTOR.invoke(env, addr.addr() as i64)
CTOR.invoke(env, ptr.addr() as i64)
}
pub fn close(&self, env: &mut java::Env) {
static CLOSE: java::Method<MessageRef, fn()> = java::Method::new("close");
CLOSE.invoke(env, self)
}
pub fn ty(&self, env: &mut java::Env) -> Type {
static TYPE: java::Method<MessageRef, fn() -> MessageType> = java::Method::new("type");
TYPE.invoke(env, self).to_toad(env)
}
pub unsafe fn ptr<'a>(addr: i64) -> &'a mut Message {
crate::mem::Runtime::deref_inner::<Message>(/* TODO */ 0, addr).as_mut()
crate::mem::Shared::deref::<Message>(/* TODO */ 0, addr).as_mut()
.unwrap()
}
}
@ -60,7 +72,7 @@ pub extern "system" fn Java_dev_toad_msg_MessageRef_type<'local>(mut e: java::En
addr: i64)
-> jobject {
let msg = unsafe { MessageRef::ptr(addr) };
MessageType::new(&mut e, msg.ty).downcast(&mut e).as_raw()
MessageType::new(&mut e, msg.ty).yield_to_java(&mut e)
}
#[no_mangle]
@ -69,7 +81,7 @@ pub extern "system" fn Java_dev_toad_msg_MessageRef_code<'local>(mut e: java::En
addr: i64)
-> jobject {
let msg = unsafe { MessageRef::ptr(addr) };
MessageCode::new(&mut e, msg.code).downcast(&mut e).as_raw()
MessageCode::new(&mut e, msg.code).yield_to_java(&mut e)
}
#[no_mangle]
@ -84,5 +96,5 @@ pub extern "system" fn Java_dev_toad_msg_MessageRef_opts<'local>(mut e: java::En
.map(|(n, v)| MessageOptRef::new(&mut e, v as *const _ as i64, n.0.into()))
.collect::<Vec<_>>();
refs.downcast(&mut e).as_raw()
refs.yield_to_java(&mut e)
}

View File

@ -22,4 +22,16 @@ impl MessageType {
FROM_STRING.invoke(env, str.to_string())
}
pub fn to_toad(&self, env: &mut java::Env) -> Type {
static TO_STRING: java::Method<MessageType, fn() -> String> = java::Method::new("toString");
match TO_STRING.invoke(env, self).trim().to_uppercase().as_str() {
| "CON" => Type::Con,
| "NON" => Type::Non,
| "ACK" => Type::Ack,
| "RESET" => Type::Reset,
| o => panic!("malformed message type {}", o),
}
}
}

View File

@ -3,7 +3,7 @@ use jni::sys::jobject;
use toad::platform::Platform;
use toad_jni::java::{self, Object};
use crate::mem::RuntimeAllocator;
use crate::mem::SharedMemoryRegion;
use crate::message_ref::MessageRef;
use crate::runtime_config::RuntimeConfig;
use crate::Runtime as ToadRuntime;
@ -29,19 +29,22 @@ impl Runtime {
}
pub fn ref_(&self, e: &mut java::Env) -> &'static ToadRuntime {
unsafe { crate::mem::Runtime::deref(self.addr(e)).as_ref().unwrap() }
unsafe {
crate::mem::Shared::deref::<ToadRuntime>(0, self.addr(e)).as_ref()
.unwrap()
}
}
fn init_impl(e: &mut java::Env, cfg: RuntimeConfig) -> i64 {
let r =
|| ToadRuntime::try_new(format!("0.0.0.0:{}", cfg.net(e).port(e)), cfg.to_toad(e)).unwrap();
unsafe { crate::mem::Runtime::alloc(r).addr() as i64 }
unsafe { crate::mem::Shared::init(r).addr() as i64 }
}
fn poll_req_impl(&self, e: &mut java::Env) -> java::util::Optional<MessageRef> {
match self.ref_(e).poll_req() {
| Ok(req) => {
let mr = MessageRef::new(e, req.data().msg());
let mr = MessageRef::new(e, req.unwrap().into());
java::util::Optional::<MessageRef>::of(e, mr)
},
| Err(nb::Error::WouldBlock) => java::util::Optional::<MessageRef>::empty(e),
@ -77,7 +80,5 @@ pub extern "system" fn Java_dev_toad_Runtime_pollReq<'local>(mut e: java::Env<'l
let e = &mut e;
java::lang::Object::from_local(e, runtime).upcast_to::<Runtime>(e)
.poll_req_impl(e)
.downcast(e)
.to_local(e)
.as_raw()
.yield_to_java(e)
}

View File

@ -0,0 +1,89 @@
package dev.toad;
import java.lang.ref.Cleaner;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
/**
* Static class used to track pointers issued by rust code
*
* When an object instance containing a pointer tracked by RefHawk
* is not automatically freed before `RefHawk.ensureReleased` invoked,
* an error is thrown indicating incorrect usage of an object containing
* a native pointer.
*/
public class RefHawk {
private static final HashSet<Long> addrs = new HashSet<>();
private RefHawk() {}
public static class IllegalStorageOfRefError extends Error {
private static final String fmt =
"Instance of %s may not be stored by user code.\n" +
"Object was registered by:\n" +
">>>>>>\n" +
"%s\n" +
"<<<<<<\n";
IllegalStorageOfRefError(Ptr ptr) {
super(String.format(fmt, ptr.clazz, ptr.trace));
}
}
public static class Ptr {
protected final long addr;
private final String clazz;
private final String trace;
Ptr(long addr, String clazz, String trace) {
this.clazz = clazz;
this.addr = addr;
this.trace = trace;
}
public long addr() {
RefHawk.ensureValid(this);
return this.addr;
}
}
/**
* Associate an object with a raw `long` pointer and a short text
* describing the scope in which the object is intended to be valid for
* (e.g. "lambda")
*/
public static Ptr register(Class c, long addr) {
var trace = Thread.currentThread().getStackTrace();
var traceStr = Arrays
.asList(trace)
.stream()
.skip(2)
.map(StackTraceElement::toString)
.reduce("", (s, tr) -> s == "" ? tr : s + "\n\t" + tr);
RefHawk.addrs.add(addr);
return new Ptr(addr, c.toString(), traceStr);
}
/**
* Invokes the cleaning action on the object associated with an address
*/
public static void release(Ptr ptr) {
RefHawk.addrs.remove(ptr.addr);
}
/**
* Throw `IllegalStorageOfRefError` if object has been leaked
* outside of its appropriate context.
*/
public static void ensureValid(Ptr ptr) {
if (!RefHawk.addrs.contains(ptr.addr)) {
throw new IllegalStorageOfRefError(ptr);
}
}
}

View File

@ -1,17 +1,19 @@
package dev.toad.msg;
import dev.toad.RefHawk;
import dev.toad.RefHawk.Ptr;
import java.util.Arrays;
import java.util.List;
public class MessageOptionRef implements MessageOption {
public class MessageOptionRef implements MessageOption, AutoCloseable {
private final long addr;
private Ptr ptr;
private final long number;
private native MessageOptionValueRef[] values(long addr);
private native MessageOptionValueRef[] values(long ptr);
public MessageOptionRef(long addr, long number) {
this.addr = addr;
this.ptr = RefHawk.register(this.getClass(), addr);
this.number = number;
}
@ -20,14 +22,19 @@ public class MessageOptionRef implements MessageOption {
}
public MessageOptionValueRef[] valueRefs() {
return this.values(this.addr);
return this.values(this.ptr.addr());
}
public List<MessageOptionValue> values() {
return Arrays.asList(this.values(this.addr));
return Arrays.asList(this.values(this.ptr.addr()));
}
public MessageOption clone() {
return new MessageOptionOwned(this);
}
@Override
public void close() {
RefHawk.release(this.ptr);
}
}

View File

@ -1,24 +1,33 @@
package dev.toad.msg;
public class MessageOptionValueRef implements MessageOptionValue {
import dev.toad.RefHawk;
import dev.toad.RefHawk.Ptr;
private final long addr;
public class MessageOptionValueRef
implements MessageOptionValue, AutoCloseable {
private final Ptr ptr;
private native byte[] bytes(long addr);
public MessageOptionValueRef(long addr) {
this.addr = addr;
this.ptr = RefHawk.register(this.getClass(), addr);
}
public byte[] asBytes() {
return this.bytes(this.addr);
return this.bytes(this.ptr.addr());
}
public String asString() {
return new String(this.bytes(this.addr));
return new String(this.bytes(this.ptr.addr()));
}
public MessageOptionValue clone() {
return this;
}
@Override
public void close() {
RefHawk.release(this.ptr);
}
}

View File

@ -1,5 +1,7 @@
package dev.toad.msg;
import dev.toad.RefHawk;
import dev.toad.RefHawk.Ptr;
import java.util.Arrays;
import java.util.List;
@ -10,9 +12,9 @@ import java.util.List;
* control is yielded back to the rust runtime, meaning instances of
* MessageRef should never be stored in state; invoke `.clone()` first.
*/
public class MessageRef implements Message {
public class MessageRef implements Message, AutoCloseable {
private final long addr;
private Ptr ptr;
private static native int id(long addr);
@ -27,7 +29,7 @@ public class MessageRef implements Message {
private static native MessageOptionRef[] opts(long addr);
public MessageRef(long addr) {
this.addr = addr;
this.ptr = RefHawk.register(this.getClass(), addr);
}
public Message clone() {
@ -35,34 +37,39 @@ public class MessageRef implements Message {
}
public int id() {
return this.id(this.addr);
return this.id(this.ptr.addr());
}
public byte[] token() {
return this.token(this.addr);
return this.token(this.ptr.addr());
}
public MessageCode code() {
return this.code(this.addr);
return this.code(this.ptr.addr());
}
public MessageType type() {
return this.type(this.addr);
return this.type(this.ptr.addr());
}
public MessageOptionRef[] optionRefs() {
return this.opts(this.addr);
return this.opts(this.ptr.addr());
}
public List<MessageOption> options() {
return Arrays.asList(this.opts(this.addr));
return Arrays.asList(this.opts(this.ptr.addr()));
}
public byte[] payloadBytes() {
return this.payload(this.addr);
return this.payload(this.ptr.addr());
}
public String payloadString() {
return new String(this.payload(this.addr));
return new String(this.payload(this.ptr.addr()));
}
@Override
public void close() {
RefHawk.release(this.ptr);
}
}