Support SCRAM channel binding for Postgres 11

This commit is contained in:
Steven Fackler 2018-05-26 12:53:12 -07:00
parent 9762eb609f
commit 11ffcac087
13 changed files with 188 additions and 45 deletions

View File

@ -26,7 +26,7 @@ jobs:
- image: rust:1.23.0
environment:
RUSTFLAGS: -D warnings
- image: sfackler/rust-postgres-test:3
- image: sfackler/rust-postgres-test:4
steps:
- checkout
- *RESTORE_REGISTRY

View File

@ -1,6 +1,6 @@
version: '2'
services:
postgres:
image: "sfackler/rust-postgres-test:3"
image: "sfackler/rust-postgres-test:4"
ports:
- 5433:5433

View File

@ -1,3 +1,3 @@
FROM postgres:10.0
FROM postgres:11-beta1
COPY sql_setup.sh /docker-entrypoint-initdb.d/

View File

@ -4,6 +4,6 @@ version = "0.1.0"
authors = ["Steven Fackler <sfackler@gmail.com>"]
[dependencies]
openssl = "0.10"
openssl = "0.10.9"
postgres = { version = "0.15", path = "../postgres" }

View File

@ -2,6 +2,8 @@ pub extern crate openssl;
extern crate postgres;
use openssl::error::ErrorStack;
use openssl::hash::MessageDigest;
use openssl::nid::Nid;
use openssl::ssl::{ConnectConfiguration, SslConnector, SslMethod, SslStream};
use postgres::tls::{Stream, TlsHandshake, TlsStream};
use std::error::Error;
@ -84,4 +86,19 @@ impl TlsStream for OpenSslStream {
fn get_mut(&mut self) -> &mut Stream {
self.0.get_mut()
}
fn tls_server_end_point(&self) -> Option<Vec<u8>> {
let cert = self.0.ssl().peer_certificate()?;
let algo_nid = cert.signature_algorithm().object().nid();
let signature_algorithms = algo_nid.signature_algorithms()?;
let md = match signature_algorithms.digest {
Nid::MD5 | Nid::SHA1 => MessageDigest::sha256(),
nid => MessageDigest::from_nid(nid)?,
};
let digest = cert.digest(md).ok()?;
Some(digest.to_vec())
}
}

View File

@ -4,7 +4,7 @@ use postgres::{Connection, TlsMode};
use OpenSsl;
#[test]
fn test_require_ssl_conn() {
fn require() {
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
builder.set_ca_file("../test/server.crt").unwrap();
let negotiator = OpenSsl::with_connector(builder.build());
@ -16,7 +16,7 @@ fn test_require_ssl_conn() {
}
#[test]
fn test_prefer_ssl_conn() {
fn prefer() {
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
builder.set_ca_file("../test/server.crt").unwrap();
let negotiator = OpenSsl::with_connector(builder.build());
@ -26,3 +26,15 @@ fn test_prefer_ssl_conn() {
).unwrap();
conn.execute("SELECT 1::VARCHAR", &[]).unwrap();
}
#[test]
fn scram_user() {
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
builder.set_ca_file("../test/server.crt").unwrap();
let negotiator = OpenSsl::with_connector(builder.build());
let conn = Connection::connect(
"postgres://scram_user:password@localhost:5433/postgres",
TlsMode::Require(&negotiator),
).unwrap();
conn.execute("SELECT 1::VARCHAR", &[]).unwrap();
}

View File

@ -16,6 +16,6 @@ generic-array = "0.11"
hmac = "0.6"
md5 = "0.3"
memchr = "2.0"
rand = "0.4"
rand = "0.5"
sha2 = "0.7"
stringprep = "0.1"

View File

@ -4,7 +4,7 @@ use base64;
use generic_array::typenum::U32;
use generic_array::GenericArray;
use hmac::{Hmac, Mac};
use rand::{OsRng, Rng};
use rand::{self, Rng};
use sha2::{Digest, Sha256};
use std::fmt::Write;
use std::io;
@ -17,6 +17,8 @@ const NONCE_LENGTH: usize = 24;
/// The identifier of the SCRAM-SHA-256 SASL authentication mechanism.
pub const SCRAM_SHA_256: &'static str = "SCRAM-SHA-256";
/// The identifier of the SCRAM-SHA-256-PLUS SASL authentication mechanism.
pub const SCRAM_SHA_256_PLUS: &'static str = "SCRAM-SHA-256-PLUS";
// since postgres passwords are not required to exclude saslprep-prohibited
// characters or even be valid UTF8, we run saslprep if possible and otherwise
@ -54,10 +56,61 @@ fn hi(str: &[u8], salt: &[u8], i: u32) -> GenericArray<u8, U32> {
hi
}
enum ChannelBindingInner {
Unrequested,
Unsupported,
TlsUnique(Vec<u8>),
TlsServerEndPoint(Vec<u8>),
}
/// The channel binding configuration for a SCRAM authentication exchange.
pub struct ChannelBinding(ChannelBindingInner);
impl ChannelBinding {
/// The server did not request channel binding.
pub fn unrequested() -> ChannelBinding {
ChannelBinding(ChannelBindingInner::Unrequested)
}
/// The server requested channel binding but the client is unable to provide it.
pub fn unsupported() -> ChannelBinding {
ChannelBinding(ChannelBindingInner::Unsupported)
}
/// The server requested channel binding and the client will use the `tls-unique` method.
pub fn tls_unique(finished: Vec<u8>) -> ChannelBinding {
ChannelBinding(ChannelBindingInner::TlsUnique(finished))
}
/// The server requested channel binding and the client will use the `tls-server-end-point`
/// method.
pub fn tls_server_end_point(signature: Vec<u8>) -> ChannelBinding {
ChannelBinding(ChannelBindingInner::TlsServerEndPoint(signature))
}
fn gs2_header(&self) -> &'static str {
match self.0 {
ChannelBindingInner::Unrequested => "y,,",
ChannelBindingInner::Unsupported => "n,,",
ChannelBindingInner::TlsUnique(_) => "p=tls-unique,,",
ChannelBindingInner::TlsServerEndPoint(_) => "p=tls-server-end-point,,",
}
}
fn cbind_data(&self) -> &[u8] {
match self.0 {
ChannelBindingInner::Unrequested | ChannelBindingInner::Unsupported => &[],
ChannelBindingInner::TlsUnique(ref buf)
| ChannelBindingInner::TlsServerEndPoint(ref buf) => buf,
}
}
}
enum State {
Update {
nonce: String,
password: Vec<u8>,
channel_binding: ChannelBinding,
},
Finish {
salted_password: GenericArray<u8, U32>,
@ -66,7 +119,8 @@ enum State {
Done,
}
/// A type which handles the client side of the SCRAM-SHA-256 authentication process.
/// A type which handles the client side of the SCRAM-SHA-256/SCRAM-SHA-256-PLUS authentication
/// process.
///
/// During the authentication process, if the backend sends an `AuthenticationSASL` message which
/// includes `SCRAM-SHA-256` as an authentication mechanism, this type can be used.
@ -85,11 +139,11 @@ pub struct ScramSha256 {
state: State,
}
#[allow(missing_docs)]
impl ScramSha256 {
/// Constructs a new instance which will use the provided password for authentication.
pub fn new(password: &[u8]) -> io::Result<ScramSha256> {
let mut rng = OsRng::new()?;
pub fn new(password: &[u8], channel_binding: ChannelBinding) -> io::Result<ScramSha256> {
// rand 0.5's ThreadRng is cryptographically secure
let mut rng = rand::thread_rng();
let nonce = (0..NONCE_LENGTH)
.map(|_| {
let mut v = rng.gen_range(0x21u8, 0x7e);
@ -100,21 +154,20 @@ impl ScramSha256 {
})
.collect::<String>();
ScramSha256::new_inner(password, nonce)
ScramSha256::new_inner(password, channel_binding, nonce)
}
fn new_inner(password: &[u8], nonce: String) -> io::Result<ScramSha256> {
// the docs say to use pg_same_as_startup_message as the username, but
// psql uses an empty string, so we'll go with that.
let message = format!("n,,n=,r={}", nonce);
let password = normalize(password);
fn new_inner(
password: &[u8],
channel_binding: ChannelBinding,
nonce: String,
) -> io::Result<ScramSha256> {
Ok(ScramSha256 {
message: message,
message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce),
state: State::Update {
nonce: nonce,
password: password,
nonce,
password: normalize(password),
channel_binding: channel_binding,
},
})
}
@ -131,8 +184,13 @@ impl ScramSha256 {
///
/// This should be called when an `AuthenticationSASLContinue` message is received.
pub fn update(&mut self, message: &[u8]) -> io::Result<()> {
let (client_nonce, password) = match mem::replace(&mut self.state, State::Done) {
State::Update { nonce, password } => (nonce, password),
let (client_nonce, password, channel_binding) =
match mem::replace(&mut self.state, State::Done) {
State::Update {
nonce,
password,
channel_binding,
} => (nonce, password, channel_binding),
_ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
};
@ -161,8 +219,13 @@ impl ScramSha256 {
hash.input(client_key.as_slice());
let stored_key = hash.result();
let mut cbind_input = vec![];
cbind_input.extend(channel_binding.gs2_header().as_bytes());
cbind_input.extend(channel_binding.cbind_data());
let cbind_input = base64::encode(&cbind_input);
self.message.clear();
write!(&mut self.message, "c=biws,r={}", parsed.nonce).unwrap();
write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap();
let auth_message = format!("n=,r={},{},{}", client_nonce, message, self.message);
@ -420,7 +483,11 @@ mod test {
1NTlQYNs5BTeQjdHdk7lOflDo5re2an8=";
let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=";
let mut scram = ScramSha256::new_inner(password.as_bytes(), nonce.to_string()).unwrap();
let mut scram = ScramSha256::new_inner(
password.as_bytes(),
ChannelBinding::unsupported(),
nonce.to_string(),
).unwrap();
assert_eq!(str::from_utf8(scram.message()).unwrap(), client_first);
scram.update(server_first.as_bytes()).unwrap();

View File

@ -81,7 +81,7 @@ extern crate socket2;
use fallible_iterator::FallibleIterator;
use postgres_protocol::authentication;
use postgres_protocol::authentication::sasl::{self, ScramSha256};
use postgres_protocol::authentication::sasl::{self, ChannelBinding, ScramSha256};
use postgres_protocol::message::backend::{self, ErrorFields};
use postgres_protocol::message::frontend;
use postgres_shared::rows::RowData;
@ -422,25 +422,52 @@ impl InnerConnection {
self.stream.flush()?;
}
backend::Message::AuthenticationSasl(body) => {
// count to validate the entire message body.
if body
.mechanisms()
.filter(|m| *m == sasl::SCRAM_SHA_256)
.count()? == 0
{
let mut has_scram = false;
let mut has_scram_plus = false;
let mut mechanisms = body.mechanisms();
while let Some(mechanism) = mechanisms.next()? {
match mechanism {
sasl::SCRAM_SHA_256 => has_scram = true,
sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
_ => {}
}
}
let channel_binding = self
.stream
.get_ref()
.tls_unique()
.map(ChannelBinding::tls_unique)
.or_else(|| {
self.stream
.get_ref()
.tls_server_end_point()
.map(ChannelBinding::tls_server_end_point)
});
let (channel_binding, mechanism) = if has_scram_plus {
match channel_binding {
Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS),
None => (ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
}
} else if has_scram {
match channel_binding {
Some(_) => (ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
None => (ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
}
} else {
return Err(
io::Error::new(io::ErrorKind::Other, "unsupported authentication").into(),
);
}
};
let pass = user.password().ok_or_else(|| {
error::connect("a password was requested but not provided".into())
})?;
let mut scram = ScramSha256::new(pass.as_bytes())?;
let mut scram = ScramSha256::new(pass.as_bytes(), channel_binding)?;
self.stream.write_message(|buf| {
frontend::sasl_initial_response(sasl::SCRAM_SHA_256, scram.message(), buf)
frontend::sasl_initial_response(mechanism, scram.message(), buf)
})?;
self.stream.flush()?;

View File

@ -12,6 +12,26 @@ pub trait TlsStream: fmt::Debug + Read + Write + Send {
/// Returns a mutable reference to the underlying `Stream`.
fn get_mut(&mut self) -> &mut Stream;
/// Returns the data associated with the `tls-unique` channel binding type as described in
/// [RFC 5929], if supported.
///
/// An implementation only needs to support one of this or `tls_server_end_point`.
///
/// [RFC 5929]: https://tools.ietf.org/html/rfc5929
fn tls_unique(&self) -> Option<Vec<u8>> {
None
}
/// Returns the data associated with the `tls-server-end-point` channel binding type as
/// described in [RFC 5929], if supported.
///
/// An implementation only needs to support one of this or `tls_unique`.
///
/// [RFC 5929]: https://tools.ietf.org/html/rfc5929
fn tls_server_end_point(&self) -> Option<Vec<u8>> {
None
}
}
/// A trait implemented by types that can initiate a TLS session over a Postgres

View File

@ -44,8 +44,8 @@ tokio-core = "0.1.8"
tokio-dns-unofficial = "0.1"
tokio-io = "0.1"
tokio-openssl = { version = "0.1", optional = true }
openssl = { version = "0.9.23", optional = true }
tokio-openssl = { version = "0.2", optional = true }
openssl = { version = "0.10", optional = true }
[target.'cfg(unix)'.dependencies]
tokio-uds = "0.1"

View File

@ -270,10 +270,10 @@ fn ssl_user_ssl_required() {
#[cfg(feature = "with-openssl")]
#[test]
fn openssl_required() {
use tls::openssl::openssl::ssl::{SslConnectorBuilder, SslMethod};
use tls::openssl::openssl::ssl::{SslConnector, SslMethod};
use tls::openssl::OpenSsl;
let mut builder = SslConnectorBuilder::new(SslMethod::tls()).unwrap();
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
builder.set_ca_file("../test/server.crt").unwrap();
let negotiator = OpenSsl::from(builder.build());

View File

@ -3,7 +3,7 @@ extern crate tokio_openssl;
pub extern crate openssl;
use futures::Future;
use self::openssl::ssl::{SslMethod, SslConnector, SslConnectorBuilder};
use self::openssl::ssl::{SslMethod, SslConnector};
use self::openssl::error::ErrorStack;
use std::error::Error;
use self::tokio_openssl::{SslConnectorExt, SslStream};
@ -27,7 +27,7 @@ pub struct OpenSsl(SslConnector);
impl OpenSsl {
/// Creates a new `OpenSsl` with default settings.
pub fn new() -> Result<OpenSsl, ErrorStack> {
let connector = SslConnectorBuilder::new(SslMethod::tls())?.build();
let connector = SslConnector::builder(SslMethod::tls())?.build();
Ok(OpenSsl(connector))
}
}