Correctly handle bad column counts in copy
This commit is contained in:
parent
f88f908498
commit
249db6b55a
37
src/lib.rs
37
src/lib.rs
@ -1614,26 +1614,33 @@ impl<'a> PostgresCopyInStatement<'a> {
|
||||
let _ = buf.write_be_i32(0);
|
||||
let _ = buf.write_be_i32(0);
|
||||
|
||||
for mut row in rows {
|
||||
for row in rows {
|
||||
let _ = buf.write_be_i16(self.column_types.len() as i16);
|
||||
|
||||
let mut count = 0;
|
||||
for (i, (val, ty)) in row.by_ref().zip(self.column_types.iter()).enumerate() {
|
||||
match try!(val.to_sql(ty)) {
|
||||
(_, None) => {
|
||||
let _ = buf.write_be_i32(-1);
|
||||
let mut row = row.fuse();
|
||||
let mut types = self.column_types.iter();
|
||||
loop {
|
||||
match (row.next(), types.next()) {
|
||||
(Some(val), Some(ty)) => {
|
||||
match try!(val.to_sql(ty)) {
|
||||
(_, None) => {
|
||||
let _ = buf.write_be_i32(-1);
|
||||
}
|
||||
(_, Some(val)) => {
|
||||
let _ = buf.write_be_i32(val.len() as i32);
|
||||
let _ = buf.write(val.as_slice());
|
||||
}
|
||||
}
|
||||
}
|
||||
(_, Some(val)) => {
|
||||
let _ = buf.write_be_i32(val.len() as i32);
|
||||
let _ = buf.write(val.as_slice());
|
||||
(Some(_), None) | (None, Some(_)) => {
|
||||
try_pg!(conn.stream.write_message(
|
||||
&CopyFail {
|
||||
message: "Invalid column count",
|
||||
}));
|
||||
break;
|
||||
}
|
||||
(None, None) => break
|
||||
}
|
||||
count = i+1;
|
||||
}
|
||||
|
||||
if row.next().is_some() || count != self.column_types.len() {
|
||||
// FIXME
|
||||
fail!()
|
||||
}
|
||||
|
||||
try_pg!(conn.stream.write_message(
|
||||
|
@ -723,3 +723,30 @@ fn test_copy_in() {
|
||||
assert_eq!(vec![(0i32, Some("Steven".to_string())), (1, None)],
|
||||
or_fail!(stmt.query([])).map(|r| (r.get(0u), r.get(1u))).collect());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_copy_in_bad_column_count() {
|
||||
let conn = or_fail!(PostgresConnection::connect("postgres://postgres@localhost", &NoSsl));
|
||||
or_fail!(conn.execute("CREATE TEMPORARY TABLE foo (id INT, name VARCHAR)", []));
|
||||
|
||||
let stmt = or_fail!(conn.prepare_copy_in("foo", ["id", "name"]));
|
||||
let data: &[&[&ToSql]] = &[&[&0i32, &"Steven".to_string()], &[&1i32]];
|
||||
|
||||
let res = stmt.execute(data.iter().map(|r| r.iter().map(|&e| e)));
|
||||
match res {
|
||||
Err(PgDbError(ref err)) if err.message.as_slice().contains("Invalid column count") => {}
|
||||
Err(err) => fail!("unexpected error {}", err),
|
||||
_ => fail!("Expected error"),
|
||||
}
|
||||
|
||||
let data: &[&[&ToSql]] = &[&[&0i32, &"Steven".to_string()], &[&1i32, &"Steven".to_string(), &1i32]];
|
||||
|
||||
let res = stmt.execute(data.iter().map(|r| r.iter().map(|&e| e)));
|
||||
match res {
|
||||
Err(PgDbError(ref err)) if err.message.as_slice().contains("Invalid column count") => {}
|
||||
Err(err) => fail!("unexpected error {}", err),
|
||||
_ => fail!("Expected error"),
|
||||
}
|
||||
|
||||
or_fail!(conn.execute("SELECT 1", []));
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user