Remove synchronous crate

It will be coming back! It's just going to involve a full rewrite and
removing it for now makes some of that restructuring easier.
This commit is contained in:
Steven Fackler 2018-12-08 16:11:03 -08:00
parent 73ee31e522
commit 14571ab029
27 changed files with 0 additions and 6274 deletions

View File

@ -34,6 +34,5 @@ jobs:
- run: rustc --version > ~/rust-version
- *RESTORE_DEPS
- run: cargo test --all
- run: cargo test -p postgres --all-features
- run: cargo test -p tokio-postgres --all-features
- *SAVE_DEPS

View File

@ -1,11 +1,8 @@
[workspace]
members = [
"codegen",
"postgres",
"postgres-protocol",
"postgres-shared",
"postgres-openssl",
"postgres-native-tls",
"tokio-postgres",
"tokio-postgres-native-tls",
"tokio-postgres-openssl",

View File

@ -1,9 +0,0 @@
[package]
name = "postgres-native-tls"
version = "0.1.0"
authors = ["Steven Fackler <sfackler@gmail.com>"]
[dependencies]
native-tls = "0.2"
postgres = { version = "0.15", path = "../postgres" }

View File

@ -1,76 +0,0 @@
pub extern crate native_tls;
extern crate postgres;
use native_tls::TlsConnector;
use postgres::tls::{Stream, TlsHandshake, TlsStream};
use std::error::Error;
use std::fmt::{self, Debug};
use std::io::{self, Read, Write};
#[cfg(test)]
mod test;
pub struct NativeTls {
connector: TlsConnector,
}
impl Debug for NativeTls {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.debug_struct("NativeTls").finish()
}
}
impl NativeTls {
pub fn new() -> Result<NativeTls, native_tls::Error> {
let connector = TlsConnector::builder().build()?;
Ok(NativeTls::with_connector(connector))
}
pub fn with_connector(connector: TlsConnector) -> NativeTls {
NativeTls { connector }
}
}
impl TlsHandshake for NativeTls {
fn tls_handshake(
&self,
domain: &str,
stream: Stream,
) -> Result<Box<TlsStream>, Box<Error + Sync + Send>> {
let stream = self.connector.connect(domain, stream)?;
Ok(Box::new(NativeTlsStream(stream)))
}
}
#[derive(Debug)]
struct NativeTlsStream(native_tls::TlsStream<Stream>);
impl Read for NativeTlsStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl Write for NativeTlsStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}
impl TlsStream for NativeTlsStream {
fn get_ref(&self) -> &Stream {
self.0.get_ref()
}
fn get_mut(&mut self) -> &mut Stream {
self.0.get_mut()
}
fn tls_server_end_point(&self) -> Option<Vec<u8>> {
self.0.tls_server_end_point().ok().and_then(|o| o)
}
}

View File

@ -1,38 +0,0 @@
use native_tls::{Certificate, TlsConnector};
use postgres::{Connection, TlsMode};
use NativeTls;
#[test]
fn connect() {
let cert = include_bytes!("../../test/server.crt");
let cert = Certificate::from_pem(cert).unwrap();
let mut builder = TlsConnector::builder();
builder.add_root_certificate(cert);
let connector = builder.build().unwrap();
let handshake = NativeTls::with_connector(connector);
let conn = Connection::connect(
"postgres://ssl_user@localhost:5433/postgres",
TlsMode::Require(&handshake),
).unwrap();
conn.execute("SELECT 1::VARCHAR", &[]).unwrap();
}
#[test]
fn scram_user() {
let cert = include_bytes!("../../test/server.crt");
let cert = Certificate::from_pem(cert).unwrap();
let mut builder = TlsConnector::builder();
builder.add_root_certificate(cert);
let connector = builder.build().unwrap();
let handshake = NativeTls::with_connector(connector);
let conn = Connection::connect(
"postgres://scram_user:password@localhost:5433/postgres",
TlsMode::Require(&handshake),
).unwrap();
conn.execute("SELECT 1::VARCHAR", &[]).unwrap();
}

View File

@ -1,9 +0,0 @@
[package]
name = "postgres-openssl"
version = "0.1.0"
authors = ["Steven Fackler <sfackler@gmail.com>"]
[dependencies]
openssl = "0.10.9"
postgres = { version = "0.15", path = "../postgres" }

View File

@ -1,100 +0,0 @@
pub extern crate openssl;
extern crate postgres;
use openssl::error::ErrorStack;
use openssl::hash::MessageDigest;
use openssl::nid::Nid;
use openssl::ssl::{ConnectConfiguration, SslConnector, SslMethod, SslStream};
use postgres::tls::{Stream, TlsHandshake, TlsStream};
use std::error::Error;
use std::fmt;
use std::io::{self, Read, Write};
#[cfg(test)]
mod test;
pub struct OpenSsl {
connector: SslConnector,
config: Box<Fn(&mut ConnectConfiguration) -> Result<(), ErrorStack> + Sync + Send>,
}
impl fmt::Debug for OpenSsl {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.debug_struct("OpenSsl").finish()
}
}
impl OpenSsl {
pub fn new() -> Result<OpenSsl, ErrorStack> {
let connector = SslConnector::builder(SslMethod::tls())?.build();
Ok(OpenSsl::with_connector(connector))
}
pub fn with_connector(connector: SslConnector) -> OpenSsl {
OpenSsl {
connector,
config: Box::new(|_| Ok(())),
}
}
pub fn callback<F>(&mut self, f: F)
where
F: Fn(&mut ConnectConfiguration) -> Result<(), ErrorStack> + 'static + Sync + Send,
{
self.config = Box::new(f);
}
}
impl TlsHandshake for OpenSsl {
fn tls_handshake(
&self,
domain: &str,
stream: Stream,
) -> Result<Box<TlsStream>, Box<Error + Sync + Send>> {
let mut ssl = self.connector.configure()?;
(self.config)(&mut ssl)?;
let stream = ssl.connect(domain, stream)?;
Ok(Box::new(OpenSslStream(stream)))
}
}
#[derive(Debug)]
struct OpenSslStream(SslStream<Stream>);
impl Read for OpenSslStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl Write for OpenSslStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}
impl TlsStream for OpenSslStream {
fn get_ref(&self) -> &Stream {
self.0.get_ref()
}
fn get_mut(&mut self) -> &mut Stream {
self.0.get_mut()
}
fn tls_server_end_point(&self) -> Option<Vec<u8>> {
let cert = self.0.ssl().peer_certificate()?;
let algo_nid = cert.signature_algorithm().object().nid();
let signature_algorithms = algo_nid.signature_algorithms()?;
let md = match signature_algorithms.digest {
Nid::MD5 | Nid::SHA1 => MessageDigest::sha256(),
nid => MessageDigest::from_nid(nid)?,
};
cert.digest(md).ok().map(|b| b.to_vec())
}
}

View File

@ -1,40 +0,0 @@
use openssl::ssl::{SslConnector, SslMethod};
use postgres::{Connection, TlsMode};
use OpenSsl;
#[test]
fn require() {
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
builder.set_ca_file("../test/server.crt").unwrap();
let negotiator = OpenSsl::with_connector(builder.build());
let conn = Connection::connect(
"postgres://ssl_user@localhost:5433/postgres",
TlsMode::Require(&negotiator),
).unwrap();
conn.execute("SELECT 1::VARCHAR", &[]).unwrap();
}
#[test]
fn prefer() {
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
builder.set_ca_file("../test/server.crt").unwrap();
let negotiator = OpenSsl::with_connector(builder.build());
let conn = Connection::connect(
"postgres://ssl_user@localhost:5433/postgres",
TlsMode::Require(&negotiator),
).unwrap();
conn.execute("SELECT 1::VARCHAR", &[]).unwrap();
}
#[test]
fn scram_user() {
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
builder.set_ca_file("../test/server.crt").unwrap();
let negotiator = OpenSsl::with_connector(builder.build());
let conn = Connection::connect(
"postgres://scram_user:password@localhost:5433/postgres",
TlsMode::Require(&negotiator),
).unwrap();
conn.execute("SELECT 1::VARCHAR", &[]).unwrap();
}

View File

@ -1,66 +0,0 @@
[package]
name = "postgres"
version = "0.15.2"
authors = ["Steven Fackler <sfackler@gmail.com>"]
license = "MIT"
description = "A native PostgreSQL driver"
repository = "https://github.com/sfackler/rust-postgres"
readme = "../README.md"
keywords = ["database", "postgres", "postgresql", "sql"]
include = ["src/*", "Cargo.toml", "LICENSE", "README.md", "THIRD_PARTY"]
categories = ["database"]
[package.metadata.docs.rs]
features = [
"with-bit-vec-0.5",
"with-chrono-0.4",
"with-eui48-0.3",
"with-geo-0.10",
"with-serde_json-1",
"with-uuid-0.6",
"with-openssl",
"with-native-tls",
]
[badges]
circle-ci = { repository = "sfackler/rust-postgres" }
[lib]
name = "postgres"
path = "src/lib.rs"
test = false
bench = false
[[test]]
name = "test"
path = "tests/test.rs"
[features]
"with-bit-vec-0.5" = ["postgres-shared/with-bit-vec-0.5"]
"with-chrono-0.4" = ["postgres-shared/with-chrono-0.4"]
"with-eui48-0.3" = ["postgres-shared/with-eui48-0.3"]
"with-geo-0.10" = ["postgres-shared/with-geo-0.10"]
"with-serde_json-1" = ["postgres-shared/with-serde_json-1"]
"with-uuid-0.6" = ["postgres-shared/with-uuid-0.6"]
no-logging = []
[dependencies]
bytes = "0.4"
fallible-iterator = "0.1.3"
log = "0.4"
socket2 = { version = "0.3.5", features = ["unix"] }
postgres-protocol = { version = "0.3.0", path = "../postgres-protocol" }
postgres-shared = { version = "0.4.1", path = "../postgres-shared" }
[dev-dependencies]
hex = "0.3"
url = "1.0"
bit-vec = "0.5"
chrono = "0.4"
eui48 = "0.3"
geo = "0.10"
serde_json = "1.0"
uuid = "0.6"

File diff suppressed because it is too large Load Diff

View File

@ -1,69 +0,0 @@
macro_rules! try_desync {
($s:expr, $e:expr) => (
match $e {
Ok(ok) => ok,
Err(err) => {
$s.desynchronized = true;
return Err(::std::convert::From::from(err));
}
}
)
}
macro_rules! check_desync {
($e:expr) => ({
if $e.is_desynchronized() {
return Err(::desynchronized().into());
}
})
}
macro_rules! bad_response {
($s:expr) => {{
debug!("Bad response at {}:{}", file!(), line!());
$s.desynchronized = true;
return Err(::bad_response().into());
}};
}
#[cfg(feature = "no-logging")]
macro_rules! debug {
($($t:tt)*) => {};
}
#[cfg(feature = "no-logging")]
macro_rules! info {
($($t:tt)*) => {};
}
/// Generates a simple implementation of `ToSql::accepts` which accepts the
/// types passed to it.
#[macro_export]
macro_rules! accepts {
($($expected:pat),+) => (
fn accepts(ty: &$crate::types::Type) -> bool {
match *ty {
$($expected)|+ => true,
_ => false
}
}
)
}
/// Generates an implementation of `ToSql::to_sql_checked`.
///
/// All `ToSql` implementations should use this macro.
#[macro_export]
macro_rules! to_sql_checked {
() => {
fn to_sql_checked(&self,
ty: &$crate::types::Type,
out: &mut ::std::vec::Vec<u8>)
-> ::std::result::Result<$crate::types::IsNull,
Box<::std::error::Error +
::std::marker::Sync +
::std::marker::Send>> {
$crate::types::__to_sql_checked(self, ty, out)
}
}
}

View File

@ -1,203 +0,0 @@
//! Asynchronous notifications.
use error::DbError;
use fallible_iterator::{FallibleIterator, IntoFallibleIterator};
use postgres_protocol::message::backend::{self, ErrorFields};
use std::fmt;
use std::time::Duration;
#[doc(inline)]
use postgres_shared;
pub use postgres_shared::Notification;
use error::Error;
use {desynchronized, Connection, Result};
/// Notifications from the Postgres backend.
pub struct Notifications<'conn> {
conn: &'conn Connection,
}
impl<'a> fmt::Debug for Notifications<'a> {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.debug_struct("Notifications")
.field("pending", &self.len())
.finish()
}
}
impl<'conn> Notifications<'conn> {
pub(crate) fn new(conn: &'conn Connection) -> Notifications<'conn> {
Notifications { conn: conn }
}
/// Returns the number of pending notifications.
pub fn len(&self) -> usize {
self.conn.0.borrow().notifications.len()
}
/// Determines if there are any pending notifications.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Returns a fallible iterator over pending notifications.
///
/// # Note
///
/// This iterator may start returning `Some` after previously returning
/// `None` if more notifications are received.
pub fn iter<'a>(&'a self) -> Iter<'a> {
Iter { conn: self.conn }
}
/// Returns a fallible iterator over notifications that blocks until one is
/// received if none are pending.
///
/// The iterator will never return `None`.
pub fn blocking_iter<'a>(&'a self) -> BlockingIter<'a> {
BlockingIter { conn: self.conn }
}
/// Returns a fallible iterator over notifications that blocks for a limited
/// time waiting to receive one if none are pending.
///
/// # Note
///
/// This iterator may start returning `Some` after previously returning
/// `None` if more notifications are received.
pub fn timeout_iter<'a>(&'a self, timeout: Duration) -> TimeoutIter<'a> {
TimeoutIter {
conn: self.conn,
timeout: timeout,
}
}
}
impl<'a, 'conn> IntoFallibleIterator for &'a Notifications<'conn> {
type Item = Notification;
type Error = Error;
type IntoIter = Iter<'a>;
fn into_fallible_iterator(self) -> Iter<'a> {
self.iter()
}
}
/// A fallible iterator over pending notifications.
pub struct Iter<'a> {
conn: &'a Connection,
}
impl<'a> FallibleIterator for Iter<'a> {
type Item = Notification;
type Error = Error;
fn next(&mut self) -> Result<Option<Notification>> {
let mut conn = self.conn.0.borrow_mut();
if let Some(notification) = conn.notifications.pop_front() {
return Ok(Some(notification));
}
if conn.is_desynchronized() {
return Err(desynchronized().into());
}
match conn.read_message_with_notification_nonblocking() {
Ok(Some(backend::Message::NotificationResponse(body))) => Ok(Some(Notification {
process_id: body.process_id(),
channel: body.channel()?.to_owned(),
payload: body.message()?.to_owned(),
})),
Ok(Some(backend::Message::ErrorResponse(body))) => Err(err(&mut body.fields())),
Ok(None) => Ok(None),
Err(err) => Err(err.into()),
_ => unreachable!(),
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
(self.conn.0.borrow().notifications.len(), None)
}
}
/// An iterator over notifications which will block if none are pending.
pub struct BlockingIter<'a> {
conn: &'a Connection,
}
impl<'a> FallibleIterator for BlockingIter<'a> {
type Item = Notification;
type Error = Error;
fn next(&mut self) -> Result<Option<Notification>> {
let mut conn = self.conn.0.borrow_mut();
if let Some(notification) = conn.notifications.pop_front() {
return Ok(Some(notification));
}
if conn.is_desynchronized() {
return Err(desynchronized().into());
}
match conn.read_message_with_notification() {
Ok(backend::Message::NotificationResponse(body)) => Ok(Some(Notification {
process_id: body.process_id(),
channel: body.channel()?.to_owned(),
payload: body.message()?.to_owned(),
})),
Ok(backend::Message::ErrorResponse(body)) => Err(err(&mut body.fields())),
Err(err) => Err(err.into()),
_ => unreachable!(),
}
}
}
/// An iterator over notifications which will block for a period of time if
/// none are pending.
pub struct TimeoutIter<'a> {
conn: &'a Connection,
timeout: Duration,
}
impl<'a> FallibleIterator for TimeoutIter<'a> {
type Item = Notification;
type Error = Error;
fn next(&mut self) -> Result<Option<Notification>> {
let mut conn = self.conn.0.borrow_mut();
if let Some(notification) = conn.notifications.pop_front() {
return Ok(Some(notification));
}
if conn.is_desynchronized() {
return Err(desynchronized().into());
}
match conn.read_message_with_notification_timeout(self.timeout) {
Ok(Some(backend::Message::NotificationResponse(body))) => Ok(Some(Notification {
process_id: body.process_id(),
channel: body.channel()?.to_owned(),
payload: body.message()?.to_owned(),
})),
Ok(Some(backend::Message::ErrorResponse(body))) => Err(err(&mut body.fields())),
Ok(None) => Ok(None),
Err(err) => Err(err.into()),
_ => unreachable!(),
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
(self.conn.0.borrow().notifications.len(), None)
}
}
fn err(fields: &mut ErrorFields) -> Error {
match DbError::new(fields) {
Ok(err) => postgres_shared::error::db(err),
Err(err) => err.into(),
}
}

View File

@ -1,3 +0,0 @@
//! Connection parameters
pub use postgres_shared::params::{Builder, ConnectParams, Host, IntoConnectParams, User};

View File

@ -1,259 +0,0 @@
use bytes::{BufMut, BytesMut};
use postgres_protocol::message::backend;
use postgres_protocol::message::frontend;
use socket2::{Domain, SockAddr, Socket, Type};
use std::io::{self, BufWriter, Read, Write};
use std::net::{SocketAddr, ToSocketAddrs};
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, RawSocket};
use std::result;
use std::time::Duration;
use error;
use params::{ConnectParams, Host};
use tls::TlsStream;
use {Result, TlsMode};
const INITIAL_CAPACITY: usize = 8 * 1024;
pub struct MessageStream {
stream: BufWriter<Box<TlsStream>>,
in_buf: BytesMut,
out_buf: Vec<u8>,
}
impl MessageStream {
pub fn new(stream: Box<TlsStream>) -> MessageStream {
MessageStream {
stream: BufWriter::new(stream),
in_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
out_buf: vec![],
}
}
pub fn get_ref(&self) -> &TlsStream {
&**self.stream.get_ref()
}
pub fn write_message<F, E>(&mut self, f: F) -> result::Result<(), E>
where
F: FnOnce(&mut Vec<u8>) -> result::Result<(), E>,
E: From<io::Error>,
{
self.out_buf.clear();
f(&mut self.out_buf)?;
self.stream.write_all(&self.out_buf).map_err(From::from)
}
pub fn read_message(&mut self) -> io::Result<backend::Message> {
loop {
match backend::Message::parse(&mut self.in_buf) {
Ok(Some(message)) => return Ok(message),
Ok(None) => self.read_in()?,
Err(e) => return Err(e),
}
}
}
fn read_in(&mut self) -> io::Result<()> {
self.in_buf.reserve(1);
match self
.stream
.get_mut()
.read(unsafe { self.in_buf.bytes_mut() })
{
Ok(0) => Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"unexpected EOF",
)),
Ok(n) => {
unsafe { self.in_buf.advance_mut(n) };
Ok(())
}
Err(e) => Err(e),
}
}
pub fn read_message_timeout(
&mut self,
timeout: Duration,
) -> io::Result<Option<backend::Message>> {
if self.in_buf.is_empty() {
self.set_read_timeout(Some(timeout))?;
let r = self.read_in();
self.set_read_timeout(None)?;
match r {
Ok(()) => {}
Err(ref e)
if e.kind() == io::ErrorKind::WouldBlock
|| e.kind() == io::ErrorKind::TimedOut =>
{
return Ok(None)
}
Err(e) => return Err(e),
}
}
self.read_message().map(Some)
}
pub fn read_message_nonblocking(&mut self) -> io::Result<Option<backend::Message>> {
if self.in_buf.is_empty() {
self.set_nonblocking(true)?;
let r = self.read_in();
self.set_nonblocking(false)?;
match r {
Ok(()) => {}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(None),
Err(e) => return Err(e),
}
}
self.read_message().map(Some)
}
pub fn flush(&mut self) -> io::Result<()> {
self.stream.flush()
}
fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
self.stream.get_ref().get_ref().0.set_read_timeout(timeout)
}
fn set_nonblocking(&self, nonblock: bool) -> io::Result<()> {
self.stream.get_ref().get_ref().0.set_nonblocking(nonblock)
}
}
/// A connection to the Postgres server.
///
/// It implements `Read`, `Write` and `TlsStream`, as well as `AsRawFd` on
/// Unix platforms and `AsRawSocket` on Windows platforms.
#[derive(Debug)]
pub struct Stream(Socket);
impl Read for Stream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl Write for Stream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}
impl TlsStream for Stream {
fn get_ref(&self) -> &Stream {
self
}
fn get_mut(&mut self) -> &mut Stream {
self
}
}
#[cfg(unix)]
impl AsRawFd for Stream {
fn as_raw_fd(&self) -> RawFd {
self.0.as_raw_fd()
}
}
#[cfg(windows)]
impl AsRawSocket for Stream {
fn as_raw_socket(&self) -> RawSocket {
self.0.as_raw_socket()
}
}
fn open_socket(params: &ConnectParams) -> Result<Socket> {
let port = params.port();
match *params.host() {
Host::Tcp(ref host) => {
let mut error = None;
for addr in (&**host, port).to_socket_addrs()? {
let domain = match addr {
SocketAddr::V4(_) => Domain::ipv4(),
SocketAddr::V6(_) => Domain::ipv6(),
};
let socket = Socket::new(domain, Type::stream(), None)?;
if let Some(keepalive) = params.keepalive() {
socket.set_keepalive(Some(keepalive))?;
}
let addr = SockAddr::from(addr);
let r = match params.connect_timeout() {
Some(timeout) => socket.connect_timeout(&addr, timeout),
None => socket.connect(&addr),
};
match r {
Ok(()) => return Ok(socket),
Err(e) => error = Some(e),
}
}
Err(error
.unwrap_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve any addresses",
)
}).into())
}
#[cfg(unix)]
Host::Unix(ref path) => {
let path = path.join(&format!(".s.PGSQL.{}", port));
let socket = Socket::new(Domain::unix(), Type::stream(), None)?;
let addr = SockAddr::unix(path)?;
socket.connect(&addr)?;
Ok(socket)
}
#[cfg(not(unix))]
Host::Unix(..) => Err(io::Error::new(
io::ErrorKind::InvalidInput,
"unix sockets are not supported on this system",
).into()),
}
}
pub fn initialize_stream(params: &ConnectParams, tls: TlsMode) -> Result<Box<TlsStream>> {
let mut socket = Stream(open_socket(params)?);
let (tls_required, handshaker) = match tls {
TlsMode::None => return Ok(Box::new(socket)),
TlsMode::Prefer(handshaker) => (false, handshaker),
TlsMode::Require(handshaker) => (true, handshaker),
};
let mut buf = vec![];
frontend::ssl_request(&mut buf);
socket.write_all(&buf)?;
socket.flush()?;
let mut b = [0; 1];
socket.read_exact(&mut b)?;
if b[0] == b'N' {
if tls_required {
return Err(error::tls("the server does not support TLS".into()));
} else {
return Ok(Box::new(socket));
}
}
let host = match *params.host() {
Host::Tcp(ref host) => host,
// Postgres doesn't support TLS over unix sockets
Host::Unix(_) => return Err(::bad_response().into()),
};
handshaker.tls_handshake(host, socket).map_err(error::tls)
}

View File

@ -1,342 +0,0 @@
//! Query result rows.
use fallible_iterator::FallibleIterator;
use postgres_protocol::message::frontend;
use postgres_shared::rows::RowData;
use std::collections::VecDeque;
use std::fmt;
use std::io;
use std::ops::Deref;
use std::slice;
use std::sync::Arc;
#[doc(inline)]
pub use postgres_shared::rows::RowIndex;
use error;
use stmt::{Column, Statement};
use transaction::Transaction;
use types::{FromSql, WrongType};
use {Error, Result, StatementInfo};
enum MaybeOwned<'a, T: 'a> {
Borrowed(&'a T),
Owned(T),
}
impl<'a, T> Deref for MaybeOwned<'a, T> {
type Target = T;
fn deref(&self) -> &T {
match *self {
MaybeOwned::Borrowed(s) => s,
MaybeOwned::Owned(ref s) => s,
}
}
}
/// The resulting rows of a query.
pub struct Rows {
stmt_info: Arc<StatementInfo>,
data: Vec<RowData>,
}
impl fmt::Debug for Rows {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.debug_struct("Rows")
.field("columns", &self.columns())
.field("rows", &self.data.len())
.finish()
}
}
impl Rows {
pub(crate) fn new(stmt: &Statement, data: Vec<RowData>) -> Rows {
Rows {
stmt_info: stmt.info().clone(),
data: data,
}
}
/// Returns a slice describing the columns of the `Rows`.
pub fn columns(&self) -> &[Column] {
&self.stmt_info.columns[..]
}
/// Returns the number of rows present.
pub fn len(&self) -> usize {
self.data.len()
}
/// Determines if there are any rows present.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Returns a specific `Row`.
///
/// # Panics
///
/// Panics if `idx` is out of bounds.
pub fn get<'a>(&'a self, idx: usize) -> Row<'a> {
Row {
stmt_info: &self.stmt_info,
data: MaybeOwned::Borrowed(&self.data[idx]),
}
}
/// Returns an iterator over the `Row`s.
pub fn iter<'a>(&'a self) -> Iter<'a> {
Iter {
stmt_info: &self.stmt_info,
iter: self.data.iter(),
}
}
}
impl<'a> IntoIterator for &'a Rows {
type Item = Row<'a>;
type IntoIter = Iter<'a>;
fn into_iter(self) -> Iter<'a> {
self.iter()
}
}
/// An iterator over `Row`s.
pub struct Iter<'a> {
stmt_info: &'a StatementInfo,
iter: slice::Iter<'a, RowData>,
}
impl<'a> Iterator for Iter<'a> {
type Item = Row<'a>;
fn next(&mut self) -> Option<Row<'a>> {
self.iter.next().map(|row| Row {
stmt_info: self.stmt_info,
data: MaybeOwned::Borrowed(row),
})
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.iter.size_hint()
}
}
impl<'a> DoubleEndedIterator for Iter<'a> {
fn next_back(&mut self) -> Option<Row<'a>> {
self.iter.next_back().map(|row| Row {
stmt_info: self.stmt_info,
data: MaybeOwned::Borrowed(row),
})
}
}
impl<'a> ExactSizeIterator for Iter<'a> {}
/// A single result row of a query.
pub struct Row<'a> {
stmt_info: &'a StatementInfo,
data: MaybeOwned<'a, RowData>,
}
impl<'a> fmt::Debug for Row<'a> {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.debug_struct("Row")
.field("statement", self.stmt_info)
.finish()
}
}
impl<'a> Row<'a> {
/// Returns the number of values in the row.
pub fn len(&self) -> usize {
self.data.len()
}
/// Determines if there are any values in the row.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Returns a slice describing the columns of the `Row`.
pub fn columns(&self) -> &[Column] {
&self.stmt_info.columns[..]
}
/// Retrieves the contents of a field of the row.
///
/// A field can be accessed by the name or index of its column, though
/// access by index is more efficient. Rows are 0-indexed.
///
/// # Panics
///
/// Panics if the index does not reference a column or the return type is
/// not compatible with the Postgres type.
///
/// # Example
///
/// ```rust,no_run
/// # use postgres::{Connection, TlsMode};
/// # let conn = Connection::connect("", TlsMode::None).unwrap();
/// let stmt = conn.prepare("SELECT foo, bar from BAZ").unwrap();
/// for row in &stmt.query(&[]).unwrap() {
/// let foo: i32 = row.get(0);
/// let bar: String = row.get("bar");
/// println!("{}: {}", foo, bar);
/// }
/// ```
pub fn get<'b, I, T>(&'b self, idx: I) -> T
where
I: RowIndex + fmt::Debug,
T: FromSql<'b>,
{
match self.get_inner(&idx) {
Some(Ok(ok)) => ok,
Some(Err(err)) => panic!("error retrieving column {:?}: {:?}", idx, err),
None => panic!("no such column {:?}", idx),
}
}
/// Retrieves the contents of a field of the row.
///
/// A field can be accessed by the name or index of its column, though
/// access by index is more efficient. Rows are 0-indexed.
///
/// Returns `None` if the index does not reference a column, `Some(Err(..))`
/// if there was an error converting the result value, and `Some(Ok(..))`
/// on success.
pub fn get_opt<'b, I, T>(&'b self, idx: I) -> Option<Result<T>>
where
I: RowIndex,
T: FromSql<'b>,
{
self.get_inner(&idx)
}
fn get_inner<'b, I, T>(&'b self, idx: &I) -> Option<Result<T>>
where
I: RowIndex,
T: FromSql<'b>,
{
let idx = match idx.__idx(&self.stmt_info.columns) {
Some(idx) => idx,
None => return None,
};
let ty = self.stmt_info.columns[idx].type_();
if !<T as FromSql>::accepts(ty) {
return Some(Err(error::conversion(Box::new(WrongType::new(ty.clone())))));
}
let value = FromSql::from_sql_nullable(ty, self.data.get(idx));
Some(value.map_err(error::conversion))
}
}
/// A lazily-loaded iterator over the resulting rows of a query.
pub struct LazyRows<'trans, 'stmt> {
stmt: &'stmt Statement<'stmt>,
data: VecDeque<RowData>,
name: String,
row_limit: i32,
more_rows: bool,
finished: bool,
_trans: &'trans Transaction<'trans>,
}
impl<'a, 'b> Drop for LazyRows<'a, 'b> {
fn drop(&mut self) {
if !self.finished {
let _ = self.finish_inner();
}
}
}
impl<'a, 'b> fmt::Debug for LazyRows<'a, 'b> {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.debug_struct("LazyRows")
.field("name", &self.name)
.field("row_limit", &self.row_limit)
.field("remaining_rows", &self.data.len())
.field("more_rows", &self.more_rows)
.finish()
}
}
impl<'trans, 'stmt> LazyRows<'trans, 'stmt> {
pub(crate) fn new(
stmt: &'stmt Statement<'stmt>,
data: VecDeque<RowData>,
name: String,
row_limit: i32,
more_rows: bool,
finished: bool,
trans: &'trans Transaction<'trans>,
) -> LazyRows<'trans, 'stmt> {
LazyRows {
stmt: stmt,
data: data,
name: name,
row_limit: row_limit,
more_rows: more_rows,
finished: finished,
_trans: trans,
}
}
fn finish_inner(&mut self) -> Result<()> {
let mut conn = self.stmt.conn().0.borrow_mut();
check_desync!(conn);
conn.close_statement(&self.name, b'P')
}
fn execute(&mut self) -> Result<()> {
let mut conn = self.stmt.conn().0.borrow_mut();
conn.stream
.write_message(|buf| frontend::execute(&self.name, self.row_limit, buf))?;
conn.stream
.write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))?;
conn.stream.flush()?;
conn.read_rows(|row| self.data.push_back(row))
.map(|more_rows| self.more_rows = more_rows)
}
/// Returns a slice describing the columns of the `LazyRows`.
pub fn columns(&self) -> &[Column] {
self.stmt.columns()
}
/// Consumes the `LazyRows`, cleaning up associated state.
///
/// Functionally identical to the `Drop` implementation on `LazyRows`
/// except that it returns any error to the caller.
pub fn finish(mut self) -> Result<()> {
self.finish_inner()
}
}
impl<'trans, 'stmt> FallibleIterator for LazyRows<'trans, 'stmt> {
type Item = Row<'stmt>;
type Error = Error;
fn next(&mut self) -> Result<Option<Row<'stmt>>> {
if self.data.is_empty() && self.more_rows {
self.execute()?;
}
let row = self.data.pop_front().map(|r| Row {
stmt_info: &**self.stmt.info(),
data: MaybeOwned::Owned(r),
});
Ok(row)
}
fn size_hint(&self) -> (usize, Option<usize>) {
let lower = self.data.len();
let upper = if self.more_rows { None } else { Some(lower) };
(lower, upper)
}
}

View File

@ -1,605 +0,0 @@
//! Prepared statements
use fallible_iterator::FallibleIterator;
use postgres_protocol::message::{backend, frontend};
use postgres_shared::rows::RowData;
use std::cell::Cell;
use std::collections::VecDeque;
use std::fmt;
use std::io::{self, Read, Write};
use std::sync::Arc;
#[doc(inline)]
pub use postgres_shared::stmt::Column;
use rows::{LazyRows, Rows};
use transaction::Transaction;
use types::{ToSql, Type};
use {bad_response, err, Connection, Result, StatementInfo};
/// A prepared statement.
pub struct Statement<'conn> {
conn: &'conn Connection,
info: Arc<StatementInfo>,
next_portal_id: Cell<u32>,
finished: bool,
}
impl<'a> fmt::Debug for Statement<'a> {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&*self.info, fmt)
}
}
impl<'conn> Drop for Statement<'conn> {
fn drop(&mut self) {
let _ = self.finish_inner();
}
}
impl<'conn> Statement<'conn> {
pub(crate) fn new(
conn: &'conn Connection,
info: Arc<StatementInfo>,
next_portal_id: Cell<u32>,
finished: bool,
) -> Statement<'conn> {
Statement {
conn: conn,
info: info,
next_portal_id: next_portal_id,
finished: finished,
}
}
pub(crate) fn info(&self) -> &Arc<StatementInfo> {
&self.info
}
pub(crate) fn conn(&self) -> &'conn Connection {
self.conn
}
pub(crate) fn into_query(self, params: &[&ToSql]) -> Result<Rows> {
check_desync!(self.conn);
let mut rows = vec![];
self.inner_query("", 0, params, |row| rows.push(row))?;
Ok(Rows::new(&self, rows))
}
fn finish_inner(&mut self) -> Result<()> {
if self.finished {
Ok(())
} else {
self.finished = true;
let mut conn = self.conn.0.borrow_mut();
check_desync!(conn);
conn.close_statement(&self.info.name, b'S')
}
}
#[allow(type_complexity)]
fn inner_query<F>(
&self,
portal_name: &str,
row_limit: i32,
params: &[&ToSql],
acceptor: F,
) -> Result<bool>
where
F: FnMut(RowData),
{
let mut conn = self.conn.0.borrow_mut();
conn.raw_execute(
&self.info.name,
portal_name,
row_limit,
self.param_types(),
params,
)?;
conn.read_rows(acceptor)
}
/// Returns a slice containing the expected parameter types.
pub fn param_types(&self) -> &[Type] {
&self.info.param_types
}
/// Returns a slice describing the columns of the result of the query.
pub fn columns(&self) -> &[Column] {
&self.info.columns
}
/// Executes the prepared statement, returning the number of rows modified.
///
/// If the statement does not modify any rows (e.g. SELECT), 0 is returned.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number
/// expected.
///
/// # Example
///
/// ```rust,no_run
/// # use postgres::{Connection, TlsMode};
/// # let conn = Connection::connect("", TlsMode::None).unwrap();
/// # let bar = 1i32;
/// # let baz = true;
/// let stmt = conn.prepare("UPDATE foo SET bar = $1 WHERE baz = $2").unwrap();
/// let rows_updated = stmt.execute(&[&bar, &baz]).unwrap();
/// println!("{} rows updated", rows_updated);
/// ```
pub fn execute(&self, params: &[&ToSql]) -> Result<u64> {
let mut conn = self.conn.0.borrow_mut();
check_desync!(conn);
conn.raw_execute(&self.info.name, "", 0, self.param_types(), params)?;
let num;
loop {
match conn.read_message()? {
backend::Message::DataRow(_) => {}
backend::Message::ErrorResponse(body) => {
conn.wait_for_ready()?;
return Err(err(&mut body.fields()));
}
backend::Message::CommandComplete(body) => {
num = parse_update_count(body.tag()?);
break;
}
backend::Message::EmptyQueryResponse => {
num = 0;
break;
}
backend::Message::CopyInResponse(_) => {
conn.stream.write_message(|buf| {
frontend::copy_fail("COPY queries cannot be directly executed", buf)
})?;
conn.stream
.write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))?;
conn.stream.flush()?;
}
backend::Message::CopyOutResponse(_) => {
loop {
match conn.read_message()? {
backend::Message::CopyDone => break,
backend::Message::ErrorResponse(body) => {
conn.wait_for_ready()?;
return Err(err(&mut body.fields()));
}
_ => {}
}
}
num = 0;
break;
}
_ => {
conn.desynchronized = true;
return Err(bad_response().into());
}
}
}
conn.wait_for_ready()?;
Ok(num)
}
/// Executes the prepared statement, returning the resulting rows.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number
/// expected.
///
/// # Example
///
/// ```rust,no_run
/// # use postgres::{Connection, TlsMode};
/// # let conn = Connection::connect("", TlsMode::None).unwrap();
/// let stmt = conn.prepare("SELECT foo FROM bar WHERE baz = $1").unwrap();
/// # let baz = true;
/// for row in &stmt.query(&[&baz]).unwrap() {
/// let foo: i32 = row.get("foo");
/// println!("foo: {}", foo);
/// }
/// ```
pub fn query(&self, params: &[&ToSql]) -> Result<Rows> {
check_desync!(self.conn);
let mut rows = vec![];
self.inner_query("", 0, params, |row| rows.push(row))?;
Ok(Rows::new(self, rows))
}
/// Executes the prepared statement, returning a lazily loaded iterator
/// over the resulting rows.
///
/// No more than `row_limit` rows will be stored in memory at a time. Rows
/// will be pulled from the database in batches of `row_limit` as needed.
/// If `row_limit` is less than or equal to 0, `lazy_query` is equivalent
/// to `query`.
///
/// This can only be called inside of a transaction, and the `Transaction`
/// object representing the active transaction must be passed to
/// `lazy_query`.
///
/// # Panics
///
/// Panics if the provided `Transaction` is not associated with the same
/// `Connection` as this `Statement`, if the `Transaction` is not
/// active, or if the number of parameters provided does not match the
/// number of parameters expected.
///
/// # Examples
///
/// ```no_run
/// extern crate fallible_iterator;
/// extern crate postgres;
///
/// use fallible_iterator::FallibleIterator;
/// # use postgres::{Connection, TlsMode};
///
/// # fn main() {
/// # let conn = Connection::connect("", TlsMode::None).unwrap();
/// let stmt = conn.prepare("SELECT foo FROM bar WHERE baz = $1").unwrap();
/// let trans = conn.transaction().unwrap();
/// # let baz = true;
/// let mut rows = stmt.lazy_query(&trans, &[&baz], 100).unwrap();
///
/// while let Some(row) = rows.next().unwrap() {
/// let foo: i32 = row.get("foo");
/// println!("foo: {}", foo);
/// }
/// # }
/// ```
pub fn lazy_query<'trans, 'stmt>(
&'stmt self,
trans: &'trans Transaction,
params: &[&ToSql],
row_limit: i32,
) -> Result<LazyRows<'trans, 'stmt>> {
assert!(
self.conn as *const _ == trans.conn() as *const _,
"the `Transaction` passed to `lazy_query` must be associated with the same \
`Connection` as the `Statement`"
);
let conn = self.conn.0.borrow();
check_desync!(conn);
assert!(
conn.trans_depth == trans.depth(),
"`lazy_query` must be passed the active transaction"
);
drop(conn);
let id = self.next_portal_id.get();
self.next_portal_id.set(id + 1);
let portal_name = format!("{}p{}", self.info.name, id);
let mut rows = VecDeque::new();
let more_rows =
self.inner_query(&portal_name, row_limit, params, |row| rows.push_back(row))?;
Ok(LazyRows::new(
self,
rows,
portal_name,
row_limit,
more_rows,
false,
trans,
))
}
/// Executes a `COPY FROM STDIN` statement, returning the number of rows
/// added.
///
/// The contents of the provided reader are passed to the Postgres server
/// verbatim; it is the caller's responsibility to ensure it uses the
/// proper format. See the
/// [Postgres documentation](http://www.postgresql.org/docs/9.4/static/sql-copy.html)
/// for details.
///
/// If the statement is not a `COPY FROM STDIN` statement it will still be
/// executed and this method will return an error.
///
/// # Examples
///
/// ```rust,no_run
/// # use postgres::{Connection, TlsMode};
/// # let conn = Connection::connect("", TlsMode::None).unwrap();
/// conn.batch_execute("CREATE TABLE people (id INT PRIMARY KEY, name VARCHAR)").unwrap();
/// let stmt = conn.prepare("COPY people FROM STDIN").unwrap();
/// stmt.copy_in(&[], &mut "1\tjohn\n2\tjane\n".as_bytes()).unwrap();
/// ```
pub fn copy_in<R: ReadWithInfo>(&self, params: &[&ToSql], r: &mut R) -> Result<u64> {
let mut conn = self.conn.0.borrow_mut();
conn.raw_execute(&self.info.name, "", 0, self.param_types(), params)?;
let (format, column_formats) = match conn.read_message()? {
backend::Message::CopyInResponse(body) => {
let format = body.format();
let column_formats = body
.column_formats()
.map(|f| Format::from_u16(f))
.collect()?;
(format, column_formats)
}
backend::Message::ErrorResponse(body) => {
conn.wait_for_ready()?;
return Err(err(&mut body.fields()));
}
_ => loop {
if let backend::Message::ReadyForQuery(_) = conn.read_message()? {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"called `copy_in` on a non-`COPY FROM STDIN` statement",
).into());
}
},
};
let info = CopyInfo {
format: Format::from_u16(format as u16),
column_formats: column_formats,
};
let mut buf = [0; 16 * 1024];
loop {
match fill_copy_buf(&mut buf, r, &info) {
Ok(0) => break,
Ok(len) => {
conn.stream
.write_message(|out| frontend::copy_data(&buf[..len], out))?;
}
Err(err) => {
conn.stream
.write_message(|buf| frontend::copy_fail("", buf))?;
conn.stream
.write_message(|buf| Ok::<(), io::Error>(frontend::copy_done(buf)))?;
conn.stream
.write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))?;
conn.stream.flush()?;
match conn.read_message()? {
backend::Message::ErrorResponse(_) => {
// expected from the CopyFail
}
_ => {
conn.desynchronized = true;
return Err(bad_response().into());
}
}
conn.wait_for_ready()?;
return Err(err.into());
}
}
}
conn.stream
.write_message(|buf| Ok::<(), io::Error>(frontend::copy_done(buf)))?;
conn.stream
.write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))?;
conn.stream.flush()?;
let num = match conn.read_message()? {
backend::Message::CommandComplete(body) => parse_update_count(body.tag()?),
backend::Message::ErrorResponse(body) => {
conn.wait_for_ready()?;
return Err(err(&mut body.fields()));
}
_ => {
conn.desynchronized = true;
return Err(bad_response().into());
}
};
conn.wait_for_ready()?;
Ok(num)
}
/// Executes a `COPY TO STDOUT` statement, passing the resulting data to
/// the provided writer and returning the number of rows received.
///
/// See the [Postgres documentation](http://www.postgresql.org/docs/9.4/static/sql-copy.html)
/// for details on the data format.
///
/// If the statement is not a `COPY TO STDOUT` statement it will still be
/// executed and this method will return an error.
///
/// # Examples
///
/// ```rust,no_run
/// # use postgres::{Connection, TlsMode};
/// # let conn = Connection::connect("", TlsMode::None).unwrap();
/// conn.batch_execute("
/// CREATE TABLE people (id INT PRIMARY KEY, name VARCHAR);
/// INSERT INTO people (id, name) VALUES (1, 'john'), (2, 'jane');").unwrap();
/// let stmt = conn.prepare("COPY people TO STDOUT").unwrap();
/// let mut buf = vec![];
/// stmt.copy_out(&[], &mut buf).unwrap();
/// assert_eq!(buf, b"1\tjohn\n2\tjane\n");
/// ```
pub fn copy_out<'a, W: WriteWithInfo>(&'a self, params: &[&ToSql], w: &mut W) -> Result<u64> {
let mut conn = self.conn.0.borrow_mut();
conn.raw_execute(&self.info.name, "", 0, self.param_types(), params)?;
let (format, column_formats) = match conn.read_message()? {
backend::Message::CopyOutResponse(body) => {
let format = body.format();
let column_formats = body
.column_formats()
.map(|f| Format::from_u16(f))
.collect()?;
(format, column_formats)
}
backend::Message::CopyInResponse(_) => {
conn.stream
.write_message(|buf| frontend::copy_fail("", buf))?;
conn.stream
.write_message(|buf| Ok::<(), io::Error>(frontend::copy_done(buf)))?;
conn.stream
.write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))?;
conn.stream.flush()?;
match conn.read_message()? {
backend::Message::ErrorResponse(_) => {
// expected from the CopyFail
}
_ => {
conn.desynchronized = true;
return Err(bad_response().into());
}
}
conn.wait_for_ready()?;
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"called `copy_out` on a non-`COPY TO STDOUT` statement",
).into());
}
backend::Message::ErrorResponse(body) => {
conn.wait_for_ready()?;
return Err(err(&mut body.fields()));
}
_ => loop {
if let backend::Message::ReadyForQuery(_) = conn.read_message()? {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"called `copy_out` on a non-`COPY TO STDOUT` statement",
).into());
}
},
};
let info = CopyInfo {
format: Format::from_u16(format as u16),
column_formats: column_formats,
};
let count;
loop {
match conn.read_message()? {
backend::Message::CopyData(body) => {
let mut data = body.data();
while !data.is_empty() {
match w.write_with_info(data, &info) {
Ok(n) => data = &data[n..],
Err(e) => loop {
if let backend::Message::ReadyForQuery(_) = conn.read_message()? {
return Err(e.into());
}
},
}
}
}
backend::Message::CopyDone => {}
backend::Message::CommandComplete(body) => {
count = parse_update_count(body.tag()?);
break;
}
backend::Message::ErrorResponse(body) => loop {
if let backend::Message::ReadyForQuery(_) = conn.read_message()? {
return Err(err(&mut body.fields()));
}
},
_ => loop {
if let backend::Message::ReadyForQuery(_) = conn.read_message()? {
return Err(bad_response().into());
}
},
}
}
conn.wait_for_ready()?;
Ok(count)
}
/// Consumes the statement, clearing it from the Postgres session.
///
/// If this statement was created via the `prepare_cached` method, `finish`
/// does nothing.
///
/// Functionally identical to the `Drop` implementation of the
/// `Statement` except that it returns any error to the caller.
pub fn finish(mut self) -> Result<()> {
self.finish_inner()
}
}
fn fill_copy_buf<R: ReadWithInfo>(buf: &mut [u8], r: &mut R, info: &CopyInfo) -> io::Result<usize> {
let mut nread = 0;
while nread < buf.len() {
match r.read_with_info(&mut buf[nread..], info) {
Ok(0) => break,
Ok(n) => nread += n,
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
Err(e) => return Err(e),
}
}
Ok(nread)
}
/// A struct containing information relevant for a `COPY` operation.
pub struct CopyInfo {
format: Format,
column_formats: Vec<Format>,
}
impl CopyInfo {
/// Returns the format of the overall data.
pub fn format(&self) -> Format {
self.format
}
/// Returns the format of the individual columns.
pub fn column_formats(&self) -> &[Format] {
&self.column_formats
}
}
/// Like `Read` except that a `CopyInfo` object is provided as well.
///
/// All types that implement `Read` also implement this trait.
pub trait ReadWithInfo {
/// Like `Read::read`.
fn read_with_info(&mut self, buf: &mut [u8], info: &CopyInfo) -> io::Result<usize>;
}
impl<R: Read> ReadWithInfo for R {
fn read_with_info(&mut self, buf: &mut [u8], _: &CopyInfo) -> io::Result<usize> {
self.read(buf)
}
}
/// Like `Write` except that a `CopyInfo` object is provided as well.
///
/// All types that implement `Write` also implement this trait.
pub trait WriteWithInfo {
/// Like `Write::write`.
fn write_with_info(&mut self, buf: &[u8], info: &CopyInfo) -> io::Result<usize>;
}
impl<W: Write> WriteWithInfo for W {
fn write_with_info(&mut self, buf: &[u8], _: &CopyInfo) -> io::Result<usize> {
self.write(buf)
}
}
/// The format of a portion of COPY query data.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Format {
/// A text based format.
Text,
/// A binary format.
Binary,
}
impl Format {
fn from_u16(value: u16) -> Format {
match value {
0 => Format::Text,
_ => Format::Binary,
}
}
}
fn parse_update_count(tag: &str) -> u64 {
tag.split(' ').last().unwrap().parse().unwrap_or(0)
}

View File

@ -1,191 +0,0 @@
//! Query result rows.
use postgres_shared::rows::RowData;
use std::fmt;
use std::slice;
use std::str;
#[doc(inline)]
pub use postgres_shared::rows::RowIndex;
use stmt::Column;
use {error, Result};
/// The resulting rows of a query.
pub struct TextRows {
columns: Vec<Column>,
data: Vec<RowData>,
}
impl fmt::Debug for TextRows {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.debug_struct("TextRows")
.field("columns", &self.columns())
.field("rows", &self.data.len())
.finish()
}
}
impl TextRows {
pub(crate) fn new(columns: Vec<Column>, data: Vec<RowData>) -> TextRows {
TextRows {
columns: columns,
data: data,
}
}
/// Returns a slice describing the columns of the `TextRows`.
pub fn columns(&self) -> &[Column] {
&self.columns[..]
}
/// Returns the number of rows present.
pub fn len(&self) -> usize {
self.data.len()
}
/// Determines if there are any rows present.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Returns a specific `TextRow`.
///
/// # Panics
///
/// Panics if `idx` is out of bounds.
pub fn get<'a>(&'a self, idx: usize) -> TextRow<'a> {
TextRow {
columns: &self.columns,
data: &self.data[idx],
}
}
/// Returns an iterator over the `TextRow`s.
pub fn iter<'a>(&'a self) -> Iter<'a> {
Iter {
columns: self.columns(),
iter: self.data.iter(),
}
}
}
impl<'a> IntoIterator for &'a TextRows {
type Item = TextRow<'a>;
type IntoIter = Iter<'a>;
fn into_iter(self) -> Iter<'a> {
self.iter()
}
}
/// An iterator over `TextRow`s.
pub struct Iter<'a> {
columns: &'a [Column],
iter: slice::Iter<'a, RowData>,
}
impl<'a> Iterator for Iter<'a> {
type Item = TextRow<'a>;
fn next(&mut self) -> Option<TextRow<'a>> {
self.iter.next().map(|row| TextRow {
columns: self.columns,
data: row,
})
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.iter.size_hint()
}
}
impl<'a> DoubleEndedIterator for Iter<'a> {
fn next_back(&mut self) -> Option<TextRow<'a>> {
self.iter.next_back().map(|row| TextRow {
columns: self.columns,
data: row,
})
}
}
impl<'a> ExactSizeIterator for Iter<'a> {}
/// A single result row of a query.
pub struct TextRow<'a> {
columns: &'a [Column],
data: &'a RowData,
}
impl<'a> fmt::Debug for TextRow<'a> {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.debug_struct("TextRow")
.field("columns", &self.columns)
.finish()
}
}
impl<'a> TextRow<'a> {
/// Returns the number of values in the row.
pub fn len(&self) -> usize {
self.data.len()
}
/// Determines if there are any values in the row.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Returns a slice describing the columns of the `TextRow`.
pub fn columns(&self) -> &[Column] {
self.columns
}
/// Retrieve the contents of a field of a row
///
/// A field can be accessed by the name or index of its column, though
/// access by index is more efficient. Rows are 0-indexed.
///
/// # Panics
///
/// Panics if the index does not reference a column
pub fn get<I>(&self, idx: I) -> &str
where
I: RowIndex + fmt::Debug,
{
match self.get_inner(&idx) {
Some(Ok(value)) => value,
Some(Err(err)) => panic!("error retrieving column {:?}: {:?}", idx, err),
None => panic!("no such column {:?}", idx),
}
}
/// Retrieves the contents of a field of the row.
///
/// A field can be accessed by the name or index of its column, though
/// access by index is more efficient. Rows are 0-indexed.
///
/// Returns None if the index does not reference a column, Some(Err(..)) if
/// there was an error parsing the result as UTF-8, and Some(Ok(..)) on
/// success.
pub fn get_opt<I>(&self, idx: I) -> Option<Result<&str>>
where
I: RowIndex,
{
self.get_inner(&idx)
}
fn get_inner<I>(&self, idx: &I) -> Option<Result<&str>>
where
I: RowIndex,
{
let idx = match idx.__idx(self.columns) {
Some(idx) => idx,
None => return None,
};
self.data
.get(idx)
.map(|s| str::from_utf8(s).map_err(|e| error::conversion(Box::new(e))))
}
}

View File

@ -1,50 +0,0 @@
//! Types and traits for TLS support.
pub use priv_io::Stream;
use std::error::Error;
use std::fmt;
use std::io::prelude::*;
/// A trait implemented by TLS streams.
pub trait TlsStream: fmt::Debug + Read + Write + Send {
/// Returns a reference to the underlying `Stream`.
fn get_ref(&self) -> &Stream;
/// Returns a mutable reference to the underlying `Stream`.
fn get_mut(&mut self) -> &mut Stream;
/// Returns the data associated with the `tls-server-end-point` channel binding type as
/// described in [RFC 5929], if supported.
///
/// An implementation only needs to support one of this or `tls_unique`.
///
/// [RFC 5929]: https://tools.ietf.org/html/rfc5929
fn tls_server_end_point(&self) -> Option<Vec<u8>> {
None
}
}
/// A trait implemented by types that can initiate a TLS session over a Postgres
/// stream.
pub trait TlsHandshake: fmt::Debug {
/// Performs a client-side TLS handshake, returning a wrapper around the
/// provided stream.
///
/// The host portion of the connection parameters is provided for hostname
/// verification.
fn tls_handshake(
&self,
host: &str,
stream: Stream,
) -> Result<Box<TlsStream>, Box<Error + Sync + Send>>;
}
impl<T: TlsHandshake + ?Sized> TlsHandshake for Box<T> {
fn tls_handshake(
&self,
host: &str,
stream: Stream,
) -> Result<Box<TlsStream>, Box<Error + Sync + Send>> {
(**self).tls_handshake(host, stream)
}
}

View File

@ -1,327 +0,0 @@
//! Transactions
use std::cell::Cell;
use std::fmt;
use rows::Rows;
use stmt::Statement;
use text_rows::TextRows;
use types::ToSql;
use {bad_response, Connection, Result};
/// An enumeration of transaction isolation levels.
///
/// See the [Postgres documentation](http://www.postgresql.org/docs/9.4/static/transaction-iso.html)
/// for full details on the semantics of each level.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IsolationLevel {
/// The "read uncommitted" level.
///
/// In current versions of Postgres, this behaves identically to
/// `ReadCommitted`.
ReadUncommitted,
/// The "read committed" level.
///
/// This is the default isolation level in Postgres.
ReadCommitted,
/// The "repeatable read" level.
RepeatableRead,
/// The "serializable" level.
Serializable,
}
impl IsolationLevel {
pub(crate) fn new(raw: &str) -> Result<IsolationLevel> {
if raw.eq_ignore_ascii_case("READ UNCOMMITTED") {
Ok(IsolationLevel::ReadUncommitted)
} else if raw.eq_ignore_ascii_case("READ COMMITTED") {
Ok(IsolationLevel::ReadCommitted)
} else if raw.eq_ignore_ascii_case("REPEATABLE READ") {
Ok(IsolationLevel::RepeatableRead)
} else if raw.eq_ignore_ascii_case("SERIALIZABLE") {
Ok(IsolationLevel::Serializable)
} else {
Err(bad_response().into())
}
}
fn to_sql(&self) -> &'static str {
match *self {
IsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
IsolationLevel::ReadCommitted => "READ COMMITTED",
IsolationLevel::RepeatableRead => "REPEATABLE READ",
IsolationLevel::Serializable => "SERIALIZABLE",
}
}
}
/// Configuration of a transaction.
#[derive(Debug)]
pub struct Config {
isolation_level: Option<IsolationLevel>,
read_only: Option<bool>,
deferrable: Option<bool>,
}
impl Default for Config {
fn default() -> Config {
Config {
isolation_level: None,
read_only: None,
deferrable: None,
}
}
}
impl Config {
pub(crate) fn build_command(&self, s: &mut String) {
let mut first = true;
if let Some(isolation_level) = self.isolation_level {
s.push_str(" ISOLATION LEVEL ");
s.push_str(isolation_level.to_sql());
first = false;
}
if let Some(read_only) = self.read_only {
if !first {
s.push(',');
}
if read_only {
s.push_str(" READ ONLY");
} else {
s.push_str(" READ WRITE");
}
first = false;
}
if let Some(deferrable) = self.deferrable {
if !first {
s.push(',');
}
if deferrable {
s.push_str(" DEFERRABLE");
} else {
s.push_str(" NOT DEFERRABLE");
}
}
}
/// Creates a new `Config` with no configuration overrides.
pub fn new() -> Config {
Config::default()
}
/// Sets the isolation level of the configuration.
pub fn isolation_level(&mut self, isolation_level: IsolationLevel) -> &mut Config {
self.isolation_level = Some(isolation_level);
self
}
/// Sets the read-only property of a transaction.
///
/// If enabled, a transaction will be unable to modify any persistent
/// database state.
pub fn read_only(&mut self, read_only: bool) -> &mut Config {
self.read_only = Some(read_only);
self
}
/// Sets the deferrable property of a transaction.
///
/// If enabled in a read only, serializable transaction, the transaction may
/// block when created, after which it will run without the normal overhead
/// of a serializable transaction and will not be forced to roll back due
/// to serialization failures.
pub fn deferrable(&mut self, deferrable: bool) -> &mut Config {
self.deferrable = Some(deferrable);
self
}
}
/// A transaction on a database connection.
///
/// The transaction will roll back by default.
pub struct Transaction<'conn> {
conn: &'conn Connection,
depth: u32,
savepoint_name: Option<String>,
commit: Cell<bool>,
finished: bool,
}
impl<'a> fmt::Debug for Transaction<'a> {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.debug_struct("Transaction")
.field("commit", &self.commit.get())
.field("depth", &self.depth)
.finish()
}
}
impl<'conn> Drop for Transaction<'conn> {
fn drop(&mut self) {
if !self.finished {
let _ = self.finish_inner();
}
}
}
impl<'conn> Transaction<'conn> {
pub(crate) fn new(conn: &'conn Connection, depth: u32) -> Transaction<'conn> {
Transaction {
conn: conn,
depth: depth,
savepoint_name: None,
commit: Cell::new(false),
finished: false,
}
}
pub(crate) fn conn(&self) -> &'conn Connection {
self.conn
}
pub(crate) fn depth(&self) -> u32 {
self.depth
}
fn finish_inner(&mut self) -> Result<()> {
let mut conn = self.conn.0.borrow_mut();
debug_assert!(self.depth == conn.trans_depth);
conn.trans_depth -= 1;
match (self.commit.get(), &self.savepoint_name) {
(false, &Some(ref sp)) => conn.quick_query(&format!("ROLLBACK TO {}", sp))?,
(false, &None) => conn.quick_query("ROLLBACK")?,
(true, &Some(ref sp)) => conn.quick_query(&format!("RELEASE {}", sp))?,
(true, &None) => conn.quick_query("COMMIT")?,
};
Ok(())
}
/// Like `Connection::prepare`.
pub fn prepare(&self, query: &str) -> Result<Statement<'conn>> {
self.conn.prepare(query)
}
/// Like `Connection::prepare_cached`.
///
/// # Note
///
/// The statement will be cached for the duration of the
/// connection, not just the duration of this transaction.
pub fn prepare_cached(&self, query: &str) -> Result<Statement<'conn>> {
self.conn.prepare_cached(query)
}
/// Like `Connection::execute`.
pub fn execute(&self, query: &str, params: &[&ToSql]) -> Result<u64> {
self.conn.execute(query, params)
}
/// Like `Connection::query`.
pub fn query<'a>(&'a self, query: &str, params: &[&ToSql]) -> Result<Rows> {
self.conn.query(query, params)
}
/// Like `Connection::batch_execute`.
#[deprecated(since = "0.15.3", note = "please use `simple_query` instead")]
pub fn batch_execute(&self, query: &str) -> Result<()> {
self.simple_query(query).map(|_| ())
}
/// Like `Connection::simple_query`.
pub fn simple_query(&self, query: &str) -> Result<Vec<TextRows>> {
self.conn.simple_query(query)
}
/// Like `Connection::transaction`, but creates a nested transaction via
/// a savepoint.
///
/// # Panics
///
/// Panics if there is an active nested transaction.
pub fn transaction<'a>(&'a self) -> Result<Transaction<'a>> {
self.savepoint(format!("sp_{}", self.depth()))
}
/// Like `Connection::transaction`, but creates a nested transaction via
/// a savepoint with the specified name.
///
/// # Panics
///
/// Panics if there is an active nested transaction.
#[inline]
pub fn savepoint<'a, I>(&'a self, name: I) -> Result<Transaction<'a>>
where
I: Into<String>,
{
self._savepoint(name.into())
}
fn _savepoint<'a>(&'a self, name: String) -> Result<Transaction<'a>> {
let mut conn = self.conn.0.borrow_mut();
check_desync!(conn);
assert!(
conn.trans_depth == self.depth,
"`savepoint` may only be called on the active transaction"
);
conn.quick_query(&format!("SAVEPOINT {}", name))?;
conn.trans_depth += 1;
Ok(Transaction {
conn: self.conn,
depth: self.depth + 1,
savepoint_name: Some(name),
commit: Cell::new(false),
finished: false,
})
}
/// Returns a reference to the `Transaction`'s `Connection`.
pub fn connection(&self) -> &'conn Connection {
self.conn
}
/// Like `Connection::is_active`.
pub fn is_active(&self) -> bool {
self.conn.0.borrow().trans_depth == self.depth
}
/// Alters the configuration of the active transaction.
pub fn set_config(&self, config: &Config) -> Result<()> {
let mut command = "SET TRANSACTION".to_owned();
config.build_command(&mut command);
self.simple_query(&command).map(|_| ())
}
/// Determines if the transaction is currently set to commit or roll back.
pub fn will_commit(&self) -> bool {
self.commit.get()
}
/// Sets the transaction to commit at its completion.
pub fn set_commit(&self) {
self.commit.set(true);
}
/// Sets the transaction to roll back at its completion.
pub fn set_rollback(&self) {
self.commit.set(false);
}
/// A convenience method which consumes and commits a transaction.
pub fn commit(self) -> Result<()> {
self.set_commit();
self.finish()
}
/// Consumes the transaction, committing or rolling it back as appropriate.
///
/// Functionally equivalent to the `Drop` implementation of `Transaction`
/// except that it returns any error to the caller.
pub fn finish(mut self) -> Result<()> {
self.finished = true;
self.finish_inner()
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,30 +0,0 @@
extern crate bit_vec;
use self::bit_vec::BitVec;
use types::test_type;
#[test]
fn test_bit_params() {
let mut bv = BitVec::from_bytes(&[0b0110_1001, 0b0000_0111]);
bv.pop();
bv.pop();
test_type(
"BIT(14)",
&[(Some(bv), "B'01101001000001'"), (None, "NULL")],
)
}
#[test]
fn test_varbit_params() {
let mut bv = BitVec::from_bytes(&[0b0110_1001, 0b0000_0111]);
bv.pop();
bv.pop();
test_type(
"VARBIT",
&[
(Some(bv), "B'01101001000001'"),
(Some(BitVec::from_bytes(&[])), "B''"),
(None, "NULL"),
],
)
}

View File

@ -1,150 +0,0 @@
extern crate chrono;
use self::chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc};
use types::test_type;
use postgres::types::{Date, Timestamp};
#[test]
fn test_naive_date_time_params() {
fn make_check<'a>(time: &'a str) -> (Option<NaiveDateTime>, &'a str) {
(
Some(NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap()),
time,
)
}
test_type(
"TIMESTAMP",
&[
make_check("'1970-01-01 00:00:00.010000000'"),
make_check("'1965-09-25 11:19:33.100314000'"),
make_check("'2010-02-09 23:11:45.120200000'"),
(None, "NULL"),
],
);
}
#[test]
fn test_with_special_naive_date_time_params() {
fn make_check<'a>(time: &'a str) -> (Timestamp<NaiveDateTime>, &'a str) {
(
Timestamp::Value(
NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap(),
),
time,
)
}
test_type(
"TIMESTAMP",
&[
make_check("'1970-01-01 00:00:00.010000000'"),
make_check("'1965-09-25 11:19:33.100314000'"),
make_check("'2010-02-09 23:11:45.120200000'"),
(Timestamp::PosInfinity, "'infinity'"),
(Timestamp::NegInfinity, "'-infinity'"),
],
);
}
#[test]
fn test_date_time_params() {
fn make_check<'a>(time: &'a str) -> (Option<DateTime<Utc>>, &'a str) {
(
Some(
Utc.datetime_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'")
.unwrap(),
),
time,
)
}
test_type(
"TIMESTAMP WITH TIME ZONE",
&[
make_check("'1970-01-01 00:00:00.010000000'"),
make_check("'1965-09-25 11:19:33.100314000'"),
make_check("'2010-02-09 23:11:45.120200000'"),
(None, "NULL"),
],
);
}
#[test]
fn test_with_special_date_time_params() {
fn make_check<'a>(time: &'a str) -> (Timestamp<DateTime<Utc>>, &'a str) {
(
Timestamp::Value(
Utc.datetime_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'")
.unwrap(),
),
time,
)
}
test_type(
"TIMESTAMP WITH TIME ZONE",
&[
make_check("'1970-01-01 00:00:00.010000000'"),
make_check("'1965-09-25 11:19:33.100314000'"),
make_check("'2010-02-09 23:11:45.120200000'"),
(Timestamp::PosInfinity, "'infinity'"),
(Timestamp::NegInfinity, "'-infinity'"),
],
);
}
#[test]
fn test_date_params() {
fn make_check<'a>(time: &'a str) -> (Option<NaiveDate>, &'a str) {
(
Some(NaiveDate::parse_from_str(time, "'%Y-%m-%d'").unwrap()),
time,
)
}
test_type(
"DATE",
&[
make_check("'1970-01-01'"),
make_check("'1965-09-25'"),
make_check("'2010-02-09'"),
(None, "NULL"),
],
);
}
#[test]
fn test_with_special_date_params() {
fn make_check<'a>(date: &'a str) -> (Date<NaiveDate>, &'a str) {
(
Date::Value(NaiveDate::parse_from_str(date, "'%Y-%m-%d'").unwrap()),
date,
)
}
test_type(
"DATE",
&[
make_check("'1970-01-01'"),
make_check("'1965-09-25'"),
make_check("'2010-02-09'"),
(Date::PosInfinity, "'infinity'"),
(Date::NegInfinity, "'-infinity'"),
],
);
}
#[test]
fn test_time_params() {
fn make_check<'a>(time: &'a str) -> (Option<NaiveTime>, &'a str) {
(
Some(NaiveTime::parse_from_str(time, "'%H:%M:%S.%f'").unwrap()),
time,
)
}
test_type(
"TIME",
&[
make_check("'00:00:00.010000000'"),
make_check("'11:19:33.100314000'"),
make_check("'23:11:45.120200000'"),
(None, "NULL"),
],
);
}

View File

@ -1,17 +0,0 @@
extern crate eui48;
use types::test_type;
#[test]
fn test_eui48_params() {
test_type(
"MACADDR",
&[
(
Some(eui48::MacAddress::parse_str("12-34-56-AB-CD-EF").unwrap()),
"'12-34-56-ab-cd-ef'",
),
(None, "NULL"),
],
)
}

View File

@ -1,58 +0,0 @@
extern crate geo;
use self::geo::{Coordinate, LineString, Point, Rect};
use types::test_type;
#[test]
fn test_point_params() {
test_type(
"POINT",
&[
(Some(Point::new(0.0, 0.0)), "POINT(0, 0)"),
(Some(Point::new(-3.14, 1.618)), "POINT(-3.14, 1.618)"),
(None, "NULL"),
],
);
}
#[test]
fn test_box_params() {
test_type(
"BOX",
&[
(
Some(Rect {
min: Coordinate { x: -3.14, y: 1.618 },
max: Coordinate {
x: 160.0,
y: 69701.5615,
},
}),
"BOX(POINT(160.0, 69701.5615), POINT(-3.14, 1.618))",
),
(None, "NULL"),
],
);
}
#[test]
fn test_path_params() {
let points = vec![
Coordinate { x: 0., y: 0. },
Coordinate { x: -3.14, y: 1.618 },
Coordinate {
x: 160.0,
y: 69701.5615,
},
];
test_type(
"PATH",
&[
(
Some(LineString(points)),
"path '((0, 0), (-3.14, 1.618), (160.0, 69701.5615))'",
),
(None, "NULL"),
],
);
}

View File

@ -1,530 +0,0 @@
use std::collections::HashMap;
use std::error;
use std::f32;
use std::f64;
use std::fmt;
use std::result;
use std::time::{Duration, UNIX_EPOCH};
use postgres::types::{FromSql, FromSqlOwned, IsNull, Kind, ToSql, Type, WrongType};
use postgres::{Connection, TlsMode};
#[cfg(feature = "with-bit-vec-0.5")]
mod bit_vec;
#[cfg(feature = "with-chrono-0.4")]
mod chrono;
#[cfg(feature = "with-eui48-0.3")]
mod eui48;
#[cfg(feature = "with-geo-0.10")]
mod geo;
#[cfg(feature = "with-serde_json-1")]
mod serde_json;
#[cfg(feature = "with-uuid-0.6")]
mod uuid;
fn test_type<T, S>(sql_type: &str, checks: &[(T, S)])
where
T: PartialEq + for<'a> FromSqlOwned + ToSql,
S: fmt::Display,
{
let conn = or_panic!(Connection::connect(
"postgres://postgres@localhost:5433",
TlsMode::None,
));
for &(ref val, ref repr) in checks.iter() {
let stmt = or_panic!(conn.prepare(&*format!("SELECT {}::{}", *repr, sql_type)));
let rows = or_panic!(stmt.query(&[]));
let row = rows.iter().next().unwrap();
let result = row.get(0);
assert_eq!(val, &result);
let stmt = or_panic!(conn.prepare(&*format!("SELECT $1::{}", sql_type)));
let rows = or_panic!(stmt.query(&[val]));
let row = rows.iter().next().unwrap();
let result = row.get(0);
assert_eq!(val, &result);
}
}
#[test]
fn test_ref_tosql() {
let conn = or_panic!(Connection::connect(
"postgres://postgres@localhost:5433",
TlsMode::None,
));
let stmt = conn.prepare("SELECT $1::Int").unwrap();
let num: &ToSql = &&7;
stmt.query(&[num]).unwrap();
}
#[test]
fn test_bool_params() {
test_type(
"BOOL",
&[(Some(true), "'t'"), (Some(false), "'f'"), (None, "NULL")],
);
}
#[test]
fn test_i8_params() {
test_type("\"char\"", &[(Some('a' as i8), "'a'"), (None, "NULL")]);
}
#[test]
fn test_name_params() {
test_type(
"NAME",
&[
(Some("hello world".to_owned()), "'hello world'"),
(
Some("イロハニホヘト チリヌルヲ".to_owned()),
"'イロハニホヘト チリヌルヲ'",
),
(None, "NULL"),
],
);
}
#[test]
fn test_i16_params() {
test_type(
"SMALLINT",
&[
(Some(15001i16), "15001"),
(Some(-15001i16), "-15001"),
(None, "NULL"),
],
);
}
#[test]
fn test_i32_params() {
test_type(
"INT",
&[
(Some(2147483548i32), "2147483548"),
(Some(-2147483548i32), "-2147483548"),
(None, "NULL"),
],
);
}
#[test]
fn test_oid_params() {
test_type(
"OID",
&[
(Some(2147483548u32), "2147483548"),
(Some(4000000000), "4000000000"),
(None, "NULL"),
],
);
}
#[test]
fn test_i64_params() {
test_type(
"BIGINT",
&[
(Some(9223372036854775708i64), "9223372036854775708"),
(Some(-9223372036854775708i64), "-9223372036854775708"),
(None, "NULL"),
],
);
}
#[test]
fn test_f32_params() {
test_type(
"REAL",
&[
(Some(f32::INFINITY), "'infinity'"),
(Some(f32::NEG_INFINITY), "'-infinity'"),
(Some(1000.55), "1000.55"),
(None, "NULL"),
],
);
}
#[test]
fn test_f64_params() {
test_type(
"DOUBLE PRECISION",
&[
(Some(f64::INFINITY), "'infinity'"),
(Some(f64::NEG_INFINITY), "'-infinity'"),
(Some(10000.55), "10000.55"),
(None, "NULL"),
],
);
}
#[test]
fn test_varchar_params() {
test_type(
"VARCHAR",
&[
(Some("hello world".to_owned()), "'hello world'"),
(
Some("イロハニホヘト チリヌルヲ".to_owned()),
"'イロハニホヘト チリヌルヲ'",
),
(None, "NULL"),
],
);
}
#[test]
fn test_text_params() {
test_type(
"TEXT",
&[
(Some("hello world".to_owned()), "'hello world'"),
(
Some("イロハニホヘト チリヌルヲ".to_owned()),
"'イロハニホヘト チリヌルヲ'",
),
(None, "NULL"),
],
);
}
#[test]
fn test_borrowed_text() {
let conn = or_panic!(Connection::connect(
"postgres://postgres@localhost:5433",
TlsMode::None,
));
let rows = or_panic!(conn.query("SELECT 'foo'", &[]));
let row = rows.get(0);
let s: &str = row.get(0);
assert_eq!(s, "foo");
}
#[test]
fn test_bpchar_params() {
let conn = or_panic!(Connection::connect(
"postgres://postgres@localhost:5433",
TlsMode::None,
));
or_panic!(conn.execute(
"CREATE TEMPORARY TABLE foo (
id SERIAL PRIMARY KEY,
b CHAR(5)
)",
&[],
));
or_panic!(conn.execute(
"INSERT INTO foo (b) VALUES ($1), ($2), ($3)",
&[&Some("12345"), &Some("123"), &None::<&'static str>],
));
let stmt = or_panic!(conn.prepare("SELECT b FROM foo ORDER BY id"));
let res = or_panic!(stmt.query(&[]));
assert_eq!(
vec![Some("12345".to_owned()), Some("123 ".to_owned()), None],
res.iter().map(|row| row.get(0)).collect::<Vec<_>>()
);
}
#[test]
fn test_citext_params() {
let conn = or_panic!(Connection::connect(
"postgres://postgres@localhost:5433",
TlsMode::None,
));
or_panic!(conn.execute(
"CREATE TEMPORARY TABLE foo (
id SERIAL PRIMARY KEY,
b CITEXT
)",
&[],
));
or_panic!(conn.execute(
"INSERT INTO foo (b) VALUES ($1), ($2), ($3)",
&[&Some("foobar"), &Some("FooBar"), &None::<&'static str>],
));
let stmt = or_panic!(conn.prepare("SELECT id FROM foo WHERE b = 'FOOBAR' ORDER BY id",));
let res = or_panic!(stmt.query(&[]));
assert_eq!(
vec![Some(1i32), Some(2i32)],
res.iter().map(|row| row.get(0)).collect::<Vec<_>>()
);
}
#[test]
fn test_bytea_params() {
test_type(
"BYTEA",
&[
(Some(vec![0u8, 1, 2, 3, 254, 255]), "'\\x00010203feff'"),
(None, "NULL"),
],
);
}
#[test]
fn test_borrowed_bytea() {
let conn = or_panic!(Connection::connect(
"postgres://postgres@localhost:5433",
TlsMode::None,
));
let rows = or_panic!(conn.query("SELECT 'foo'::BYTEA", &[]));
let row = rows.get(0);
let s: &[u8] = row.get(0);
assert_eq!(s, b"foo");
}
#[test]
fn test_hstore_params() {
macro_rules! make_map {
($($k:expr => $v:expr),+) => ({
let mut map = HashMap::new();
$(map.insert($k, $v);)+
map
})
}
test_type(
"hstore",
&[
(
Some(make_map!("a".to_owned() => Some("1".to_owned()))),
"'a=>1'",
),
(
Some(make_map!("hello".to_owned() => Some("world!".to_owned()),
"hola".to_owned() => Some("mundo!".to_owned()),
"what".to_owned() => None)),
"'hello=>world!,hola=>mundo!,what=>NULL'",
),
(None, "NULL"),
],
);
}
#[test]
fn test_array_params() {
test_type(
"integer[]",
&[
(Some(vec![1i32, 2i32]), "ARRAY[1,2]"),
(Some(vec![1i32]), "ARRAY[1]"),
(Some(vec![]), "ARRAY[]"),
(None, "NULL"),
],
);
}
fn test_nan_param<T>(sql_type: &str)
where
T: PartialEq + ToSql + FromSqlOwned,
{
let conn = or_panic!(Connection::connect(
"postgres://postgres@localhost:5433",
TlsMode::None,
));
let stmt = or_panic!(conn.prepare(&*format!("SELECT 'NaN'::{}", sql_type)));
let result = or_panic!(stmt.query(&[]));
let val: T = result.iter().next().unwrap().get(0);
assert!(val != val);
}
#[test]
fn test_f32_nan_param() {
test_nan_param::<f32>("REAL");
}
#[test]
fn test_f64_nan_param() {
test_nan_param::<f64>("DOUBLE PRECISION");
}
#[test]
fn test_pg_database_datname() {
let conn = or_panic!(Connection::connect(
"postgres://postgres@localhost:5433",
TlsMode::None,
));
let stmt = or_panic!(conn.prepare("SELECT datname FROM pg_database"));
let result = or_panic!(stmt.query(&[]));
let next = result.iter().next().unwrap();
or_panic!(next.get_opt::<_, String>(0).unwrap());
or_panic!(next.get_opt::<_, String>("datname").unwrap());
}
#[test]
fn test_slice() {
let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap();
conn.simple_query(
"CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY, f VARCHAR);
INSERT INTO foo (f) VALUES ('a'), ('b'), ('c'), ('d');",
).unwrap();
let stmt = conn
.prepare("SELECT f FROM foo WHERE id = ANY($1)")
.unwrap();
let result = stmt.query(&[&&[1i32, 3, 4][..]]).unwrap();
assert_eq!(
vec!["a".to_owned(), "c".to_owned(), "d".to_owned()],
result
.iter()
.map(|r| r.get::<_, String>(0))
.collect::<Vec<_>>()
);
}
#[test]
fn test_slice_wrong_type() {
let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap();
conn.simple_query("CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY)")
.unwrap();
let stmt = conn
.prepare("SELECT * FROM foo WHERE id = ANY($1)")
.unwrap();
let err = stmt.query(&[&&["hi"][..]]).unwrap_err();
match err.as_conversion() {
Some(e) if e.is::<WrongType>() => {}
_ => panic!("Unexpected error {:?}", err),
};
}
#[test]
fn test_slice_range() {
let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap();
let stmt = conn.prepare("SELECT $1::INT8RANGE").unwrap();
let err = stmt.query(&[&&[1i64][..]]).unwrap_err();
match err.as_conversion() {
Some(e) if e.is::<WrongType>() => {}
_ => panic!("Unexpected error {:?}", err),
};
}
#[test]
fn domain() {
#[derive(Debug, PartialEq)]
struct SessionId(Vec<u8>);
impl ToSql for SessionId {
fn to_sql(
&self,
ty: &Type,
out: &mut Vec<u8>,
) -> result::Result<IsNull, Box<error::Error + Sync + Send>> {
let inner = match *ty.kind() {
Kind::Domain(ref inner) => inner,
_ => unreachable!(),
};
self.0.to_sql(inner, out)
}
fn accepts(ty: &Type) -> bool {
ty.name() == "session_id" && match *ty.kind() {
Kind::Domain(_) => true,
_ => false,
}
}
to_sql_checked!();
}
impl<'a> FromSql<'a> for SessionId {
fn from_sql(
ty: &Type,
raw: &[u8],
) -> result::Result<Self, Box<error::Error + Sync + Send>> {
Vec::<u8>::from_sql(ty, raw).map(SessionId)
}
fn accepts(ty: &Type) -> bool {
// This is super weird!
<Vec<u8> as FromSql>::accepts(ty)
}
}
let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap();
conn.simple_query(
"CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16);
CREATE TABLE pg_temp.foo (id pg_temp.session_id);",
).unwrap();
let id = SessionId(b"0123456789abcdef".to_vec());
conn.execute("INSERT INTO pg_temp.foo (id) VALUES ($1)", &[&id])
.unwrap();
let rows = conn.query("SELECT id FROM pg_temp.foo", &[]).unwrap();
assert_eq!(id, rows.get(0).get(0));
}
#[test]
fn composite() {
let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap();
conn.simple_query(
"CREATE TYPE pg_temp.inventory_item AS (
name TEXT,
supplier INTEGER,
price NUMERIC
)",
).unwrap();
let stmt = conn.prepare("SELECT $1::inventory_item").unwrap();
let type_ = &stmt.param_types()[0];
assert_eq!(type_.name(), "inventory_item");
match *type_.kind() {
Kind::Composite(ref fields) => {
assert_eq!(fields[0].name(), "name");
assert_eq!(fields[0].type_(), &Type::TEXT);
assert_eq!(fields[1].name(), "supplier");
assert_eq!(fields[1].type_(), &Type::INT4);
assert_eq!(fields[2].name(), "price");
assert_eq!(fields[2].type_(), &Type::NUMERIC);
}
ref t => panic!("bad type {:?}", t),
}
}
#[test]
fn enum_() {
let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap();
conn.simple_query("CREATE TYPE pg_temp.mood AS ENUM ('sad', 'ok', 'happy');")
.unwrap();
let stmt = conn.prepare("SELECT $1::mood").unwrap();
let type_ = &stmt.param_types()[0];
assert_eq!(type_.name(), "mood");
match *type_.kind() {
Kind::Enum(ref variants) => {
assert_eq!(
variants,
&["sad".to_owned(), "ok".to_owned(), "happy".to_owned()]
);
}
_ => panic!("bad type"),
}
}
#[test]
fn system_time() {
test_type(
"TIMESTAMP",
&[
(
Some(UNIX_EPOCH + Duration::from_millis(1_010)),
"'1970-01-01 00:00:01.01'",
),
(
Some(UNIX_EPOCH - Duration::from_millis(1_010)),
"'1969-12-31 23:59:58.99'",
),
(
Some(UNIX_EPOCH + Duration::from_millis(946684800 * 1000 + 1_010)),
"'2000-01-01 00:00:01.01'",
),
(None, "NULL"),
],
);
}

View File

@ -1,40 +0,0 @@
extern crate serde_json;
use self::serde_json::Value;
use types::test_type;
#[test]
fn test_json_params() {
test_type(
"JSON",
&[
(
Some(serde_json::from_str::<Value>("[10, 11, 12]").unwrap()),
"'[10, 11, 12]'",
),
(
Some(serde_json::from_str::<Value>("{\"f\": \"asd\"}").unwrap()),
"'{\"f\": \"asd\"}'",
),
(None, "NULL"),
],
)
}
#[test]
fn test_jsonb_params() {
test_type(
"JSONB",
&[
(
Some(serde_json::from_str::<Value>("[10, 11, 12]").unwrap()),
"'[10, 11, 12]'",
),
(
Some(serde_json::from_str::<Value>("{\"f\": \"asd\"}").unwrap()),
"'{\"f\": \"asd\"}'",
),
(None, "NULL"),
],
)
}

View File

@ -1,17 +0,0 @@
extern crate uuid;
use types::test_type;
#[test]
fn test_uuid_params() {
test_type(
"UUID",
&[
(
Some(uuid::Uuid::parse_str("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11").unwrap()),
"'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'",
),
(None, "NULL"),
],
)
}