Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow multiple media types per endpoint #184

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 32 additions & 32 deletions derive/src/resource_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,38 +320,32 @@ pub fn expand_resource_error(input: DeriveInput) -> Result<TokenStream> {

from_impls.push(quote! {
impl #generics ::std::convert::From<#from_ty> for #ident #generics
where #( #fields_where ),*
where #(#fields_where),*
{
fn from(#from_ident: #from_ty) -> Self {
#( #fields_let )*
#(#fields_let)*
Self::#var_ident #fields_pat
}
}
});
}

let status_codes = if cfg!(feature = "openapi") {
let were = variants
.iter()
.filter_map(|variant| variant.were())
.collect::<Vec<_>>();

let response_schema = if cfg!(feature = "openapi") {
let codes = variants.iter().map(|v| match v.status() {
Some(code) => quote!(status_codes.push(#code);),
None => {
// we would've errored before if from_ty was not set
let from_ty = &v.from_ty.as_ref().unwrap().1;
quote!(status_codes.extend(<#from_ty as ::gotham_restful::IntoResponseError>::status_codes());)
quote!(status_codes.extend(<#from_ty as ::gotham_restful::ResponseSchema>::status_codes());)
}
});
Some(quote! {
fn status_codes() -> ::std::vec::Vec<::gotham_restful::gotham::hyper::StatusCode> {
let mut status_codes = <::std::vec::Vec<::gotham_restful::gotham::hyper::StatusCode>>::new();
#(#codes)*
status_codes
}
})
} else {
None
};

let schema = if cfg!(feature = "openapi") {
let codes = variants.iter().map(|v| match v.status() {
let codes_schema = variants.iter().map(|v| match v.status() {
Some(code) => quote! {
#code => <::gotham_restful::NoContent as ::gotham_restful::ResponseSchema>::schema(
::gotham_restful::gotham::hyper::StatusCode::NO_CONTENT
Expand All @@ -361,28 +355,35 @@ pub fn expand_resource_error(input: DeriveInput) -> Result<TokenStream> {
// we would've errored before if from_ty was not set
let from_ty = &v.from_ty.as_ref().unwrap().1;
quote! {
code if <#from_ty as ::gotham_restful::IntoResponseError>::status_codes().contains(&code) => {
<#from_ty as ::gotham_restful::IntoResponseError>::schema(code)
code if <#from_ty as ::gotham_restful::ResponseSchema>::status_codes().contains(&code) => {
<#from_ty as ::gotham_restful::ResponseSchema>::schema(code)
}
}
}
});

Some(quote! {
fn schema(code: ::gotham_restful::gotham::hyper::StatusCode) -> ::gotham_restful::private::OpenapiSchema {
match code {
#(#codes,)*
code => panic!("Invalid status code {}", code)
impl #generics ::gotham_restful::ResponseSchema for #ident #generics
where #(#were),*
{
fn status_codes() -> ::std::vec::Vec<::gotham_restful::gotham::hyper::StatusCode> {
let mut status_codes = <::std::vec::Vec<::gotham_restful::gotham::hyper::StatusCode>>::new();
#(#codes)*
status_codes
}

fn schema(code: ::gotham_restful::gotham::hyper::StatusCode) -> ::std::vec::Vec<::gotham_restful::MimeAndSchema> {
match code {
#(#codes_schema,)*
code => panic!("Invalid status code {}", code)
}
}
}
})
} else {
None
};

let were = variants
.iter()
.filter_map(|variant| variant.were())
.collect::<Vec<_>>();
let variants = variants
.into_iter()
.map(|variant| variant.into_match_arm(&ident))
Expand All @@ -392,21 +393,20 @@ pub fn expand_resource_error(input: DeriveInput) -> Result<TokenStream> {
#display_impl

impl #generics ::gotham_restful::IntoResponseError for #ident #generics
where #( #were ),*
where #(#were),*
{
type Err = ::gotham_restful::private::serde_json::Error;

fn into_response_error(self) -> ::std::result::Result<::gotham_restful::Response, Self::Err>
{
match self {
#( #variants ),*
#(#variants),*
}
}

#status_codes
#schema
}

#( #from_impls )*
#response_schema

#(#from_impls)*
})
}
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ pub use response::{
NoContent, Raw, Redirect, Response, Success
};
#[cfg(feature = "openapi")]
pub use response::{IntoResponseWithSchema, ResponseSchema};
pub use response::{IntoResponseWithSchema, MimeAndSchema, ResponseSchema};

mod routing;
pub use routing::{DrawResourceRoutes, DrawResources};
Expand Down
54 changes: 32 additions & 22 deletions src/openapi/operation.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::SECURITY_NAME;
use crate::{response::OrAllTypes, EndpointWithSchema, IntoResponse, RequestBody};
use crate::{response::OrAllTypes, EndpointWithSchema, RequestBody};
use gotham::{hyper::StatusCode, mime::Mime};
use openapi_type::{
indexmap::IndexMap,
Expand Down Expand Up @@ -97,8 +97,7 @@ pub(crate) struct OperationDescription {
operation_id: Option<String>,
description: Option<String>,

accepted_types: Option<Vec<Mime>>,
responses: HashMap<StatusCode, ReferenceOr<Schema>>,
responses: HashMap<StatusCode, Vec<(Mime, ReferenceOr<Schema>)>>,
params: OperationParams,
body_schema: Option<ReferenceOr<Schema>>,
supported_types: Option<Vec<Mime>>,
Expand All @@ -109,7 +108,7 @@ impl OperationDescription {
/// Create a new operation description for the given endpoint type and schema. If the endpoint
/// does not specify an operation id, the path is used to generate one.
pub(crate) fn new<E: EndpointWithSchema>(
responses: HashMap<StatusCode, ReferenceOr<Schema>>,
responses: HashMap<StatusCode, Vec<(Mime, ReferenceOr<Schema>)>>,
path: &str
) -> Self {
let operation_id = E::operation_id().or_else(|| {
Expand All @@ -120,7 +119,6 @@ impl OperationDescription {
operation_id,
description: E::description(),

accepted_types: E::Output::accepted_types(),
responses,
params: Default::default(),
body_schema: None,
Expand All @@ -142,13 +140,10 @@ impl OperationDescription {
self.supported_types = Body::supported_types();
}

fn schema_to_content(
types: Vec<Mime>,
schema: ReferenceOr<Schema>
) -> IndexMap<String, MediaType> {
fn schema_to_content(schemas: Vec<(Mime, ReferenceOr<Schema>)>) -> IndexMap<String, MediaType> {
let mut content: IndexMap<String, MediaType> = IndexMap::new();
for ty in types {
content.insert(ty.to_string(), MediaType {
for (mime, schema) in schemas {
content.insert(mime.to_string(), MediaType {
schema: Some(schema.clone()),
..Default::default()
});
Expand All @@ -161,7 +156,6 @@ impl OperationDescription {
let (
operation_id,
description,
accepted_types,
responses,
params,
body_schema,
Expand All @@ -170,7 +164,6 @@ impl OperationDescription {
) = (
self.operation_id,
self.description,
self.accepted_types,
self.responses,
self.params,
self.body_schema,
Expand All @@ -180,9 +173,8 @@ impl OperationDescription {

let responses: IndexMap<OAStatusCode, ReferenceOr<Response>> = responses
.into_iter()
.map(|(code, schema)| {
let content =
Self::schema_to_content(accepted_types.clone().or_all_types(), schema);
.map(|(code, schemas)| {
let content = Self::schema_to_content(schemas);
(
OAStatusCode::Code(code.as_u16()),
Item(Response {
Expand All @@ -199,7 +191,13 @@ impl OperationDescription {

let request_body = body_schema.map(|schema| {
Item(OARequestBody {
content: Self::schema_to_content(supported_types.or_all_types(), schema),
content: Self::schema_to_content(
supported_types
.or_all_types()
.into_iter()
.map(|mime| (mime, schema.clone()))
.collect()
),
required: true,
..Default::default()
})
Expand Down Expand Up @@ -232,23 +230,35 @@ impl OperationDescription {
#[cfg(test)]
mod test {
use super::*;
use crate::{NoContent, Raw, ResponseSchema};
use crate::{IntoResponse, MimeAndSchema, NoContent, Raw, ResponseSchema};

fn schema_to_content(schema: Vec<MimeAndSchema>) -> IndexMap<String, MediaType> {
OperationDescription::schema_to_content(
schema
.into_iter()
.map(|mime_schema| {
(
mime_schema.mime,
ReferenceOr::Item(mime_schema.schema.schema)
)
})
.collect()
)
}

#[test]
fn no_content_schema_to_content() {
let types = NoContent::accepted_types();
let schema = <NoContent as ResponseSchema>::schema(StatusCode::NO_CONTENT);
let content =
OperationDescription::schema_to_content(types.or_all_types(), Item(schema.schema));
let content = schema_to_content(schema);
assert!(content.is_empty());
}

#[test]
fn raw_schema_to_content() {
let types = Raw::<&str>::accepted_types();
let schema = <Raw<&str> as ResponseSchema>::schema(StatusCode::OK);
let content =
OperationDescription::schema_to_content(types.or_all_types(), Item(schema.schema));
let content = schema_to_content(schema);
assert_eq!(content.len(), 1);
let json = serde_json::to_string(&content.values().nth(0).unwrap()).unwrap();
assert_eq!(json, r#"{"schema":{"type":"string","format":"binary"}}"#);
Expand Down
10 changes: 9 additions & 1 deletion src/openapi/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,15 @@ macro_rules! implOpenapiRouter {
for code in E::Output::status_codes() {
responses.insert(
code,
(self.0).openapi_builder.add_schema(E::Output::schema(code))
E::Output::schema(code)
.into_iter()
.map(|mime_schema| {
(
mime_schema.mime,
(self.0).openapi_builder.add_schema(mime_schema.schema)
)
})
.collect()
);
}
let mut path = format!("{}/{}", self.0.scope.unwrap_or_default(), self.1);
Expand Down
13 changes: 7 additions & 6 deletions src/response/auth_result.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::{IntoResponseError, Response};
#[cfg(feature = "openapi")]
use crate::{MimeAndSchema, Raw, ResponseSchema};
use gotham::{hyper::StatusCode, mime::TEXT_PLAIN_UTF_8};
use gotham_restful_derive::ResourceError;
#[cfg(feature = "openapi")]
use openapi_type::{OpenapiSchema, OpenapiType};

/// This is an error type that always yields a _403 Forbidden_ response. This type
/// is best used in combination with [`AuthSuccess`] or [`AuthResult`].
Expand All @@ -26,16 +26,17 @@ impl IntoResponseError for AuthError {
Some(TEXT_PLAIN_UTF_8)
))
}
}

#[cfg(feature = "openapi")]
#[cfg(feature = "openapi")]
impl ResponseSchema for AuthError {
fn status_codes() -> Vec<StatusCode> {
vec![StatusCode::FORBIDDEN]
}

#[cfg(feature = "openapi")]
fn schema(code: StatusCode) -> OpenapiSchema {
fn schema(code: StatusCode) -> Vec<MimeAndSchema> {
assert_eq!(code, StatusCode::FORBIDDEN);
<super::Raw<String> as OpenapiType>::schema()
<Raw<String> as ResponseSchema>::schema(StatusCode::OK)
}
}

Expand Down
15 changes: 13 additions & 2 deletions src/response/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,22 @@ pub trait IntoResponse {
fn into_response(self) -> BoxFuture<'static, Result<Response, Self::Err>>;

/// Return a list of supported mime types.
#[cfg_attr(
feature = "openapi",
doc = "\n Note that this does not influence the auto-generated OpenAPI specification."
)]
fn accepted_types() -> Option<Vec<Mime>> {
None
}
}

#[cfg(feature = "openapi")]
#[derive(Debug)]
pub struct MimeAndSchema {
pub mime: Mime,
pub schema: OpenapiSchema
}

/// Additional details for [IntoResponse] to be used with an OpenAPI-aware router.
#[cfg(feature = "openapi")]
pub trait ResponseSchema {
Expand All @@ -195,7 +206,7 @@ pub trait ResponseSchema {
/// Return the schema of the response for the given status code. The code may
/// only be one that was previously returned by [Self::status_codes]. The
/// implementation should panic if that is not the case.
fn schema(code: StatusCode) -> OpenapiSchema;
fn schema(code: StatusCode) -> Vec<MimeAndSchema>;
}

#[cfg(feature = "openapi")]
Expand Down Expand Up @@ -284,7 +295,7 @@ where
Res::status_codes()
}

fn schema(code: StatusCode) -> OpenapiSchema {
fn schema(code: StatusCode) -> Vec<MimeAndSchema> {
Res::schema(code)
}
}
Expand Down
Loading