Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 27 additions & 13 deletions src/webserver/oidc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,16 +155,6 @@ fn get_app_host(config: &AppConfig) -> String {
host
}

fn build_absolute_uri(app_host: &str, relative_path: &str, scheme: &str) -> anyhow::Result<String> {
let mut base_url = Url::parse(&format!("{scheme}://{app_host}"))
.with_context(|| format!("Failed to parse app_host: {app_host}"))?;
base_url.set_path("");
let absolute_url = base_url
.join(relative_path)
.with_context(|| format!("Failed to join path {relative_path}"))?;
Ok(absolute_url.to_string())
}

pub struct ClientWithTime {
client: OidcClient,
end_session_endpoint: Option<EndSessionUrl>,
Expand Down Expand Up @@ -246,6 +236,29 @@ impl OidcState {
.map_err(|e| anyhow::anyhow!("Could not verify the ID token: {e}"))?;
Ok(claims)
}

/// Builds an absolute redirect URI by joining the relative redirect URI with the client's redirect URL
pub async fn build_absolute_redirect_uri(
&self,
relative_redirect_uri: &str,
) -> anyhow::Result<String> {
let client_guard = self.get_client().await;
let client_redirect_url = client_guard
.redirect_uri()
.ok_or_else(|| anyhow!("OIDC client has no redirect URL configured"))?;
let absolute_redirect_uri = client_redirect_url
.url()
.join(relative_redirect_uri)
.with_context(|| {
format!(
"Failed to join redirect URI {} with client redirect URL {}",
relative_redirect_uri,
client_redirect_url.url()
)
})?
.to_string();
Ok(absolute_redirect_uri)
}
}

pub async fn initialize_oidc_state(
Expand Down Expand Up @@ -494,11 +507,12 @@ async fn process_oidc_logout(
.ok()
.flatten();

let scheme = request.connection_info().scheme().to_string();
let mut response =
if let Some(end_session_endpoint) = oidc_state.get_end_session_endpoint().await {
let absolute_redirect_uri =
build_absolute_uri(&oidc_state.config.app_host, &params.redirect_uri, &scheme)?;
let absolute_redirect_uri = oidc_state
.build_absolute_redirect_uri(&params.redirect_uri)
.await?;

let post_logout_redirect_uri =
PostLogoutRedirectUrl::new(absolute_redirect_uri.clone()).with_context(|| {
format!("Invalid post_logout_redirect_uri: {absolute_redirect_uri}")
Expand Down
45 changes: 45 additions & 0 deletions tests/oidc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ struct DiscoveryResponse {
response_types_supported: Vec<String>,
subject_types_supported: Vec<String>,
id_token_signing_alg_values_supported: Vec<String>,
end_session_endpoint: String,
}

#[derive(Serialize)]
Expand All @@ -89,6 +90,7 @@ async fn discovery_endpoint(state: Data<SharedProviderState>) -> impl Responder
response_types_supported: vec!["code".to_string()],
subject_types_supported: vec!["public".to_string()],
id_token_signing_alg_values_supported: vec!["HS256".to_string()],
end_session_endpoint: format!("{}/logout", state.issuer_url),
};
HttpResponse::Ok()
.insert_header((header::CONTENT_TYPE, "application/json"))
Expand Down Expand Up @@ -435,3 +437,46 @@ async fn test_oidc_expired_token_is_rejected() {
})
.await;
}

#[actix_web::test]
async fn test_oidc_logout_uses_correct_scheme() {
use sqlpage::{
app_config::{test_database_url, AppConfig},
webserver::oidc::create_logout_url,
AppState,
};

crate::common::init_log();
let provider = FakeOidcProvider::new().await;

let db_url = test_database_url();
let config_json = format!(
r#"{{
"database_url": "{db_url}",
"oidc_issuer_url": "{}",
"oidc_client_id": "{}",
"oidc_client_secret": "{}",
"https_domain": "example.com"
}}"#,
provider.issuer_url, provider.client_id, provider.client_secret
);

let config: AppConfig = serde_json::from_str(&config_json).unwrap();
let app_state = AppState::init(&config).await.unwrap();
let app = test::init_service(create_app(Data::new(app_state))).await;

let logout_path = create_logout_url("/logged_out", "", &provider.client_secret);
// make sure the logout path includes the configured domain
assert!(logout_path.starts_with("/sqlpage/oidc_logout"));

let req = test::TestRequest::get().uri(&logout_path).to_request();
let resp = test::call_service(&app, req).await;

assert_eq!(resp.status(), StatusCode::SEE_OTHER);
let location = resp.headers().get("location").unwrap().to_str().unwrap();
let location_url = Url::parse(location).unwrap();
assert_eq!(location_url.path(), "/logout");
let params: HashMap<String, String> = location_url.query_pairs().into_owned().collect();
let post_logout = params.get("post_logout_redirect_uri").unwrap();
assert_eq!(post_logout, "https://example.com/logged_out");
}