Support copy_in

This commit is contained in:
Steven Fackler 2019-07-30 21:25:30 -07:00
parent 4afd5235db
commit f45884711f
8 changed files with 365 additions and 163 deletions

View File

@ -1,8 +1,8 @@
use crate::client::SocketConfig;
use crate::config::{SslMode, Host};
use crate::config::{Host, SslMode};
use crate::tls::MakeTlsConnect;
use crate::{cancel_query_raw, connect_socket, connect_tls, Error, Socket};
use std::io;
use crate::tls::MakeTlsConnect;
pub(crate) async fn cancel_query<T>(
config: Option<SocketConfig>,
@ -10,7 +10,10 @@ pub(crate) async fn cancel_query<T>(
mut tls: T,
process_id: i32,
secret_key: i32,
) -> Result<(), Error> where T: MakeTlsConnect<Socket> {
) -> Result<(), Error>
where
T: MakeTlsConnect<Socket>,
{
let config = match config {
Some(config) => config,
None => {
@ -27,7 +30,9 @@ pub(crate) async fn cancel_query<T>(
#[cfg(unix)]
Host::Unix(_) => "",
};
let tls = tls.make_tls_connect(hostname).map_err(|e| Error::tls(e.into()))?;
let tls = tls
.make_tls_connect(hostname)
.map_err(|e| Error::tls(e.into()))?;
let socket = connect_socket::connect_socket(
&config.host,

View File

@ -7,17 +7,19 @@ use crate::tls::TlsConnect;
use crate::types::{Oid, ToSql, Type};
#[cfg(feature = "runtime")]
use crate::Socket;
use crate::{cancel_query, cancel_query_raw, query, Transaction};
use crate::{cancel_query, cancel_query_raw, copy_in, query, Transaction};
use crate::{prepare, SimpleQueryMessage};
use crate::{simple_query, Row};
use crate::{Error, Statement};
use bytes::IntoBuf;
use fallible_iterator::FallibleIterator;
use futures::channel::mpsc;
use futures::{future, Stream};
use futures::{future, Stream, TryStream};
use futures::{ready, StreamExt};
use parking_lot::Mutex;
use postgres_protocol::message::backend::Message;
use std::collections::HashMap;
use std::error;
use std::future::Future;
use std::sync::Arc;
use std::task::{Context, Poll};
@ -240,6 +242,30 @@ impl Client {
query::execute(self.inner(), buf)
}
/// Executes a `COPY FROM STDIN` statement, returning the number of rows created.
///
/// The data in the provided stream is passed along to the server verbatim; it is the caller's responsibility to
/// ensure it uses the proper format.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
pub fn copy_in<S>(
&mut self,
statement: &Statement,
params: &[&dyn ToSql],
stream: S,
) -> impl Future<Output = Result<u64, Error>>
where
S: TryStream,
S::Ok: IntoBuf,
<S::Ok as IntoBuf>::Buf: 'static + Send,
S::Error: Into<Box<dyn error::Error + Sync + Send>>,
{
let buf = query::encode(statement, params.iter().cloned());
copy_in::copy_in(self.inner(), buf, stream)
}
/// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
///
/// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that

View File

@ -1,4 +1,5 @@
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::copy_in::CopyInReceiver;
use crate::error::DbError;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::{AsyncMessage, Error, Notification};
@ -17,6 +18,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
pub enum RequestMessages {
Single(FrontendMessage),
CopyIn(CopyInReceiver),
}
pub struct Request {
@ -237,6 +239,24 @@ where
self.state = State::Closing;
}
}
RequestMessages::CopyIn(mut receiver) => {
let message = match receiver.poll_next_unpin(cx) {
Poll::Ready(Some(message)) => message,
Poll::Ready(None) => {
trace!("poll_write: finished copy_in request");
continue;
}
Poll::Pending => {
trace!("poll_write: waiting on copy_in stream");
self.pending_request = Some(RequestMessages::CopyIn(receiver));
return Ok(true);
}
};
Pin::new(&mut self.stream)
.start_send(message)
.map_err(Error::io)?;
self.pending_request = Some(RequestMessages::CopyIn(receiver));
}
}
}
}

View File

@ -0,0 +1,155 @@
use crate::client::InnerClient;
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::Error;
use bytes::{Buf, BufMut, BytesMut, IntoBuf};
use futures::channel::mpsc;
use futures::ready;
use futures::{SinkExt, Stream, StreamExt, TryStream, TryStreamExt};
use pin_utils::pin_mut;
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use std::error;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use postgres_protocol::message::frontend::CopyData;
enum CopyInMessage {
Message(FrontendMessage),
Done,
}
pub struct CopyInReceiver {
receiver: mpsc::Receiver<CopyInMessage>,
done: bool,
}
impl CopyInReceiver {
fn new(receiver: mpsc::Receiver<CopyInMessage>) -> CopyInReceiver {
CopyInReceiver {
receiver,
done: false,
}
}
}
impl Stream for CopyInReceiver {
type Item = FrontendMessage;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<FrontendMessage>> {
if self.done {
return Poll::Ready(None);
}
match ready!(self.receiver.poll_next_unpin(cx)) {
Some(CopyInMessage::Message(message)) => Poll::Ready(Some(message)),
Some(CopyInMessage::Done) => {
self.done = true;
let mut buf = vec![];
frontend::copy_done(&mut buf);
frontend::sync(&mut buf);
Poll::Ready(Some(FrontendMessage::Raw(buf)))
}
None => {
self.done = true;
let mut buf = vec![];
frontend::copy_fail("", &mut buf).unwrap();
frontend::sync(&mut buf);
Poll::Ready(Some(FrontendMessage::Raw(buf)))
}
}
}
}
pub async fn copy_in<S>(
client: Arc<InnerClient>,
buf: Result<Vec<u8>, Error>,
stream: S,
) -> Result<u64, Error>
where
S: TryStream,
S::Ok: IntoBuf,
<S::Ok as IntoBuf>::Buf: 'static + Send,
S::Error: Into<Box<dyn error::Error + Sync + Send>>,
{
let buf = buf?;
let (mut sender, receiver) = mpsc::channel(1);
let receiver = CopyInReceiver::new(receiver);
let mut responses = client.send(RequestMessages::CopyIn(receiver))?;
sender
.send(CopyInMessage::Message(FrontendMessage::Raw(buf)))
.await
.map_err(|_| Error::closed())?;
match responses.next().await? {
Message::BindComplete => {}
_ => return Err(Error::unexpected_message()),
}
match responses.next().await? {
Message::CopyInResponse(_) => {}
_ => return Err(Error::unexpected_message()),
}
let mut bytes = BytesMut::new();
let stream = stream.into_stream();
pin_mut!(stream);
while let Some(buf) = stream.try_next().await.map_err(Error::copy_in_stream)? {
let buf = buf.into_buf();
let data: Box<dyn Buf + Send> = if buf.remaining() > 4096 {
if bytes.is_empty() {
Box::new(buf)
} else {
Box::new(bytes.take().freeze().into_buf().chain(buf))
}
} else {
bytes.reserve(buf.remaining());
bytes.put(buf);
if bytes.len() > 4096 {
Box::new(bytes.take().freeze().into_buf())
} else {
continue;
}
};
let data = CopyData::new(data).map_err(Error::encode)?;
sender
.send(CopyInMessage::Message(FrontendMessage::CopyData(data)))
.await
.map_err(|_| Error::closed())?;
}
if !bytes.is_empty() {
let data: Box<dyn Buf + Send> = Box::new(bytes.freeze().into_buf());
let data = CopyData::new(data).map_err(Error::encode)?;
sender
.send(CopyInMessage::Message(FrontendMessage::CopyData(data)))
.await
.map_err(|_| Error::closed())?;
}
sender
.send(CopyInMessage::Done)
.await
.map_err(|_| Error::closed())?;
match responses.next().await? {
Message::CommandComplete(body) => {
let rows = body
.tag()
.map_err(Error::parse)?
.rsplit(' ')
.next()
.unwrap()
.parse()
.unwrap_or(0);
Ok(rows)
}
_ => Err(Error::unexpected_message()),
}
}

View File

@ -9,10 +9,11 @@
//! use tokio_postgres::{NoTls, Error, Row};
//!
//! # #[cfg(feature = "runtime")]
//! #[tokio::main]
//! #[tokio::main] // By default, tokio_postgres uses the tokio crate as its runtime.
//! async fn main() -> Result<(), Error> {
//! // Connect to the database.
//! let (mut client, connection) = tokio_postgres::connect("host=localhost user=postgres", NoTls).await?;
//! let (mut client, connection) =
//! tokio_postgres::connect("host=localhost user=postgres", NoTls).await?;
//!
//! // The connection object performs the actual communication with the database,
//! // so spawn it off to run on its own.
@ -108,7 +109,6 @@
pub use crate::client::Client;
pub use crate::config::Config;
pub use crate::transaction::Transaction;
pub use crate::connection::Connection;
use crate::error::DbError;
pub use crate::error::Error;
@ -118,6 +118,7 @@ pub use crate::socket::Socket;
#[cfg(feature = "runtime")]
use crate::tls::MakeTlsConnect;
pub use crate::tls::NoTls;
pub use crate::transaction::Transaction;
pub use statement::{Column, Statement};
#[cfg(feature = "runtime")]
@ -133,6 +134,7 @@ mod connect_raw;
mod connect_socket;
mod connect_tls;
mod connection;
mod copy_in;
pub mod error;
mod maybe_tls_stream;
mod prepare;
@ -142,8 +144,8 @@ mod simple_query;
#[cfg(feature = "runtime")]
mod socket;
mod statement;
mod transaction;
pub mod tls;
mod transaction;
pub mod types;
/// A convenience function which parses a connection string and connects to the database.

View File

@ -7,8 +7,10 @@ use crate::types::{ToSql, Type};
#[cfg(feature = "runtime")]
use crate::Socket;
use crate::{query, Client, Error, Row, SimpleQueryMessage, Statement};
use futures::Stream;
use bytes::IntoBuf;
use futures::{Stream, TryStream};
use postgres_protocol::message::frontend;
use std::error;
use std::future::Future;
use tokio::io::{AsyncRead, AsyncWrite};
@ -120,6 +122,22 @@ impl<'a> Transaction<'a> {
query::execute(self.client.inner(), buf)
}
/// Like `Client::copy_in`.
pub fn copy_in<S>(
&mut self,
statement: &Statement,
params: &[&dyn ToSql],
stream: S,
) -> impl Future<Output = Result<u64, Error>>
where
S: TryStream,
S::Ok: IntoBuf,
<S::Ok as IntoBuf>::Buf: 'static + Send,
S::Error: Into<Box<dyn error::Error + Sync + Send>>,
{
self.client.copy_in(statement, params, stream)
}
/// Like `Client::simple_query`.
pub fn simple_query(
&mut self,

View File

@ -2,13 +2,15 @@
#![feature(async_await)]
use futures::{join, try_join, FutureExt, TryStreamExt};
use std::fmt::Write;
use std::time::{Duration, Instant};
use futures::stream;
use tokio::net::TcpStream;
use tokio::timer::Delay;
use tokio_postgres::error::SqlState;
use tokio_postgres::tls::{NoTls, NoTlsStream};
use tokio_postgres::types::{Kind, Type};
use tokio_postgres::{Client, Config, Connection, Error, SimpleQueryMessage};
use tokio::timer::Delay;
use std::time::{Duration, Instant};
mod parse;
#[cfg(feature = "runtime")]
@ -301,7 +303,9 @@ async fn simple_query() {
async fn cancel_query_raw() {
let mut client = connect("user=postgres").await;
let socket = TcpStream::connect(&"127.0.0.1:5433".parse().unwrap()).await.unwrap();
let socket = TcpStream::connect(&"127.0.0.1:5433".parse().unwrap())
.await
.unwrap();
let cancel = client.cancel_query_raw(socket, NoTls);
let cancel = Delay::new(Instant::now() + Duration::from_millis(100)).then(|()| cancel);
@ -317,19 +321,29 @@ async fn cancel_query_raw() {
async fn transaction_commit() {
let mut client = connect("user=postgres").await;
client.batch_execute(
"CREATE TEMPORARY TABLE foo(
client
.batch_execute(
"CREATE TEMPORARY TABLE foo(
id SERIAL,
name TEXT
)",
).await.unwrap();
)
.await
.unwrap();
let mut transaction = client.transaction().await.unwrap();
transaction.batch_execute("INSERT INTO foo (name) VALUES ('steven')").await.unwrap();
transaction
.batch_execute("INSERT INTO foo (name) VALUES ('steven')")
.await
.unwrap();
transaction.commit().await.unwrap();
let stmt = client.prepare("SELECT name FROM foo").await.unwrap();
let rows = client.query(&stmt, &[]).try_collect::<Vec<_>>().await.unwrap();
let rows = client
.query(&stmt, &[])
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get::<_, &str>(0), "steven");
@ -339,19 +353,29 @@ async fn transaction_commit() {
async fn transaction_rollback() {
let mut client = connect("user=postgres").await;
client.batch_execute(
"CREATE TEMPORARY TABLE foo(
client
.batch_execute(
"CREATE TEMPORARY TABLE foo(
id SERIAL,
name TEXT
)",
).await.unwrap();
)
.await
.unwrap();
let mut transaction = client.transaction().await.unwrap();
transaction.batch_execute("INSERT INTO foo (name) VALUES ('steven')").await.unwrap();
transaction
.batch_execute("INSERT INTO foo (name) VALUES ('steven')")
.await
.unwrap();
transaction.rollback().await.unwrap();
let stmt = client.prepare("SELECT name FROM foo").await.unwrap();
let rows = client.query(&stmt, &[]).try_collect::<Vec<_>>().await.unwrap();
let rows = client
.query(&stmt, &[])
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(rows.len(), 0);
}
@ -360,20 +384,105 @@ async fn transaction_rollback() {
async fn transaction_rollback_drop() {
let mut client = connect("user=postgres").await;
client.batch_execute(
"CREATE TEMPORARY TABLE foo(
client
.batch_execute(
"CREATE TEMPORARY TABLE foo(
id SERIAL,
name TEXT
)",
).await.unwrap();
)
.await
.unwrap();
let mut transaction = client.transaction().await.unwrap();
transaction.batch_execute("INSERT INTO foo (name) VALUES ('steven')").await.unwrap();
transaction
.batch_execute("INSERT INTO foo (name) VALUES ('steven')")
.await
.unwrap();
drop(transaction);
let stmt = client.prepare("SELECT name FROM foo").await.unwrap();
let rows = client
.query(&stmt, &[])
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(rows.len(), 0);
}
#[tokio::test]
async fn copy_in() {
let mut client = connect("user=postgres").await;
client.batch_execute(
"CREATE TEMPORARY TABLE foo (
id INTEGER,\
name TEXT\
)"
).await.unwrap();
let stmt = client.prepare("COPY foo FROM STDIN").await.unwrap();
let stream = stream::iter(vec![b"1\tjim\n".to_vec(), b"2\tjoe\n".to_vec()].into_iter().map(Ok::<_, String>));
let rows = client.copy_in(&stmt, &[], stream).await.unwrap();
assert_eq!(rows, 2);
let stmt = client.prepare("SELECT id, name FROM foo ORDER BY id").await.unwrap();
let rows = client.query(&stmt, &[]).try_collect::<Vec<_>>().await.unwrap();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0].get::<_, i32>(0), 1);
assert_eq!(rows[0].get::<_, &str>(1), "jim");
assert_eq!(rows[1].get::<_, i32>(0), 2);
assert_eq!(rows[1].get::<_, &str>(1), "joe");
}
#[tokio::test]
async fn copy_in_large() {
let mut client = connect("user=postgres").await;
client.batch_execute(
"CREATE TEMPORARY TABLE foo (
id INTEGER,\
name TEXT\
)"
).await.unwrap();
let stmt = client.prepare("COPY foo FROM STDIN").await.unwrap();
let a = "0\tname0\n".to_string();
let mut b = String::new();
for i in 1..5_000 {
writeln!(b, "{0}\tname{0}", i).unwrap();
}
let mut c = String::new();
for i in 5_000..10_000 {
writeln!(c, "{0}\tname{0}", i).unwrap();
}
let stream = stream::iter(vec![a, b, c].into_iter().map(Ok::<_, String>));
let rows = client.copy_in(&stmt, &[], stream).await.unwrap();
assert_eq!(rows, 10_000);
}
#[tokio::test]
async fn copy_in_error() {
let mut client = connect("user=postgres").await;
client.batch_execute(
"CREATE TEMPORARY TABLE foo (
id INTEGER,\
name TEXT\
)"
).await.unwrap();
let stmt = client.prepare("COPY foo FROM STDIN").await.unwrap();
let stream = stream::iter(vec![Ok(b"1\tjim\n".to_vec()), Err("asdf")]);
let error = client.copy_in(&stmt, &[], stream).await.unwrap_err();
assert!(error.to_string().contains("asdf"));
let stmt = client.prepare("SELECT id, name FROM foo ORDER BY id").await.unwrap();
let rows = client.query(&stmt, &[]).try_collect::<Vec<_>>().await.unwrap();
assert_eq!(rows.len(), 0);
}
@ -466,139 +575,6 @@ fn notifications() {
assert_eq!(notifications[1].payload(), "world");
}
#[test]
fn copy_in() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.handle().spawn(connection).unwrap();
runtime
.block_on(
client
.simple_query(
"CREATE TEMPORARY TABLE foo (
id INTEGER,
name TEXT
)",
)
.for_each(|_| Ok(())),
)
.unwrap();
let stream = stream::iter_ok::<_, String>(vec![b"1\tjim\n".to_vec(), b"2\tjoe\n".to_vec()]);
let rows = runtime
.block_on(
client
.prepare("COPY foo FROM STDIN")
.and_then(|s| client.copy_in(&s, &[], stream)),
)
.unwrap();
assert_eq!(rows, 2);
let rows = runtime
.block_on(
client
.prepare("SELECT id, name FROM foo ORDER BY id")
.and_then(|s| client.query(&s, &[]).collect()),
)
.unwrap();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0].get::<_, i32>(0), 1);
assert_eq!(rows[0].get::<_, &str>(1), "jim");
assert_eq!(rows[1].get::<_, i32>(0), 2);
assert_eq!(rows[1].get::<_, &str>(1), "joe");
}
#[test]
fn copy_in_large() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.handle().spawn(connection).unwrap();
runtime
.block_on(
client
.simple_query(
"CREATE TEMPORARY TABLE foo (
id INTEGER,
name TEXT
)",
)
.for_each(|_| Ok(())),
)
.unwrap();
let a = "0\tname0\n".to_string();
let mut b = String::new();
for i in 1..5_000 {
writeln!(b, "{0}\tname{0}", i).unwrap();
}
let mut c = String::new();
for i in 5_000..10_000 {
writeln!(c, "{0}\tname{0}", i).unwrap();
}
let stream = stream::iter_ok::<_, String>(vec![a, b, c]);
let rows = runtime
.block_on(
client
.prepare("COPY foo FROM STDIN")
.and_then(|s| client.copy_in(&s, &[], stream)),
)
.unwrap();
assert_eq!(rows, 10_000);
}
#[test]
fn copy_in_error() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.handle().spawn(connection).unwrap();
runtime
.block_on(
client
.simple_query(
"CREATE TEMPORARY TABLE foo (
id INTEGER,
name TEXT
)",
)
.for_each(|_| Ok(())),
)
.unwrap();
let stream = stream::iter_result(vec![Ok(b"1\tjim\n".to_vec()), Err("asdf")]);
let error = runtime
.block_on(
client
.prepare("COPY foo FROM STDIN")
.and_then(|s| client.copy_in(&s, &[], stream)),
)
.unwrap_err();
assert!(error.to_string().contains("asdf"));
let rows = runtime
.block_on(
client
.prepare("SELECT id, name FROM foo ORDER BY id")
.and_then(|s| client.query(&s, &[]).collect()),
)
.unwrap();
assert_eq!(rows.len(), 0);
}
#[test]
fn copy_out() {
let _ = env_logger::try_init();

View File

@ -1,8 +1,8 @@
use futures::{FutureExt, TryStreamExt, join};
use futures::{join, FutureExt, TryStreamExt};
use std::time::{Duration, Instant};
use tokio::timer::Delay;
use tokio_postgres::error::SqlState;
use tokio_postgres::{NoTls, Client};
use tokio_postgres::{Client, NoTls};
async fn connect(s: &str) -> Client {
let (client, connection) = tokio_postgres::connect(s, NoTls).await.unwrap();