Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 32 additions & 12 deletions src/murfey/server/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,15 +219,7 @@ def get_visit_name(session_id: int) -> str:
)


async def validate_frontend_session_access(
session_id: int,
token: Annotated[str, Depends(oauth2_scheme)],
) -> int:
"""
Validates whether a frontend request can access information about this session
"""
visit_name = get_visit_name(session_id)

async def submit_to_auth_endpoint(url_subpath: str, token: str) -> None:
if auth_url:
headers = (
{}
Expand All @@ -241,7 +233,7 @@ async def validate_frontend_session_access(
)
async with aiohttp.ClientSession(cookies=cookies) as session:
async with session.get(
f"{auth_url}/validate_visit_access/{visit_name}",
f"{auth_url}/{url_subpath}",
headers=headers,
) as response:
success = response.status == 200
Expand All @@ -253,10 +245,21 @@ async def validate_frontend_session_access(
detail="You do not have access to this visit",
headers={"WWW-Authenticate": "Bearer"},
)


async def validate_frontend_session_access(
session_id: int,
token: Annotated[str, Depends(oauth2_scheme)],
) -> int:
"""
Validates whether a frontend request can access information about this session
"""
visit_name = get_visit_name(session_id)
await submit_to_auth_endpoint(f"validate_visit_access/{visit_name}", token)
return session_id


async def validate_instrument_session_access(
async def validate_instrument_server_session_access(
session_id: int,
token: Annotated[str, Depends(instrument_oauth2_scheme)],
) -> int:
Expand Down Expand Up @@ -288,9 +291,26 @@ async def validate_instrument_session_access(
return session_id


async def validate_user_instrument_access(
instrument_name: str,
token: Annotated[str, Depends(oauth2_scheme)],
) -> str:
"""
Validates whether a frontend request can access information about this instrument
"""
await submit_to_auth_endpoint(
f"validate_instrument_access/{instrument_name}", token
)
return instrument_name


# Set validation conditions for the session ID based on where the request is from
MurfeySessionIDFrontend = Annotated[int, Depends(validate_frontend_session_access)]
MurfeySessionIDInstrument = Annotated[int, Depends(validate_instrument_session_access)]
MurfeySessionIDInstrument = Annotated[
int, Depends(validate_instrument_server_session_access)
]

MurfeyInstrumentNameFrontend = Annotated[str, Depends(validate_user_instrument_access)]


"""
Expand Down
13 changes: 8 additions & 5 deletions src/murfey/server/api/instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sqlmodel import select
from werkzeug.utils import secure_filename

from murfey.server.api.auth import MurfeyInstrumentNameFrontend as MurfeyInstrumentName
from murfey.server.api.auth import MurfeySessionIDFrontend as MurfeySessionID
from murfey.server.api.auth import (
create_access_token,
Expand Down Expand Up @@ -42,7 +43,7 @@
"/instruments/{instrument_name}/sessions/{session_id}/activate_instrument_server"
)
async def activate_instrument_server_for_session(
instrument_name: str,
instrument_name: MurfeyInstrumentName,
session_id: int,
token_in: Annotated[str, Depends(oauth2_scheme)],
db=murfey_db,
Expand Down Expand Up @@ -80,7 +81,9 @@ async def activate_instrument_server_for_session(


@router.get("/instruments/{instrument_name}/sessions/{session_id}/active")
async def check_if_session_is_active(instrument_name: str, session_id: int):
async def check_if_session_is_active(
instrument_name: MurfeyInstrumentName, session_id: int
):
if instrument_server_tokens.get(session_id) is None:
return {"active": False}
async with lock:
Expand Down Expand Up @@ -214,7 +217,7 @@ async def pass_proc_params_to_instrument_server(


@router.get("/instruments/{instrument_name}/instrument_server")
async def check_instrument_server(instrument_name: str):
async def check_instrument_server(instrument_name: MurfeyInstrumentName):
data = None
machine_config = get_machine_config(instrument_name=instrument_name)[
instrument_name
Expand All @@ -232,7 +235,7 @@ async def check_instrument_server(instrument_name: str):
"/instruments/{instrument_name}/sessions/{session_id}/possible_gain_references"
)
async def get_possible_gain_references(
instrument_name: str, session_id: MurfeySessionID
instrument_name: MurfeyInstrumentName, session_id: MurfeySessionID
) -> List[File]:
data = []
machine_config = get_machine_config(instrument_name=instrument_name)[
Expand Down Expand Up @@ -491,7 +494,7 @@ class RSyncerInfo(BaseModel):

@router.get("/instruments/{instrument_name}/sessions/{session_id}/rsyncer_info")
async def get_rsyncer_info(
instrument_name: str, session_id: MurfeySessionID, db=murfey_db
instrument_name: MurfeyInstrumentName, session_id: MurfeySessionID, db=murfey_db
) -> List[RSyncerInfo]:
rsyncer_list = []
analyser_list = []
Expand Down
17 changes: 11 additions & 6 deletions src/murfey/server/api/session_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import murfey.server.api.websocket as ws
from murfey.server import _transport_object
from murfey.server.api import templates
from murfey.server.api.auth import MurfeyInstrumentNameFrontend as MurfeyInstrumentName
from murfey.server.api.auth import MurfeySessionIDFrontend as MurfeySessionID
from murfey.server.api.auth import validate_token
from murfey.server.api.shared import get_foil_hole as _get_foil_hole
Expand Down Expand Up @@ -74,20 +75,24 @@ def connections_check():


@router.get("/instruments/{instrument_name}/machine")
def machine_info_by_instrument(instrument_name: str) -> Optional[MachineConfig]:
def machine_info_by_instrument(
instrument_name: MurfeyInstrumentName,
) -> Optional[MachineConfig]:
return get_machine_config_for_instrument(instrument_name)


@router.get("/instruments/{instrument_name}/visits_raw", response_model=List[Visit])
def get_current_visits(instrument_name: str, db=ispyb_db):
def get_current_visits(instrument_name: MurfeyInstrumentName, db=ispyb_db):
logger.debug(
f"Received request to look up ongoing visits for {sanitise(instrument_name)}"
)
return get_all_ongoing_visits(instrument_name, db)


@router.get("/instruments/{instrument_name}/visits/")
def all_visit_info(instrument_name: str, request: Request, db=ispyb_db):
def all_visit_info(
instrument_name: MurfeyInstrumentName, request: Request, db=ispyb_db
):
visits = get_all_ongoing_visits(instrument_name, db)

if visits:
Expand Down Expand Up @@ -159,7 +164,7 @@ class VisitEndTime(BaseModel):

@router.post("/instruments/{instrument_name}/visits/{visit}/session/{name}")
def create_session(
instrument_name: str,
instrument_name: MurfeyInstrumentName,
visit: str,
name: str,
visit_end_time: VisitEndTime,
Expand Down Expand Up @@ -195,7 +200,7 @@ def remove_session(session_id: MurfeySessionID, db=murfey_db):

@router.get("/instruments/{instrument_name}/visits/{visit_name}/sessions")
def get_sessions_with_visit(
instrument_name: str, visit_name: str, db=murfey_db
instrument_name: MurfeyInstrumentName, visit_name: str, db=murfey_db
) -> List[Session]:
sessions = db.exec(
select(Session)
Expand All @@ -207,7 +212,7 @@ def get_sessions_with_visit(

@router.get("/instruments/{instrument_name}/sessions")
async def get_sessions_by_instrument_name(
instrument_name: str, db=murfey_db
instrument_name: MurfeyInstrumentName, db=murfey_db
) -> List[Session]:
sessions = db.exec(
select(Session).where(Session.instrument_name == instrument_name)
Expand Down