Correctly handle bad column counts in copy

This commit is contained in:
Steven Fackler 2014-09-29 23:32:57 -07:00
parent f88f908498
commit 249db6b55a
2 changed files with 49 additions and 15 deletions

View File

@ -1614,26 +1614,33 @@ impl<'a> PostgresCopyInStatement<'a> {
let _ = buf.write_be_i32(0);
let _ = buf.write_be_i32(0);
for mut row in rows {
for row in rows {
let _ = buf.write_be_i16(self.column_types.len() as i16);
let mut count = 0;
for (i, (val, ty)) in row.by_ref().zip(self.column_types.iter()).enumerate() {
match try!(val.to_sql(ty)) {
(_, None) => {
let _ = buf.write_be_i32(-1);
let mut row = row.fuse();
let mut types = self.column_types.iter();
loop {
match (row.next(), types.next()) {
(Some(val), Some(ty)) => {
match try!(val.to_sql(ty)) {
(_, None) => {
let _ = buf.write_be_i32(-1);
}
(_, Some(val)) => {
let _ = buf.write_be_i32(val.len() as i32);
let _ = buf.write(val.as_slice());
}
}
}
(_, Some(val)) => {
let _ = buf.write_be_i32(val.len() as i32);
let _ = buf.write(val.as_slice());
(Some(_), None) | (None, Some(_)) => {
try_pg!(conn.stream.write_message(
&CopyFail {
message: "Invalid column count",
}));
break;
}
(None, None) => break
}
count = i+1;
}
if row.next().is_some() || count != self.column_types.len() {
// FIXME
fail!()
}
try_pg!(conn.stream.write_message(

View File

@ -723,3 +723,30 @@ fn test_copy_in() {
assert_eq!(vec![(0i32, Some("Steven".to_string())), (1, None)],
or_fail!(stmt.query([])).map(|r| (r.get(0u), r.get(1u))).collect());
}
#[test]
fn test_copy_in_bad_column_count() {
let conn = or_fail!(PostgresConnection::connect("postgres://postgres@localhost", &NoSsl));
or_fail!(conn.execute("CREATE TEMPORARY TABLE foo (id INT, name VARCHAR)", []));
let stmt = or_fail!(conn.prepare_copy_in("foo", ["id", "name"]));
let data: &[&[&ToSql]] = &[&[&0i32, &"Steven".to_string()], &[&1i32]];
let res = stmt.execute(data.iter().map(|r| r.iter().map(|&e| e)));
match res {
Err(PgDbError(ref err)) if err.message.as_slice().contains("Invalid column count") => {}
Err(err) => fail!("unexpected error {}", err),
_ => fail!("Expected error"),
}
let data: &[&[&ToSql]] = &[&[&0i32, &"Steven".to_string()], &[&1i32, &"Steven".to_string(), &1i32]];
let res = stmt.execute(data.iter().map(|r| r.iter().map(|&e| e)));
match res {
Err(PgDbError(ref err)) if err.message.as_slice().contains("Invalid column count") => {}
Err(err) => fail!("unexpected error {}", err),
_ => fail!("Expected error"),
}
or_fail!(conn.execute("SELECT 1", []));
}