Skip to content

Commit 5ad4a69

Browse files
author
Ben Elam
committed
Implement auth for inputstep
1 parent b969bc9 commit 5ad4a69

File tree

3 files changed

+88
-8
lines changed

3 files changed

+88
-8
lines changed

orchestrator/services/processes.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,11 @@ def thread_resume_process(
523523

524524
form = pstat.log[0].form
525525

526-
user_input = post_form(form, pstat.state.unwrap(), user_inputs)
526+
# Add OIDC user to state to be processed by form for authorization
527+
state = pstat.state.unwrap()
528+
state["__process_user"] = user_model
529+
530+
user_input = post_form(form, state, user_inputs)
527531

528532
if user:
529533
pstat.update(current_user=user)

orchestrator/workflow.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313

1414

1515
from __future__ import annotations
16-
1716
import contextvars
1817
import functools
1918
import inspect
2019
import secrets
2120
from collections.abc import Callable
2221
from dataclasses import asdict, dataclass
22+
from functools import update_wrapper
23+
from http import HTTPStatus
2324
from itertools import dropwhile
2425
from typing import (
2526
Any,
@@ -40,6 +41,7 @@
4041

4142
from nwastdlib import const, identity
4243
from oauth2_lib.fastapi import OIDCUserModel
44+
from orchestrator.api.error_handling import raise_status
4345
from orchestrator.config.assignee import Assignee
4446
from orchestrator.db import db, transactional
4547
from orchestrator.services.settings import get_engine_settings
@@ -99,7 +101,7 @@ def __call__(self) -> NoReturn: ...
99101

100102

101103
def make_step_function(
102-
f: Callable, name: str, form: InputFormGenerator | None = None, assignee: Assignee | None = Assignee.SYSTEM
104+
f: Callable, name: str, form: InputFormGenerator | None = None, assignee: Assignee | None = Assignee.SYSTEM, authorize_callback: Callable[[OIDCUserModel | None], bool] | None = None
103105
) -> Step:
104106
step_func = cast(Step, f)
105107

@@ -166,14 +168,39 @@ def __repr__(self) -> str:
166168
return f"StepList [{', '.join(repr(x) for x in self)}]"
167169

168170

169-
def _handle_simple_input_form_generator(f: StateInputStepFunc) -> StateInputFormGenerator:
171+
def _handle_simple_input_form_generator(f: StateInputStepFunc, authorize_callback: Callable[[OIDCUserModel | None], bool] | None = None) -> StateInputFormGenerator:
172+
"""Processes f into a form generator and injects a pre-hook for user authorization"""
173+
def authorize_user_from_state(state: State) -> None:
174+
logger.error("authorize_user_from_state: called")
175+
user_model = state.pop("__process_user", None)
176+
if user_model is not None:
177+
user_model = cast(OIDCUserModel, user_model)
178+
else:
179+
logger.error("authorize_user_from_state: no user model")
180+
181+
if authorize_callback is not None:
182+
logger.error("authorize_user_from_state: callback found")
183+
authorize_callback(user_model)
184+
if not authorize_callback(user_model):
185+
logger.error("authorize_user_from_state: FORBIDDEN")
186+
#TODO not sure that step name is available here, but could put it on state?
187+
raise_status(HTTPStatus.FORBIDDEN, "User is not authorized to execute step")
188+
else:
189+
logger.error("authorize_user_from_state: no callback!")
190+
170191
if inspect.isgeneratorfunction(f):
171-
return cast(StateInputFormGenerator, f)
192+
def generator_wrapper(state: State):
193+
authorize_user_from_state(state)
194+
return f(state)
195+
196+
update_wrapper(generator_wrapper, f)
197+
return cast(StateInputFormGenerator, generator_wrapper)
172198
if inspect.isgenerator(f):
173199
raise ValueError("Got a generator object instead of function, this is not correct")
174200

175201
# If f is a SimpleInputFormGenerator convert to new style generator function
176202
def form_generator(state: State) -> FormGenerator:
203+
authorize_user_from_state(state)
177204
user_input: FormPage = yield cast(StateSimpleInputFormGenerator, f)(state)
178205
return user_input.model_dump()
179206

@@ -270,7 +297,7 @@ def wrapper(state: State) -> Process:
270297
return decorator
271298

272299

273-
def inputstep(name: str, assignee: Assignee) -> Callable[[InputStepFunc], Step]:
300+
def inputstep(name: str, assignee: Assignee, authorize_callback: Callable[[OIDCUserModel | None], bool] | None = None) -> Callable[[InputStepFunc], Step]:
274301
"""Add user input step to workflow.
275302
276303
IMPORTANT: In contrast to other workflow steps, the `@inputstep` wrapped function will not run in the
@@ -291,7 +318,7 @@ def decorator(func: InputStepFunc) -> Step:
291318
def wrapper(state: State) -> FormGenerator:
292319
form_generator_in_form_inject_args = form_inject_args(func)
293320

294-
form_generator = _handle_simple_input_form_generator(form_generator_in_form_inject_args)
321+
form_generator = _handle_simple_input_form_generator(form_generator_in_form_inject_args, authorize_callback=authorize_callback)
295322

296323
return form_generator(state)
297324

test/unit_tests/api/test_processes.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from uuid import uuid4
77

88
import pytest
9+
from pydantic_forms.core import FormPage
10+
from pydantic_forms.types import State
911
from sqlalchemy import select
1012

1113
from oauth2_lib.fastapi import OIDCUserModel
@@ -22,7 +24,7 @@
2224
from orchestrator.services.settings import get_engine_settings
2325
from orchestrator.settings import app_settings
2426
from orchestrator.targets import Target
25-
from orchestrator.workflow import ProcessStatus, done, init, step, workflow
27+
from orchestrator.workflow import ProcessStatus, done, init, inputstep, step, workflow
2628
from test.unit_tests.helpers import URL_STR_TYPE
2729
from test.unit_tests.workflows import WorkflowInstanceForTests
2830

@@ -593,3 +595,50 @@ def unauthorized_workflow():
593595
with WorkflowInstanceForTests(unauthorized_workflow, "unauthorized_workflow"):
594596
response = test_client.post("/api/processes/unauthorized_workflow", json=[{}])
595597
assert HTTPStatus.FORBIDDEN == response.status_code
598+
599+
600+
def test_inputstep_authorization(test_client):
601+
def disallow(_: OIDCUserModel | None = None) -> bool:
602+
return False
603+
604+
def allow(_: OIDCUserModel | None = None) -> bool:
605+
return True
606+
607+
class ConfirmForm(FormPage):
608+
confirm: bool
609+
610+
@inputstep("unauthorized_resume", assignee=Assignee.SYSTEM, authorize_callback=disallow)
611+
def unauthorized_resume(state: State) -> State:
612+
user_input = yield ConfirmForm
613+
return user_input.model_dump()
614+
615+
@inputstep("authorized_resume", assignee=Assignee.SYSTEM, authorize_callback=allow)
616+
def authorized_resume(state: State) -> State:
617+
user_input = yield ConfirmForm
618+
return user_input.model_dump()
619+
620+
@inputstep("noauth_resume", assignee=Assignee.SYSTEM)
621+
def noauth_resume(state: State) -> State:
622+
user_input = yield ConfirmForm
623+
return user_input.model_dump()
624+
625+
@workflow("test_auth_workflow", target=Target.CREATE)
626+
def test_auth_workflow():
627+
return init >> noauth_resume >> authorized_resume >> unauthorized_resume >> done
628+
629+
with WorkflowInstanceForTests(test_auth_workflow, "test_auth_workflow"):
630+
response = test_client.post("/api/processes/test_auth_workflow", json=[{}])
631+
assert HTTPStatus.CREATED == response.status_code
632+
process_id = response.json()["id"]
633+
# No auth succeeds
634+
response = test_client.put(f"/api/processes/{process_id}/resume", json=[{"confirm": True}])
635+
assert HTTPStatus.NO_CONTENT == response.status_code
636+
# Authorized succeeds
637+
response = test_client.put(f"/api/processes/{process_id}/resume", json=[{"confirm": True}])
638+
assert HTTPStatus.NO_CONTENT == response.status_code
639+
# Unauthorized fails
640+
response = test_client.put(f"/api/processes/{process_id}/resume", json=[{"confirm": True}])
641+
assert HTTPStatus.FORBIDDEN == response.status_code
642+
643+
#TODO test how this interacts with passing a different callback to @authorize_workflow
644+
# These should be as functionally independent as possible.

0 commit comments

Comments
 (0)