rust-postgres/tokio-postgres/tests/test/main.rs
Steven Fackler 1f6d9ddc06 Overhaul query cancellation
Multi-host support means we can't simply take the old approach - we need
to know which of the hosts we actually connected to. It's also nice to
move this from the connection to the client since that's what you'd
normally have access to.
2019-01-06 18:03:51 -08:00

780 lines
23 KiB
Rust

#![warn(rust_2018_idioms)]
use futures::sync::mpsc;
use futures::{future, stream, try_ready};
use log::debug;
use std::error::Error;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, Instant};
use tokio::net::TcpStream;
use tokio::prelude::*;
use tokio::runtime::current_thread::Runtime;
use tokio::timer::Delay;
use tokio_postgres::error::SqlState;
use tokio_postgres::types::{Kind, Type};
use tokio_postgres::{AsyncMessage, Client, Connection, NoTls};
mod parse;
#[cfg(feature = "runtime")]
mod runtime;
mod types;
/// Opens a TCP connection to the test server and performs the Postgres handshake
/// using the configuration parsed from `s`.
///
/// The TCP connection always goes to `127.0.0.1:5433`, regardless of any host or
/// port settings in `s`; the parsed configuration is only used for the handshake.
///
/// # Panics
///
/// Panics if `s` is not a valid configuration string, or if the TCP connection
/// cannot be established. The TCP error is not converted into a
/// `tokio_postgres::Error`, so it is reported by panicking with the address included.
fn connect(
    s: &str,
) -> impl Future<Item = (Client, Connection<TcpStream>), Error = tokio_postgres::Error> {
    const ADDR: &str = "127.0.0.1:5433";

    let config = s
        .parse::<tokio_postgres::Config>()
        .unwrap_or_else(|e| panic!("invalid config string {:?}: {}", s, e));
    let addr = ADDR.parse().expect("ADDR is a valid socket address");
    TcpStream::connect(&addr)
        .map_err(|e| panic!("failed to connect to test server at {}: {}", ADDR, e))
        .and_then(move |socket| config.handshake(socket, NoTls))
}
/// Connects using the configuration string `s`, runs `SELECT 1::INT4`, checks that
/// exactly one row containing `1` comes back, then tears everything down.
fn smoke_test(s: &str) {
    let _ = env_logger::try_init();
    let mut rt = Runtime::new().unwrap();

    let (mut client, connection) = rt.block_on(connect(s)).unwrap();
    rt.handle()
        .spawn(connection.map_err(|e| panic!("{}", e)))
        .unwrap();

    let stmt = rt.block_on(client.prepare("SELECT 1::INT4")).unwrap();
    let check = client.query(&stmt, &[]).collect().map(|rows| {
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].get::<_, i32>(0), 1);
    });
    rt.block_on(check).unwrap();

    // The statement and client are released before `run`, which waits for the
    // spawned connection future to finish (presumably it ends once the client is gone).
    drop(stmt);
    drop(client);
    rt.run().unwrap();
}
/// Connecting as `pass_user` without a password must fail.
#[test]
fn plain_password_missing() {
    let _ = env_logger::try_init();
    let mut runtime = Runtime::new().unwrap();
    let handshake = connect("user=pass_user dbname=postgres");
    // Match explicitly so an unexpected success reports "unexpected success"
    // rather than the opaque panic from `Option::unwrap` on `None`.
    match runtime.block_on(handshake) {
        Ok(_) => panic!("unexpected success"),
        Err(e) => debug!("handshake failed as expected: {}", e),
    }
}
/// A wrong password for the plain-password user must be rejected with
/// `SqlState::INVALID_PASSWORD`; success or any other error fails the test.
#[test]
fn plain_password_wrong() {
    let _ = env_logger::try_init();
    let mut runtime = Runtime::new().unwrap();

    let outcome = runtime.block_on(connect("user=pass_user password=foo dbname=postgres"));
    let err = match outcome {
        Ok(_) => panic!("unexpected success"),
        Err(err) => err,
    };
    if err.code() != Some(&SqlState::INVALID_PASSWORD) {
        panic!("{}", err);
    }
}
/// The correct password for the plain-password user connects and can run a query.
#[test]
fn plain_password_ok() {
    let config = "user=pass_user password=password dbname=postgres";
    smoke_test(config);
}
/// Connecting as `md5_user` without a password must fail.
#[test]
fn md5_password_missing() {
    let _ = env_logger::try_init();
    let mut runtime = Runtime::new().unwrap();
    let handshake = connect("user=md5_user dbname=postgres");
    // Match explicitly so an unexpected success reports "unexpected success"
    // rather than the opaque panic from `Option::unwrap` on `None`.
    match runtime.block_on(handshake) {
        Ok(_) => panic!("unexpected success"),
        Err(e) => debug!("handshake failed as expected: {}", e),
    }
}
/// A wrong password for the MD5-authenticated user must be rejected with
/// `SqlState::INVALID_PASSWORD`; success or any other error fails the test.
#[test]
fn md5_password_wrong() {
    let _ = env_logger::try_init();
    let mut runtime = Runtime::new().unwrap();

    let outcome = runtime.block_on(connect("user=md5_user password=foo dbname=postgres"));
    let err = match outcome {
        Ok(_) => panic!("unexpected success"),
        Err(err) => err,
    };
    if err.code() != Some(&SqlState::INVALID_PASSWORD) {
        panic!("{}", err);
    }
}
/// The correct password for the MD5-authenticated user connects and can run a query.
#[test]
fn md5_password_ok() {
    let config = "user=md5_user password=password dbname=postgres";
    smoke_test(config);
}
/// Connecting as `scram_user` without a password must fail.
#[test]
fn scram_password_missing() {
    let _ = env_logger::try_init();
    let mut runtime = Runtime::new().unwrap();
    let handshake = connect("user=scram_user dbname=postgres");
    // Match explicitly so an unexpected success reports "unexpected success"
    // rather than the opaque panic from `Option::unwrap` on `None`.
    match runtime.block_on(handshake) {
        Ok(_) => panic!("unexpected success"),
        Err(e) => debug!("handshake failed as expected: {}", e),
    }
}
/// A wrong password for the SCRAM-authenticated user must be rejected with
/// `SqlState::INVALID_PASSWORD`; success or any other error fails the test.
#[test]
fn scram_password_wrong() {
    let _ = env_logger::try_init();
    let mut runtime = Runtime::new().unwrap();

    let outcome = runtime.block_on(connect("user=scram_user password=foo dbname=postgres"));
    let err = match outcome {
        Ok(_) => panic!("unexpected success"),
        Err(err) => err,
    };
    if err.code() != Some(&SqlState::INVALID_PASSWORD) {
        panic!("{}", err);
    }
}
/// The correct password for the SCRAM-authenticated user connects and can run a query.
#[test]
fn scram_password_ok() {
    let config = "user=scram_user password=password dbname=postgres";
    smoke_test(config);
}
/// Two prepares of the same `HSTORE[]` query are created before either is driven,
/// so both requests can be outstanding at once; both must succeed.
#[test]
fn pipelined_prepare() {
    let _ = env_logger::try_init();
    let mut rt = Runtime::new().unwrap();

    let (mut client, connection) = rt.block_on(connect("user=postgres")).unwrap();
    rt.handle()
        .spawn(connection.map_err(|e| panic!("{}", e)))
        .unwrap();

    let query = "SELECT $1::HSTORE[]";
    let first = client.prepare(query);
    let second = client.prepare(query);
    rt.block_on(first.join(second)).unwrap();

    drop(client);
    rt.run().unwrap();
}
/// Prepares an INSERT and a SELECT together, then issues them concurrently (joined).
/// The SELECT must see both inserted rows with sequential ids.
#[test]
fn insert_select() {
    let _ = env_logger::try_init();
    let mut rt = Runtime::new().unwrap();

    let (mut client, connection) = rt.block_on(connect("user=postgres")).unwrap();
    rt.handle()
        .spawn(connection.map_err(|e| panic!("{}", e)))
        .unwrap();

    rt.block_on(client.batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL, name TEXT)"))
        .unwrap();

    let (insert_stmt, select_stmt) = {
        let insert = client.prepare("INSERT INTO foo (name) VALUES ($1), ($2)");
        let select = client.prepare("SELECT id, name FROM foo ORDER BY id");
        rt.block_on(insert.join(select)).unwrap()
    };

    let do_insert = client
        .execute(&insert_stmt, &[&"alice", &"bob"])
        .map(|count| assert_eq!(count, 2));
    let do_select = client.query(&select_stmt, &[]).collect().map(|rows| {
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].get::<_, i32>(0), 1);
        assert_eq!(rows[0].get::<_, &str>(1), "alice");
        assert_eq!(rows[1].get::<_, i32>(0), 2);
        assert_eq!(rows[1].get::<_, &str>(1), "bob");
    });
    rt.block_on(do_insert.join(do_select)).unwrap();
}
/// Binds a portal over a three-row result and fetches it in batches of at most two
/// rows. The batches must be `[alice, bob]`, then `[charlie]`, then empty.
#[test]
fn query_portal() {
    let _ = env_logger::try_init();
    let mut rt = Runtime::new().unwrap();

    let (mut client, connection) = rt.block_on(connect("user=postgres")).unwrap();
    rt.handle()
        .spawn(connection.map_err(|e| panic!("{}", e)))
        .unwrap();

    // The trailing BEGIN leaves a transaction open, so the portal survives across the
    // separate fetches (Postgres drops portals at transaction end).
    let setup = "CREATE TEMPORARY TABLE foo (id SERIAL, name TEXT);\nINSERT INTO foo (name) VALUES ('alice'), ('bob'), ('charlie');\nBEGIN;";
    rt.block_on(client.batch_execute(setup)).unwrap();

    let stmt = rt
        .block_on(client.prepare("SELECT id, name FROM foo ORDER BY id"))
        .unwrap();
    let portal = rt.block_on(client.bind(&stmt, &[])).unwrap();

    let fetch1 = client.query_portal(&portal, 2).collect();
    let fetch2 = client.query_portal(&portal, 2).collect();
    let fetch3 = client.query_portal(&portal, 2).collect();
    let (first, second, third) = rt.block_on(fetch1.join3(fetch2, fetch3)).unwrap();

    let expected: [&[(i32, &str)]; 3] = [&[(1, "alice"), (2, "bob")], &[(3, "charlie")], &[]];
    for (batch, want) in [first, second, third].iter().zip(expected.iter()) {
        assert_eq!(batch.len(), want.len());
        for (row, &(id, name)) in batch.iter().zip(want.iter()) {
            assert_eq!(row.get::<_, i32>(0), id);
            assert_eq!(row.get::<_, &str>(1), name);
        }
    }
}
/// Cancels an in-flight `pg_sleep(100)` query. The cancel request is sent over a second,
/// freshly opened TCP socket via `Client::cancel_query_raw`.
#[test]
fn cancel_query_raw() {
    let _ = env_logger::try_init();
    let mut runtime = Runtime::new().unwrap();
    let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
    let connection = connection.map_err(|e| panic!("{}", e));
    runtime.handle().spawn(connection).unwrap();
    // The long-running query must end with QUERY_CANCELED. Completing normally or
    // failing with any other error panics.
    let sleep = client
        .batch_execute("SELECT pg_sleep(100)")
        .then(|r| match r {
            Ok(_) => panic!("unexpected success"),
            Err(ref e) if e.code() == Some(&SqlState::QUERY_CANCELED) => Ok::<(), ()>(()),
            Err(e) => panic!("unexpected error {}", e),
        });
    // After 100ms, open a new socket to the same hard-coded address that `connect` uses
    // and send the cancel request over it. Any failure along the way panics.
    // NOTE(review): the 100ms delay is a timing assumption. If the cancel is handled
    // before pg_sleep starts, the query presumably runs its full 100s and the test then
    // fails with "unexpected success".
    let cancel = Delay::new(Instant::now() + Duration::from_millis(100))
        .then(|r| {
            r.unwrap();
            TcpStream::connect(&"127.0.0.1:5433".parse().unwrap())
        })
        .then(|r| {
            let s = r.unwrap();
            client.cancel_query_raw(s, NoTls)
        })
        .then(|r| {
            r.unwrap();
            Ok::<(), ()>(())
        });
    // Drive both futures concurrently; each yields Ok(()) only on its expected outcome.
    let ((), ()) = runtime.block_on(sleep.join(cancel)).unwrap();
}
/// A user-defined enum type parameter must be reported by name with its variants,
/// in declaration order, as `Kind::Enum`.
#[test]
fn custom_enum() {
    let _ = env_logger::try_init();
    let mut rt = Runtime::new().unwrap();

    let (mut client, connection) = rt.block_on(connect("user=postgres")).unwrap();
    rt.handle()
        .spawn(connection.map_err(|e| panic!("{}", e)))
        .unwrap();

    let ddl = "CREATE TYPE pg_temp.mood AS ENUM (\n'sad',\n'ok',\n'happy'\n)";
    rt.block_on(client.batch_execute(ddl)).unwrap();

    let stmt = rt.block_on(client.prepare("SELECT $1::mood")).unwrap();
    let param = &stmt.params()[0];
    assert_eq!("mood", param.name());

    let variants = ["sad", "ok", "happy"]
        .iter()
        .map(|v| v.to_string())
        .collect::<Vec<_>>();
    assert_eq!(&Kind::Enum(variants), param.kind());
}
/// A domain over `bytea` must be reported under its own name as `Kind::Domain(BYTEA)`.
#[test]
fn custom_domain() {
    let _ = env_logger::try_init();
    let mut rt = Runtime::new().unwrap();

    let (mut client, connection) = rt.block_on(connect("user=postgres")).unwrap();
    rt.handle()
        .spawn(connection.map_err(|e| panic!("{}", e)))
        .unwrap();

    let ddl = "CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16)";
    rt.block_on(client.batch_execute(ddl)).unwrap();

    let stmt = rt.block_on(client.prepare("SELECT $1::session_id")).unwrap();
    let param = &stmt.params()[0];
    assert_eq!("session_id", param.name());
    assert_eq!(&Kind::Domain(Type::BYTEA), param.kind());
}
/// An `HSTORE[]` parameter must be reported as `_hstore`, an array whose element type
/// is the simple `hstore` type.
#[test]
fn custom_array() {
    let _ = env_logger::try_init();
    let mut rt = Runtime::new().unwrap();

    let (mut client, connection) = rt.block_on(connect("user=postgres")).unwrap();
    rt.handle()
        .spawn(connection.map_err(|e| panic!("{}", e)))
        .unwrap();

    let stmt = rt.block_on(client.prepare("SELECT $1::HSTORE[]")).unwrap();
    let param = &stmt.params()[0];
    assert_eq!("_hstore", param.name());

    let element = match *param.kind() {
        Kind::Array(ref element) => element,
        _ => panic!("unexpected kind"),
    };
    assert_eq!("hstore", element.name());
    assert_eq!(&Kind::Simple, element.kind());
}
/// A user-defined composite type parameter must be reported as `Kind::Composite`,
/// with exactly its three fields' names and types in declaration order.
#[test]
fn custom_composite() {
    let _ = env_logger::try_init();
    let mut runtime = Runtime::new().unwrap();
    let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
    let connection = connection.map_err(|e| panic!("{}", e));
    runtime.handle().spawn(connection).unwrap();
    runtime
        .block_on(client.batch_execute(
            "CREATE TYPE pg_temp.inventory_item AS (\nname TEXT,\nsupplier INTEGER,\nprice NUMERIC\n)",
        ))
        .unwrap();
    let select = client.prepare("SELECT $1::inventory_item");
    let select = runtime.block_on(select).unwrap();
    let ty = &select.params()[0];
    assert_eq!(ty.name(), "inventory_item");
    match *ty.kind() {
        Kind::Composite(ref fields) => {
            // Pin the field count first. Indexing alone would silently accept extra
            // fields and fail with an index-out-of-bounds panic on missing ones.
            assert_eq!(fields.len(), 3);
            assert_eq!(fields[0].name(), "name");
            assert_eq!(fields[0].type_(), &Type::TEXT);
            assert_eq!(fields[1].name(), "supplier");
            assert_eq!(fields[1].type_(), &Type::INT4);
            assert_eq!(fields[2].name(), "price");
            assert_eq!(fields[2].type_(), &Type::NUMERIC);
        }
        ref t => panic!("bad type {:?}", t),
    }
}
/// A user-defined range over `float8` must be reported as `Kind::Range(FLOAT8)`.
#[test]
fn custom_range() {
    let _ = env_logger::try_init();
    let mut rt = Runtime::new().unwrap();

    let (mut client, connection) = rt.block_on(connect("user=postgres")).unwrap();
    rt.handle()
        .spawn(connection.map_err(|e| panic!("{}", e)))
        .unwrap();

    let ddl = "CREATE TYPE pg_temp.floatrange AS RANGE (\nsubtype = float8,\nsubtype_diff = float8mi\n)";
    rt.block_on(client.batch_execute(ddl)).unwrap();

    let stmt = rt.block_on(client.prepare("SELECT $1::floatrange")).unwrap();
    let param = &stmt.params()[0];
    assert_eq!("floatrange", param.name());
    assert_eq!(&Kind::Range(Type::FLOAT8), param.kind());
}
/// An `HSTORE` parameter must be reported as the simple `hstore` type.
#[test]
fn custom_simple() {
    let _ = env_logger::try_init();
    let mut rt = Runtime::new().unwrap();

    let (mut client, connection) = rt.block_on(connect("user=postgres")).unwrap();
    rt.handle()
        .spawn(connection.map_err(|e| panic!("{}", e)))
        .unwrap();

    let stmt = rt.block_on(client.prepare("SELECT $1::HSTORE")).unwrap();
    let param = &stmt.params()[0];
    assert_eq!("hstore", param.name());
    assert_eq!(&Kind::Simple, param.kind());
}
/// `NOTIFY` payloads on a channel the session is `LISTEN`ing to must be delivered as
/// `AsyncMessage::Notification`s, in order.
#[test]
fn notifications() {
    let _ = env_logger::try_init();
    let mut runtime = Runtime::new().unwrap();
    let (mut client, mut connection) = runtime.block_on(connect("user=postgres")).unwrap();
    // Notifications are forwarded from the connection task to the test through this channel.
    let (tx, rx) = mpsc::unbounded();
    // Drive the connection by hand, rather than spawning it directly, so that async
    // messages can be inspected. The loop runs until `poll_message` yields `None`.
    // Connection errors panic, and non-notification messages are ignored.
    let connection = future::poll_fn(move || {
        while let Some(message) = try_ready!(connection.poll_message().map_err(|e| panic!("{}", e)))
        {
            if let AsyncMessage::Notification(notification) = message {
                debug!("received {}", notification.payload);
                tx.unbounded_send(notification).unwrap();
            }
        }
        Ok(Async::Ready(()))
    });
    runtime.handle().spawn(connection).unwrap();
    runtime
        .block_on(client.batch_execute("LISTEN test_notifications"))
        .unwrap();
    runtime
        .block_on(client.batch_execute("NOTIFY test_notifications, 'hello'"))
        .unwrap();
    runtime
        .block_on(client.batch_execute("NOTIFY test_notifications, 'world'"))
        .unwrap();
    // Drop the client and run the spawned connection future to completion. `tx` was moved
    // into that future and is dropped with it, which closes `rx` so `collect` can finish.
    // (Presumably dropping the client is what makes `poll_message` return `None`.)
    drop(client);
    runtime.run().unwrap();
    let notifications = rx.collect().wait().unwrap();
    assert_eq!(notifications.len(), 2);
    assert_eq!(notifications[0].channel, "test_notifications");
    assert_eq!(notifications[0].payload, "hello");
    assert_eq!(notifications[1].channel, "test_notifications");
    assert_eq!(notifications[1].payload, "world");
}
/// A future that succeeds inside `client.transaction()` must leave its writes visible
/// afterwards.
#[test]
fn transaction_commit() {
    let _ = env_logger::try_init();
    let mut rt = Runtime::new().unwrap();

    let (mut client, connection) = rt.block_on(connect("user=postgres")).unwrap();
    rt.handle()
        .spawn(connection.map_err(|e| panic!("{}", e)))
        .unwrap();

    let ddl = "CREATE TEMPORARY TABLE foo (\nid SERIAL,\nname TEXT\n)";
    rt.block_on(client.batch_execute(ddl)).unwrap();

    let insert = client.batch_execute("INSERT INTO foo (name) VALUES ('steven')");
    let txn = client.transaction().build(insert);
    rt.block_on(txn).unwrap();

    let select = client
        .prepare("SELECT name FROM foo")
        .and_then(|stmt| client.query(&stmt, &[]).collect());
    let rows = rt.block_on(select).unwrap();
    assert_eq!(rows.len(), 1);
    assert_eq!(rows[0].get::<_, &str>(0), "steven");
}
/// When the future wrapped by `client.transaction()` fails, the transaction must be
/// rolled back, so the row inserted inside it is not visible afterwards.
#[test]
fn transaction_abort() {
    let _ = env_logger::try_init();
    let mut runtime = Runtime::new().unwrap();
    let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
    let connection = connection.map_err(|e| panic!("{}", e));
    runtime.handle().spawn(connection).unwrap();
    runtime
        .block_on(client.batch_execute(
            "CREATE TEMPORARY TABLE foo (\nid SERIAL,\nname TEXT\n)",
        ))
        .unwrap();

    // Error injected after the INSERT succeeds, to force the transaction to abort.
    const ABORT_MESSAGE: &str = "intentional abort";
    let f = client
        .batch_execute("INSERT INTO foo (name) VALUES ('steven')")
        .map_err(|e| Box::new(e) as Box<dyn Error>)
        .and_then(|_| Err::<(), _>(Box::<dyn Error>::from(ABORT_MESSAGE)));
    let err = runtime.block_on(client.transaction().build(f)).unwrap_err();
    // Any failure (e.g. the INSERT itself failing) would also leave the table empty,
    // so check that the transaction failed with the injected error specifically.
    assert_eq!(err.to_string(), ABORT_MESSAGE);

    let rows = runtime
        .block_on(
            client
                .prepare("SELECT name FROM foo")
                .and_then(|s| client.query(&s, &[]).collect()),
        )
        .unwrap();
    assert_eq!(rows.len(), 0);
}
/// Streams two tab-separated rows into a table with `COPY ... FROM STDIN`. The reported
/// row count must be 2, and the rows must read back unchanged.
#[test]
fn copy_in() {
    let _ = env_logger::try_init();
    let mut rt = Runtime::new().unwrap();

    let (mut client, connection) = rt.block_on(connect("user=postgres")).unwrap();
    rt.handle()
        .spawn(connection.map_err(|e| panic!("{}", e)))
        .unwrap();

    let ddl = "CREATE TEMPORARY TABLE foo (\nid INTEGER,\nname TEXT\n)";
    rt.block_on(client.batch_execute(ddl)).unwrap();

    let chunks = vec![b"1\tjim\n".to_vec(), b"2\tjoe\n".to_vec()];
    let input = stream::iter_ok::<_, String>(chunks);
    let copy = client
        .prepare("COPY foo FROM STDIN")
        .and_then(|stmt| client.copy_in(&stmt, &[], input));
    let copied = rt.block_on(copy).unwrap();
    assert_eq!(copied, 2);

    let select = client
        .prepare("SELECT id, name FROM foo ORDER BY id")
        .and_then(|stmt| client.query(&stmt, &[]).collect());
    let rows = rt.block_on(select).unwrap();
    assert_eq!(rows.len(), 2);
    for (row, &(id, name)) in rows.iter().zip([(1, "jim"), (2, "joe")].iter()) {
        assert_eq!(row.get::<_, i32>(0), id);
        assert_eq!(row.get::<_, &str>(1), name);
    }
}
/// If the `COPY ... FROM STDIN` input stream fails partway, the copy must surface the
/// stream's error and leave the table empty.
#[test]
fn copy_in_error() {
    let _ = env_logger::try_init();
    let mut rt = Runtime::new().unwrap();

    let (mut client, connection) = rt.block_on(connect("user=postgres")).unwrap();
    rt.handle()
        .spawn(connection.map_err(|e| panic!("{}", e)))
        .unwrap();

    let ddl = "CREATE TEMPORARY TABLE foo (\nid INTEGER,\nname TEXT\n)";
    rt.block_on(client.batch_execute(ddl)).unwrap();

    // One valid row, then an error item.
    let input = stream::iter_result(vec![Ok(b"1\tjim\n".to_vec()), Err("asdf")]);
    let copy = client
        .prepare("COPY foo FROM STDIN")
        .and_then(|stmt| client.copy_in(&stmt, &[], input));
    let err = rt.block_on(copy).unwrap_err();
    assert!(err.to_string().contains("asdf"));

    let select = client
        .prepare("SELECT id, name FROM foo ORDER BY id")
        .and_then(|stmt| client.query(&stmt, &[]).collect());
    let rows = rt.block_on(select).unwrap();
    assert!(rows.is_empty());
}
/// `COPY ... TO STDOUT` must stream the table's rows. Once concatenated, the chunks must
/// equal the expected tab-separated payload.
#[test]
fn copy_out() {
    let _ = env_logger::try_init();
    let mut rt = Runtime::new().unwrap();

    let (mut client, connection) = rt.block_on(connect("user=postgres")).unwrap();
    rt.handle()
        .spawn(connection.map_err(|e| panic!("{}", e)))
        .unwrap();

    let setup = "CREATE TEMPORARY TABLE foo (\nid SERIAL,\nname TEXT\n);\nINSERT INTO foo (name) VALUES ('jim'), ('joe');";
    rt.block_on(client.batch_execute(setup)).unwrap();

    let copy = client
        .prepare("COPY foo TO STDOUT")
        .and_then(|stmt| client.copy_out(&stmt, &[]).concat2());
    let data = rt.block_on(copy).unwrap();
    assert_eq!(&data[..], b"1\tjim\n2\tjoe\n");
}
/// A `TransactionBuilder` taken from a client can wrap a future that has taken ownership
/// of that same client (moved into a closure) and returns it as its output.
#[test]
fn transaction_builder_around_moved_client() {
    let _ = env_logger::try_init();
    let mut runtime = Runtime::new().unwrap();
    let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
    let connection = connection.map_err(|e| panic!("{}", e));
    runtime.handle().spawn(connection).unwrap();
    // Obtain the builder before `client` is moved into `work` below.
    let transaction_builder = client.transaction();
    // `work` owns the client. It creates a table, prepares a two-row INSERT, runs it,
    // and finally yields the client back as the future's output.
    let work = future::lazy(move || {
        let execute = client.batch_execute(
            "CREATE TEMPORARY TABLE transaction_foo (
id SERIAL,
name TEXT
)",
        );
        execute.and_then(move |_| {
            client
                .prepare("INSERT INTO transaction_foo (name) VALUES ($1), ($2)")
                .map(|statement| (client, statement))
        })
    })
    .and_then(|(mut client, statement)| {
        client
            .query(&statement, &[&"jim", &"joe"])
            .collect()
            .map(|_res| client)
    });
    // The transaction future resolves to `work`'s output, i.e. the client.
    let transaction = transaction_builder.build(work);
    let mut client = runtime.block_on(transaction).unwrap();
    // The rows inserted inside the transaction must be visible afterwards.
    let data = runtime
        .block_on(
            client
                .prepare("COPY transaction_foo TO STDOUT")
                .and_then(|s| client.copy_out(&s, &[]).concat2()),
        )
        .unwrap();
    assert_eq!(&data[..], b"1\tjim\n2\tjoe\n");
    drop(client);
    runtime.run().unwrap();
}
/// `Client::poll_idle` must not report idle until a `COPY ... FROM STDIN` running on
/// another task has finished consuming its input.
#[test]
fn poll_idle_running() {
    // Copy-in data source that yields no rows. It waits for its `Delay`, records in
    // `QUERY_DONE` that the input is exhausted, then ends the stream.
    struct DelayStream(Delay);
    impl Stream for DelayStream {
        type Item = Vec<u8>;
        type Error = tokio_postgres::Error;
        fn poll(&mut self) -> Poll<Option<Vec<u8>>, tokio_postgres::Error> {
            try_ready!(self.0.poll().map_err(|e| panic!("{}", e)));
            QUERY_DONE.store(true, Ordering::SeqCst);
            Ok(Async::Ready(None))
        }
    }
    // Completes once the client reports idle, asserting at that moment that the
    // copy input had already been exhausted.
    struct IdleFuture(tokio_postgres::Client);
    impl Future for IdleFuture {
        type Item = ();
        type Error = tokio_postgres::Error;
        fn poll(&mut self) -> Poll<(), tokio_postgres::Error> {
            try_ready!(self.0.poll_idle());
            assert!(QUERY_DONE.load(Ordering::SeqCst));
            Ok(Async::Ready(()))
        }
    }
    // Flag shared by the two types above; set when the copy stream ends.
    static QUERY_DONE: AtomicBool = AtomicBool::new(false);
    let _ = env_logger::try_init();
    let mut runtime = Runtime::new().unwrap();
    let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
    let connection = connection.map_err(|e| panic!("{}", e));
    runtime.handle().spawn(connection).unwrap();
    let execute = client.batch_execute("CREATE TEMPORARY TABLE foo (id INT)");
    runtime.block_on(execute).unwrap();
    let prepare = client.prepare("COPY foo FROM STDIN");
    let stmt = runtime.block_on(prepare).unwrap();
    // Start the COPY with an input stream that only finishes after a 10ms delay.
    let copy_in = client.copy_in(
        &stmt,
        &[],
        DelayStream(Delay::new(Instant::now() + Duration::from_millis(10))),
    );
    let copy_in = copy_in.map(|_| ()).map_err(|e| panic!("{}", e));
    // Run the copy on its own task so the idle check below can observe it in progress.
    runtime.spawn(copy_in);
    let future = IdleFuture(client);
    runtime.block_on(future).unwrap();
}
/// `Client::poll_idle` must report busy while a newly created prepare request is
/// outstanding, even if that prepare future is dropped unpolled, and idle afterwards.
#[test]
fn poll_idle_new() {
    // First poll (`prepare` is `Some`): drops the prepare future without polling it and
    // asserts the client is not idle. Later poll (`prepare` is `None`): asserts the
    // client is idle and completes.
    struct IdleFuture {
        client: tokio_postgres::Client,
        prepare: Option<tokio_postgres::Prepare>,
    }
    impl Future for IdleFuture {
        type Item = ();
        type Error = tokio_postgres::Error;
        fn poll(&mut self) -> Poll<(), tokio_postgres::Error> {
            match self.prepare.take() {
                Some(_future) => {
                    // The client must still count the request as outstanding here.
                    // NOTE(review): returning NotReady relies on `poll_idle` having
                    // registered this task for wake-up; confirm against `Client::poll_idle`.
                    assert!(!self.client.poll_idle().unwrap().is_ready());
                    Ok(Async::NotReady)
                }
                None => {
                    assert!(self.client.poll_idle().unwrap().is_ready());
                    Ok(Async::Ready(()))
                }
            }
        }
    }
    let _ = env_logger::try_init();
    let mut runtime = Runtime::new().unwrap();
    let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
    let connection = connection.map_err(|e| panic!("{}", e));
    runtime.handle().spawn(connection).unwrap();
    // Create (but do not drive) a prepare request for an empty query.
    let prepare = client.prepare("");
    let future = IdleFuture {
        client,
        prepare: Some(prepare),
    };
    runtime.block_on(future).unwrap();
}