extern crate bytes;
extern crate futures;
extern crate openssl;
extern crate tokio_io;
extern crate tokio_openssl;
extern crate tokio_postgres;

#[cfg(test)]
extern crate tokio;

use bytes::{Buf, BufMut};
use futures::{Future, IntoFuture, Poll};
use openssl::error::ErrorStack;
use openssl::ssl::{ConnectConfiguration, SslConnector, SslMethod, SslRef};
use std::error::Error;
use std::io::{self, Read, Write};
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_openssl::ConnectConfigurationExt;
use tokio_postgres::tls::{Socket, TlsConnect, TlsStream};

#[cfg(test)]
mod test;

/// An OpenSSL-backed connector implementing `tokio_postgres`'s `TlsConnect`.
///
/// Each connection attempt derives a fresh `ConnectConfiguration` from
/// `connector` and runs it through `callback` before the handshake starts.
pub struct TlsConnector {
    /// OpenSSL connector from which per-connection configurations are created.
    connector: SslConnector,
    /// Hook applied to every new `ConnectConfiguration`; returning `Err`
    /// aborts that connection attempt.
    callback: Box<Fn(&mut ConnectConfiguration) -> Result<(), ErrorStack> + Send + Sync>,
}

impl TlsConnector {
    /// Creates a connector from a default OpenSSL `SslConnector` built for
    /// `SslMethod::tls()`, with a callback that changes nothing.
    ///
    /// # Errors
    ///
    /// Returns the OpenSSL `ErrorStack` if the connector builder cannot be
    /// created.
    pub fn new() -> Result<TlsConnector, ErrorStack> {
        SslConnector::builder(SslMethod::tls())
            .map(|builder| TlsConnector::with_connector(builder.build()))
    }

    /// Wraps an already-configured `SslConnector`.
    ///
    /// The installed callback accepts every connection configuration
    /// unchanged; use `set_callback` to customize it.
    pub fn with_connector(connector: SslConnector) -> TlsConnector {
        // Default hook: leave the configuration exactly as produced.
        fn unchanged(_: &mut ConnectConfiguration) -> Result<(), ErrorStack> {
            Ok(())
        }

        TlsConnector {
            callback: Box::new(unchanged),
            connector,
        }
    }

    /// Replaces the per-connection configuration hook with `f`.
    ///
    /// `f` is invoked on each freshly created `ConnectConfiguration` before the
    /// TLS handshake begins; an `Err` from it fails that connection attempt.
    pub fn set_callback<F>(&mut self, f: F)
    where
        F: Fn(&mut ConnectConfiguration) -> Result<(), ErrorStack> + Send + Sync + 'static,
    {
        self.callback = Box::new(f);
    }
}

impl TlsConnect for TlsConnector {
    fn connect(
        &self,
        domain: &str,
        socket: Socket,
    ) -> Box<Future<Item = Box<TlsStream>, Error = Box<Error + Sync + Send>> + Sync + Send> {
        let f = self
            .connector
            .configure()
            .and_then(|mut ssl| (self.callback)(&mut ssl).map(|_| ssl))
            .map_err(|e| {
                let e: Box<Error + Sync + Send> = Box::new(e);
                e
            })
            .into_future()
            .and_then({
                let domain = domain.to_string();
                move |ssl| {
                    ssl.connect_async(&domain, socket)
                        .map(|s| {
                            let s: Box<TlsStream> = Box::new(SslStream(s));
                            s
                        })
                        .map_err(|e| {
                            let e: Box<Error + Sync + Send> = Box::new(e);
                            e
                        })
                }
            });
        Box::new(f)
    }
}

/// Newtype around `tokio_openssl::SslStream<Socket>`; `TlsConnector::connect`
/// boxes it as a `Box<TlsStream>`. The `Read`/`Write`/`AsyncRead`/`AsyncWrite`
/// impls forward to the wrapped stream via `.0`.
// NOTE(review): presumably a newtype so that the foreign `TlsStream` trait can
// be implemented for a foreign stream type (orphan rule) — confirm.
struct SslStream(tokio_openssl::SslStream<Socket>);

impl Read for SslStream {
    /// Forwards the read to the wrapped `tokio_openssl` stream.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        Read::read(&mut self.0, buf)
    }
}

impl AsyncRead for SslStream {
    unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
        self.0.prepare_uninitialized_buffer(buf)
    }

    fn read_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
    where
        B: BufMut,
    {
        self.0.read_buf(buf)
    }
}

impl Write for SslStream {
    /// Forwards the flush to the wrapped `tokio_openssl` stream.
    fn flush(&mut self) -> io::Result<()> {
        Write::flush(&mut self.0)
    }

    /// Forwards the write to the wrapped `tokio_openssl` stream.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        Write::write(&mut self.0, buf)
    }
}

impl AsyncWrite for SslStream {
    fn shutdown(&mut self) -> Poll<(), io::Error> {
        self.0.shutdown()
    }

    fn write_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
    where
        B: Buf,
    {
        self.0.write_buf(buf)
    }
}

impl TlsStream for SslStream {
    fn tls_unique(&self) -> Option<Vec<u8>> {
        let f = if self.0.get_ref().ssl().session_reused() {
            SslRef::peer_finished
        } else {
            SslRef::finished
        };

        let len = f(self.0.get_ref().ssl(), &mut []);
        let mut buf = vec![0; len];
        f(self.0.get_ref().ssl(), &mut buf);

        Some(buf)
    }
}