SCRAM-SHA-256 protocol support
This commit is contained in:
parent
801835a05b
commit
8fc5ba218b
@ -9,8 +9,14 @@ documentation = "https://docs.rs/postgres-protocol/0.2.2/postgres_protocol"
|
||||
readme = "../README.md"
|
||||
|
||||
[dependencies]
|
||||
bytes = "0.4"
|
||||
base64 = "0.5"
|
||||
byteorder = "1.0"
|
||||
bytes = "0.4"
|
||||
fallible-iterator = "0.1"
|
||||
generic-array = "0.7"
|
||||
hmac = "0.1"
|
||||
md5 = "0.3"
|
||||
memchr = "1.0"
|
||||
rand = "0.3"
|
||||
sha2 = "0.5"
|
||||
stringprep = "0.1"
|
||||
|
@ -1,6 +1,8 @@
|
||||
//! Authentication protocol support.
|
||||
use md5::Context;
|
||||
|
||||
pub mod sasl;
|
||||
|
||||
/// Hashes authentication information in a way suitable for use in response
|
||||
/// to an `AuthenticationMd5Password` message.
|
||||
///
|
418
postgres-protocol/src/authentication/sasl.rs
Normal file
418
postgres-protocol/src/authentication/sasl.rs
Normal file
@ -0,0 +1,418 @@
|
||||
//! SASL-based authentication support.
|
||||
|
||||
use base64;
|
||||
use generic_array::GenericArray;
|
||||
use generic_array::typenum::U32;
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::{Sha256, Digest};
|
||||
use std::fmt::Write;
|
||||
use std::io;
|
||||
use std::iter;
|
||||
use std::mem;
|
||||
use std::str;
|
||||
use rand::{OsRng, Rng};
|
||||
use stringprep;
|
||||
|
||||
const NONCE_LENGTH: usize = 24;
|
||||
|
||||
/// The identifier of the SCRAM-SHA-256 SASL authentication mechanism.
|
||||
pub const SCRAM_SHA_256: &'static str = "SCRAM-SHA-256";
|
||||
|
||||
// since postgres passwords are not required to exclude saslprep-prohibited
|
||||
// characters or even be valid UTF8, we run saslprep if possible and otherwise
|
||||
// return the raw password.
|
||||
fn normalize(pass: &[u8]) -> Vec<u8> {
|
||||
let pass = match str::from_utf8(pass) {
|
||||
Ok(pass) => pass,
|
||||
Err(_) => return pass.to_vec(),
|
||||
};
|
||||
|
||||
match stringprep::saslprep(pass) {
|
||||
Ok(pass) => pass.into_owned().into_bytes(),
|
||||
Err(_) => pass.as_bytes().to_vec(),
|
||||
}
|
||||
}
|
||||
|
||||
fn hi(str: &[u8], salt: &[u8], i: u32) -> GenericArray<u8, U32> {
|
||||
let mut hmac = Hmac::<Sha256>::new(str);
|
||||
hmac.input(salt);
|
||||
hmac.input(&[0, 0, 0, 1]);
|
||||
let mut prev = hmac.result();
|
||||
|
||||
let mut hi = GenericArray::<u8, U32>::clone_from_slice(prev.code());
|
||||
|
||||
for _ in 1..i {
|
||||
let mut hmac = Hmac::<Sha256>::new(str);
|
||||
hmac.input(prev.code());
|
||||
prev = hmac.result();
|
||||
|
||||
for (hi, prev) in hi.iter_mut().zip(prev.code()) {
|
||||
*hi ^= *prev;
|
||||
}
|
||||
}
|
||||
|
||||
hi
|
||||
}
|
||||
|
||||
enum State {
|
||||
Update { nonce: String, password: Vec<u8> },
|
||||
Finish {
|
||||
salted_password: GenericArray<u8, U32>,
|
||||
auth_message: String,
|
||||
},
|
||||
Done,
|
||||
}
|
||||
|
||||
/// A type which handles the client side of the SCRAM-SHA-256 authentication process.
|
||||
///
|
||||
/// During the authentication process, if the backend sends an `AuthenticationSASL` message which
|
||||
/// includes `SCRAM-SHA-256` as an authentication mechanism, this type can be used.
|
||||
///
|
||||
/// After a `ScramSha256` is constructed, the buffer returned by the `message()` method should be
|
||||
/// sent to the backend in a `SASLInitialResponse` message along with the mechanism name.
|
||||
///
|
||||
/// The server will reply with an `AuthenticationSASLContinue` message. Its contents should be
|
||||
/// passed to the `update()` method, after which the buffer returned by the `message()` method
|
||||
/// should be sent to the backend in a `SASLResponse` message.
|
||||
///
|
||||
/// The server will reply with an `AuthenticationSASLFinal` message. Its contents should be passed
|
||||
/// to the `finish()` method method, after which the authentication process is complete.
|
||||
pub struct ScramSha256 {
|
||||
message: String,
|
||||
state: State,
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
impl ScramSha256 {
|
||||
/// Constructs a new instance which will use the provided password for authentication.
|
||||
pub fn new(password: &[u8]) -> io::Result<ScramSha256> {
|
||||
let mut rng = try!(OsRng::new());
|
||||
let nonce = (0..NONCE_LENGTH)
|
||||
.map(|_| {
|
||||
let mut v = rng.gen_range(0x21u8, 0x7e);
|
||||
if v == 0x2c {
|
||||
v = 0x7e
|
||||
}
|
||||
v as char
|
||||
})
|
||||
.collect::<String>();
|
||||
|
||||
ScramSha256::new_inner(password, nonce)
|
||||
}
|
||||
|
||||
fn new_inner(password: &[u8], nonce: String) -> io::Result<ScramSha256> {
|
||||
// the docs say to use pg_same_as_startup_message as the username, but
|
||||
// psql uses an empty string, so we'll go with that.
|
||||
let message = format!("n,,n=,r={}", nonce);
|
||||
|
||||
let password = normalize(password);
|
||||
|
||||
Ok(ScramSha256 {
|
||||
message: message,
|
||||
state: State::Update {
|
||||
nonce: nonce,
|
||||
password: password,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the message which should be sent to the backend in an `SASLResponse` message.
|
||||
pub fn message(&self) -> &[u8] {
|
||||
if let State::Done = self.state {
|
||||
panic!("invalid SCRAM state");
|
||||
}
|
||||
self.message.as_bytes()
|
||||
}
|
||||
|
||||
/// Updates the state machine with the response from the backend.
|
||||
///
|
||||
/// This should be called when an `AuthenticationSASLContinue` message is received.
|
||||
pub fn update(&mut self, message: &[u8]) -> io::Result<()> {
|
||||
let (client_nonce, password) = match mem::replace(&mut self.state, State::Done) {
|
||||
State::Update { nonce, password } => (nonce, password),
|
||||
_ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
|
||||
};
|
||||
|
||||
let message = str::from_utf8(message)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
|
||||
|
||||
let parsed = Parser::new(message).server_first_message()?;
|
||||
|
||||
if !parsed.nonce.starts_with(&client_nonce) {
|
||||
return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce"));
|
||||
}
|
||||
|
||||
let salt = match base64::decode(parsed.salt) {
|
||||
Ok(salt) => salt,
|
||||
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
|
||||
};
|
||||
|
||||
let salted_password = hi(&password, &salt, parsed.iteration_count);
|
||||
|
||||
let mut hmac = Hmac::<Sha256>::new(&salted_password);
|
||||
hmac.input(b"Client Key");
|
||||
let client_key = hmac.result();
|
||||
|
||||
let mut hash = Sha256::default();
|
||||
hash.input(client_key.code());
|
||||
let stored_key = hash.result();
|
||||
|
||||
self.message.clear();
|
||||
write!(&mut self.message, "c=biws,r={}", parsed.nonce).unwrap();
|
||||
|
||||
let auth_message = format!("n=,r={},{},{}", client_nonce, message, self.message);
|
||||
|
||||
let mut hmac = Hmac::<Sha256>::new(&stored_key);
|
||||
hmac.input(auth_message.as_bytes());
|
||||
let client_signature = hmac.result();
|
||||
|
||||
let mut client_proof = GenericArray::<u8, U32>::clone_from_slice(client_key.code());
|
||||
for (proof, signature) in client_proof.iter_mut().zip(client_signature.code()) {
|
||||
*proof ^= *signature;
|
||||
}
|
||||
|
||||
write!(&mut self.message, ",p={}", base64::encode(&*client_proof)).unwrap();
|
||||
|
||||
self.state = State::Finish {
|
||||
salted_password: salted_password,
|
||||
auth_message: auth_message,
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Finalizes the authentication process.
|
||||
///
|
||||
/// This should be called when the backend sends an `AuthenticationSASLFinal` message.
|
||||
/// Authentication has only succeeded if this method returns `Ok(())`.
|
||||
pub fn finish(&mut self, message: &[u8]) -> io::Result<()> {
|
||||
let (salted_password, auth_message) = match mem::replace(&mut self.state, State::Done) {
|
||||
State::Finish {
|
||||
salted_password,
|
||||
auth_message,
|
||||
} => (salted_password, auth_message),
|
||||
_ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
|
||||
};
|
||||
|
||||
let message = str::from_utf8(message)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
|
||||
|
||||
let parsed = Parser::new(message).server_final_message()?;
|
||||
|
||||
let verifier = match parsed {
|
||||
ServerFinalMessage::Error(e) => {
|
||||
return Err(io::Error::new(io::ErrorKind::Other, format!("SCRAM error: {}", e)))
|
||||
}
|
||||
ServerFinalMessage::Verifier(verifier) => verifier,
|
||||
};
|
||||
|
||||
let verifier = match base64::decode(verifier) {
|
||||
Ok(verifier) => verifier,
|
||||
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
|
||||
};
|
||||
|
||||
let mut hmac = Hmac::<Sha256>::new(&salted_password);
|
||||
hmac.input(b"Server Key");
|
||||
let server_key = hmac.result();
|
||||
|
||||
let mut hmac = Hmac::<Sha256>::new(server_key.code());
|
||||
hmac.input(auth_message.as_bytes());
|
||||
if hmac.verify(&verifier) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(io::Error::new(io::ErrorKind::InvalidInput, "SCRAM verification error"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Parser<'a> {
|
||||
s: &'a str,
|
||||
it: iter::Peekable<str::CharIndices<'a>>,
|
||||
}
|
||||
|
||||
impl<'a> Parser<'a> {
|
||||
fn new(s: &'a str) -> Parser<'a> {
|
||||
Parser {
|
||||
s: s,
|
||||
it: s.char_indices().peekable(),
|
||||
}
|
||||
}
|
||||
|
||||
fn eat(&mut self, target: char) -> io::Result<()> {
|
||||
match self.it.next() {
|
||||
Some((_, c)) if c == target => Ok(()),
|
||||
Some((i, c)) => {
|
||||
let m = format!("unexpected character at byte {}: expected `{}` but got `{}",
|
||||
i,
|
||||
target,
|
||||
c);
|
||||
Err(io::Error::new(io::ErrorKind::InvalidInput, m))
|
||||
}
|
||||
None => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF")),
|
||||
}
|
||||
}
|
||||
|
||||
fn take_while<F>(&mut self, f: F) -> io::Result<&'a str>
|
||||
where F: Fn(char) -> bool
|
||||
{
|
||||
let start = match self.it.peek() {
|
||||
Some(&(i, _)) => i,
|
||||
None => return Ok(""),
|
||||
};
|
||||
|
||||
loop {
|
||||
match self.it.peek() {
|
||||
Some(&(_, c)) if f(c) => {
|
||||
self.it.next();
|
||||
}
|
||||
Some(&(i, _)) => return Ok(&self.s[start..i]),
|
||||
None => return Ok(&self.s[start..]),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn printable(&mut self) -> io::Result<&'a str> {
|
||||
self.take_while(|c| match c {
|
||||
'\x21'...'\x2b' | '\x2d'...'\x7e' => true,
|
||||
_ => false,
|
||||
})
|
||||
}
|
||||
|
||||
fn nonce(&mut self) -> io::Result<&'a str> {
|
||||
self.eat('r')?;
|
||||
self.eat('=')?;
|
||||
self.printable()
|
||||
}
|
||||
|
||||
fn base64(&mut self) -> io::Result<&'a str> {
|
||||
self.take_while(|c| match c {
|
||||
'a'...'z' | 'A'...'Z' | '0'...'9' | '/' | '+' | '=' => true,
|
||||
_ => false,
|
||||
})
|
||||
}
|
||||
|
||||
fn salt(&mut self) -> io::Result<&'a str> {
|
||||
self.eat('s')?;
|
||||
self.eat('=')?;
|
||||
self.base64()
|
||||
}
|
||||
|
||||
fn posit_number(&mut self) -> io::Result<u32> {
|
||||
let n = self.take_while(|c| match c {
|
||||
'0'...'9' => true,
|
||||
_ => false,
|
||||
})?;
|
||||
n.parse()
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
|
||||
}
|
||||
|
||||
fn iteration_count(&mut self) -> io::Result<u32> {
|
||||
self.eat('i')?;
|
||||
self.eat('=')?;
|
||||
self.posit_number()
|
||||
}
|
||||
|
||||
fn eof(&mut self) -> io::Result<()> {
|
||||
match self.it.peek() {
|
||||
Some(&(i, _)) => {
|
||||
Err(io::Error::new(io::ErrorKind::InvalidInput,
|
||||
format!("unexpected trailing data at byte {}", i)))
|
||||
}
|
||||
None => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
fn server_first_message(&mut self) -> io::Result<ServerFirstMessage<'a>> {
|
||||
let nonce = self.nonce()?;
|
||||
self.eat(',')?;
|
||||
let salt = self.salt()?;
|
||||
self.eat(',')?;
|
||||
let iteration_count = self.iteration_count()?;
|
||||
self.eof()?;
|
||||
|
||||
Ok(ServerFirstMessage {
|
||||
nonce: nonce,
|
||||
salt: salt,
|
||||
iteration_count: iteration_count,
|
||||
})
|
||||
}
|
||||
|
||||
fn value(&mut self) -> io::Result<&'a str> {
|
||||
self.take_while(|c| match c {
|
||||
'\0' | '=' | ',' => false,
|
||||
_ => true,
|
||||
})
|
||||
}
|
||||
|
||||
fn server_error(&mut self) -> io::Result<Option<&'a str>> {
|
||||
match self.it.peek() {
|
||||
Some(&(_, 'e')) => {}
|
||||
_ => return Ok(None),
|
||||
}
|
||||
|
||||
self.eat('e')?;
|
||||
self.eat('=')?;
|
||||
self.value().map(Some)
|
||||
}
|
||||
|
||||
fn verifier(&mut self) -> io::Result<&'a str> {
|
||||
self.eat('v')?;
|
||||
self.eat('=')?;
|
||||
self.base64()
|
||||
}
|
||||
|
||||
fn server_final_message(&mut self) -> io::Result<ServerFinalMessage<'a>> {
|
||||
let message = match self.server_error()? {
|
||||
Some(error) => ServerFinalMessage::Error(error),
|
||||
None => ServerFinalMessage::Verifier(self.verifier()?),
|
||||
};
|
||||
self.eof()?;
|
||||
Ok(message)
|
||||
}
|
||||
}
|
||||
|
||||
struct ServerFirstMessage<'a> {
|
||||
nonce: &'a str,
|
||||
salt: &'a str,
|
||||
iteration_count: u32,
|
||||
}
|
||||
|
||||
enum ServerFinalMessage<'a> {
|
||||
Error(&'a str),
|
||||
Verifier(&'a str),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_server_first_message() {
|
||||
let message = "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096";
|
||||
let message = Parser::new(message).server_first_message().unwrap();
|
||||
assert_eq!(message.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j");
|
||||
assert_eq!(message.salt, "QSXCR+Q6sek8bf92");
|
||||
assert_eq!(message.iteration_count, 4096);
|
||||
}
|
||||
|
||||
// recorded auth exchange from psql
|
||||
#[test]
|
||||
fn exchange() {
|
||||
let password = "foobar";
|
||||
let nonce = "9IZ2O01zb9IgiIZ1WJ/zgpJB";
|
||||
|
||||
let client_first = "n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB";
|
||||
let server_first = "r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i\
|
||||
=4096";
|
||||
let client_final = "c=biws,r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,p=AmNKosjJzS3\
|
||||
1NTlQYNs5BTeQjdHdk7lOflDo5re2an8=";
|
||||
let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=";
|
||||
|
||||
let mut scram = ScramSha256::new_inner(password.as_bytes(), nonce.to_string()).unwrap();
|
||||
assert_eq!(str::from_utf8(scram.message()).unwrap(), client_first);
|
||||
|
||||
scram.update(server_first.as_bytes()).unwrap();
|
||||
assert_eq!(str::from_utf8(scram.message()).unwrap(), client_final);
|
||||
|
||||
scram.finish(server_final.as_bytes()).unwrap();
|
||||
}
|
||||
}
|
@ -11,11 +11,17 @@
|
||||
//! set to `UTF8`. It will most likely not behave properly if that is not the case.
|
||||
#![doc(html_root_url="https://docs.rs/postgres-protocol/0.2.2")]
|
||||
#![warn(missing_docs)]
|
||||
extern crate bytes;
|
||||
extern crate base64;
|
||||
extern crate byteorder;
|
||||
extern crate bytes;
|
||||
extern crate fallible_iterator;
|
||||
extern crate generic_array;
|
||||
extern crate hmac;
|
||||
extern crate md5;
|
||||
extern crate memchr;
|
||||
extern crate rand;
|
||||
extern crate sha2;
|
||||
extern crate stringprep;
|
||||
|
||||
use byteorder::{BigEndian, ByteOrder};
|
||||
use std::io;
|
||||
|
Loading…
Reference in New Issue
Block a user