Add new SASL messages to protocol

This commit is contained in:
Steven Fackler 2017-05-18 21:54:39 -07:00
parent 588ecc8a6c
commit 801835a05b
2 changed files with 173 additions and 93 deletions

View File

@ -20,6 +20,10 @@ pub enum Message {
AuthenticationOk,
AuthenticationScmCredential,
AuthenticationSspi,
AuthenticationGssContinue(AuthenticationGssContinueBody),
AuthenticationSasl(AuthenticationSaslBody),
AuthenticationSaslContinue(AuthenticationSaslContinueBody),
AuthenticationSaslFinal(AuthenticationSaslFinalBody),
BackendKeyData(BackendKeyDataBody),
BindComplete,
CloseComplete,
@ -81,17 +85,15 @@ impl Message {
let channel = try!(buf.read_cstr());
let message = try!(buf.read_cstr());
Message::NotificationResponse(NotificationResponseBody {
process_id: process_id,
channel: channel,
message: message,
})
process_id: process_id,
channel: channel,
message: message,
})
}
b'c' => Message::CopyDone,
b'C' => {
let tag = try!(buf.read_cstr());
Message::CommandComplete(CommandCompleteBody {
tag: tag,
})
Message::CommandComplete(CommandCompleteBody { tag: tag })
}
b'd' => {
let storage = buf.read_all();
@ -101,9 +103,9 @@ impl Message {
let len = try!(buf.read_u16::<BigEndian>());
let storage = buf.read_all();
Message::DataRow(DataRowBody {
storage: storage,
len: len,
})
storage: storage,
len: len,
})
}
b'E' => {
let storage = buf.read_all();
@ -114,36 +116,34 @@ impl Message {
let len = try!(buf.read_u16::<BigEndian>());
let storage = buf.read_all();
Message::CopyInResponse(CopyInResponseBody {
format: format,
len: len,
storage: storage,
})
format: format,
len: len,
storage: storage,
})
}
b'H' => {
let format = try!(buf.read_u8());
let len = try!(buf.read_u16::<BigEndian>());
let storage = buf.read_all();
Message::CopyOutResponse(CopyOutResponseBody {
format: format,
len: len,
storage: storage,
})
format: format,
len: len,
storage: storage,
})
}
b'I' => Message::EmptyQueryResponse,
b'K' => {
let process_id = try!(buf.read_i32::<BigEndian>());
let secret_key = try!(buf.read_i32::<BigEndian>());
Message::BackendKeyData(BackendKeyDataBody {
process_id: process_id,
secret_key: secret_key,
})
process_id: process_id,
secret_key: secret_key,
})
}
b'n' => Message::NoData,
b'N' => {
let storage = buf.read_all();
Message::NoticeResponse(NoticeResponseBody {
storage: storage,
})
Message::NoticeResponse(NoticeResponseBody { storage: storage })
}
b'R' => {
match try!(buf.read_i32::<BigEndian>()) {
@ -154,12 +154,28 @@ impl Message {
let mut salt = [0; 4];
try!(buf.read_exact(&mut salt));
Message::AuthenticationMd5Password(AuthenticationMd5PasswordBody {
salt: salt,
})
salt: salt,
})
}
6 => Message::AuthenticationScmCredential,
7 => Message::AuthenticationGss,
8 => {
let storage = buf.read_all();
Message::AuthenticationGssContinue(AuthenticationGssContinueBody(storage))
}
9 => Message::AuthenticationSspi,
10 => {
let storage = buf.read_all();
Message::AuthenticationSasl(AuthenticationSaslBody(storage))
}
11 => {
let storage = buf.read_all();
Message::AuthenticationSaslContinue(AuthenticationSaslContinueBody(storage))
}
12 => {
let storage = buf.read_all();
Message::AuthenticationSaslFinal(AuthenticationSaslFinalBody(storage))
}
tag => {
return Err(io::Error::new(io::ErrorKind::InvalidInput,
format!("unknown authentication tag `{}`", tag)));
@ -171,31 +187,29 @@ impl Message {
let name = try!(buf.read_cstr());
let value = try!(buf.read_cstr());
Message::ParameterStatus(ParameterStatusBody {
name: name,
value: value,
})
name: name,
value: value,
})
}
b't' => {
let len = try!(buf.read_u16::<BigEndian>());
let storage = buf.read_all();
Message::ParameterDescription(ParameterDescriptionBody {
storage: storage,
len: len,
})
storage: storage,
len: len,
})
}
b'T' => {
let len = try!(buf.read_u16::<BigEndian>());
let storage = buf.read_all();
Message::RowDescription(RowDescriptionBody {
storage: storage,
len: len,
})
storage: storage,
len: len,
})
}
b'Z' => {
let status = try!(buf.read_u8());
Message::ReadyForQuery(ReadyForQueryBody {
status: status,
})
Message::ReadyForQuery(ReadyForQueryBody { status: status })
}
tag => {
return Err(io::Error::new(io::ErrorKind::InvalidInput,
@ -269,6 +283,64 @@ impl AuthenticationMd5PasswordBody {
}
}
pub struct AuthenticationGssContinueBody(Bytes);
impl AuthenticationGssContinueBody {
#[inline]
pub fn data(&self) -> &[u8] {
&self.0
}
}
pub struct AuthenticationSaslBody(Bytes);
impl AuthenticationSaslBody {
#[inline]
pub fn mechanisms<'a>(&'a self) -> SaslMechanisms<'a> {
SaslMechanisms(&self.0)
}
}
pub struct SaslMechanisms<'a>(&'a [u8]);
impl<'a> FallibleIterator for SaslMechanisms<'a> {
type Item = &'a str;
type Error = io::Error;
#[inline]
fn next(&mut self) -> io::Result<Option<&'a str>> {
let value_end = try!(find_null(self.0, 0));
if value_end == 0 {
if self.0.len() != 1 {
return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid message length"));
}
Ok(None)
} else {
let value = try!(get_str(&self.0[..value_end]));
self.0 = &self.0[value_end + 1..];
Ok(Some(value))
}
}
}
pub struct AuthenticationSaslContinueBody(Bytes);
impl AuthenticationSaslContinueBody {
#[inline]
pub fn data(&self) -> &[u8] {
&self.0
}
}
pub struct AuthenticationSaslFinalBody(Bytes);
impl AuthenticationSaslFinalBody {
#[inline]
pub fn data(&self) -> &[u8] {
&self.0
}
}
pub struct BackendKeyDataBody {
process_id: i32,
secret_key: i32,

View File

@ -16,44 +16,24 @@ pub enum Message<'a> {
values: &'a [Option<Vec<u8>>],
result_formats: &'a [i16],
},
CancelRequest {
process_id: i32,
secret_key: i32,
},
Close {
variant: u8,
name: &'a str,
},
CopyData {
data: &'a [u8],
},
CancelRequest { process_id: i32, secret_key: i32 },
Close { variant: u8, name: &'a str },
CopyData { data: &'a [u8] },
CopyDone,
CopyFail {
message: &'a str,
},
Describe {
variant: u8,
name: &'a str,
},
Execute {
portal: &'a str,
max_rows: i32,
},
CopyFail { message: &'a str },
Describe { variant: u8, name: &'a str },
Execute { portal: &'a str, max_rows: i32 },
Parse {
name: &'a str,
query: &'a str,
param_types: &'a [Oid],
},
PasswordMessage {
password: &'a str,
},
Query {
query: &'a str,
},
PasswordMessage { password: &'a str },
Query { query: &'a str },
SaslInitialResponse { mechanism: &'a str, data: &'a [u8] },
SaslResponse { data: &'a [u8] },
SslRequest,
StartupMessage {
parameters: &'a [(String, String)],
},
StartupMessage { parameters: &'a [(String, String)] },
Sync,
Terminate,
#[doc(hidden)]
@ -64,19 +44,23 @@ impl<'a> Message<'a> {
#[inline]
pub fn serialize(&self, buf: &mut Vec<u8>) -> io::Result<()> {
match *self {
Message::Bind { portal, statement, formats, values, result_formats } => {
Message::Bind {
portal,
statement,
formats,
values,
result_formats,
} => {
let r = bind(portal,
statement,
formats.iter().cloned(),
values,
|v, buf| {
match *v {
Some(ref v) => {
buf.extend_from_slice(v);
Ok(IsNull::No)
}
None => Ok(IsNull::Yes),
}
|v, buf| match *v {
Some(ref v) => {
buf.extend_from_slice(v);
Ok(IsNull::No)
}
None => Ok(IsNull::Yes),
},
result_formats.iter().cloned(),
buf);
@ -86,20 +70,27 @@ impl<'a> Message<'a> {
Err(BindError::Serialization(e)) => Err(e),
}
}
Message::CancelRequest { process_id, secret_key } => {
Ok(cancel_request(process_id, secret_key, buf))
}
Message::CancelRequest {
process_id,
secret_key,
} => Ok(cancel_request(process_id, secret_key, buf)),
Message::Close { variant, name } => close(variant, name, buf),
Message::CopyData { data } => copy_data(data, buf),
Message::CopyDone => Ok(copy_done(buf)),
Message::CopyFail { message } => copy_fail(message, buf),
Message::Describe { variant, name } => describe(variant, name, buf),
Message::Execute { portal, max_rows } => execute(portal, max_rows, buf),
Message::Parse { name, query, param_types } => {
parse(name, query, param_types.iter().cloned(), buf)
}
Message::Parse {
name,
query,
param_types,
} => parse(name, query, param_types.iter().cloned(), buf),
Message::PasswordMessage { password } => password_message(password, buf),
Message::Query { query: q } => query(q, buf),
Message::SaslInitialResponse { mechanism, data } => {
sasl_initial_response(mechanism, data, buf)
}
Message::SaslResponse { data } => sasl_response(data, buf),
Message::SslRequest => Ok(ssl_request(buf)),
Message::StartupMessage { parameters } => {
startup_message(parameters.iter().map(|&(ref k, ref v)| (&**k, &**v)), buf)
@ -147,17 +138,17 @@ impl From<io::Error> for BindError {
#[inline]
pub fn bind<I, J, F, T, K>(portal: &str,
statement: &str,
formats: I,
values: J,
mut serializer: F,
result_formats: K,
buf: &mut Vec<u8>)
-> Result<(), BindError>
statement: &str,
formats: I,
values: J,
mut serializer: F,
result_formats: K,
buf: &mut Vec<u8>)
-> Result<(), BindError>
where I: IntoIterator<Item = i16>,
J: IntoIterator<Item = T>,
F: FnMut(T, &mut Vec<u8>) -> Result<IsNull, Box<Error + marker::Sync + Send>>,
K: IntoIterator<Item = i16>,
K: IntoIterator<Item = i16>
{
buf.push(b'B');
@ -199,7 +190,8 @@ pub fn cancel_request(process_id: i32, secret_key: i32, buf: &mut Vec<u8>) {
buf.write_i32::<BigEndian>(80877102).unwrap();
buf.write_i32::<BigEndian>(process_id).unwrap();
buf.write_i32::<BigEndian>(secret_key)
}).unwrap();
})
.unwrap();
}
#[inline]
@ -277,6 +269,22 @@ pub fn query(query: &str, buf: &mut Vec<u8>) -> io::Result<()> {
write_body(buf, |buf| buf.write_cstr(query))
}
#[inline]
pub fn sasl_initial_response(mechanism: &str, data: &[u8], buf: &mut Vec<u8>) -> io::Result<()> {
buf.push(b'p');
write_body(buf, |buf| {
try!(buf.write_cstr(mechanism));
buf.extend_from_slice(data);
Ok(())
})
}
#[inline]
pub fn sasl_response(data: &[u8], buf: &mut Vec<u8>) -> io::Result<()> {
buf.push(b'p');
write_body(buf, |buf| Ok(buf.extend_from_slice(data)))
}
#[inline]
pub fn ssl_request(buf: &mut Vec<u8>) {
write_body(buf, |buf| buf.write_i32::<BigEndian>(80877103)).unwrap();