Skip to content

feat(wasm): added wasm support, moved blocking to a feature #38

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
[package]
name = "firestore-db-and-auth"
version = "0.8.0"
version = "0.8.1"
authors = ["David Gräff <[email protected]>"]
edition = "2021"
license = "MIT"
description = "This crate allows easy access to your Google Firestore DB via service account or OAuth impersonated Google Firebase Auth credentials."
readme = "readme.md"
keywords = ["firestore", "auth"]
categories = ["api-bindings","authentication"]
categories = ["api-bindings", "authentication"]
maintenance = { status = "passively-maintained" }
repository = "https://github.com/davidgraeff/firestore-db-and-auth-rs"
rust-version = "1.64"

[dependencies]
bytes = "1.1"
cache_control = "0.2"
reqwest = { version = "0.11", default-features = false, features = ["json", "blocking", "hyper-rustls"] }
reqwest = { version = "0.11.16", default-features = false, features = ["json"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
chrono = { version = "0.4", features = ["serde"] }
Expand All @@ -38,17 +38,19 @@ optional = true

# Render the readme file on doc.rs
[package.metadata.docs.rs]
features = [ "external_doc", "rocket_support" ]
features = ["external_doc", "rocket_support"]

[features]
default = ["rustls-tls", "unstable"]
default = ["rustls-tls", "unstable", "blocking"]
blocking = ["reqwest/blocking"]
rocket_support = ["rocket"]
rustls-tls = ["reqwest/rustls-tls"]
default-tls = ["reqwest/default-tls"]
native-tls = ["reqwest/native-tls"]
native-tls-vendored = ["reqwest/native-tls-vendored"]
unstable = []
external_doc = []
wasm32 = ["ring/wasm32_unknown_unknown_js", "ring/std", "tokio/rt", "tokio/sync"]

[[example]]
name = "create_read_write_document"
Expand All @@ -65,4 +67,4 @@ test = true
[[example]]
name = "rocket_http_protected_route"
test = true
required-features = ["rustls-tls","rocket_support"]
required-features = ["rustls-tls", "rocket_support"]
13 changes: 7 additions & 6 deletions src/documents/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ use futures::{
task::{Context, Poll},
Future,
};
use std::boxed::Box;

use std::sync::Arc;

/// List all documents of a given collection.
///
Expand Down Expand Up @@ -44,15 +45,15 @@ use std::boxed::Box;
pub fn list<T, AUTH>(
auth: &AUTH,
collection_id: impl Into<String>,
) -> Pin<Box<dyn Stream<Item = Result<(T, dto::Document)>> + Send>>
) -> impl Stream<Item=Result<(T, dto::Document)>>
where
for<'b> T: Deserialize<'b> + 'static,
AUTH: FirebaseAuthBearer + Clone + Send + Sync + 'static,
for<'b> T: Deserialize<'b> + 'static,
AUTH: FirebaseAuthBearer + Clone + Send + Sync + 'static,
{
let auth = auth.clone();
let collection_id = collection_id.into();

Box::pin(stream::unfold(
stream::unfold(
ListInner {
url: firebase_url(auth.project_id(), &collection_id),
auth,
Expand Down Expand Up @@ -116,7 +117,7 @@ where
)),
}
},
))
)
}

async fn get_new_data<'a>(
Expand Down
19 changes: 14 additions & 5 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::fmt;
use reqwest;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use crate::http2;

/// A result type that uses [`FirebaseError`] as an error type
pub type Result<T> = std::result::Result<T, FirebaseError>;
Expand Down Expand Up @@ -168,17 +169,25 @@ struct GoogleRESTApiErrorWrapper {
/// Arguments:
/// - response: The http requests response. Must be mutable, because the contained value will be extracted in an error case
/// - context: A function that will be called in an error case that returns a context string
pub(crate) fn extract_google_api_error(
response: reqwest::blocking::Response,
pub(crate) async fn extract_google_api_error(
response: http2::Response,
context: impl Fn() -> String,
) -> Result<reqwest::blocking::Response> {
) -> Result<http2::Response> {
if response.status() == 200 {
return Ok(response);
}

let status = response.status().clone();

#[cfg(feature = "blocking")]
let res_text = response.text()?;

#[cfg(not(feature = "blocking"))]
let res_text = response.text().await?;

Err(extract_google_api_error_intern(
response.status().clone(),
response.text()?,
status,
res_text,
context,
))
}
Expand Down
11 changes: 11 additions & 0 deletions src/http2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#[cfg(feature = "blocking")]
pub type Response = reqwest::blocking::Response;

#[cfg(not(feature = "blocking"))]
pub type Response = reqwest::Response;

#[cfg(feature = "blocking")]
pub type Client = reqwest::blocking::Client;

#[cfg(not(feature = "blocking"))]
pub type Client = reqwest::Client;
5 changes: 4 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ pub mod users;

#[cfg(feature = "rocket_support")]
pub mod rocket;
mod http2;

use async_trait::async_trait;
// Forward declarations
pub use credentials::Credentials;
pub use jwt::JWKSet;
Expand All @@ -26,7 +28,8 @@ pub use sessions::user::Session as UserSession;
/// Firestore document methods in [`crate::documents`] expect an object that implements this `FirebaseAuthBearer` trait.
///
/// Implement this trait for your own data structure and provide the Firestore project id and a valid access token.
#[async_trait::async_trait]
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
pub trait FirebaseAuthBearer {
/// Return the project ID. This is required for the firebase REST API.
fn project_id(&self) -> &str;
Expand Down
55 changes: 37 additions & 18 deletions src/sessions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ use std::sync::Arc;
use tokio::sync::RwLock;

pub mod user {
use async_trait::async_trait;
use super::*;
use crate::dto::{OAuthResponse, SignInWithIdpRequest};
use credentials::Credentials;
use crate::http2::Client;

#[inline]
fn token_endpoint(v: &str) -> String {
Expand Down Expand Up @@ -86,16 +88,13 @@ pub mod user {
pub client: reqwest::Client,
}

#[async_trait::async_trait]
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
impl super::FirebaseAuthBearer for Session {
fn project_id(&self) -> &str {
&self.project_id_
}

async fn access_token_unchecked(&self) -> String {
self.access_token_.read().await.clone()
}

/// Returns the current access token.
/// This method will automatically refresh your access token, if it has expired.
///
Expand All @@ -107,7 +106,7 @@ pub mod user {

if is_expired(&jwt, 0).unwrap() {
// Unwrap: the token is always valid at this point
if let Ok(response) = get_new_access_token(&self.api_key, &jwt).await {
if let Ok(response) = get_new_access_token(&self.client, &self.api_key, &jwt).await {
*jwt = response.id_token.clone();
return response.id_token;
} else {
Expand All @@ -119,22 +118,26 @@ pub mod user {
jwt.clone()
}

async fn access_token_unchecked(&self) -> String {
self.access_token_.read().await.clone()
}

fn client(&self) -> &reqwest::Client {
&self.client
}
}

/// Gets a new access token via an api_key and a refresh_token.
async fn get_new_access_token(
client: &Client,
api_key: &str,
refresh_token: &str,
) -> Result<RefreshTokenToAccessTokenResponse, FirebaseError> {
let request_body = vec![("grant_type", "refresh_token"), ("refresh_token", refresh_token)];

let url = refresh_to_access_endpoint(api_key);
let client = reqwest::Client::new();
let response = client.post(&url).form(&request_body).send().await?;
Ok(response.json().await?)
Ok(response.json::<RefreshTokenToAccessTokenResponse>().await?)
}

#[allow(non_snake_case)]
Expand Down Expand Up @@ -235,15 +238,16 @@ pub mod user {
credentials: &Credentials,
refresh_token: &str,
) -> Result<Session, FirebaseError> {
let client = Client::new();
let r: RefreshTokenToAccessTokenResponse =
get_new_access_token(&credentials.api_key, refresh_token).await?;
get_new_access_token(&client, &credentials.api_key, refresh_token).await?;
Ok(Session {
user_id: r.user_id,
access_token_: Arc::new(RwLock::new(r.id_token)),
refresh_token: Some(r.refresh_token),
project_id_: credentials.project_id.to_owned(),
api_key: credentials.api_key.clone(),
client: reqwest::Client::new(),
client: client,
})
}

Expand Down Expand Up @@ -361,6 +365,7 @@ pub mod user {
}

pub mod session_cookie {
use crate::http2;
use super::*;

pub static GOOGLE_OAUTH2_URL: &str = "https://accounts.google.com/o/oauth2/token";
Expand Down Expand Up @@ -430,27 +435,38 @@ pub mod session_cookie {
let assertion = crate::jwt::session_cookie::create_jwt_encoded(credentials, duration).await?;

// Request Google Oauth2 to retrieve the access token in order to create a session cookie
let client = reqwest::blocking::Client::new();
let response_oauth2: Oauth2ResponseDTO = client
let client = http2::Client::new();

let _res_oauth = client
.post(GOOGLE_OAUTH2_URL)
.form(&[
("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
("assertion", &assertion),
])
.send()?
.json()?;
.send();

#[cfg(feature = "blocking")]
let response_oauth2: Oauth2ResponseDTO = _res_oauth?.json()?;

#[cfg(not(feature = "blocking"))]
let response_oauth2: Oauth2ResponseDTO = _res_oauth.await?.json().await?;

// Create a session cookie with the access token previously retrieved
let response_session_cookie_json: CreateSessionCookieResponseDTO = client
let _res_cookie = client
.post(&identitytoolkit_url(&credentials.project_id))
.bearer_auth(&response_oauth2.access_token)
.json(&SessionLoginDTO {
id_token,
valid_duration: duration.num_seconds() as u64,
tenant_id: None,
})
.send()?
.json()?;
.send();

#[cfg(feature = "blocking")]
let response_session_cookie_json: CreateSessionCookieResponseDTO = _res_cookie?.json()?;

#[cfg(not(feature = "blocking"))]
let response_session_cookie_json: CreateSessionCookieResponseDTO = _res_cookie.await?.json().await?;

Ok(response_session_cookie_json.session_cookie_jwk)
}
Expand All @@ -466,6 +482,8 @@ pub mod service_account {
use chrono::Duration;
use std::cell::RefCell;
use std::ops::Deref;
use async_trait::async_trait;
use crate::http2;

/// Service account session
#[derive(Clone, Debug)]
Expand All @@ -478,7 +496,8 @@ pub mod service_account {
access_token_: Arc<RwLock<String>>,
}

#[async_trait::async_trait]
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
impl super::FirebaseAuthBearer for Session {
fn project_id(&self) -> &str {
&self.credentials.project_id
Expand Down