Skip to content

Commit

Permalink
Merge pull request #913 from flavio/fix-certificate-rotation-detect-c…
Browse files Browse the repository at this point in the history
…hanges

fix: make cert rotation detection more reliable
  • Loading branch information
fabriziosestito authored Sep 20, 2024
2 parents 9cd3c64 + fb9cbf1 commit d5c629b
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 43 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ tempfile = "3.12.0"
tower = { version = "0.5", features = ["util"] }
http-body-util = "0.1.1"
testcontainers = { version = "0.22", features = ["watchdog"] }
backon = { version = "1.1.0", features = ["tokio-sleep"] }
backon = { version = "1.2", features = ["tokio-sleep"] }

[target.'cfg(target_os = "linux")'.dev-dependencies]
rcgen = { version = "0.13", features = ["crypto"] }
Expand Down
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,11 +302,11 @@ async fn create_tls_config_and_watch_certificate_changes(
inotify::Inotify::init().map_err(|e| anyhow!("Cannot initialize inotify: {e}"))?;
let cert_watch = inotify
.watches()
.add(cert_file.clone(), inotify::WatchMask::MODIFY)
.add(cert_file.clone(), inotify::WatchMask::CLOSE_WRITE)
.map_err(|e| anyhow!("Cannot watch certificate file: {e}"))?;
let key_watch = inotify
.watches()
.add(key_file.clone(), inotify::WatchMask::MODIFY)
.add(key_file.clone(), inotify::WatchMask::CLOSE_WRITE)
.map_err(|e| anyhow!("Cannot watch key file: {e}"))?;

let buffer = [0; 1024];
Expand Down
85 changes: 45 additions & 40 deletions tests/integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,9 +567,12 @@ async fn test_policy_with_wrong_url() {
// helper functions for certificate rotation test, which is a feature supported only on Linux
#[cfg(target_os = "linux")]
mod certificate_reload_helpers {
use std::net::TcpStream;

use anyhow::anyhow;
use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
use rcgen::{generate_simple_self_signed, CertifiedKey};
use std::net::TcpStream;
use reqwest::StatusCode;

pub struct TlsData {
pub key: String,
Expand Down Expand Up @@ -614,48 +617,34 @@ mod certificate_reload_helpers {
.unwrap()
}

pub async fn check_tls_san_name(domain_ip: &str, domain_port: &str, hostname: &str) -> bool {
let sleep_interval = std::time::Duration::from_secs(1);
let max_retries = 10;
let mut failed_retries = 0;
pub async fn check_tls_san_name(
domain_ip: &str,
domain_port: &str,
hostname: &str,
) -> anyhow::Result<()> {
let hostname = hostname.to_string();
loop {
let san_names = get_tls_san_names(domain_ip, domain_port).await;
if san_names.contains(&hostname) {
return true;
}
failed_retries += 1;
if failed_retries >= max_retries {
return false;
}
tokio::time::sleep(sleep_interval).await;
let san_names = get_tls_san_names(domain_ip, domain_port).await;
if san_names.contains(&hostname) {
Ok(())
} else {
Err(anyhow!(
"SAN names do not contain the expected hostname ({}): {:?}",
hostname,
san_names
))
}
}

pub async fn wait_for_policy_server_to_be_ready(address: &str) {
let sleep_interval = std::time::Duration::from_secs(1);
let max_retries = 5;
let mut failed_retries = 0;

pub async fn policy_server_is_ready(address: &str) -> anyhow::Result<StatusCode> {
// wait for the server to start
let client = reqwest::Client::builder()
.danger_accept_invalid_certs(true)
.build()
.unwrap();

loop {
let url = reqwest::Url::parse(&format!("https://{address}/readiness")).unwrap();
match client.get(url).send().await {
Ok(_) => break,
Err(e) => {
failed_retries += 1;
if failed_retries >= max_retries {
panic!("failed to start the server: {:?}", e);
}
tokio::time::sleep(sleep_interval).await;
}
}
}
let url = reqwest::Url::parse(&format!("https://{address}/readiness")).unwrap();
let response = client.get(url).send().await?;
Ok(response.status())
}
}

Expand Down Expand Up @@ -699,9 +688,22 @@ async fn test_detect_certificate_rotation() {
.unwrap();
api_server.run().await.unwrap();
});
wait_for_policy_server_to_be_ready(format!("{domain_ip}:{domain_port}").as_str()).await;

assert!(check_tls_san_name(&domain_ip, &domain_port, hostname1).await);
let exponential_backoff = ExponentialBuilder::default()
.with_min_delay(Duration::from_secs(10))
.with_max_delay(Duration::from_secs(30))
.with_max_times(5);

let status_code =
(|| async { policy_server_is_ready(format!("{domain_ip}:{domain_port}").as_str()).await })
.retry(exponential_backoff)
.await
.unwrap();
assert_eq!(status_code, reqwest::StatusCode::OK);

check_tls_san_name(&domain_ip, &domain_port, hostname1)
.await
.expect("certificate served doesn't use the expected SAN name");

// Generate a new certificate and key, and switch to them

Expand All @@ -715,16 +717,19 @@ async fn test_detect_certificate_rotation() {
tokio::time::sleep(std::time::Duration::from_secs(4)).await;

// the old certificate should still be in use, since we didn't change also the key
assert!(check_tls_san_name(&domain_ip, &domain_port, hostname1).await);
check_tls_san_name(&domain_ip, &domain_port, hostname1)
.await
.expect("certificate should not have been changed");

// write only the key file
std::fs::write(&key_file, tls_data2.key).unwrap();

// give inotify some time to ensure it detected the cert change
// give inotify some time to ensure it detected the cert change,
// also give axum some time to complete the certificate reload
tokio::time::sleep(std::time::Duration::from_secs(4)).await;

// the new certificate should be in use
assert!(check_tls_san_name(&domain_ip, &domain_port, hostname2).await);
check_tls_san_name(&domain_ip, &domain_port, hostname2)
.await
.expect("certificate hasn't been reloaded");
}

#[tokio::test]
Expand Down

0 comments on commit d5c629b

Please sign in to comment.