Skip to content

Commit 78d6b76

Browse files
committed
Fixed authorization type interceptor flow. (#8096)
1 parent 23d261d commit 78d6b76

6 files changed

+252
-86
lines changed

src/HotChocolate/Core/src/Authorization/AuthorizationTypeInterceptor.cs

+100-83
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ internal override void OnBeforeCreateSchemaInternal(
5353
schemaBuilder.SetSchema(d => _schemaContextData = d.Extend().Definition.ContextData);
5454
}
5555

56+
private ITypeCompletionContext _tc = default!;
57+
private AuthorizeDirectiveType _t = default!;
58+
5659
public override void OnBeforeCompleteName(
5760
ITypeCompletionContext completionContext,
5861
DefinitionBase definition)
@@ -69,14 +72,33 @@ public override void OnBeforeCompleteName(
6972
case UnionType when definition is UnionTypeDefinition unionTypeDef:
7073
_unionTypes.Add(new UnionTypeInfo(completionContext, unionTypeDef));
7174
break;
75+
76+
case AuthorizeDirectiveType type:
77+
_t = type;
78+
_tc = completionContext;
79+
break;
7280
}
7381

7482
// note, we do not need to collect interfaces as the object type has a
7583
// list implements that links to the interfaces that expose an object type.
7684
}
7785

78-
public override void OnBeforeCompleteTypes()
86+
public override void OnAfterResolveRootType(
87+
ITypeCompletionContext completionContext,
88+
ObjectTypeDefinition definition,
89+
OperationType operationType)
90+
{
91+
if (operationType is OperationType.Query)
92+
{
93+
_queryContext = completionContext;
94+
}
95+
}
96+
97+
public override void OnBeforeCompleteMetadata()
7998
{
99+
_t.CompleteMetadata(_tc);
100+
((RegisteredType)_tc).Status = TypeStatus.MetadataCompleted;
101+
80102
// at this stage in the type initialization we will create some state that we
81103
// will use to transform the schema authorization.
82104
var state = _state = CreateState();
@@ -99,18 +121,7 @@ public override void OnBeforeCompleteTypes()
99121
FindFieldsAndApplyAuthMiddleware(state);
100122
}
101123

102-
public override void OnAfterResolveRootType(
103-
ITypeCompletionContext completionContext,
104-
ObjectTypeDefinition definition,
105-
OperationType operationType)
106-
{
107-
if (operationType is OperationType.Query)
108-
{
109-
_queryContext = completionContext;
110-
}
111-
}
112-
113-
public override void OnBeforeCompleteType(
124+
public override void OnBeforeCompleteMetadata(
114125
ITypeCompletionContext completionContext,
115126
DefinitionBase definition)
116127
{
@@ -130,101 +141,107 @@ public override void OnAfterMakeExecutable()
130141
{
131142
var objectType = (ObjectType)type.TypeReg.Type;
132143

133-
if (objectType.ContextData.TryGetValue(NodeResolver, out var o) &&
134-
o is NodeResolverInfo nodeResolverInfo)
144+
if (!objectType.ContextData.TryGetValue(NodeResolver, out var o)
145+
|| o is not NodeResolverInfo nodeResolverInfo)
135146
{
136-
var pipeline = nodeResolverInfo.Pipeline;
137-
var directives = (DirectiveCollection)objectType.Directives;
138-
var length = directives.Count;
139-
ref var start = ref directives.GetReference();
147+
continue;
148+
}
140149

141-
for (var i = length - 1; i >= 0; i--)
142-
{
143-
var directive = Unsafe.Add(ref start, i);
150+
var pipeline = nodeResolverInfo.Pipeline;
151+
var directives = (DirectiveCollection)objectType.Directives;
152+
var length = directives.Count;
153+
ref var start = ref directives.GetReference();
144154

145-
if (directive.Type.Name.EqualsOrdinal(Authorize))
146-
{
147-
var authDir = directive.AsValue<AuthorizeDirective>();
148-
pipeline = CreateAuthMiddleware(authDir).Middleware.Invoke(pipeline);
149-
}
150-
}
155+
for (var i = length - 1; i >= 0; i--)
156+
{
157+
var directive = Unsafe.Add(ref start, i);
151158

152-
type.TypeDef.ContextData[NodeResolver] =
153-
new NodeResolverInfo(nodeResolverInfo.QueryField, pipeline);
159+
if (directive.Type.Name.EqualsOrdinal(Authorize))
160+
{
161+
var authDir = directive.AsValue<AuthorizeDirective>();
162+
pipeline = CreateAuthMiddleware(authDir).Middleware.Invoke(pipeline);
163+
}
154164
}
165+
166+
type.TypeDef.ContextData[NodeResolver] =
167+
new NodeResolverInfo(nodeResolverInfo.QueryField, pipeline);
155168
}
156169
}
157170

158171
private void InspectObjectTypesForAuthDirective(State state)
159172
{
160173
foreach (var type in _objectTypes)
161174
{
162-
if (IsAuthorizedType(type.TypeDef))
175+
if (!IsAuthorizedType(type.TypeDef))
163176
{
164-
var registration = type.TypeReg;
165-
var mainTypeRef = registration.TypeReference;
177+
continue;
178+
}
179+
180+
var registration = type.TypeReg;
181+
var mainTypeRef = registration.TypeReference;
166182

167-
// if this type is a root type we will copy type level auth down to the field.
168-
if (registration.IsQueryType == true ||
169-
registration.IsMutationType == true ||
170-
registration.IsSubscriptionType == true)
183+
// if this type is a root type we will copy type level auth down to the field.
184+
if (registration.IsQueryType == true ||
185+
registration.IsMutationType == true ||
186+
registration.IsSubscriptionType == true)
187+
{
188+
foreach (var fieldDef in type.TypeDef.Fields)
171189
{
172-
foreach (var fieldDef in type.TypeDef.Fields)
190+
// we are not interested in introspection fields or the node fields.
191+
if (fieldDef.IsIntrospectionField || fieldDef.IsNodeField())
173192
{
174-
// we are not interested in introspection fields or the node fields.
175-
if (fieldDef.IsIntrospectionField || fieldDef.IsNodeField())
176-
{
177-
continue;
178-
}
179-
180-
// if the field contains the AnonymousAllowed flag we will not
181-
// apply authorization on it.
182-
if(fieldDef.GetContextData().ContainsKey(AllowAnonymous))
183-
{
184-
continue;
185-
}
193+
continue;
194+
}
186195

187-
ApplyAuthMiddleware(fieldDef, registration, false);
196+
// if the field contains the AnonymousAllowed flag we will not
197+
// apply authorization on it.
198+
if(fieldDef.GetContextData().ContainsKey(AllowAnonymous))
199+
{
200+
continue;
188201
}
189-
}
190202

191-
foreach (var reference in registration.References)
192-
{
193-
state.AuthTypes.Add(reference);
194-
state.NeedsAuth.Add(reference);
203+
ApplyAuthMiddleware(fieldDef, registration, false);
195204
}
205+
}
206+
207+
foreach (var reference in registration.References)
208+
{
209+
state.AuthTypes.Add(reference);
210+
state.NeedsAuth.Add(reference);
211+
}
212+
213+
if (!type.TypeDef.HasInterfaces)
214+
{
215+
continue;
216+
}
196217

197-
if (type.TypeDef.HasInterfaces)
218+
CollectInterfaces(
219+
type.TypeDef.GetInterfaces(),
220+
interfaceTypeRef =>
198221
{
199-
CollectInterfaces(
200-
type.TypeDef.GetInterfaces(),
201-
interfaceTypeRef =>
222+
if (_typeRegistry.TryGetType(
223+
interfaceTypeRef,
224+
out var interfaceTypeReg))
225+
{
226+
foreach (var typeRef in interfaceTypeReg.References)
202227
{
203-
if (_typeRegistry.TryGetType(
204-
interfaceTypeRef,
205-
out var interfaceTypeReg))
228+
state.NeedsAuth.Add(typeRef);
229+
230+
if (!state.AbstractToConcrete.TryGetValue(
231+
typeRef,
232+
out var authTypeRefs))
206233
{
207-
foreach (var typeRef in interfaceTypeReg.References)
208-
{
209-
state.NeedsAuth.Add(typeRef);
210-
211-
if (!state.AbstractToConcrete.TryGetValue(
212-
typeRef,
213-
out var authTypeRefs))
214-
{
215-
authTypeRefs = [];
216-
state.AbstractToConcrete.Add(typeRef, authTypeRefs);
217-
}
218-
219-
authTypeRefs.Add(mainTypeRef);
220-
}
234+
authTypeRefs = [];
235+
state.AbstractToConcrete.Add(typeRef, authTypeRefs);
221236
}
222-
},
223-
state);
224237

225-
state.Completed.Clear();
226-
}
227-
}
238+
authTypeRefs.Add(mainTypeRef);
239+
}
240+
}
241+
},
242+
state);
243+
244+
state.Completed.Clear();
228245
}
229246
}
230247

@@ -623,7 +640,7 @@ private State CreateState()
623640
}
624641
}
625642

626-
static file class AuthorizationTypeInterceptorExtensions
643+
file static class AuthorizationTypeInterceptorExtensions
627644
{
628645
public static bool IsNodeField(this ObjectFieldDefinition fieldDef)
629646
{

src/HotChocolate/Core/src/Types/Configuration/TypeInitializer.cs

+5-2
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,7 @@ private void CompleteTypes()
595595

596596
internal bool CompleteType(RegisteredType registeredType)
597597
{
598-
if (registeredType.Status is TypeStatus.Completed)
598+
if (registeredType.Type.IsCompleted)
599599
{
600600
return true;
601601
}
@@ -615,9 +615,10 @@ private void CompleteMetadata()
615615

616616
foreach (var registeredType in _typeRegistry.Types)
617617
{
618-
if (!registeredType.IsExtension)
618+
if (registeredType is { IsExtension: false, Status: TypeStatus.Completed })
619619
{
620620
registeredType.Type.CompleteMetadata(registeredType);
621+
registeredType.Status = TypeStatus.MetadataCompleted;
621622
}
622623
}
623624

@@ -635,6 +636,7 @@ private void MakeExecutable()
635636
if (!registeredType.IsExtension)
636637
{
637638
registeredType.Type.MakeExecutable(registeredType);
639+
registeredType.Status = TypeStatus.Executable;
638640
}
639641
}
640642

@@ -650,6 +652,7 @@ private void FinalizeTypes()
650652
if (!registeredType.IsExtension)
651653
{
652654
registeredType.Type.FinalizeType(registeredType);
655+
registeredType.Status = TypeStatus.Finalized;
653656
}
654657
}
655658

src/HotChocolate/Core/test/Authorization.Tests/AnnotationBasedAuthorizationTests.cs

+25-1
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,30 @@ public async Task Authorize_Type_Field()
499499
Assert.Equal(401, value);
500500
}
501501

502+
[Fact]
503+
public async Task Authorize_Node_Field_Schema()
504+
{
505+
// arrange
506+
var handler = new AuthHandler(
507+
resolver: (_, _) => AuthorizeResult.Allowed,
508+
validation: (_, d) => d.Policy.EqualsOrdinal("READ_NODE")
509+
? AuthorizeResult.NotAllowed
510+
: AuthorizeResult.Allowed);
511+
512+
// act
513+
var services = CreateServices(
514+
handler,
515+
options =>
516+
{
517+
options.ConfigureNodeFields =
518+
descriptor => descriptor.Authorize("READ_NODE", ApplyPolicy.Validation);
519+
});
520+
521+
// assert
522+
var executor = await services.GetRequestExecutorAsync();
523+
executor.Schema.MatchSnapshot();
524+
}
525+
502526
[Fact]
503527
public async Task Authorize_Node_Field()
504528
{
@@ -513,7 +537,7 @@ public async Task Authorize_Node_Field()
513537
options =>
514538
{
515539
options.ConfigureNodeFields =
516-
descriptor => { descriptor.Authorize("READ_NODE", ApplyPolicy.Validation); };
540+
descriptor => descriptor.Authorize("READ_NODE", ApplyPolicy.Validation);
517541
});
518542
var executor = await services.GetRequestExecutorAsync();
519543

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
using HotChocolate.Execution;
2+
using HotChocolate.Resolvers;
3+
using Microsoft.Extensions.DependencyInjection;
4+
5+
namespace HotChocolate.Authorization;
6+
7+
public class SchemaFirstAuthorizationTests
8+
{
9+
[Fact]
10+
public async Task Authorize_Apply_Can_Be_Omitted()
11+
{
12+
var schema = await new ServiceCollection()
13+
.AddGraphQLServer()
14+
.AddDocumentFromString(
15+
"""
16+
type Query @authorize(roles: [ "policy_tester_noupdate", "policy_tester_update_noread", "authorizationHandlerTester" ]) {
17+
hello: String @authorize(roles: ["admin"])
18+
}
19+
""")
20+
.AddResolver("Query", "hello", "world")
21+
.AddAuthorizationHandler<MockAuth>()
22+
.BuildSchemaAsync();
23+
24+
schema.MatchSnapshot();
25+
}
26+
27+
private sealed class MockAuth : IAuthorizationHandler
28+
{
29+
public ValueTask<AuthorizeResult> AuthorizeAsync(
30+
IMiddlewareContext context,
31+
AuthorizeDirective directive,
32+
CancellationToken cancellationToken = default)
33+
=> new(AuthorizeResult.NotAllowed);
34+
35+
public ValueTask<AuthorizeResult> AuthorizeAsync(
36+
AuthorizationContext context,
37+
IReadOnlyList<AuthorizeDirective> directives,
38+
CancellationToken cancellationToken = default)
39+
=> new(AuthorizeResult.NotAllowed);
40+
}
41+
}

0 commit comments

Comments
 (0)