Change batch_execute into simple_query

Closes #413
This commit is contained in:
Steven Fackler 2019-01-31 20:34:49 -08:00
parent 80ebcd18e8
commit 32e09dbb91
21 changed files with 499 additions and 245 deletions

View File

@ -1,3 +1,4 @@
use fallible_iterator::FallibleIterator;
use futures::{Async, Future, Poll, Stream};
use std::io::{self, Read};
use tokio_postgres::types::{ToSql, Type};
@ -7,7 +8,7 @@ use tokio_postgres::{MakeTlsConnect, Socket, TlsConnect};
#[cfg(feature = "runtime")]
use crate::Config;
use crate::{CopyOutReader, Query, Statement, ToStatement, Transaction};
use crate::{CopyOutReader, Query, SimpleQuery, Statement, ToStatement, Transaction};
pub struct Client(tokio_postgres::Client);
@ -81,12 +82,12 @@ impl Client {
CopyOutReader::new(stream)
}
pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> {
self.0.batch_execute(query).wait()
pub fn simple_query(&mut self, query: &str) -> Result<SimpleQuery<'_>, Error> {
Ok(SimpleQuery::new(self.0.simple_query(query)))
}
pub fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
self.batch_execute("BEGIN")?;
self.simple_query("BEGIN")?.count()?;
Ok(Transaction::new(self))
}

View File

@ -10,6 +10,7 @@ mod copy_out_reader;
mod portal;
mod query;
mod query_portal;
mod simple_query;
mod statement;
mod to_statement;
mod transaction;
@ -25,6 +26,7 @@ pub use crate::copy_out_reader::*;
pub use crate::portal::*;
pub use crate::query::*;
pub use crate::query_portal::*;
pub use crate::simple_query::*;
pub use crate::statement::*;
pub use crate::to_statement::*;
pub use crate::transaction::*;

View File

@ -0,0 +1,42 @@
use fallible_iterator::FallibleIterator;
use futures::stream::{self, Stream};
use std::marker::PhantomData;
use tokio_postgres::impls;
use tokio_postgres::{Error, SimpleQueryMessage};
pub struct SimpleQuery<'a> {
it: stream::Wait<impls::SimpleQuery>,
_p: PhantomData<&'a mut ()>,
}
// no-op impl to extend borrow until drop
impl<'a> Drop for SimpleQuery<'a> {
fn drop(&mut self) {}
}
impl<'a> SimpleQuery<'a> {
pub(crate) fn new(stream: impls::SimpleQuery) -> SimpleQuery<'a> {
SimpleQuery {
it: stream.wait(),
_p: PhantomData,
}
}
/// A convenience API which collects the resulting messages into a `Vec` and returns them.
pub fn into_vec(self) -> Result<Vec<SimpleQueryMessage>, Error> {
self.collect()
}
}
impl<'a> FallibleIterator for SimpleQuery<'a> {
type Item = SimpleQueryMessage;
type Error = Error;
fn next(&mut self) -> Result<Option<SimpleQueryMessage>, Error> {
match self.it.next() {
Some(Ok(row)) => Ok(Some(row)),
Some(Err(e)) => Err(e),
None => Ok(None),
}
}
}

View File

@ -1,3 +1,4 @@
use fallible_iterator::FallibleIterator;
use std::io::Read;
use tokio_postgres::types::Type;
use tokio_postgres::NoTls;
@ -47,7 +48,9 @@ fn transaction_commit() {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
client
.batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY)")
.simple_query("CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY)")
.unwrap()
.count()
.unwrap();
let mut transaction = client.transaction().unwrap();
@ -72,7 +75,9 @@ fn transaction_rollback() {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
client
.batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY)")
.simple_query("CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY)")
.unwrap()
.count()
.unwrap();
let mut transaction = client.transaction().unwrap();
@ -96,7 +101,9 @@ fn transaction_drop() {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
client
.batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY)")
.simple_query("CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY)")
.unwrap()
.count()
.unwrap();
let mut transaction = client.transaction().unwrap();
@ -120,7 +127,9 @@ fn nested_transactions() {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
client
.batch_execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)")
.simple_query("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)")
.unwrap()
.count()
.unwrap();
let mut transaction = client.transaction().unwrap();
@ -177,7 +186,9 @@ fn copy_in() {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
client
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, name TEXT)")
.simple_query("CREATE TEMPORARY TABLE foo (id INT, name TEXT)")
.unwrap()
.count()
.unwrap();
client
@ -206,13 +217,12 @@ fn copy_out() {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
client
.batch_execute(
"
CREATE TEMPORARY TABLE foo (id INT, name TEXT);
INSERT INTO foo (id, name) VALUES (1, 'steven'), (2, 'timothy');
",
.simple_query(
"CREATE TEMPORARY TABLE foo (id INT, name TEXT);
INSERT INTO foo (id, name) VALUES (1, 'steven'), (2, 'timothy');",
)
.unwrap()
.count()
.unwrap();
let mut reader = client
@ -224,7 +234,7 @@ fn copy_out() {
assert_eq!(s, "1\tsteven\n2\ttimothy\n");
client.batch_execute("SELECT 1").unwrap();
client.simple_query("SELECT 1").unwrap().count().unwrap();
}
#[test]
@ -232,13 +242,12 @@ fn portal() {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
client
.batch_execute(
"
CREATE TEMPORARY TABLE foo (id INT);
INSERT INTO foo (id) VALUES (1), (2), (3);
",
.simple_query(
"CREATE TEMPORARY TABLE foo (id INT);
INSERT INTO foo (id) VALUES (1), (2), (3);",
)
.unwrap()
.count()
.unwrap();
let mut transaction = client.transaction().unwrap();

View File

@ -1,9 +1,12 @@
use fallible_iterator::FallibleIterator;
use futures::Future;
use std::io::Read;
use tokio_postgres::types::{ToSql, Type};
use tokio_postgres::Error;
use crate::{Client, CopyOutReader, Portal, Query, QueryPortal, Statement, ToStatement};
use crate::{
Client, CopyOutReader, Portal, Query, QueryPortal, SimpleQuery, Statement, ToStatement,
};
pub struct Transaction<'a> {
client: &'a mut Client,
@ -30,12 +33,14 @@ impl<'a> Transaction<'a> {
pub fn commit(mut self) -> Result<(), Error> {
self.done = true;
if self.depth == 0 {
self.client.batch_execute("COMMIT")
let it = if self.depth == 0 {
self.client.simple_query("COMMIT")?
} else {
self.client
.batch_execute(&format!("RELEASE sp{}", self.depth))
}
.simple_query(&format!("RELEASE sp{}", self.depth))?
};
it.count()?;
Ok(())
}
pub fn rollback(mut self) -> Result<(), Error> {
@ -44,12 +49,14 @@ impl<'a> Transaction<'a> {
}
fn rollback_inner(&mut self) -> Result<(), Error> {
if self.depth == 0 {
self.client.batch_execute("ROLLBACK")
let it = if self.depth == 0 {
self.client.simple_query("ROLLBACK")?
} else {
self.client
.batch_execute(&format!("ROLLBACK TO sp{}", self.depth))
}
.simple_query(&format!("ROLLBACK TO sp{}", self.depth))?
};
it.count()?;
Ok(())
}
pub fn prepare(&mut self, query: &str) -> Result<Statement, Error> {
@ -120,14 +127,15 @@ impl<'a> Transaction<'a> {
self.client.copy_out(query, params)
}
pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> {
self.client.batch_execute(query)
pub fn simple_query(&mut self, query: &str) -> Result<SimpleQuery<'_>, Error> {
self.client.simple_query(query)
}
pub fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
let depth = self.depth + 1;
self.client
.batch_execute(&format!("SAVEPOINT sp{}", depth))?;
.simple_query(&format!("SAVEPOINT sp{}", depth))?
.count()?;
Ok(Transaction {
client: self.client,
depth,

View File

@ -100,6 +100,6 @@ fn runtime() {
let connection = connection.map_err(|e| panic!("{}", e));
runtime.spawn(connection);
let execute = client.batch_execute("SELECT 1");
let execute = client.simple_query("SELECT 1").for_each(|_| Ok(()));
runtime.block_on(execute).unwrap();
}

View File

@ -85,6 +85,6 @@ fn runtime() {
let connection = connection.map_err(|e| panic!("{}", e));
runtime.spawn(connection);
let execute = client.batch_execute("SELECT 1");
let execute = client.simple_query("SELECT 1").for_each(|_| Ok(()));
runtime.block_on(execute).unwrap();
}

View File

@ -336,6 +336,7 @@ enum Kind {
Tls,
ToSql,
FromSql,
Column,
CopyInStream,
Closed,
Db,
@ -373,6 +374,7 @@ impl fmt::Display for Error {
Kind::Tls => "error performing TLS handshake",
Kind::ToSql => "error serializing a value",
Kind::FromSql => "error deserializing a value",
Kind::Column => "invalid column",
Kind::CopyInStream => "error from a copy_in stream",
Kind::Closed => "connection closed",
Kind::Db => "db error",
@ -451,6 +453,10 @@ impl Error {
Error::new(Kind::FromSql, Some(e))
}
pub(crate) fn column() -> Error {
Error::new(Kind::Column, None)
}
pub(crate) fn copy_in_stream<E>(e: E) -> Error
where
E: Into<Box<dyn error::Error + Sync + Send>>,

View File

@ -5,7 +5,7 @@ use std::error;
use tokio_io::{AsyncRead, AsyncWrite};
use crate::proto;
use crate::{Client, Connection, Error, Portal, Row, Statement, TlsConnect};
use crate::{Client, Connection, Error, Portal, Row, SimpleQueryMessage, Statement, TlsConnect};
#[cfg(feature = "runtime")]
use crate::{MakeTlsConnect, Socket};
@ -187,3 +187,16 @@ impl Stream for CopyOut {
self.0.poll()
}
}
/// The stream returned by `Client::simple_query`.
#[must_use = "streams do nothing unless polled"]
pub struct SimpleQuery(pub(crate) proto::SimpleQueryStream);
impl Stream for SimpleQuery {
type Item = SimpleQueryMessage;
type Error = Error;
fn poll(&mut self) -> Poll<Option<SimpleQueryMessage>, Error> {
self.0.poll()
}
}

View File

@ -102,7 +102,7 @@
#![warn(rust_2018_idioms, clippy::all)]
use bytes::IntoBuf;
use futures::{try_ready, Async, Future, Poll, Stream};
use futures::{Future, Poll, Stream};
use std::error::Error as StdError;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio_io::{AsyncRead, AsyncWrite};
@ -240,19 +240,21 @@ impl Client {
impls::CopyOut(self.0.copy_out(&statement.0, params))
}
/// Executes a sequence of SQL statements.
/// Executes a sequence of SQL statements using the simple query protocol.
///
/// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
/// point. This is intended for the execution of batches of non-dynamic statements, for example, the creation of
/// a schema for a fresh database.
/// point. The simple query protocol returns the values in rows as strings rather than in their binary encodings,
/// so the associated row type doesn't work with the `FromSql` trait. Rather than simply returning a stream over the
/// rows, this method returns a stream over an enum which indicates either the completion of one of the commands,
/// or a row of data. This preserves the framing between the separate statements in the request.
///
/// # Warning
///
/// Prepared statements should be use for any query which contains user-specified data, as they provided the
/// functionality to safely imbed that data in the request. Do not form statements via string concatenation and pass
/// them to this method!
pub fn batch_execute(&mut self, query: &str) -> BatchExecute {
BatchExecute(self.0.batch_execute(query))
pub fn simple_query(&mut self, query: &str) -> impls::SimpleQuery {
impls::SimpleQuery(self.0.simple_query(query))
}
/// A utility method to wrap a future in a database transaction.
@ -445,18 +447,11 @@ where
}
}
#[must_use = "futures do nothing unless polled"]
pub struct BatchExecute(proto::SimpleQueryStream);
impl Future for BatchExecute {
type Item = ();
type Error = Error;
fn poll(&mut self) -> Poll<(), Error> {
while let Some(_) = try_ready!(self.0.poll()) {}
Ok(Async::Ready(()))
}
pub enum SimpleQueryMessage {
Row(SimpleQueryRow),
CommandComplete(u64),
#[doc(hidden)]
__NonExhaustive,
}
/// An asynchronous notification.

View File

@ -143,7 +143,7 @@ impl Client {
.map_err(|_| Error::closed())
}
pub fn batch_execute(&self, query: &str) -> SimpleQueryStream {
pub fn simple_query(&self, query: &str) -> SimpleQueryStream {
let pending = self.pending(|buf| {
frontend::query(query, buf).map_err(Error::parse)?;
Ok(())

View File

@ -1,6 +1,5 @@
#![allow(clippy::large_enum_variant)]
use fallible_iterator::FallibleIterator;
use futures::{try_ready, Async, Future, Poll, Stream};
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
use std::io;
@ -8,7 +7,7 @@ use std::io;
use crate::proto::{
Client, ConnectRawFuture, ConnectSocketFuture, Connection, MaybeTlsStream, SimpleQueryStream,
};
use crate::{Config, Error, Socket, TargetSessionAttrs, TlsConnect};
use crate::{Config, Error, SimpleQueryMessage, Socket, TargetSessionAttrs, TlsConnect};
#[derive(StateMachineFuture)]
pub enum ConnectOnce<T>
@ -75,7 +74,7 @@ where
if let TargetSessionAttrs::ReadWrite = state.target_session_attrs {
transition!(CheckingSessionAttrs {
stream: client.batch_execute("SHOW transaction_read_only"),
stream: client.simple_query("SHOW transaction_read_only"),
client,
connection,
})
@ -87,24 +86,26 @@ where
fn poll_checking_session_attrs<'a>(
state: &'a mut RentToOwn<'a, CheckingSessionAttrs<T>>,
) -> Poll<AfterCheckingSessionAttrs<T>, Error> {
if let Async::Ready(()) = state.connection.poll()? {
return Err(Error::closed());
}
match try_ready!(state.stream.poll()) {
Some(row) => {
let range = row.ranges().next().map_err(Error::parse)?.and_then(|r| r);
if range.map(|r| &row.buffer()[r]) == Some(b"on") {
Err(Error::connect(io::Error::new(
io::ErrorKind::PermissionDenied,
"database does not allow writes",
)))
} else {
let state = state.take();
transition!(Finished((state.client, state.connection)))
}
loop {
if let Async::Ready(()) = state.connection.poll()? {
return Err(Error::closed());
}
match try_ready!(state.stream.poll()) {
Some(SimpleQueryMessage::Row(row)) => {
if row.try_get(0)? == Some("on") {
return Err(Error::connect(io::Error::new(
io::ErrorKind::PermissionDenied,
"database does not allow writes",
)));
} else {
let state = state.take();
transition!(Finished((state.client, state.connection)))
}
}
Some(_) => {}
None => return Err(Error::closed()),
}
None => Err(Error::closed()),
}
}
}

View File

@ -1,10 +1,12 @@
use fallible_iterator::FallibleIterator;
use futures::sync::mpsc;
use futures::{Async, Poll, Stream};
use postgres_protocol::message::backend::{DataRowBody, Message};
use postgres_protocol::message::backend::Message;
use std::mem;
use std::sync::Arc;
use crate::proto::client::{Client, PendingRequest};
use crate::Error;
use crate::{Error, SimpleQueryMessage, SimpleQueryRow};
pub enum State {
Start {
@ -12,6 +14,7 @@ pub enum State {
request: PendingRequest,
},
ReadResponse {
columns: Option<Arc<[String]>>,
receiver: mpsc::Receiver<Message>,
},
Done,
@ -20,35 +23,76 @@ pub enum State {
pub struct SimpleQueryStream(State);
impl Stream for SimpleQueryStream {
type Item = DataRowBody;
type Item = SimpleQueryMessage;
type Error = Error;
fn poll(&mut self) -> Poll<Option<DataRowBody>, Error> {
fn poll(&mut self) -> Poll<Option<SimpleQueryMessage>, Error> {
loop {
match mem::replace(&mut self.0, State::Done) {
State::Start { client, request } => {
let receiver = client.send(request)?;
self.0 = State::ReadResponse { receiver };
self.0 = State::ReadResponse {
columns: None,
receiver,
};
}
State::ReadResponse { mut receiver } => {
State::ReadResponse {
columns,
mut receiver,
} => {
let message = match receiver.poll() {
Ok(Async::Ready(message)) => message,
Ok(Async::NotReady) => {
self.0 = State::ReadResponse { receiver };
self.0 = State::ReadResponse { columns, receiver };
return Ok(Async::NotReady);
}
Err(()) => unreachable!("mpsc receiver can't panic"),
};
match message {
Some(Message::CommandComplete(_))
| Some(Message::RowDescription(_))
| Some(Message::EmptyQueryResponse) => {
self.0 = State::ReadResponse { receiver };
Some(Message::CommandComplete(body)) => {
let rows = body
.tag()
.map_err(Error::parse)?
.rsplit(' ')
.next()
.unwrap()
.parse()
.unwrap_or(0);
self.0 = State::ReadResponse {
columns: None,
receiver,
};
return Ok(Async::Ready(Some(SimpleQueryMessage::CommandComplete(
rows,
))));
}
Some(Message::EmptyQueryResponse) => {
self.0 = State::ReadResponse {
columns: None,
receiver,
};
return Ok(Async::Ready(Some(SimpleQueryMessage::CommandComplete(0))));
}
Some(Message::RowDescription(body)) => {
let columns = body
.fields()
.map(|f| f.name().to_string())
.collect::<Vec<_>>()
.map_err(Error::parse)?
.into();
self.0 = State::ReadResponse {
columns: Some(columns),
receiver,
};
}
Some(Message::DataRow(body)) => {
self.0 = State::ReadResponse { receiver };
return Ok(Async::Ready(Some(body)));
let row = match &columns {
Some(columns) => SimpleQueryRow::new(columns.clone(), body)?,
None => return Err(Error::unexpected_message()),
};
self.0 = State::ReadResponse { columns, receiver };
return Ok(Async::Ready(Some(SimpleQueryMessage::Row(row))));
}
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
Some(Message::ReadyForQuery(_)) => return Ok(Async::Ready(None)),

View File

@ -42,7 +42,7 @@ where
) -> Poll<AfterStart<F, T, E>, E> {
let state = state.take();
transition!(Beginning {
begin: state.client.batch_execute("BEGIN"),
begin: state.client.simple_query("BEGIN"),
client: state.client,
future: state.future,
})
@ -66,11 +66,11 @@ where
match state.future.poll() {
Ok(Async::NotReady) => Ok(Async::NotReady),
Ok(Async::Ready(t)) => transition!(Finishing {
future: state.client.batch_execute("COMMIT"),
future: state.client.simple_query("COMMIT"),
result: Ok(t),
}),
Err(e) => transition!(Finishing {
future: state.client.batch_execute("ROLLBACK"),
future: state.client.simple_query("ROLLBACK"),
result: Err(e),
}),
}

View File

@ -184,27 +184,13 @@ impl PollTypeinfo for Typeinfo {
None => return Err(Error::unexpected_message()),
};
let name = row
.try_get::<_, String>(0)?
.ok_or_else(Error::unexpected_message)?;
let type_ = row
.try_get::<_, i8>(1)?
.ok_or_else(Error::unexpected_message)?;
let elem_oid = row
.try_get::<_, Oid>(2)?
.ok_or_else(Error::unexpected_message)?;
let rngsubtype = row
.try_get::<_, Option<Oid>>(3)?
.ok_or_else(Error::unexpected_message)?;
let basetype = row
.try_get::<_, Oid>(4)?
.ok_or_else(Error::unexpected_message)?;
let schema = row
.try_get::<_, String>(5)?
.ok_or_else(Error::unexpected_message)?;
let relid = row
.try_get::<_, Oid>(6)?
.ok_or_else(Error::unexpected_message)?;
let name = row.try_get::<_, String>(0)?;
let type_ = row.try_get::<_, i8>(1)?;
let elem_oid = row.try_get::<_, Oid>(2)?;
let rngsubtype = row.try_get::<_, Option<Oid>>(3)?;
let basetype = row.try_get::<_, Oid>(4)?;
let schema = row.try_get::<_, String>(5)?;
let relid = row.try_get::<_, Oid>(6)?;
let kind = if type_ == b'e' as i8 {
transition!(QueryingEnumVariants {

View File

@ -96,8 +96,8 @@ impl PollTypeinfoComposite for TypeinfoComposite {
let fields = rows
.iter()
.map(|row| {
let name = row.try_get(0)?.ok_or_else(Error::unexpected_message)?;
let oid = row.try_get(1)?.ok_or_else(Error::unexpected_message)?;
let name = row.try_get(0)?;
let oid = row.try_get(1)?;
Ok((name, oid))
})
.collect::<Result<Vec<(String, Oid)>, Error>>()?;

View File

@ -124,7 +124,7 @@ impl PollTypeinfoEnum for TypeinfoEnum {
let variants = rows
.iter()
.map(|row| row.try_get(0)?.ok_or_else(Error::unexpected_message))
.map(|row| row.try_get(0))
.collect::<Result<Vec<_>, _>>()?;
transition!(Finished((variants, state.client)))

View File

@ -3,15 +3,32 @@ use postgres_protocol::message::backend::DataRowBody;
use std::fmt;
use std::ops::Range;
use std::str;
use std::sync::Arc;
use crate::proto;
use crate::row::sealed::Sealed;
use crate::row::sealed::{AsName, Sealed};
use crate::stmt::Column;
use crate::types::{FromSql, WrongType};
use crate::types::{FromSql, Type, WrongType};
use crate::Error;
mod sealed {
pub trait Sealed {}
pub trait AsName {
fn as_name(&self) -> &str;
}
}
impl AsName for Column {
fn as_name(&self) -> &str {
self.name()
}
}
impl AsName for String {
fn as_name(&self) -> &str {
self
}
}
/// A trait implemented by types that can index into columns of a row.
@ -19,14 +36,19 @@ mod sealed {
/// This cannot be implemented outside of this crate.
pub trait RowIndex: Sealed {
#[doc(hidden)]
fn __idx(&self, columns: &[Column]) -> Option<usize>;
fn __idx<T>(&self, columns: &[T]) -> Option<usize>
where
T: AsName;
}
impl Sealed for usize {}
impl RowIndex for usize {
#[inline]
fn __idx(&self, columns: &[Column]) -> Option<usize> {
fn __idx<T>(&self, columns: &[T]) -> Option<usize>
where
T: AsName,
{
if *self >= columns.len() {
None
} else {
@ -39,8 +61,11 @@ impl Sealed for str {}
impl RowIndex for str {
#[inline]
fn __idx(&self, columns: &[Column]) -> Option<usize> {
if let Some(idx) = columns.iter().position(|d| d.name() == self) {
fn __idx<T>(&self, columns: &[T]) -> Option<usize>
where
T: AsName,
{
if let Some(idx) = columns.iter().position(|d| d.as_name() == self) {
return Some(idx);
};
@ -49,7 +74,7 @@ impl RowIndex for str {
// uses the US locale.
columns
.iter()
.position(|d| d.name().eq_ignore_ascii_case(self))
.position(|d| d.as_name().eq_ignore_ascii_case(self))
}
}
@ -60,7 +85,10 @@ where
T: ?Sized + RowIndex,
{
#[inline]
fn __idx(&self, columns: &[Column]) -> Option<usize> {
fn __idx<U>(&self, columns: &[U]) -> Option<usize>
where
U: AsName,
{
T::__idx(*self, columns)
}
}
@ -100,13 +128,12 @@ impl Row {
T: FromSql<'a>,
{
match self.get_inner(&idx) {
Ok(Some(ok)) => ok,
Ok(ok) => ok,
Err(err) => panic!("error retrieving column {}: {}", idx, err),
Ok(None) => panic!("no such column {}", idx),
}
}
pub fn try_get<'a, I, T>(&'a self, idx: I) -> Result<Option<T>, Error>
pub fn try_get<'a, I, T>(&'a self, idx: I) -> Result<T, Error>
where
I: RowIndex,
T: FromSql<'a>,
@ -114,14 +141,14 @@ impl Row {
self.get_inner(&idx)
}
fn get_inner<'a, I, T>(&'a self, idx: &I) -> Result<Option<T>, Error>
fn get_inner<'a, I, T>(&'a self, idx: &I) -> Result<T, Error>
where
I: RowIndex,
T: FromSql<'a>,
{
let idx = match idx.__idx(self.columns()) {
Some(idx) => idx,
None => return Ok(None),
None => return Err(Error::column()),
};
let ty = self.columns()[idx].type_();
@ -130,7 +157,62 @@ impl Row {
}
let buf = self.ranges[idx].clone().map(|r| &self.body.buffer()[r]);
let value = FromSql::from_sql_nullable(ty, buf);
value.map(Some).map_err(Error::from_sql)
FromSql::from_sql_nullable(ty, buf).map_err(Error::from_sql)
}
}
pub struct SimpleQueryRow {
columns: Arc<[String]>,
body: DataRowBody,
ranges: Vec<Option<Range<usize>>>,
}
impl SimpleQueryRow {
#[allow(clippy::new_ret_no_self)]
pub(crate) fn new(columns: Arc<[String]>, body: DataRowBody) -> Result<SimpleQueryRow, Error> {
let ranges = body.ranges().collect().map_err(Error::parse)?;
Ok(SimpleQueryRow {
columns,
body,
ranges,
})
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn len(&self) -> usize {
self.columns.len()
}
pub fn get<I>(&self, idx: I) -> Option<&str>
where
I: RowIndex + fmt::Display,
{
match self.get_inner(&idx) {
Ok(ok) => ok,
Err(err) => panic!("error retrieving column {}: {}", idx, err),
}
}
pub fn try_get<I>(&self, idx: I) -> Result<Option<&str>, Error>
where
I: RowIndex,
{
self.get_inner(&idx)
}
fn get_inner<I>(&self, idx: &I) -> Result<Option<&str>, Error>
where
I: RowIndex,
{
let idx = match idx.__idx(&self.columns) {
Some(idx) => idx,
None => return Err(Error::column()),
};
let buf = self.ranges[idx].clone().map(|r| &self.body.buffer()[r]);
FromSql::from_sql_nullable(&Type::TEXT, buf).map_err(Error::from_sql)
}
}

View File

@ -161,7 +161,11 @@ fn insert_select() {
runtime.handle().spawn(connection).unwrap();
runtime
.block_on(client.batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL, name TEXT)"))
.block_on(
client
.simple_query("CREATE TEMPORARY TABLE foo (id SERIAL, name TEXT)")
.for_each(|_| Ok(())),
)
.unwrap();
let insert = client.prepare("INSERT INTO foo (name) VALUES ($1), ($2)");
@ -193,11 +197,15 @@ fn query_portal() {
runtime.handle().spawn(connection).unwrap();
runtime
.block_on(client.batch_execute(
"CREATE TEMPORARY TABLE foo (id SERIAL, name TEXT);
INSERT INTO foo (name) VALUES ('alice'), ('bob'), ('charlie');
BEGIN;",
))
.block_on(
client
.simple_query(
"CREATE TEMPORARY TABLE foo (id SERIAL, name TEXT);
INSERT INTO foo (name) VALUES ('alice'), ('bob'), ('charlie');
BEGIN;",
)
.for_each(|_| Ok(())),
)
.unwrap();
let statement = runtime
@ -233,7 +241,8 @@ fn cancel_query_raw() {
runtime.handle().spawn(connection).unwrap();
let sleep = client
.batch_execute("SELECT pg_sleep(100)")
.simple_query("SELECT pg_sleep(100)")
.for_each(|_| Ok(()))
.then(|r| match r {
Ok(_) => panic!("unexpected success"),
Err(ref e) if e.code() == Some(&SqlState::QUERY_CANCELED) => Ok::<(), ()>(()),
@ -266,13 +275,17 @@ fn custom_enum() {
runtime.handle().spawn(connection).unwrap();
runtime
.block_on(client.batch_execute(
"CREATE TYPE pg_temp.mood AS ENUM (
'sad',
'ok',
'happy'
)",
))
.block_on(
client
.simple_query(
"CREATE TYPE pg_temp.mood AS ENUM (
'sad',
'ok',
'happy'
)",
)
.for_each(|_| Ok(())),
)
.unwrap();
let select = client.prepare("SELECT $1::mood");
@ -300,9 +313,13 @@ fn custom_domain() {
runtime.handle().spawn(connection).unwrap();
runtime
.block_on(client.batch_execute(
"CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16)",
))
.block_on(
client
.simple_query(
"CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16)",
)
.for_each(|_| Ok(())),
)
.unwrap();
let select = client.prepare("SELECT $1::session_id");
@ -346,13 +363,17 @@ fn custom_composite() {
runtime.handle().spawn(connection).unwrap();
runtime
.block_on(client.batch_execute(
"CREATE TYPE pg_temp.inventory_item AS (
name TEXT,
supplier INTEGER,
price NUMERIC
)",
))
.block_on(
client
.simple_query(
"CREATE TYPE pg_temp.inventory_item AS (
name TEXT,
supplier INTEGER,
price NUMERIC
)",
)
.for_each(|_| Ok(())),
)
.unwrap();
let select = client.prepare("SELECT $1::inventory_item");
@ -383,12 +404,16 @@ fn custom_range() {
runtime.handle().spawn(connection).unwrap();
runtime
.block_on(client.batch_execute(
"CREATE TYPE pg_temp.floatrange AS RANGE (
subtype = float8,
subtype_diff = float8mi
)",
))
.block_on(
client
.simple_query(
"CREATE TYPE pg_temp.floatrange AS RANGE (
subtype = float8,
subtype_diff = float8mi
)",
)
.for_each(|_| Ok(())),
)
.unwrap();
let select = client.prepare("SELECT $1::floatrange");
@ -438,15 +463,15 @@ fn notifications() {
runtime.handle().spawn(connection).unwrap();
runtime
.block_on(client.batch_execute("LISTEN test_notifications"))
.unwrap();
runtime
.block_on(client.batch_execute("NOTIFY test_notifications, 'hello'"))
.unwrap();
runtime
.block_on(client.batch_execute("NOTIFY test_notifications, 'world'"))
.block_on(
client
.simple_query(
"LISTEN test_notifications;
NOTIFY test_notifications, 'hello';
NOTIFY test_notifications, 'world';",
)
.for_each(|_| Ok(())),
)
.unwrap();
drop(client);
@ -470,15 +495,21 @@ fn transaction_commit() {
runtime.handle().spawn(connection).unwrap();
runtime
.block_on(client.batch_execute(
"CREATE TEMPORARY TABLE foo (
id SERIAL,
name TEXT
)",
))
.block_on(
client
.simple_query(
"CREATE TEMPORARY TABLE foo (
id SERIAL,
name TEXT
)",
)
.for_each(|_| Ok(())),
)
.unwrap();
let f = client.batch_execute("INSERT INTO foo (name) VALUES ('steven')");
let f = client
.simple_query("INSERT INTO foo (name) VALUES ('steven')")
.for_each(|_| Ok(()));
runtime
.block_on(client.build_transaction().build(f))
.unwrap();
@ -505,16 +536,21 @@ fn transaction_abort() {
runtime.handle().spawn(connection).unwrap();
runtime
.block_on(client.batch_execute(
"CREATE TEMPORARY TABLE foo (
id SERIAL,
name TEXT
)",
))
.block_on(
client
.simple_query(
"CREATE TEMPORARY TABLE foo (
id SERIAL,
name TEXT
)",
)
.for_each(|_| Ok(())),
)
.unwrap();
let f = client
.batch_execute("INSERT INTO foo (name) VALUES ('steven')")
.simple_query("INSERT INTO foo (name) VALUES ('steven')")
.for_each(|_| Ok(()))
.map_err(|e| Box::new(e) as Box<dyn Error>)
.and_then(|_| Err::<(), _>(Box::<dyn Error>::from("")));
runtime
@ -542,12 +578,16 @@ fn copy_in() {
runtime.handle().spawn(connection).unwrap();
runtime
.block_on(client.batch_execute(
"CREATE TEMPORARY TABLE foo (
id INTEGER,
name TEXT
)",
))
.block_on(
client
.simple_query(
"CREATE TEMPORARY TABLE foo (
id INTEGER,
name TEXT
)",
)
.for_each(|_| Ok(())),
)
.unwrap();
let stream = stream::iter_ok::<_, String>(vec![b"1\tjim\n".to_vec(), b"2\tjoe\n".to_vec()]);
@ -585,12 +625,16 @@ fn copy_in_error() {
runtime.handle().spawn(connection).unwrap();
runtime
.block_on(client.batch_execute(
"CREATE TEMPORARY TABLE foo (
id INTEGER,
name TEXT
)",
))
.block_on(
client
.simple_query(
"CREATE TEMPORARY TABLE foo (
id INTEGER,
name TEXT
)",
)
.for_each(|_| Ok(())),
)
.unwrap();
let stream = stream::iter_result(vec![Ok(b"1\tjim\n".to_vec()), Err("asdf")]);
@ -624,13 +668,17 @@ fn copy_out() {
runtime.handle().spawn(connection).unwrap();
runtime
.block_on(client.batch_execute(
"CREATE TEMPORARY TABLE foo (
id SERIAL,
name TEXT
);
INSERT INTO foo (name) VALUES ('jim'), ('joe');",
))
.block_on(
client
.simple_query(
"CREATE TEMPORARY TABLE foo (
id SERIAL,
name TEXT
);
INSERT INTO foo (name) VALUES ('jim'), ('joe');",
)
.for_each(|_| Ok(())),
)
.unwrap();
let data = runtime
@ -654,12 +702,13 @@ fn transaction_builder_around_moved_client() {
let transaction_builder = client.build_transaction();
let work = client
.batch_execute(
.simple_query(
"CREATE TEMPORARY TABLE transaction_foo (
id SERIAL,
name TEXT
)",
)
.for_each(|_| Ok(()))
.and_then(move |_| {
client
.prepare("INSERT INTO transaction_foo (name) VALUES ($1), ($2)")
@ -725,7 +774,9 @@ fn poll_idle_running() {
let connection = connection.map_err(|e| panic!("{}", e));
runtime.handle().spawn(connection).unwrap();
let execute = client.batch_execute("CREATE TEMPORARY TABLE foo (id INT)");
let execute = client
.simple_query("CREATE TEMPORARY TABLE foo (id INT)")
.for_each(|_| Ok(()));
runtime.block_on(execute).unwrap();
let prepare = client.prepare("COPY foo FROM STDIN");

View File

@ -1,4 +1,4 @@
use futures::Future;
use futures::{Future, Stream};
use std::time::{Duration, Instant};
use tokio::runtime::current_thread::Runtime;
use tokio::timer::Delay;
@ -11,7 +11,7 @@ fn smoke_test(s: &str) {
let connection = connection.map_err(|e| panic!("{}", e));
runtime.spawn(connection);
let execute = client.batch_execute("SELECT 1");
let execute = client.simple_query("SELECT 1").for_each(|_| Ok(()));
runtime.block_on(execute).unwrap();
}
@ -80,7 +80,8 @@ fn cancel_query() {
runtime.spawn(connection);
let sleep = client
.batch_execute("SELECT pg_sleep(100)")
.simple_query("SELECT pg_sleep(100)")
.for_each(|_| Ok(()))
.then(|r| match r {
Ok(_) => panic!("unexpected success"),
Err(ref e) if e.code() == Some(&SqlState::QUERY_CANCELED) => Ok::<(), ()>(()),

View File

@ -212,12 +212,14 @@ fn test_bpchar_params() {
let connection = connection.map_err(|e| panic!("{}", e));
runtime.spawn(connection);
let batch = client.batch_execute(
"CREATE TEMPORARY TABLE foo (
id SERIAL PRIMARY KEY,
b CHAR(5)
)",
);
let batch = client
.simple_query(
"CREATE TEMPORARY TABLE foo (
id SERIAL PRIMARY KEY,
b CHAR(5)
)",
)
.for_each(|_| Ok(()));
runtime.block_on(batch).unwrap();
let prepare = client.prepare("INSERT INTO foo (b) VALUES ($1), ($2), ($3)");
@ -245,12 +247,14 @@ fn test_citext_params() {
let connection = connection.map_err(|e| panic!("{}", e));
runtime.spawn(connection);
let batch = client.batch_execute(
"CREATE TEMPORARY TABLE foo (
id SERIAL PRIMARY KEY,
b CITEXT
)",
);
let batch = client
.simple_query(
"CREATE TEMPORARY TABLE foo (
id SERIAL PRIMARY KEY,
b CITEXT
)",
)
.for_each(|_| Ok(()));
runtime.block_on(batch).unwrap();
let prepare = client.prepare("INSERT INTO foo (b) VALUES ($1), ($2), ($3)");
@ -393,15 +397,16 @@ fn test_slice() {
let connection = connection.map_err(|e| panic!("{}", e));
runtime.spawn(connection);
let batch = client.batch_execute(
"CREATE TEMPORARY TABLE foo (
id SERIAL PRIMARY KEY,
f TEXT
);
let batch = client
.simple_query(
"CREATE TEMPORARY TABLE foo (
id SERIAL PRIMARY KEY,
f TEXT
);
INSERT INTO foo(f) VALUES ('a'), ('b'), ('c'), ('d');
",
);
INSERT INTO foo(f) VALUES ('a'), ('b'), ('c'), ('d');",
)
.for_each(|_| Ok(()));
runtime.block_on(batch).unwrap();
let prepare = client.prepare("SELECT f FROM foo WHERE id = ANY($1)");
@ -424,11 +429,13 @@ fn test_slice_wrong_type() {
let connection = connection.map_err(|e| panic!("{}", e));
runtime.spawn(connection);
let batch = client.batch_execute(
"CREATE TEMPORARY TABLE foo (
id SERIAL PRIMARY KEY
)",
);
let batch = client
.simple_query(
"CREATE TEMPORARY TABLE foo (
id SERIAL PRIMARY KEY
)",
)
.for_each(|_| Ok(()));
runtime.block_on(batch).unwrap();
let prepare = client.prepare("SELECT * FROM foo WHERE id = ANY($1)");
@ -507,10 +514,12 @@ fn domain() {
let connection = connection.map_err(|e| panic!("{}", e));
runtime.spawn(connection);
let batch = client.batch_execute(
"CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16);
CREATE TABLE pg_temp.foo (id pg_temp.session_id);",
);
let batch = client
.simple_query(
"CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16);
CREATE TABLE pg_temp.foo (id pg_temp.session_id);",
)
.for_each(|_| Ok(()));
runtime.block_on(batch).unwrap();
let id = SessionId(b"0123456789abcdef".to_vec());
@ -536,13 +545,15 @@ fn composite() {
let connection = connection.map_err(|e| panic!("{}", e));
runtime.spawn(connection);
let batch = client.batch_execute(
"CREATE TYPE pg_temp.inventory_item AS (
name TEXT,
supplier INTEGER,
price NUMERIC
)",
);
let batch = client
.simple_query(
"CREATE TYPE pg_temp.inventory_item AS (
name TEXT,
supplier INTEGER,
price NUMERIC
)",
)
.for_each(|_| Ok(()));
runtime.block_on(batch).unwrap();
let prepare = client.prepare("SELECT $1::inventory_item");
@ -571,7 +582,9 @@ fn enum_() {
let connection = connection.map_err(|e| panic!("{}", e));
runtime.spawn(connection);
let batch = client.batch_execute("CREATE TYPE pg_temp.mood AS ENUM ('sad', 'ok', 'happy');");
let batch = client
.simple_query("CREATE TYPE pg_temp.mood AS ENUM ('sad', 'ok', 'happy');")
.for_each(|_| Ok(()));
runtime.block_on(batch).unwrap();
let prepare = client.prepare("SELECT $1::mood");