Detect direct queries to COPY FROM

This commit is contained in:
Steven Fackler 2014-09-29 21:05:42 -07:00
parent 945714b692
commit a4a625a30c
3 changed files with 68 additions and 1 deletions

View File

@ -108,6 +108,7 @@ use message::{AuthenticationCleartextPassword,
BackendMessage,
BindComplete,
CommandComplete,
CopyInResponse,
DataRow,
EmptyQueryResponse,
ErrorResponse,
@ -124,6 +125,7 @@ use message::{AuthenticationCleartextPassword,
use message::{Bind,
CancelRequest,
Close,
CopyFail,
Describe,
Execute,
FrontendMessage,
@ -1116,7 +1118,7 @@ impl<'conn> PostgresStatement<'conn> {
}
_ => {
conn.desynchronized = true;
return Err(PgBadResponse);
Err(PgBadResponse)
}
}
}
@ -1190,6 +1192,13 @@ impl<'conn> PostgresStatement<'conn> {
num = 0;
break;
}
CopyInResponse { .. } => {
try_pg!(conn.write_messages([
CopyFail {
message: "COPY queries cannot be directly executed",
},
Sync]));
}
_ => {
conn.desynchronized = true;
return Err(PgBadResponse);
@ -1305,6 +1314,14 @@ impl<'stmt> PostgresRows<'stmt> {
try!(conn.wait_for_ready());
return Err(PgDbError(PostgresDbError::new(fields)));
}
CopyInResponse { .. } => {
try_pg!(conn.write_messages([
CopyFail {
message: "COPY queries cannot be directly executed",
},
Sync]));
continue;
}
_ => {
conn.desynchronized = true;
return Err(PgBadResponse);

View File

@ -26,6 +26,10 @@ pub enum BackendMessage {
CommandComplete {
pub tag: String,
},
CopyInResponse {
pub format: u8,
pub column_formats: Vec<u16>,
},
DataRow {
pub row: Vec<Option<Vec<u8>>>
},
@ -86,6 +90,13 @@ pub enum FrontendMessage<'a> {
pub variant: u8,
pub name: &'a str
},
CopyData {
pub data: &'a [u8],
},
CopyDone,
CopyFail {
pub message: &'a str
},
Describe {
pub variant: u8,
pub name: &'a str
@ -177,6 +188,17 @@ impl<W: Writer> WriteMessage for W {
try!(buf.write_u8(variant));
try!(buf.write_cstr(name));
}
CopyData { data } => {
ident = Some(b'd');
try!(buf.write(data));
}
CopyDone => {
ident = Some(b'C');
}
CopyFail { message } => {
ident = Some(b'f');
try!(buf.write_cstr(message));
}
Describe { variant, name } => {
ident = Some(b'D');
try!(buf.write_u8(variant));
@ -276,6 +298,17 @@ impl<R: Reader> ReadMessage for R {
b'C' => CommandComplete { tag: try!(buf.read_cstr()) },
b'D' => try!(read_data_row(&mut buf)),
b'E' => ErrorResponse { fields: try!(read_fields(&mut buf)) },
b'G' => {
let format = try!(buf.read_u8());
let mut column_formats = vec![];
for _ in range(0, try!(buf.read_be_u16())) {
column_formats.push(try!(buf.read_be_u16()));
}
CopyInResponse {
format: format,
column_formats: column_formats,
}
}
b'I' => EmptyQueryResponse,
b'K' => BackendKeyData {
process_id: try!(buf.read_be_u32()),

View File

@ -691,3 +691,20 @@ fn test_md5_pass_wrong_pass() {
_ => fail!("Expected error")
}
}
#[test]
fn test_execute_copy_from_err() {
let conn = or_fail!(PostgresConnection::connect("postgres://postgres@localhost", &NoSsl));
or_fail!(conn.execute("CREATE TEMPORARY TABLE foo (id INT)", []));
let stmt = or_fail!(conn.prepare("COPY foo (id) FROM STDIN"));
match stmt.execute([]) {
Err(PgDbError(ref err)) if err.message.as_slice().contains("COPY") => {}
Err(err) => fail!("Unexptected error {}", err),
_ => fail!("Expected error"),
}
match stmt.query([]) {
Err(PgDbError(ref err)) if err.message.as_slice().contains("COPY") => {}
Err(err) => fail!("Unexptected error {}", err),
_ => fail!("Expected error"),
}
}