Skip to content

Commit fcd1f34

Browse files
jbourassasurma
andauthored
Generate an enum for @oneOf input (#450)
* Add `is_one_of` to input * Extract struct building from input codegen * Generate enums for `oneOf` input * Clippy fixes * Empty commit for CI --------- Co-authored-by: Surma <[email protected]>
1 parent c5847ce commit fcd1f34

File tree

10 files changed

+180
-53
lines changed

10 files changed

+180
-53
lines changed

graphql_client/tests/input_object_variables.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ fn input_object_variables_default() {
3939
msg: default_input_object_variables_query::Variables::default_msg(),
4040
};
4141

42-
let out = serde_json::to_value(&variables).unwrap();
42+
let out = serde_json::to_value(variables).unwrap();
4343

4444
let expected_default = serde_json::json!({
4545
"msg":{"content":null,"to":{"category":null,"email":"[email protected]","name":null}}
@@ -130,7 +130,7 @@ pub struct RustNameQuery;
130130
#[test]
131131
fn rust_name_correctly_mapped() {
132132
use rust_name_query::*;
133-
let value = serde_json::to_value(&Variables {
133+
let value = serde_json::to_value(Variables {
134134
extern_: Some("hello".to_owned()),
135135
msg: <_>::default(),
136136
})

graphql_client/tests/one_of_input.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
use graphql_client::*;
2+
use serde_json::*;
3+
4+
#[derive(GraphQLQuery)]
5+
#[graphql(
6+
schema_path = "tests/one_of_input/schema.graphql",
7+
query_path = "tests/one_of_input/query.graphql",
8+
variables_derives = "Clone"
9+
)]
10+
pub struct OneOfMutation;
11+
12+
#[test]
13+
fn one_of_input() {
14+
use one_of_mutation::*;
15+
16+
let author = Param::Author(Author { id: 1 });
17+
let _ = Param::Name("Mark Twain".to_string());
18+
let _ = Param::RecursiveDirect(Box::new(author.clone()));
19+
let _ = Param::RecursiveIndirect(Box::new(Recursive {
20+
param: Box::new(author.clone()),
21+
}));
22+
let _ = Param::RequiredInts(vec![1]);
23+
let _ = Param::OptionalInts(vec![Some(1)]);
24+
25+
let query = OneOfMutation::build_query(Variables { param: author });
26+
assert_eq!(
27+
json!({ "param": { "author":{ "id": 1 } } }),
28+
serde_json::to_value(&query.variables).expect("json"),
29+
);
30+
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
mutation OneOfMutation($param: Param!) {
2+
oneOfMutation(query: $param)
3+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
schema {
2+
mutation: Mutation
3+
}
4+
5+
type Mutation {
6+
oneOfMutation(mutation: Param!): Int
7+
}
8+
9+
input Param @oneOf {
10+
author: Author
11+
name: String
12+
recursiveDirect: Param
13+
recursiveIndirect: Recursive
14+
requiredInts: [Int!]
15+
optionalInts: [Int]
16+
}
17+
18+
input Author {
19+
id: Int!
20+
}
21+
22+
input Recursive {
23+
param: Param!
24+
}

graphql_client_cli/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ fn set_env_logger() {
153153
.init();
154154
}
155155

156-
fn colored_level<'a>(style: &'a mut Style, level: Level) -> StyledValue<'a, &'static str> {
156+
fn colored_level(style: &mut Style, level: Level) -> StyledValue<'_, &'static str> {
157157
match level {
158158
Level::Trace => style.set_color(Color::Magenta).value("TRACE"),
159159
Level::Debug => style.set_color(Color::Blue).value("DEBUG"),

graphql_client_codegen/src/codegen/inputs.rs

Lines changed: 108 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ use super::shared::{field_rename_annotation, keyword_replace};
22
use crate::{
33
codegen_options::GraphQLClientCodegenOptions,
44
query::{BoundQuery, UsedTypes},
5-
schema::input_is_recursive_without_indirection,
5+
schema::{input_is_recursive_without_indirection, StoredInputType},
6+
type_qualifiers::GraphqlTypeQualifier,
67
};
7-
use heck::ToSnakeCase;
8+
use heck::{ToSnakeCase, ToUpperCamelCase};
89
use proc_macro2::{Ident, Span, TokenStream};
910
use quote::quote;
1011

@@ -17,48 +18,112 @@ pub(super) fn generate_input_object_definitions(
1718
all_used_types
1819
.inputs(query.schema)
1920
.map(|(_input_id, input)| {
20-
let normalized_name = options.normalization().input_name(input.name.as_str());
21-
let safe_name = keyword_replace(normalized_name);
22-
let struct_name = Ident::new(safe_name.as_ref(), Span::call_site());
23-
24-
let fields = input.fields.iter().map(|(field_name, field_type)| {
25-
let safe_field_name = keyword_replace(field_name.to_snake_case());
26-
let annotation = field_rename_annotation(field_name, safe_field_name.as_ref());
27-
let name_ident = Ident::new(safe_field_name.as_ref(), Span::call_site());
28-
let normalized_field_type_name = options
29-
.normalization()
30-
.field_type(field_type.id.name(query.schema));
31-
let optional_skip_serializing_none =
32-
if *options.skip_serializing_none() && field_type.is_optional() {
33-
Some(quote!(#[serde(skip_serializing_if = "Option::is_none")]))
34-
} else {
35-
None
36-
};
37-
let type_name = Ident::new(normalized_field_type_name.as_ref(), Span::call_site());
38-
let field_type_tokens = super::decorate_type(&type_name, &field_type.qualifiers);
39-
let field_type = if field_type
40-
.id
41-
.as_input_id()
42-
.map(|input_id| input_is_recursive_without_indirection(input_id, query.schema))
43-
.unwrap_or(false)
44-
{
45-
quote!(Box<#field_type_tokens>)
46-
} else {
47-
field_type_tokens
48-
};
49-
50-
quote!(
51-
#optional_skip_serializing_none
52-
#annotation pub #name_ident: #field_type
53-
)
54-
});
55-
56-
quote! {
57-
#variable_derives
58-
pub struct #struct_name{
59-
#(#fields,)*
60-
}
21+
if input.is_one_of {
22+
generate_enum(input, options, variable_derives, query)
23+
} else {
24+
generate_struct(input, options, variable_derives, query)
6125
}
6226
})
6327
.collect()
6428
}
29+
30+
fn generate_struct(
31+
input: &StoredInputType,
32+
options: &GraphQLClientCodegenOptions,
33+
variable_derives: &impl quote::ToTokens,
34+
query: &BoundQuery<'_>,
35+
) -> TokenStream {
36+
let normalized_name = options.normalization().input_name(input.name.as_str());
37+
let safe_name = keyword_replace(normalized_name);
38+
let struct_name = Ident::new(safe_name.as_ref(), Span::call_site());
39+
40+
let fields = input.fields.iter().map(|(field_name, field_type)| {
41+
let safe_field_name = keyword_replace(field_name.to_snake_case());
42+
let annotation = field_rename_annotation(field_name, safe_field_name.as_ref());
43+
let name_ident = Ident::new(safe_field_name.as_ref(), Span::call_site());
44+
let normalized_field_type_name = options
45+
.normalization()
46+
.field_type(field_type.id.name(query.schema));
47+
let optional_skip_serializing_none =
48+
if *options.skip_serializing_none() && field_type.is_optional() {
49+
Some(quote!(#[serde(skip_serializing_if = "Option::is_none")]))
50+
} else {
51+
None
52+
};
53+
let type_name = Ident::new(normalized_field_type_name.as_ref(), Span::call_site());
54+
let field_type_tokens = super::decorate_type(&type_name, &field_type.qualifiers);
55+
let field_type = if field_type
56+
.id
57+
.as_input_id()
58+
.map(|input_id| input_is_recursive_without_indirection(input_id, query.schema))
59+
.unwrap_or(false)
60+
{
61+
quote!(Box<#field_type_tokens>)
62+
} else {
63+
field_type_tokens
64+
};
65+
66+
quote!(
67+
#optional_skip_serializing_none
68+
#annotation pub #name_ident: #field_type
69+
)
70+
});
71+
72+
quote! {
73+
#variable_derives
74+
pub struct #struct_name{
75+
#(#fields,)*
76+
}
77+
}
78+
}
79+
80+
fn generate_enum(
81+
input: &StoredInputType,
82+
options: &GraphQLClientCodegenOptions,
83+
variable_derives: &impl quote::ToTokens,
84+
query: &BoundQuery<'_>,
85+
) -> TokenStream {
86+
let normalized_name = options.normalization().input_name(input.name.as_str());
87+
let safe_name = keyword_replace(normalized_name);
88+
let enum_name = Ident::new(safe_name.as_ref(), Span::call_site());
89+
90+
let variants = input.fields.iter().map(|(field_name, field_type)| {
91+
let variant_name = field_name.to_upper_camel_case();
92+
let safe_variant_name = keyword_replace(&variant_name);
93+
94+
let annotation = field_rename_annotation(field_name.as_ref(), &variant_name);
95+
let name_ident = Ident::new(safe_variant_name.as_ref(), Span::call_site());
96+
97+
let normalized_field_type_name = options
98+
.normalization()
99+
.field_type(field_type.id.name(query.schema));
100+
let type_name = Ident::new(normalized_field_type_name.as_ref(), Span::call_site());
101+
102+
// Add the required qualifier so that the variant's field isn't wrapped in Option
103+
let mut qualifiers = vec![GraphqlTypeQualifier::Required];
104+
qualifiers.extend(field_type.qualifiers.iter().cloned());
105+
106+
let field_type_tokens = super::decorate_type(&type_name, &qualifiers);
107+
let field_type = if field_type
108+
.id
109+
.as_input_id()
110+
.map(|input_id| input_is_recursive_without_indirection(input_id, query.schema))
111+
.unwrap_or(false)
112+
{
113+
quote!(Box<#field_type_tokens>)
114+
} else {
115+
field_type_tokens
116+
};
117+
118+
quote!(
119+
#annotation #name_ident(#field_type)
120+
)
121+
});
122+
123+
quote! {
124+
#variable_derives
125+
pub enum #enum_name{
126+
#(#variants,)*
127+
}
128+
}
129+
}

graphql_client_codegen/src/deprecation.rs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,17 @@ pub enum DeprecationStatus {
88
}
99

1010
/// The available deprecation strategies.
11-
#[derive(Debug, PartialEq, Eq, Clone)]
11+
#[derive(Debug, PartialEq, Eq, Clone, Default)]
1212
pub enum DeprecationStrategy {
1313
/// Allow use of deprecated items in queries, and say nothing.
1414
Allow,
1515
/// Fail compilation if a deprecated item is used.
1616
Deny,
1717
/// Allow use of deprecated items in queries, but warn about them (default).
18+
#[default]
1819
Warn,
1920
}
2021

21-
impl Default for DeprecationStrategy {
22-
fn default() -> Self {
23-
DeprecationStrategy::Warn
24-
}
25-
}
26-
2722
impl std::str::FromStr for DeprecationStrategy {
2823
type Err = ();
2924

graphql_client_codegen/src/schema.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ impl StoredInputFieldType {
210210
pub(crate) struct StoredInputType {
211211
pub(crate) name: String,
212212
pub(crate) fields: Vec<(String, StoredInputFieldType)>,
213+
pub(crate) is_one_of: bool,
213214
}
214215

215216
/// Intermediate representation for a parsed GraphQL schema used during code generation.

graphql_client_codegen/src/schema/graphql_parser_conversion.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,11 @@ fn ingest_input<'doc, T>(schema: &mut Schema, input: &mut parser::InputObjectTyp
289289
where
290290
T: graphql_parser::query::Text<'doc>,
291291
{
292+
let is_one_of = input
293+
.directives
294+
.iter()
295+
.any(|directive| directive.name.as_ref() == "oneOf");
296+
292297
let input = super::StoredInputType {
293298
name: input.name.as_ref().into(),
294299
fields: input
@@ -305,6 +310,7 @@ where
305310
)
306311
})
307312
.collect(),
313+
is_one_of,
308314
};
309315

310316
schema.stored_inputs.push(input);

graphql_client_codegen/src/schema/json_conversion.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,9 @@ fn ingest_input(schema: &mut Schema, input: &mut FullType) {
296296
let input = super::StoredInputType {
297297
fields,
298298
name: input.name.take().expect("Input without a name"),
299+
// The one-of input spec is not stable yet, thus the introspection query does not have
300+
// `isOneOf`, so this is always false.
301+
is_one_of: false,
299302
};
300303

301304
schema.stored_inputs.push(input);

0 commit comments

Comments
 (0)