Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "tmp-postgrust"
version = "0.10.1"
version = "0.11.0"
authors = ["John Children <john.children@cambridgequantum.com>"]
license = "MIT"
edition = "2018"
Expand Down
23 changes: 12 additions & 11 deletions src/asynchronous.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use tokio::{
use tracing::{debug, instrument};

use crate::errors::{ProcessCapture, TmpPostgrustError, TmpPostgrustResult};
use crate::search::{all_dir_entries, build_copy_dst_path, find_postgresql_command};
use crate::search::{all_dir_entries, build_copy_dst_path};
use crate::POSTGRES_UID_GID;

#[instrument(skip(command, fail))]
Expand Down Expand Up @@ -49,13 +49,11 @@ async fn exec_process(

#[instrument]
pub(crate) fn start_postgres_subprocess(
postgres_bin: &Path,
data_directory: &Path,
port: u32,
) -> TmpPostgrustResult<Child> {
let postgres_path =
find_postgresql_command("bin", "postgres").expect("failed to find postgres");

let mut command = Command::new(postgres_path);
let mut command = Command::new(postgres_bin);
command
.env("PGDATA", data_directory.to_str().unwrap())
.arg("-p")
Expand All @@ -69,11 +67,12 @@ pub(crate) fn start_postgres_subprocess(
}

#[instrument]
pub(crate) async fn exec_init_db(data_directory: &Path) -> TmpPostgrustResult<()> {
let initdb_path = find_postgresql_command("bin", "initdb").expect("failed to find initdb");

pub(crate) async fn exec_init_db(
initdb_bin: &Path,
data_directory: &Path,
) -> TmpPostgrustResult<()> {
debug!("Initializing database in: {:?}", data_directory);
let mut command = Command::new(initdb_path);
let mut command = Command::new(initdb_bin);
command
.env("PGDATA", data_directory.to_str().unwrap())
.arg("--username=postgres");
Expand Down Expand Up @@ -107,12 +106,13 @@ pub(crate) async fn exec_copy_dir(src_dir: &Path, dst_dir: &Path) -> TmpPostgrus

#[instrument]
pub(crate) async fn exec_create_db(
createdb_bin: &Path,
socket: &Path,
port: u32,
owner: &str,
dbname: &str,
) -> TmpPostgrustResult<()> {
let mut command = Command::new("createdb");
let mut command = Command::new(createdb_bin);
command
.arg("-h")
.arg(socket)
Expand All @@ -130,11 +130,12 @@ pub(crate) async fn exec_create_db(

#[instrument]
pub(crate) async fn exec_create_user(
createuser_bin: &Path,
socket: &Path,
port: u32,
username: &str,
) -> TmpPostgrustResult<()> {
let mut command = Command::new("createuser");
let mut command = Command::new(createuser_bin);
command
.arg("-h")
.arg(socket)
Expand Down
11 changes: 11 additions & 0 deletions src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::path::PathBuf;

use thiserror::Error;

/// UTF-8 captures of stdout and stderr for child processes used by the library.
Expand All @@ -11,6 +13,7 @@ pub struct ProcessCapture {

/// Error type for possible postgresql errors.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum TmpPostgrustError {
/// Catchall error for when a subprocess fails to run to completion
#[error("subprocess failed to execute")]
Expand Down Expand Up @@ -64,6 +67,14 @@ pub enum TmpPostgrustError {
/// Error when running migrations failed.
#[error("failed to run database migrations")]
MigrationsFailed(#[source] Box<dyn std::error::Error + Send + Sync>),
/// Error when a required postgres binary cannot be found.
#[error("could not find postgres command `{command}`{}", .searched_dir.as_ref().map(|d| format!(" in {}", d.display())).unwrap_or_default())]
PostgresCommandNotFound {
/// Name of the binary that was not found.
command: String,
/// The explicit directory that was searched, if one was provided.
searched_dir: Option<PathBuf>,
},
}

/// Result type for `TmpPostgrustError`, used by functions in this crate.
Expand Down
102 changes: 93 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub mod synchronous;
use std::fmt::Write as FmtWrite;
use std::fs::{metadata, set_permissions};
use std::io::{BufRead, BufReader};
use std::path::Path;
use std::path::{Path, PathBuf};
use std::sync::atomic::AtomicU32;
use std::sync::{Arc, OnceLock, RwLock};
use std::{fs::File, io::Write};
Expand All @@ -33,6 +33,7 @@ use tempfile::{Builder, TempDir};
use tracing::{debug, info, instrument};

use crate::errors::{TmpPostgrustError, TmpPostgrustResult};
use crate::search::PostgresBinaries;

const TMP_POSTGRUST_DB_NAME: &str = "tmp-postgrust";
const TMP_POSTGRUST_USER_NAME: &str = "tmp-postgrust-user";
Expand Down Expand Up @@ -78,6 +79,7 @@ pub fn new_default_process() -> TmpPostgrustResult<synchronous::ProcessGuard> {
RwLock::new(Some(
TmpPostgrustFactory::try_new(&TmpPostgrustFactoryConfig {
disable_fsync: true,
..Default::default()
})
.expect("Failed to initialize default postgres factory."),
))
Expand Down Expand Up @@ -111,6 +113,7 @@ pub fn new_default_process_with_migrations(
let factory_mutex = DEFAULT_POSTGRES_FACTORY.get_or_init(|| {
let factory = TmpPostgrustFactory::try_new(&TmpPostgrustFactoryConfig {
disable_fsync: true,
..Default::default()
})
.expect("Failed to initialize default postgres factory.");
factory
Expand Down Expand Up @@ -153,6 +156,7 @@ pub async fn new_default_process_async() -> TmpPostgrustResult<asynchronous::Pro
.get_or_try_init(|| async {
TmpPostgrustFactory::try_new_async(&TmpPostgrustFactoryConfig {
disable_fsync: true,
..Default::default()
})
.await
.map(|factory| tokio::sync::RwLock::new(Some(factory)))
Expand Down Expand Up @@ -195,6 +199,7 @@ where
.get_or_try_init(|| async {
let factory = TmpPostgrustFactory::try_new_async(&TmpPostgrustFactoryConfig {
disable_fsync: true,
..Default::default()
})
.await?;
factory.run_migrations_async(migrate).await?;
Expand All @@ -215,15 +220,21 @@ pub struct TmpPostgrustFactory {
cache_dir: Arc<TempDir>,
config: String,
next_port: AtomicU32,
binaries: PostgresBinaries,
}

/// Configuration for the `TmpPostgrustFactory`.
#[derive(Default, Debug)]
#[non_exhaustive]
pub struct TmpPostgrustFactoryConfig {
/// Disable fsync this will speed up unit tests in exchange for
/// not guaranteeing that files will be written if postgresql
/// crashes.
pub disable_fsync: bool,
/// Directory containing the postgres binaries (`postgres`, `initdb`,
/// `createdb`, `createuser`). When `None`, binaries are resolved from
/// `$PATH` and common install locations.
pub postgresql_bin_dir: Option<PathBuf>,
}

impl TmpPostgrustFactory {
Expand Down Expand Up @@ -255,6 +266,8 @@ impl TmpPostgrustFactory {
pub fn try_new(
factory_config: &TmpPostgrustFactoryConfig,
) -> TmpPostgrustResult<TmpPostgrustFactory> {
let binaries = PostgresBinaries::resolve(factory_config.postgresql_bin_dir.as_deref())?;

let socket_dir = Builder::new()
.prefix("tmp-postgrust-socket")
.tempdir()
Expand All @@ -266,7 +279,7 @@ impl TmpPostgrustFactory {

synchronous::chown_to_non_root(cache_dir.path())?;
synchronous::chown_to_non_root(socket_dir.path())?;
synchronous::exec_init_db(cache_dir.path())?;
synchronous::exec_init_db(&binaries.initdb, cache_dir.path())?;

let config = TmpPostgrustFactory::build_config(factory_config, socket_dir.path());

Expand All @@ -275,10 +288,17 @@ impl TmpPostgrustFactory {
cache_dir: Arc::new(cache_dir),
config,
next_port: AtomicU32::new(5432),
binaries,
};
let process = factory.start_postgresql(&factory.cache_dir)?;
synchronous::exec_create_user(process.socket_dir.path(), process.port, &process.user_name)?;
synchronous::exec_create_user(
&factory.binaries.createuser,
process.socket_dir.path(),
process.port,
&process.user_name,
)?;
synchronous::exec_create_db(
&factory.binaries.createdb,
process.socket_dir.path(),
process.port,
&process.user_name,
Expand All @@ -294,6 +314,8 @@ impl TmpPostgrustFactory {
pub async fn try_new_async(
factory_config: &TmpPostgrustFactoryConfig,
) -> TmpPostgrustResult<TmpPostgrustFactory> {
let binaries = PostgresBinaries::resolve(factory_config.postgresql_bin_dir.as_deref())?;

let socket_dir = Builder::new()
.prefix("tmp-postgrust-socket")
.tempdir()
Expand All @@ -305,7 +327,7 @@ impl TmpPostgrustFactory {

asynchronous::chown_to_non_root(cache_dir.path()).await?;
asynchronous::chown_to_non_root(socket_dir.path()).await?;
asynchronous::exec_init_db(cache_dir.path()).await?;
asynchronous::exec_init_db(&binaries.initdb, cache_dir.path()).await?;

let config = TmpPostgrustFactory::build_config(factory_config, socket_dir.path());

Expand All @@ -314,11 +336,18 @@ impl TmpPostgrustFactory {
cache_dir: Arc::new(cache_dir),
config,
next_port: AtomicU32::new(5432),
binaries,
};
let process = factory.start_postgresql_async(&factory.cache_dir).await?;
asynchronous::exec_create_user(process.socket_dir.path(), process.port, &process.user_name)
.await?;
asynchronous::exec_create_user(
&factory.binaries.createuser,
process.socket_dir.path(),
process.port,
&process.user_name,
)
.await?;
asynchronous::exec_create_db(
&factory.binaries.createdb,
process.socket_dir.path(),
process.port,
&process.user_name,
Expand Down Expand Up @@ -364,7 +393,7 @@ impl TmpPostgrustFactory {
Result<(), Box<dyn std::error::Error + Send + Sync>>,
>,
{
let process = self.start_postgresql(&self.cache_dir)?;
let process = self.start_postgresql_async(&self.cache_dir).await?;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah interesting catch, I guess I missed this before

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

claude review stays winning on that one, it got caught while checking my work on the rest of it :)


migrate(&process.connection_string())
.await
Expand Down Expand Up @@ -412,7 +441,8 @@ impl TmpPostgrustFactory {
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);

synchronous::chown_to_non_root(dir.path())?;
let mut postgres_process_handle = synchronous::start_postgres_subprocess(dir.path(), port)?;
let mut postgres_process_handle =
synchronous::start_postgres_subprocess(&self.binaries.postgres, dir.path(), port)?;
let stdout = postgres_process_handle.stdout.take().unwrap();
let stderr = postgres_process_handle.stderr.take().unwrap();

Expand Down Expand Up @@ -487,7 +517,7 @@ impl TmpPostgrustFactory {

asynchronous::chown_to_non_root(dir.path()).await?;
let mut postgres_process_handle =
asynchronous::start_postgres_subprocess(dir.path(), port)?;
asynchronous::start_postgres_subprocess(&self.binaries.postgres, dir.path(), port)?;
let stdout = postgres_process_handle.stdout.take().unwrap();
let stderr = postgres_process_handle.stderr.take().unwrap();

Expand Down Expand Up @@ -528,6 +558,7 @@ mod tests {
async fn it_works() {
let factory = TmpPostgrustFactory::try_new(&TmpPostgrustFactoryConfig {
disable_fsync: false,
..Default::default()
})
.expect("failed to create factory");

Expand All @@ -552,6 +583,7 @@ mod tests {
async fn it_works_fsync_disabled() {
let factory = TmpPostgrustFactory::try_new(&TmpPostgrustFactoryConfig {
disable_fsync: true,
..Default::default()
})
.expect("failed to create factory");

Expand All @@ -577,6 +609,7 @@ mod tests {
async fn it_works_async() {
let factory = TmpPostgrustFactory::try_new_async(&TmpPostgrustFactoryConfig {
disable_fsync: false,
..Default::default()
})
.await
.expect("failed to create factory");
Expand All @@ -603,6 +636,7 @@ mod tests {
async fn two_simulatenous_processes() {
let factory = TmpPostgrustFactory::try_new(&TmpPostgrustFactoryConfig {
disable_fsync: false,
..Default::default()
})
.expect("failed to create factory");

Expand Down Expand Up @@ -643,6 +677,7 @@ mod tests {
async fn two_simulatenous_processes_async() {
let factory = TmpPostgrustFactory::try_new_async(&TmpPostgrustFactoryConfig {
disable_fsync: false,
..Default::default()
})
.await
.expect("failed to create factory");
Expand Down Expand Up @@ -756,4 +791,53 @@ mod tests {
// Chance to catch concurrent tests or database that have already been used.
client.execute("CREATE TABLE lock ();", &[]).await.unwrap();
}

#[test(tokio::test)]
async fn explicit_bin_dir_works() {
let postgres_path = crate::search::find_postgresql_command(None, "postgres").expect(
"cannot derive bin_dir: postgres not found via PATH or known install locations",
);
let bin_dir = postgres_path
.parent()
.expect("postgres path has no parent dir")
.to_owned();

let factory = TmpPostgrustFactory::try_new(&TmpPostgrustFactoryConfig {
disable_fsync: true,
postgresql_bin_dir: Some(bin_dir),
})
.expect("failed to create factory with explicit bin dir");

let proc = factory
.new_instance()
.expect("failed to create a new instance");

let (client, conn) = tokio_postgres::connect(&proc.connection_string(), NoTls)
.await
.expect("failed to connect to postgresql");

tokio::spawn(async move {
if let Err(e) = conn.await {
error!("connection error: {}", e);
}
});

client.query("SELECT 1;", &[]).await.unwrap();
}

#[test]
fn bogus_bin_dir_returns_error() {
let result = TmpPostgrustFactory::try_new(&TmpPostgrustFactoryConfig {
disable_fsync: true,
postgresql_bin_dir: Some(std::path::PathBuf::from("/nonexistent/bin/dir")),
});
assert!(
matches!(
result,
Err(TmpPostgrustError::PostgresCommandNotFound { .. })
),
"expected PostgresCommandNotFound, got: {:?}",
result
);
}
}
Loading
Loading