Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use system installed certificates by default and accept neo4j+ssc connections #180

Merged
merged 12 commits into from
Jul 25, 2024
5 changes: 5 additions & 0 deletions .idea/.gitignore
madchicken marked this conversation as resolved.
Show resolved Hide resolved

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions .idea/aws.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions .idea/material_theme_project_new.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions .idea/modules.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions .idea/neo4rs.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ thiserror = "1.0.7"
time = { version = "0.3.22", optional = true }
tokio = { version = "1.5.0", features = ["full"] }
url = "2.0.0"
webpki-roots = "0.26.0"
rustls-native-certs = "0.7.1"
rustls-pemfile = "2.1.2"
madchicken marked this conversation as resolved.
Show resolved Hide resolved

[dependencies.chrono]
version = "0.4.35"
Expand Down
12 changes: 12 additions & 0 deletions lib/src/auth/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
mod static_client_certificate_provider;

#[derive(Debug, Clone)]
pub struct ClientCertificate {
pub(crate) cert_file: String, // Path to the TLS certificate file.
madchicken marked this conversation as resolved.
Show resolved Hide resolved
}

pub trait ClientCertificateProvider {
madchicken marked this conversation as resolved.
Show resolved Hide resolved
fn get_certificate(&self) -> ClientCertificate;
}

pub use static_client_certificate_provider::StaticClientCertificateProvider;
21 changes: 21 additions & 0 deletions lib/src/auth/static_client_certificate_provider.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
use crate::auth::{ClientCertificate, ClientCertificateProvider};

pub struct StaticClientCertificateProvider {
certificate: ClientCertificate
}

impl StaticClientCertificateProvider {
pub fn new(cert_file: String) -> Self {
Self {
certificate: ClientCertificate {
cert_file,
}
}
}
}

impl ClientCertificateProvider for StaticClientCertificateProvider {
fn get_certificate(&self) -> ClientCertificate {
self.certificate.clone()
}
}
12 changes: 12 additions & 0 deletions lib/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::errors::{Error, Result};
use std::{ops::Deref, sync::Arc};
use crate::auth::{ClientCertificate, ClientCertificateProvider};

const DEFAULT_DATABASE: &str = "neo4j";
const DEFAULT_FETCH_SIZE: usize = 200;
Expand Down Expand Up @@ -58,6 +59,7 @@ pub struct Config {
pub(crate) max_connections: usize,
pub(crate) db: Database,
pub(crate) fetch_size: usize,
pub(crate) client_certificate: Option<ClientCertificate>,
}

impl Config {
Expand All @@ -77,6 +79,7 @@ pub struct ConfigBuilder {
db: Database,
fetch_size: usize,
max_connections: usize,
client_certificate_provider: Option<Box<dyn ClientCertificateProvider>>,
madchicken marked this conversation as resolved.
Show resolved Hide resolved
}

impl ConfigBuilder {
Expand Down Expand Up @@ -128,6 +131,11 @@ impl ConfigBuilder {
self
}

pub fn with_client_certificate_provider(mut self, provider: Box<dyn ClientCertificateProvider>) -> Self {
self.client_certificate_provider = Some(provider);
self
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when the trait is removed, this could either be

Suggested change
pub fn with_client_certificate_provider(mut self, provider: Box<dyn ClientCertificateProvider>) -> Self {
self.client_certificate_provider = Some(provider);
self
}
pub fn with_client_certificate(mut self, client_cert: impl AsRef<Path>) -> Self {
self.client_certificate = Some(ClientCertificate::new(client_cert));
self
}

or with a From impl

Suggested change
pub fn with_client_certificate_provider(mut self, provider: Box<dyn ClientCertificateProvider>) -> Self {
self.client_certificate_provider = Some(provider);
self
}
pub fn with_client_certificate(mut self, client_cert: impl Into<ClientCertificate>) -> Self {
self.client_certificate = Some(client_cert.into());
self
}

pub fn build(self) -> Result<Config> {
if let (Some(uri), Some(user), Some(password)) = (self.uri, self.user, self.password) {
Ok(Config {
Expand All @@ -137,6 +145,7 @@ impl ConfigBuilder {
fetch_size: self.fetch_size,
max_connections: self.max_connections,
db: self.db,
client_certificate: self.client_certificate_provider.map(|p| p.get_certificate()),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And this would be

Suggested change
client_certificate: self.client_certificate_provider.map(|p| p.get_certificate()),
client_certificate: self.client_certificate,

})
} else {
Err(Error::InvalidConfig)
Expand All @@ -153,6 +162,7 @@ impl Default for ConfigBuilder {
db: DEFAULT_DATABASE.into(),
max_connections: DEFAULT_MAX_CONNECTIONS,
fetch_size: DEFAULT_FETCH_SIZE,
client_certificate_provider: None,
}
}
}
Expand All @@ -178,6 +188,7 @@ mod tests {
assert_eq!(&*config.db, "some_db");
assert_eq!(config.fetch_size, 10);
assert_eq!(config.max_connections, 5);
assert!(config.client_certificate.is_none());
}

#[test]
Expand All @@ -194,6 +205,7 @@ mod tests {
assert_eq!(&*config.db, "neo4j");
assert_eq!(config.fetch_size, 200);
assert_eq!(config.max_connections, 16);
assert!(config.client_certificate.is_none());
}

#[test]
Expand Down
70 changes: 66 additions & 4 deletions lib/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ use crate::{
};
use bytes::{Bytes, BytesMut};
use std::{mem, sync::Arc};
use std::fs::File;
use std::io::BufReader;
use log::warn;
use stream::ConnectionStream;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt, BufStream},
Expand All @@ -19,7 +22,9 @@ use tokio_rustls::{
},
TlsConnector,
};
use tokio_rustls::client::TlsStream;
use url::{Host, Url};
use crate::auth::ClientCertificate;

const MAX_CHUNK_SIZE: usize = 65_535 - mem::size_of::<u16>();

Expand All @@ -39,7 +44,13 @@ impl Connection {

match info.encryption {
Encryption::No => Self::new_unencrypted(stream, &info.user, &info.password).await,
Encryption::Tls => Self::new_tls(stream, &info.host, &info.user, &info.password).await,
Encryption::Tls => {
if let Some(certificate) = info.client_certificate.as_ref() {
Self::new_tls_with_certificate(stream, &info.host, &info.user, &info.password, certificate).await
} else {
Self::new_tls(stream, &info.host, &info.user, &info.password).await
}
},
}
}

Expand All @@ -53,9 +64,47 @@ impl Connection {
user: &str,
password: &str,
) -> Result<Connection> {
let root_cert_store = Self::build_cert_store();
let stream = Self::build_stream(stream, host, root_cert_store).await?;

Self::init(user, password, stream).await
}

async fn new_tls_with_certificate<T: AsRef<str>>(
stream: TcpStream,
host: &Host<T>,
user: &str,
password: &str,
certificate: &ClientCertificate,
) -> Result<Connection> {
let mut root_cert_store = Self::build_cert_store();

let cert_file = File::open(certificate.cert_file.as_str())?;
let mut reader = BufReader::new(cert_file);
let certs = rustls_pemfile::certs(&mut reader);
for certificate in certs {
if let Ok(cert) = certificate {
root_cert_store.add(cert).unwrap();
}
}
madchicken marked this conversation as resolved.
Show resolved Hide resolved

let stream = Self::build_stream(stream, host, root_cert_store).await?;
Self::init(user, password, stream).await
}

fn build_cert_store() -> RootCertStore {
let mut root_cert_store = RootCertStore::empty();
root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().map(ToOwned::to_owned));
if let Ok(certs) = rustls_native_certs::load_native_certs() {
for cert in certs {
root_cert_store.add(cert).unwrap();
}
} else {
warn!("Failed to load native certificates!");
}
root_cert_store
}
madchicken marked this conversation as resolved.
Show resolved Hide resolved

async fn build_stream<T: AsRef<str>>(stream: TcpStream, host: &Host<T>, root_cert_store: RootCertStore) -> Result<TlsStream<TcpStream>, Error> {
let config = ClientConfig::builder()
.with_root_certificates(root_cert_store)
.with_no_client_auth();
Expand All @@ -71,8 +120,7 @@ impl Connection {
};

let stream = connector.connect(domain, stream).await?;

Self::init(user, password, stream).await
Ok(stream)
}

async fn init(
Expand Down Expand Up @@ -212,6 +260,7 @@ pub(crate) struct ConnectionInfo {
host: Host<Arc<str>>,
port: u16,
encryption: Encryption,
client_certificate: Option<ClientCertificate>,
}

enum Encryption {
Expand Down Expand Up @@ -242,6 +291,13 @@ impl ConnectionInfo {
));
Encryption::Tls
}
"neo4j+ssc" => {
madchicken marked this conversation as resolved.
Show resolved Hide resolved
log::warn!(concat!(
"This driver does not yet implement client-side routing. ",
"It is possible that operations against a cluster (such as Aura) will fail."
));
Encryption::Tls
}
otherwise => return Err(Error::UnsupportedScheme(otherwise.to_owned())),
};

Expand All @@ -255,8 +311,14 @@ impl ConnectionInfo {
},
port,
encryption,
client_certificate: None
})
}

pub fn with_client_certificate(&mut self, certificate: &ClientCertificate) -> &Self {
madchicken marked this conversation as resolved.
Show resolved Hide resolved
self.client_certificate = Some(certificate.clone());
self
}
}

struct NeoUrl(Url);
Expand Down
2 changes: 2 additions & 0 deletions lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,9 @@ pub mod summary;
mod txn;
mod types;
mod version;
mod auth;

pub use crate::auth::{ClientCertificate, ClientCertificateProvider, StaticClientCertificateProvider};
pub use crate::config::{Config, ConfigBuilder, Database};
pub use crate::errors::*;
pub use crate::graph::{query, Graph};
Expand Down
10 changes: 7 additions & 3 deletions lib/src/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{
};
use deadpool::managed::{Manager, Metrics, Object, Pool, RecycleResult};
use log::info;
use crate::auth::{ClientCertificate};

pub type ConnectionPool = Pool<ConnectionManager>;
pub type ManagedConnection = Object<ConnectionManager>;
Expand All @@ -14,8 +15,11 @@ pub struct ConnectionManager {
}

impl ConnectionManager {
pub fn new(uri: &str, user: &str, password: &str) -> Result<Self> {
let info = ConnectionInfo::new(uri, user, password)?;
pub fn new(uri: &str, user: &str, password: &str, client_certificate: Option<&ClientCertificate>) -> Result<Self> {
let mut info = ConnectionInfo::new(uri, user, password)?;
if let Some(client_certificate) = client_certificate {
info.with_client_certificate(client_certificate);
}
madchicken marked this conversation as resolved.
Show resolved Hide resolved
Ok(ConnectionManager { info })
}
}
Expand All @@ -35,7 +39,7 @@ impl Manager for ConnectionManager {
}

pub async fn create_pool(config: &Config) -> Result<ConnectionPool> {
let mgr = ConnectionManager::new(&config.uri, &config.user, &config.password)?;
let mgr = ConnectionManager::new(&config.uri, &config.user, &config.password, config.client_certificate.as_ref())?;
info!(
"creating connection pool with max size {}",
config.max_connections
Expand Down
Loading