Skip to content

Commit 32596d5

Browse files
authored
[hermes] Add more rest api methods (#746)
* [hermes] Add more rest api methods Add many of the price service apis. Per David suggestion, we do validation in parsing instead of doing it later. I didn't find any suitable library to deserialize our hex format so I created a macro to implement it because we use it in a couple of places. I tried making a generic HexInput but couldn't make it working (and I need other crates like generic_array for it which makes the code more complex) * Address feedbacks
1 parent 4796516 commit 32596d5

File tree

6 files changed

+188
-56
lines changed

6 files changed

+188
-56
lines changed

hermes/Cargo.lock

+5-4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

hermes/Cargo.toml

+4-4
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@ version = "0.1.0"
44
edition = "2021"
55

66
[dependencies]
7-
axum = { version = "0.6.9", features = ["json", "ws"] }
7+
axum = { version = "0.6.9", features = ["json", "ws", "macros"] }
88
axum-extra = { version = "0.7.2", features = ["query"] }
99
axum-macros = { version = "0.3.4" }
1010
anyhow = { version = "1.0.69" }
11+
base64 = { version = "0.21.0" }
1112
borsh = { version = "0.9.0" }
1213
bs58 = { version = "0.4.0" }
1314
dashmap = { version = "5.4.0" }
1415
der = { version = "0.7.0" }
16+
derive_more = { version = "0.99.17" }
1517
env_logger = { version = "0.10.0" }
1618
futures = { version = "0.3.26" }
1719
hex = { version = "0.4.3" }
@@ -26,7 +28,7 @@ secp256k1 = { version = "0.26.0", features = ["rand", "reco
2628
serde = { version = "1.0.152", features = ["derive"] }
2729
serde_arrays = { version = "0.1.0" }
2830
serde_cbor = { version = "0.11.2" }
29-
serde_json = { version = "1.0.93" }
31+
serde_json = { version = "1.0.93" }
3032
sha256 = { version = "1.1.2" }
3133
structopt = { version = "0.3.26" }
3234
tokio = { version = "1.26.0", features = ["full"] }
@@ -58,5 +60,3 @@ libp2p = { version = "0.51.1", features = [
5860
"websocket",
5961
"yamux",
6062
]}
61-
base64 = "0.21.0"
62-
derive_more = "0.99.17"

hermes/src/macros.rs

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#[macro_export]
2+
/// A macro that generates Deserialize from string for a struct S that wraps [u8; N] where N is a
3+
/// compile-time constant. This macro deserializes a string with or without leading 0x and supports
4+
/// both lower case and upper case hex characters.
5+
macro_rules! impl_deserialize_for_hex_string_wrapper {
6+
($struct_name:ident, $array_size:expr) => {
7+
impl<'de> serde::Deserialize<'de> for $struct_name {
8+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
9+
where
10+
D: serde::Deserializer<'de>,
11+
{
12+
struct HexVisitor;
13+
14+
impl<'de> serde::de::Visitor<'de> for HexVisitor {
15+
type Value = [u8; $array_size];
16+
17+
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
18+
write!(formatter, "a hex string of length {}", $array_size * 2)
19+
}
20+
21+
fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
22+
where
23+
E: serde::de::Error,
24+
{
25+
let s = s.trim_start_matches("0x");
26+
let bytes = hex::decode(s)
27+
.map_err(|_| E::invalid_value(serde::de::Unexpected::Str(s), &self))?;
28+
if bytes.len() != $array_size {
29+
return Err(E::invalid_length(bytes.len(), &self));
30+
}
31+
let mut array = [0_u8; $array_size];
32+
array.copy_from_slice(&bytes);
33+
Ok(array)
34+
}
35+
}
36+
37+
deserializer.deserialize_str(HexVisitor).map($struct_name)
38+
}
39+
}
40+
};
41+
}

hermes/src/main.rs

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use {
1616
};
1717

1818
mod config;
19+
mod macros;
1920
mod network;
2021
mod store;
2122

hermes/src/network/rpc.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ impl State {
2828

2929
/// This method provides a background service that responds to REST requests
3030
///
31-
/// Currently this is based on Axum due to the simplicity and strong ecosystem support for the
31+
/// Currently this is based on Axum due to the simplicity and strong ecosyjtem support for the
3232
/// packages they are based on (tokio & hyper).
3333
pub async fn spawn(rpc_addr: String, store: Store) -> Result<()> {
3434
let state = State::new(store);
@@ -39,8 +39,10 @@ pub async fn spawn(rpc_addr: String, store: Store) -> Result<()> {
3939
let app = app
4040
.route("/", get(rest::index))
4141
.route("/live", get(rest::live))
42-
.route("/latest_price_feeds", get(rest::latest_price_feeds))
43-
.route("/latest_vaas", get(rest::latest_vaas))
42+
.route("/api/latest_price_feeds", get(rest::latest_price_feeds))
43+
.route("/api/latest_vaas", get(rest::latest_vaas))
44+
.route("/api/get_vaa", get(rest::get_vaa))
45+
.route("/api/get_vaa_ccip", get(rest::get_vaa_ccip))
4446
.with_state(state.clone());
4547

4648
// Listen in the background for new VAA's from the Wormhole RPC.

hermes/src/network/rpc/rest.rs

+132-45
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,9 @@
11
use {
22
crate::store::RequestTime,
3-
base64::{
4-
engine::general_purpose::STANDARD as base64_standard_engine,
5-
Engine as _,
6-
},
7-
pyth_sdk::{
8-
PriceFeed,
9-
PriceIdentifier,
3+
crate::{
4+
impl_deserialize_for_hex_string_wrapper,
5+
store::UnixTimestamp,
106
},
11-
};
12-
// This file implements a REST service for the Price Service. This is a mostly direct copy of the
13-
// TypeScript implementation in the `pyth-crosschain` repo. It uses `axum` as the web framework and
14-
// `tokio` as the async runtime.
15-
use {
167
anyhow::Result,
178
axum::{
189
extract::State,
@@ -24,47 +15,57 @@ use {
2415
Json,
2516
},
2617
axum_extra::extract::Query, // Axum extra Query allows us to parse multi-value query parameters.
18+
base64::{
19+
engine::general_purpose::STANDARD as base64_standard_engine,
20+
Engine as _,
21+
},
22+
derive_more::{
23+
Deref,
24+
DerefMut,
25+
},
26+
pyth_sdk::{
27+
PriceFeed,
28+
PriceIdentifier,
29+
},
2730
};
2831

32+
#[derive(Debug, Clone, Deref, DerefMut)]
33+
pub struct PriceIdInput([u8; 32]);
34+
// TODO: Use const generics instead of macro.
35+
impl_deserialize_for_hex_string_wrapper!(PriceIdInput, 32);
36+
37+
impl From<PriceIdInput> for PriceIdentifier {
38+
fn from(id: PriceIdInput) -> Self {
39+
Self::new(*id)
40+
}
41+
}
42+
2943
pub enum RestError {
30-
InvalidPriceId,
3144
UpdateDataNotFound,
3245
}
3346

3447
impl IntoResponse for RestError {
3548
fn into_response(self) -> Response {
3649
match self {
37-
RestError::InvalidPriceId => {
38-
(StatusCode::BAD_REQUEST, "Invalid Price Id").into_response()
39-
}
4050
RestError::UpdateDataNotFound => {
4151
(StatusCode::NOT_FOUND, "Update data not found").into_response()
4252
}
4353
}
4454
}
4555
}
4656

47-
#[derive(Debug, serde::Serialize, serde::Deserialize)]
48-
pub struct LatestVaaQueryParams {
49-
ids: Vec<String>,
57+
58+
#[derive(Debug, serde::Deserialize)]
59+
pub struct LatestVaasQueryParams {
60+
ids: Vec<PriceIdInput>,
5061
}
5162

52-
/// REST endpoint /latest_vaas?ids[]=...&ids[]=...&ids[]=...
53-
///
54-
/// TODO: This endpoint returns update data as an array of base64 encoded strings. We want
55-
/// to support other formats such as hex in the future.
63+
5664
pub async fn latest_vaas(
5765
State(state): State<super::State>,
58-
Query(params): Query<LatestVaaQueryParams>,
66+
Query(params): Query<LatestVaasQueryParams>,
5967
) -> Result<Json<Vec<String>>, RestError> {
60-
// TODO: Find better ways to validate query parameters.
61-
// FIXME: Handle ids with leading 0x
62-
let price_ids: Vec<PriceIdentifier> = params
63-
.ids
64-
.iter()
65-
.map(PriceIdentifier::from_hex)
66-
.collect::<Result<Vec<PriceIdentifier>, _>>()
67-
.map_err(|_| RestError::InvalidPriceId)?;
68+
let price_ids: Vec<PriceIdentifier> = params.ids.into_iter().map(|id| id.into()).collect();
6869
let price_feeds_with_update_data = state
6970
.store
7071
.get_price_feeds_with_update_data(price_ids, RequestTime::Latest)
@@ -74,27 +75,22 @@ pub async fn latest_vaas(
7475
.update_data
7576
.batch_vaa
7677
.iter()
77-
.map(|vaa_bytes| base64_standard_engine.encode(vaa_bytes))
78+
.map(|vaa_bytes| base64_standard_engine.encode(vaa_bytes)) // TODO: Support multiple
79+
// encoding formats
7880
.collect(),
7981
))
8082
}
8183

82-
#[derive(Debug, serde::Serialize, serde::Deserialize)]
83-
pub struct LatestPriceFeedParams {
84-
ids: Vec<String>,
84+
#[derive(Debug, serde::Deserialize)]
85+
pub struct LatestPriceFeedsQueryParams {
86+
ids: Vec<PriceIdInput>,
8587
}
8688

87-
/// REST endpoint /latest_vaas?ids[]=...&ids[]=...&ids[]=...
8889
pub async fn latest_price_feeds(
8990
State(state): State<super::State>,
90-
Query(params): Query<LatestPriceFeedParams>,
91+
Query(params): Query<LatestPriceFeedsQueryParams>,
9192
) -> Result<Json<Vec<PriceFeed>>, RestError> {
92-
let price_ids: Vec<PriceIdentifier> = params
93-
.ids
94-
.iter()
95-
.map(PriceIdentifier::from_hex)
96-
.collect::<Result<Vec<PriceIdentifier>, _>>()
97-
.map_err(|_| RestError::InvalidPriceId)?;
93+
let price_ids: Vec<PriceIdentifier> = params.ids.into_iter().map(|id| id.into()).collect();
9894
let price_feeds_with_update_data = state
9995
.store
10096
.get_price_feeds_with_update_data(price_ids, RequestTime::Latest)
@@ -107,6 +103,91 @@ pub async fn latest_price_feeds(
107103
))
108104
}
109105

106+
#[derive(Debug, serde::Deserialize)]
107+
pub struct GetVaaQueryParams {
108+
id: PriceIdInput,
109+
publish_time: UnixTimestamp,
110+
}
111+
112+
#[derive(Debug, serde::Serialize)]
113+
pub struct GetVaaResponse {
114+
pub vaa: String,
115+
#[serde(rename = "publishTime")]
116+
pub publish_time: UnixTimestamp,
117+
}
118+
119+
pub async fn get_vaa(
120+
State(state): State<super::State>,
121+
Query(params): Query<GetVaaQueryParams>,
122+
) -> Result<Json<GetVaaResponse>, RestError> {
123+
let price_id: PriceIdentifier = params.id.into();
124+
125+
let price_feeds_with_update_data = state
126+
.store
127+
.get_price_feeds_with_update_data(
128+
vec![price_id],
129+
RequestTime::FirstAfter(params.publish_time),
130+
)
131+
.map_err(|_| RestError::UpdateDataNotFound)?;
132+
133+
let vaa = price_feeds_with_update_data
134+
.update_data
135+
.batch_vaa
136+
.get(0)
137+
.map(|vaa_bytes| base64_standard_engine.encode(vaa_bytes))
138+
.ok_or(RestError::UpdateDataNotFound)?;
139+
140+
let publish_time = price_feeds_with_update_data
141+
.price_feeds
142+
.get(&price_id)
143+
.map(|price_feed| price_feed.get_price_unchecked().publish_time)
144+
.ok_or(RestError::UpdateDataNotFound)?;
145+
let publish_time: UnixTimestamp = publish_time
146+
.try_into()
147+
.map_err(|_| RestError::UpdateDataNotFound)?;
148+
149+
Ok(Json(GetVaaResponse { vaa, publish_time }))
150+
}
151+
152+
#[derive(Debug, Clone, Deref, DerefMut)]
153+
pub struct GetVaaCcipInput([u8; 40]);
154+
impl_deserialize_for_hex_string_wrapper!(GetVaaCcipInput, 40);
155+
156+
#[derive(Debug, serde::Deserialize)]
157+
pub struct GetVaaCcipQueryParams {
158+
data: GetVaaCcipInput,
159+
}
160+
161+
#[derive(Debug, serde::Serialize)]
162+
pub struct GetVaaCcipResponse {
163+
data: String, // TODO: Use a typed wrapper for the hex output with leading 0x.
164+
}
165+
166+
pub async fn get_vaa_ccip(
167+
State(state): State<super::State>,
168+
Query(params): Query<GetVaaCcipQueryParams>,
169+
) -> Result<Json<GetVaaCcipResponse>, RestError> {
170+
let price_id: PriceIdentifier = PriceIdentifier::new(params.data[0..32].try_into().unwrap());
171+
let publish_time = UnixTimestamp::from_be_bytes(params.data[32..40].try_into().unwrap());
172+
173+
let price_feeds_with_update_data = state
174+
.store
175+
.get_price_feeds_with_update_data(vec![price_id], RequestTime::FirstAfter(publish_time))
176+
.map_err(|_| RestError::UpdateDataNotFound)?;
177+
178+
let vaa = price_feeds_with_update_data
179+
.update_data
180+
.batch_vaa
181+
.get(0) // One price feed has only a single VAA as proof.
182+
.ok_or(RestError::UpdateDataNotFound)?;
183+
184+
// FIXME: We should return 5xx when the vaa is not found and 4xx when the price id is not there
185+
186+
Ok(Json(GetVaaCcipResponse {
187+
data: format!("0x{}", hex::encode(vaa)),
188+
}))
189+
}
190+
110191
// This function implements the `/live` endpoint. It returns a `200` status code. This endpoint is
111192
// used by the Kubernetes liveness probe.
112193
pub async fn live() -> Result<impl IntoResponse, std::convert::Infallible> {
@@ -116,5 +197,11 @@ pub async fn live() -> Result<impl IntoResponse, std::convert::Infallible> {
116197
// This is the index page for the REST service. It will list all the available endpoints.
117198
// TODO: Dynamically generate this list if possible.
118199
pub async fn index() -> impl IntoResponse {
119-
Json(["/live", "/latest_price_feeds", "/latest_vaas"])
200+
Json([
201+
"/live",
202+
"/api/latest_price_feeds?ids[]=<price_feed_id>&ids[]=<price_feed_id_2>&..",
203+
"/api/latest_vaas?ids[]=<price_feed_id>&ids[]=<price_feed_id_2>&...",
204+
"/api/get_vaa?id=<price_feed_id>&publish_time=<publish_time_in_unix_timestamp>",
205+
"/api/get_vaa_ccip?data=<0x<price_feed_id_32_bytes>+<publish_time_unix_timestamp_be_8_bytes>>",
206+
])
120207
}

0 commit comments

Comments
 (0)