Skip to content

Commit 5cc4c25

Browse files
authored
Merge pull request #1892 from Kobzol/db-tests
Add simple test infrastructure for testing DB queries
2 parents d458d7a + 50c6d40 commit 5cc4c25

File tree

8 files changed

+190
-26
lines changed

8 files changed

+190
-26
lines changed

.github/workflows/ci.yml

+21
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,27 @@ jobs:
1818
run: rustup update stable && rustup default stable && rustup component add rustfmt
1919
- run: cargo fmt --all --check
2020

21+
test:
22+
name: Test
23+
runs-on: ubuntu-latest
24+
env:
25+
TEST_DB_URL: postgres://postgres:postgres@localhost:5432/postgres
26+
services:
27+
postgres:
28+
image: postgres:14
29+
env:
30+
POSTGRES_USER: postgres
31+
POSTGRES_PASSWORD: postgres
32+
POSTGRES_DB: postgres
33+
ports:
34+
- 5432:5432
35+
steps:
36+
- uses: actions/checkout@v4
37+
- run: rustup toolchain install stable --profile minimal
38+
- uses: Swatinem/rust-cache@v2
39+
- name: Run tests
40+
run: cargo test --workspace --all-targets
41+
2142
ci:
2243
name: CI
2344
runs-on: ubuntu-latest

README.md

+27-14
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,12 @@ The general overview of what you will need to do:
3939
3. [Configure webhook forwarding](#configure-webhook-forwarding)
4040
4. Configure the `.env` file:
4141

42-
1. Copy `.env.sample` to `.env`
43-
2. `GITHUB_TOKEN`: This is a token needed for Triagebot to send requests to GitHub. Go to GitHub Settings > Developer Settings > Personal Access Token, and create a new token. The `repo` permission should be sufficient.
44-
If this is not set, Triagebot will also look in `~/.gitconfig` in the `github.oauth-token` setting.
45-
3. `DATABASE_URL`: This is the URL to the database. See [Configuring a database](#configuring-a-database).
46-
4. `GITHUB_WEBHOOK_SECRET`: Enter the secret you entered in the webhook above.
47-
5. `RUST_LOG`: Set this to `debug`.
42+
1. Copy `.env.sample` to `.env`
43+
2. `GITHUB_TOKEN`: This is a token needed for Triagebot to send requests to GitHub. Go to GitHub Settings > Developer Settings > Personal Access Token, and create a new token. The `repo` permission should be sufficient.
44+
If this is not set, Triagebot will also look in `~/.gitconfig` in the `github.oauth-token` setting.
45+
3. `DATABASE_URL`: This is the URL to the database. See [Configuring a database](#configuring-a-database).
46+
4. `GITHUB_WEBHOOK_SECRET`: Enter the secret you entered in the webhook above.
47+
5. `RUST_LOG`: Set this to `debug`.
4848

4949
5. Run `cargo run --bin triagebot`. This starts the http server listening for webhooks on port 8000.
5050
6. Add a `triagebot.toml` file to the main branch of your GitHub repo with whichever services you want to try out.
@@ -109,15 +109,28 @@ You need to sign up for a free account, and also deal with configuring the GitHu
109109
3. Configure GitHub webhooks in the test repo you created.
110110
In short:
111111

112-
1. Go to the settings page for your GitHub repo.
113-
2. Go to the webhook section.
114-
3. Click "Add webhook"
115-
4. Include the settings:
112+
1. Go to the settings page for your GitHub repo.
113+
2. Go to the webhook section.
114+
3. Click "Add webhook"
115+
4. Include the settings:
116116

117-
* Payload URL: This is the URL to your Triagebot server, for example http://7e9ea9dc.ngrok.io/github-hook. This URL is displayed when you ran the `ngrok` command above.
118-
* Content type: application/json
119-
* Secret: Enter a shared secret (some longish random text)
120-
* Events: "Send me everything"
117+
* Payload URL: This is the URL to your Triagebot server, for example http://7e9ea9dc.ngrok.io/github-hook. This URL is displayed when you ran the `ngrok` command above.
118+
* Content type: application/json
119+
* Secret: Enter a shared secret (some longish random text)
120+
* Events: "Send me everything"
121+
122+
### Cargo tests
123+
124+
You can run Cargo tests using `cargo test`. If you also want to run tests that access a Postgres database, you can specify an environment variables `TEST_DB_URL`, which should contain a connection string pointing to a running Postgres database instance:
125+
126+
```bash
127+
$ docker run --rm -it -p5432:5432 \
128+
-e POSTGRES_USER=triagebot \
129+
-e POSTGRES_PASSWORD=triagebot \
130+
-e POSTGRES_DB=triagebot \
131+
postgres:14
132+
$ TEST_DB_URL=postgres://triagebot:triagebot@localhost:5432/triagebot cargo test
133+
```
121134

122135
## License
123136

src/db.rs

+8-6
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ static CERTIFICATE_PEMS: LazyLock<Vec<u8>> = LazyLock::new(|| {
2323
pub struct ClientPool {
2424
connections: Arc<Mutex<Vec<tokio_postgres::Client>>>,
2525
permits: Arc<Semaphore>,
26+
db_url: String,
2627
}
2728

2829
pub struct PooledClient {
@@ -54,10 +55,11 @@ impl std::ops::DerefMut for PooledClient {
5455
}
5556

5657
impl ClientPool {
57-
pub fn new() -> ClientPool {
58+
pub fn new(db_url: String) -> ClientPool {
5859
ClientPool {
5960
connections: Arc::new(Mutex::new(Vec::with_capacity(16))),
6061
permits: Arc::new(Semaphore::new(16)),
62+
db_url,
6163
}
6264
}
6365

@@ -79,15 +81,14 @@ impl ClientPool {
7981
}
8082

8183
PooledClient {
82-
client: Some(make_client().await.unwrap()),
84+
client: Some(make_client(&self.db_url).await.unwrap()),
8385
permit,
8486
pool: self.connections.clone(),
8587
}
8688
}
8789
}
8890

89-
async fn make_client() -> anyhow::Result<tokio_postgres::Client> {
90-
let db_url = std::env::var("DATABASE_URL").expect("needs DATABASE_URL");
91+
pub async fn make_client(db_url: &str) -> anyhow::Result<tokio_postgres::Client> {
9192
if db_url.contains("rds.amazonaws.com") {
9293
let mut builder = TlsConnector::builder();
9394
for cert in make_certificates() {
@@ -230,8 +231,9 @@ pub async fn schedule_job(
230231
Ok(())
231232
}
232233

233-
pub async fn run_scheduled_jobs(ctx: &Context, db: &DbClient) -> anyhow::Result<()> {
234-
let jobs = get_jobs_to_execute(&db).await.unwrap();
234+
pub async fn run_scheduled_jobs(ctx: &Context) -> anyhow::Result<()> {
235+
let db = &ctx.db.get().await;
236+
let jobs = get_jobs_to_execute(&db).await?;
235237
tracing::trace!("jobs to execute: {:#?}", jobs);
236238

237239
for job in jobs.iter() {

src/handlers/assign.rs

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ use tracing as log;
3939
#[cfg(test)]
4040
mod tests {
4141
mod tests_candidates;
42+
mod tests_db;
4243
mod tests_from_diff;
4344
}
4445

src/handlers/assign/tests/tests_db.rs

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#[cfg(test)]
2+
mod tests {
3+
use crate::handlers::assign::filter_by_capacity;
4+
use crate::tests::run_test;
5+
use std::collections::HashSet;
6+
7+
#[tokio::test]
8+
async fn find_reviewers_no_review_prefs() {
9+
run_test(|ctx| async move {
10+
ctx.add_user("usr1", 1).await;
11+
ctx.add_user("usr2", 1).await;
12+
let _users =
13+
filter_by_capacity(ctx.db_client(), &candidates(&["usr1", "usr2"])).await?;
14+
// FIXME: this test fails, because the query is wrong
15+
// check_users(users, &["usr1", "usr2"]);
16+
Ok(ctx)
17+
})
18+
.await;
19+
}
20+
21+
fn candidates(users: &[&'static str]) -> HashSet<&'static str> {
22+
users.into_iter().copied().collect()
23+
}
24+
25+
fn check_users(users: HashSet<String>, expected: &[&'static str]) {
26+
let mut users: Vec<String> = users.into_iter().collect();
27+
users.sort();
28+
assert_eq!(users, expected);
29+
}
30+
}

src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ mod team_data;
2626
pub mod triage;
2727
pub mod zulip;
2828

29+
#[cfg(test)]
30+
mod tests;
31+
2932
/// The name of a webhook event.
3033
#[derive(Debug)]
3134
pub enum EventName {

src/main.rs

+7-6
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,8 @@ async fn serve_req(
242242
}
243243

244244
async fn run_server(addr: SocketAddr) -> anyhow::Result<()> {
245-
let pool = db::ClientPool::new();
245+
let db_url = std::env::var("DATABASE_URL").expect("needs DATABASE_URL");
246+
let pool = db::ClientPool::new(db_url.clone());
246247
db::run_migrations(&mut *pool.get().await)
247248
.await
248249
.context("database migrations")?;
@@ -271,7 +272,7 @@ async fn run_server(addr: SocketAddr) -> anyhow::Result<()> {
271272

272273
// Run all jobs that have a schedule (recurring jobs)
273274
if !is_scheduled_jobs_disabled() {
274-
spawn_job_scheduler();
275+
spawn_job_scheduler(db_url);
275276
spawn_job_runner(ctx.clone());
276277
}
277278

@@ -361,11 +362,12 @@ async fn spawn_job_oneoffs(ctx: Arc<Context>) {
361362
/// The scheduler wakes up every `JOB_SCHEDULING_CADENCE_IN_SECS` seconds to
362363
/// check if there are any jobs ready to run. Jobs get inserted into the the
363364
/// database which acts as a queue.
364-
fn spawn_job_scheduler() {
365+
fn spawn_job_scheduler(db_url: String) {
365366
task::spawn(async move {
366367
loop {
368+
let db_url = db_url.clone();
367369
let res = task::spawn(async move {
368-
let pool = db::ClientPool::new();
370+
let pool = db::ClientPool::new(db_url);
369371
let mut interval =
370372
time::interval(time::Duration::from_secs(JOB_SCHEDULING_CADENCE_IN_SECS));
371373

@@ -401,13 +403,12 @@ fn spawn_job_runner(ctx: Arc<Context>) {
401403
loop {
402404
let ctx = ctx.clone();
403405
let res = task::spawn(async move {
404-
let pool = db::ClientPool::new();
405406
let mut interval =
406407
time::interval(time::Duration::from_secs(JOB_PROCESSING_CADENCE_IN_SECS));
407408

408409
loop {
409410
interval.tick().await;
410-
db::run_scheduled_jobs(&ctx, &*pool.get().await)
411+
db::run_scheduled_jobs(&ctx)
411412
.await
412413
.context("run database scheduled jobs")
413414
.unwrap();

src/tests.rs

+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
use crate::db;
2+
use crate::db::make_client;
3+
use crate::db::notifications::record_username;
4+
use std::future::Future;
5+
use tokio_postgres::Config;
6+
7+
/// Represents a connection to a Postgres database that can be
8+
/// used in integration tests to test logic that interacts with
9+
/// a database.
10+
pub struct TestContext {
11+
client: tokio_postgres::Client,
12+
db_name: String,
13+
original_db_url: String,
14+
conn_handle: tokio::task::JoinHandle<()>,
15+
}
16+
17+
impl TestContext {
18+
async fn new(db_url: &str) -> Self {
19+
let mut config: Config = db_url.parse().expect("Cannot parse connection string");
20+
21+
// Create a new database that will be used for this specific test
22+
let client = make_client(&db_url)
23+
.await
24+
.expect("Cannot connect to database");
25+
let db_name = format!("db{}", uuid::Uuid::new_v4().to_string().replace("-", ""));
26+
client
27+
.execute(&format!("CREATE DATABASE {db_name}"), &[])
28+
.await
29+
.expect("Cannot create database");
30+
drop(client);
31+
32+
// We need to connect to the database against, because Postgres doesn't allow
33+
// changing the active database mid-connection.
34+
config.dbname(&db_name);
35+
let (mut client, connection) = config
36+
.connect(tokio_postgres::NoTls)
37+
.await
38+
.expect("Cannot connect to the newly created database");
39+
let conn_handle = tokio::spawn(async move {
40+
connection.await.unwrap();
41+
});
42+
43+
db::run_migrations(&mut client)
44+
.await
45+
.expect("Cannot run database migrations");
46+
Self {
47+
client,
48+
db_name,
49+
original_db_url: db_url.to_string(),
50+
conn_handle,
51+
}
52+
}
53+
54+
pub fn db_client(&self) -> &tokio_postgres::Client {
55+
&self.client
56+
}
57+
58+
pub async fn add_user(&self, name: &str, id: u64) {
59+
record_username(&self.client, id, name)
60+
.await
61+
.expect("Cannot create user");
62+
}
63+
64+
async fn finish(self) {
65+
// Cleanup the test database
66+
// First, we need to stop using the database
67+
drop(self.client);
68+
self.conn_handle.await.unwrap();
69+
70+
// Then we need to connect to the default database and drop our test DB
71+
let client = make_client(&self.original_db_url)
72+
.await
73+
.expect("Cannot connect to database");
74+
client
75+
.execute(&format!("DROP DATABASE {}", self.db_name), &[])
76+
.await
77+
.unwrap();
78+
}
79+
}
80+
81+
pub async fn run_test<F, Fut>(f: F)
82+
where
83+
F: FnOnce(TestContext) -> Fut,
84+
Fut: Future<Output = anyhow::Result<TestContext>>,
85+
{
86+
if let Ok(db_url) = std::env::var("TEST_DB_URL") {
87+
let ctx = TestContext::new(&db_url).await;
88+
let ctx = f(ctx).await.expect("Test failed");
89+
ctx.finish().await;
90+
} else {
91+
eprintln!("Skipping test because TEST_DB_URL was not passed");
92+
}
93+
}

0 commit comments

Comments
 (0)