Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,21 @@ reqwest = { version = "0.12", features = ["json"], default-features = false, op

# default async runtime
tokio = { version = "1", features = ["time"], optional = true }
bitcoin-ohttp = { version = "0.6.0", optional = true}
url = {version = "2.5.7", optional = true}
bhttp = { version = "0.6.1", optional = true}
http = { version = "1.3.1", optional = true}


[dev-dependencies]
serde_json = "1.0"
tokio = { version = "1.20.1", features = ["full"] }
electrsd = { version = "0.33.0", features = ["legacy", "esplora_a33e97e1", "corepc-node_28_0"] }
lazy_static = "1.4.0"
ohttp-relay = { git = "https://github.com/payjoin/ohttp-relay.git", branch = "main", features = ["_test-util"]}
hyper = {version = "1.8.1", features = ["full"]}
hyper-util = {version = "0.1.19"}
http-body-util = "0.1.1"

[features]
default = ["blocking", "async", "async-https", "tokio"]
Expand All @@ -43,6 +52,7 @@ blocking-https = ["blocking", "minreq/https"]
blocking-https-rustls = ["blocking", "minreq/https-rustls"]
blocking-https-native = ["blocking", "minreq/https-native"]
blocking-https-bundled = ["blocking", "minreq/https-bundled"]
async-ohttp = ["async", "bitcoin-ohttp", "bhttp", "reqwest", "tokio", "url", "http"]

tokio = ["dep:tokio"]
async = ["reqwest", "reqwest/socks", "tokio?/time"]
Expand Down
37 changes: 36 additions & 1 deletion src/async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ use log::{debug, error, info, trace};

use reqwest::{header, Client, Response};

#[cfg(feature = "async-ohttp")]
use crate::ohttp::OhttpClient;
use crate::{
AddressStats, BlockInfo, BlockStatus, BlockSummary, Builder, Error, MempoolRecentTx,
MempoolStats, MerkleProof, OutputStatus, ScriptHashStats, Tx, TxStatus, Utxo,
Expand All @@ -43,6 +45,9 @@ pub struct AsyncClient<S = DefaultSleeper> {

/// Marker for the type of sleeper used
marker: PhantomData<S>,
/// Ohttp config
#[cfg(feature = "async-ohttp")]
ohttp_client: Option<OhttpClient>,
}

impl<S: Sleeper> AsyncClient<S> {
Expand Down Expand Up @@ -77,6 +82,8 @@ impl<S: Sleeper> AsyncClient<S> {
client: client_builder.build()?,
max_retries: builder.max_retries,
marker: PhantomData,
#[cfg(feature = "async-ohttp")]
ohttp_client: None,
})
}

Expand All @@ -86,9 +93,17 @@ impl<S: Sleeper> AsyncClient<S> {
client,
max_retries: crate::DEFAULT_MAX_RETRIES,
marker: PhantomData,
#[cfg(feature = "async-ohttp")]
ohttp_client: None,
}
}

#[cfg(feature = "async-ohttp")]
pub(crate) fn set_ohttp_client(mut self, ohttp_client: OhttpClient) -> Self {
self.ohttp_client = Some(ohttp_client);
self
}

/// Make an HTTP GET request to given URL, deserializing to any `T` that
/// implement [`bitcoin::consensus::Decodable`].
///
Expand Down Expand Up @@ -557,12 +572,32 @@ impl<S: Sleeper> AsyncClient<S> {
let mut attempts = 0;

loop {
match self.client.get(url).send().await? {
let res = {
#[cfg(feature = "async-ohttp")]
if let Some(ohttp_client) = &self.ohttp_client {
let (body, ctx) = ohttp_client.ohttp_encapsulate("get", url, None)?;
let res = self
.client
.post(ohttp_client.relay_url().to_string())
.header("Content-Type", "message/ohttp-req")
.body(body)
.send()
.await?;
let body = res.bytes().await?.to_vec();
ohttp_client.ohttp_decapsulate(ctx, body)?.into()
} else {
self.client.get(url).send().await?
}
#[cfg(not(feature = "async-ohttp"))]
self.client.get(url).send().await?
};
match res {
resp if attempts < self.max_retries && is_status_retryable(resp.status()) => {
S::sleep(delay).await;
attempts += 1;
delay *= 2;
}

resp => return Ok(resp),
}
}
Expand Down
217 changes: 217 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ pub mod r#async;
#[cfg(feature = "blocking")]
pub mod blocking;

#[cfg(feature = "async-ohttp")]
pub(crate) mod ohttp;

pub use api::*;
#[cfg(feature = "blocking")]
pub use blocking::BlockingClient;
Expand Down Expand Up @@ -195,6 +198,20 @@ impl Builder {
pub fn build_async_with_sleeper<S: Sleeper>(self) -> Result<AsyncClient<S>, Error> {
AsyncClient::from_builder(self)
}

#[cfg(feature = "async-ohttp")]
pub async fn build_async_with_ohttp(
self,
ohttp_relay_url: &str,
ohttp_gateway_url: &str,
) -> Result<AsyncClient, Error> {
use crate::ohttp::OhttpClient;

let ohttp_client = OhttpClient::new(ohttp_relay_url, ohttp_gateway_url).await?;
Ok(self
.build_async_with_sleeper()?
.set_ohttp_client(ohttp_client))
}
}

/// Errors that can happen during a request to `Esplora` servers.
Expand Down Expand Up @@ -230,6 +247,18 @@ pub enum Error {
InvalidHttpHeaderValue(String),
/// The server sent an invalid response
InvalidResponse,
/// Error from Ohttp library
#[cfg(feature = "async-ohttp")]
Ohttp(bitcoin_ohttp::Error),
/// Error when reading and writing to bhttp payloads
#[cfg(feature = "async-ohttp")]
Bhttp(bhttp::Error),
/// Error when converting the http response to and from bhttp response
#[cfg(feature = "async-ohttp")]
Http(http::Error),
/// Error when parsing the URL
#[cfg(feature = "async-ohttp")]
UrlParsing(url::ParseError),
}

impl fmt::Display for Error {
Expand Down Expand Up @@ -344,6 +373,194 @@ mod test {
(blocking_client, async_client)
}

#[cfg(feature = "async-ohttp")]
fn find_free_port() -> u16 {
let listener = std::net::TcpListener::bind("0.0.0.0:0").unwrap();
listener.local_addr().unwrap().port()
}

#[cfg(feature = "async-ohttp")]
async fn start_ohttp_relay(
gateway_url: ohttp_relay::GatewayUri,
) -> (
u16,
tokio::task::JoinHandle<Result<(), Box<dyn std::error::Error + std::marker::Send + Sync>>>,
) {
let port = find_free_port();
let relay = ohttp_relay::listen_tcp(port, gateway_url).await.unwrap();

(port, relay)
}

#[cfg(feature = "async-ohttp")]
async fn start_ohttp_gateway() -> (u16, tokio::task::JoinHandle<()>) {
use http_body_util::Full;
use hyper::body::Incoming;
use hyper::service::service_fn;
use hyper::Response;
use hyper::{Method, Request};
use hyper_util::rt::TokioIo;
use tokio::net::TcpListener;

let port = find_free_port();
let listener = TcpListener::bind(format!("0.0.0.0:{}", port))
.await
.unwrap();

let handle = tokio::spawn(async move {
let key_config = bitcoin_ohttp::KeyConfig::new(
0,
bitcoin_ohttp::hpke::Kem::K256Sha256,
vec![bitcoin_ohttp::SymmetricSuite::new(
bitcoin_ohttp::hpke::Kdf::HkdfSha256,
bitcoin_ohttp::hpke::Aead::ChaCha20Poly1305,
)],
)
.expect("valid key config");
let server = bitcoin_ohttp::Server::new(key_config).expect("valid server");
let server = std::sync::Arc::new(server);
loop {
match listener.accept().await {
Ok((stream, _)) => {
let io = TokioIo::new(stream);
let server = server.clone();
let service = service_fn(move |req: Request<Incoming>| {
let server = server.clone();
async move {
let path = req.uri().path();
if path == "/.well-known/ohttp-gateway"
&& req.method() == Method::GET
{
let key_config = server.config().encode().unwrap();
Ok::<_, hyper::Error>(
Response::builder()
.status(200)
.header("content-type", "application/ohttp-keys")
.body(Full::new(hyper::body::Bytes::from(key_config)))
.unwrap(),
)
} else if path == "/.well-known/ohttp-gateway"
&& req.method() == Method::POST
{
use http_body_util::BodyExt;

// Assert that the content-type header is set to
// "message/ohttp-req".
let content_type_header = req
.headers()
.get("content-type")
.expect("content-type header should be set by the client");
assert_eq!(content_type_header, "message/ohttp-req");

let bytes = req.collect().await?.to_bytes();
let (bhttp_body, response_ctx) =
server.decapsulate(bytes.iter().as_slice()).unwrap();
// Reconstruct the inner HTTP message from the bhttp message.
let mut r = std::io::Cursor::new(bhttp_body);
let m: bhttp::Message = bhttp::Message::read_bhttp(&mut r)
.expect("Should be valid bhttp message");
let base_url = format!(
"http://{}",
ELECTRSD.esplora_url.as_ref().unwrap()
);
let path =
String::from_utf8(m.control().path().unwrap().to_vec())
.unwrap();
let _ =
Method::from_bytes(m.control().method().unwrap()).unwrap();
// TODO: Use the actual method from the bhttp message
// This will be refactored out to use bitreq
let req = reqwest::Request::new(
Method::GET,
url::Url::parse(&(base_url + &path)).unwrap(),
);
let mut req_builder = reqwest::RequestBuilder::from_parts(
reqwest::Client::new(),
req,
);
for field in m.header().iter() {
req_builder =
req_builder.header(field.name(), field.value());
}

let res = req_builder.send().await.unwrap();
// Convert HTTP response to bhttp response
let mut m: bhttp::Message = bhttp::Message::response(
res.status().as_u16().try_into().unwrap(),
);
m.write_content(res.bytes().await.unwrap());
let mut bhttp_res = vec![];
m.write_bhttp(bhttp::Mode::IndeterminateLength, &mut bhttp_res)
.unwrap();
// Now we need to encapsulate the response
let encapsulated_response =
response_ctx.encapsulate(&bhttp_res).unwrap();

Ok::<_, hyper::Error>(
Response::builder()
.status(200)
.header("content-type", "message/ohttp-res")
.body(Full::new(hyper::body::Bytes::copy_from_slice(
&encapsulated_response,
)))
.unwrap(),
)
} else {
Ok::<_, hyper::Error>(
Response::builder()
.status(404)
.body(Full::new(hyper::body::Bytes::from("Not Found")))
.unwrap(),
)
}
}
});

tokio::spawn(async move {
if let Err(err) = hyper::server::conn::http1::Builder::new()
.serve_connection(io, service)
.await
{
eprintln!("Error serving connection: {:?}", err);
}
});
}
Err(e) => {
eprintln!("Error accepting connection: {:?}", e);
break;
}
}
}
});
println!("OHTTP gateway started on port {}", port);

(port, handle)
}
#[cfg(feature = "async-ohttp")]
#[tokio::test]
async fn test_ohttp_e2e() {
let (_, async_client) = setup_clients().await;
let block_hash = async_client.get_block_hash(1).await.unwrap();
let esplora_url = ELECTRSD.esplora_url.as_ref().unwrap();
let (gateway_port, _) = start_ohttp_gateway().await;
let gateway_origin = format!("http://localhost:{gateway_port}");
let (relay_port, _) =
start_ohttp_relay(gateway_origin.parse::<ohttp_relay::GatewayUri>().unwrap()).await;
let gateway_url = format!(
"http://localhost:{}/.well-known/ohttp-gateway",
gateway_port
);
let relay_url = format!("http://localhost:{}", relay_port);

let ohttp_client = Builder::new(&format!("http://{}", esplora_url))
.build_async_with_ohttp(&relay_url, &gateway_url)
.await
.unwrap();

let res = ohttp_client.get_block_hash(1).await.unwrap();
assert_eq!(res, block_hash);
}

#[cfg(all(feature = "blocking", feature = "async"))]
fn generate_blocks_and_wait(num: usize) {
let cur_height = BITCOIND.client.get_block_count().unwrap().0;
Expand Down
Loading