Support database connection via SSL

This is required to use managed Postgres databases. It is necessary to
use SSL connection to the remote host as the connection goes through the
open internet.
This commit is contained in:
Rafael Caricio 2023-04-26 10:38:25 +02:00
parent b7fafe6458
commit 1e40a42524
Signed by: rafaelcaricio
GPG key ID: 3C86DBCE8E93C947
7 changed files with 59 additions and 15 deletions

4
.gitignore vendored
View file

@ -1,5 +1,9 @@
.env.local .env.local
config.yaml config.yaml
/secret/*
/files/* /files/*
!/files/.gitkeep !/files/.gitkeep
/build/*
!/build/.gitkeep
/target /target
fly.toml

0
build/.gitkeep Normal file
View file

View file

@ -25,7 +25,7 @@ async fn main() {
} }
let db_config = config.database_url.parse().unwrap(); let db_config = config.database_url.parse().unwrap();
let db_client = &mut create_database_client(&db_config).await; let db_client = &mut create_database_client(&db_config, config.tls_ca_file.as_ref().map(|p| p.as_path())).await;
apply_migrations(db_client).await; apply_migrations(db_client).await;
match subcmd { match subcmd {

View file

@ -33,6 +33,8 @@ pub struct Config {
// Core settings // Core settings
pub database_url: String, pub database_url: String,
#[serde(default)]
pub tls_ca_file: Option<PathBuf>,
pub storage_dir: PathBuf, pub storage_dir: PathBuf,
pub web_client_dir: Option<PathBuf>, pub web_client_dir: Option<PathBuf>,

View file

@ -27,6 +27,8 @@ thiserror = "1.0.37"
# Async runtime # Async runtime
tokio = { version = "1.20.4", features = [] } tokio = { version = "1.20.4", features = [] }
# Used for working with Postgresql database # Used for working with Postgresql database
openssl = { version = "0.10", features = ["vendored"] }
postgres-openssl = "0.5.0"
tokio-postgres = { version = "0.7.6", features = ["with-chrono-0_4", "with-uuid-1", "with-serde_json-1"] } tokio-postgres = { version = "0.7.6", features = ["with-chrono-0_4", "with-uuid-1", "with-serde_json-1"] }
postgres-types = { version = "0.2.3", features = ["derive", "with-chrono-0_4", "with-uuid-1", "with-serde_json-1"] } postgres-types = { version = "0.2.3", features = ["derive", "with-chrono-0_4", "with-uuid-1", "with-serde_json-1"] }
postgres-protocol = "0.6.4" postgres-protocol = "0.6.4"
@ -38,7 +40,6 @@ uuid = { version = "1.1.2", features = ["serde", "v4"] }
[dev-dependencies] [dev-dependencies]
fedimovies-utils = { path = "../fedimovies-utils", features = ["test-utils"] } fedimovies-utils = { path = "../fedimovies-utils", features = ["test-utils"] }
serial_test = "0.7.0" serial_test = "0.7.0"
[features] [features]

View file

@ -1,3 +1,7 @@
use deadpool_postgres::SslMode;
use openssl::ssl::{SslConnector, SslMethod};
use postgres_openssl::MakeTlsConnector;
use std::path::Path;
use tokio_postgres::config::Config as DatabaseConfig; use tokio_postgres::config::Config as DatabaseConfig;
use tokio_postgres::error::{Error as PgError, SqlState}; use tokio_postgres::error::{Error as PgError, SqlState};
@ -11,6 +15,7 @@ pub mod test_utils;
pub type DbPool = deadpool_postgres::Pool; pub type DbPool = deadpool_postgres::Pool;
pub use tokio_postgres::GenericClient as DatabaseClient; pub use tokio_postgres::GenericClient as DatabaseClient;
use tokio_postgres::NoTls;
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
#[error("database type error")] #[error("database type error")]
@ -37,21 +42,49 @@ pub enum DatabaseError {
AlreadyExists(&'static str), // object type AlreadyExists(&'static str), // object type
} }
pub async fn create_database_client(db_config: &DatabaseConfig) -> tokio_postgres::Client { pub async fn create_database_client(
db_config: &DatabaseConfig,
ca_file_path: Option<&Path>,
) -> tokio_postgres::Client {
let client = if let Some(ca_file_path) = ca_file_path {
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
log::debug!("Using TLS CA file: {}", ca_file_path.display());
builder.set_ca_file(ca_file_path).unwrap();
let connector = MakeTlsConnector::new(builder.build());
let (client, connection) = db_config.connect(connector).await.unwrap();
tokio::spawn(async move {
if let Err(err) = connection.await {
log::error!("connection with tls error: {}", err);
};
});
client
} else {
let (client, connection) = db_config.connect(tokio_postgres::NoTls).await.unwrap(); let (client, connection) = db_config.connect(tokio_postgres::NoTls).await.unwrap();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = connection.await { if let Err(err) = connection.await {
log::error!("connection error: {}", err); log::error!("connection error: {}", err);
}; };
}); });
client
};
client client
} }
pub fn create_pool(database_url: &str, pool_size: usize) -> DbPool { pub fn create_pool(database_url: &str, ca_file_path: Option<&Path>, pool_size: usize) -> DbPool {
let manager = deadpool_postgres::Manager::new( let manager = if let Some(ca_file_path) = ca_file_path {
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
log::info!("Using TLS CA file: {}", ca_file_path.display());
builder.set_ca_file(ca_file_path).unwrap();
let connector = MakeTlsConnector::new(builder.build());
deadpool_postgres::Manager::new(
database_url.parse().expect("invalid database URL"), database_url.parse().expect("invalid database URL"),
tokio_postgres::NoTls, connector,
); )
} else {
deadpool_postgres::Manager::new(database_url.parse().expect("invalid database URL"), NoTls)
};
DbPool::builder(manager) DbPool::builder(manager)
.max_size(pool_size) .max_size(pool_size)
.build() .build()

View file

@ -44,7 +44,11 @@ async fn main() -> std::io::Result<()> {
// https://wiki.postgresql.org/wiki/Number_Of_Database_Connections // https://wiki.postgresql.org/wiki/Number_Of_Database_Connections
let db_pool_size = num_cpus::get() * 2; let db_pool_size = num_cpus::get() * 2;
let db_pool = create_pool(&config.database_url, db_pool_size); let db_pool = create_pool(
&config.database_url,
config.tls_ca_file.as_ref().map(|s| s.as_path()),
db_pool_size,
);
let mut db_client = get_database_client(&db_pool).await.unwrap(); let mut db_client = get_database_client(&db_pool).await.unwrap();
apply_migrations(&mut db_client).await; apply_migrations(&mut db_client).await;