Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/murfey/instrument_server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,15 @@ def stop_multigrid_watcher(session_id: MurfeySessionID, label: str):
return {"success": True}


@router.get("/sessions/{session_id}/multigrid_controller/status")
def check_multigrid_controller_exists(
session_id: MurfeySessionID,
):
if controllers.get(session_id, None) is not None:
return {"exists": True}
return {"exists": False}


@router.post("/sessions/{session_id}/multigrid_controller/visit_end_time")
def update_multigrid_controller_visit_end_time(
session_id: MurfeySessionID, end_time: datetime
Expand Down
87 changes: 56 additions & 31 deletions src/murfey/server/api/instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import datetime
import logging
from pathlib import Path
from typing import Annotated, List, Optional
from typing import Annotated, Any, List, Optional
from urllib.parse import quote

import aiohttp
Expand Down Expand Up @@ -101,6 +101,31 @@ async def check_if_session_is_active(
return {"active": response.status == 200}


@router.get("/sessions/{session_id}/multigrid_controller/status")
async def check_multigrid_controller_exists(session_id: MurfeySessionID, db=murfey_db):
session = db.exec(select(Session).where(Session.id == session_id)).one()
instrument_name = session.instrument_name
machine_config = get_machine_config(instrument_name=instrument_name)[
instrument_name
]
if machine_config.instrument_server_url:
log.debug(
f"Submitting request to inspect multigrid controller for session {session_id}"
)
async with aiohttp.ClientSession() as clientsession:
async with clientsession.get(
f"{machine_config.instrument_server_url}{url_path_for('api.router', 'check_multigrid_controller_exists', session_id=session_id)}",
headers={
"Authorization": f"Bearer {instrument_server_tokens[session_id]['access_token']}"
},
) as resp:
data: dict[str, Any] = await resp.json()
else:
data = {"detail": "No instrument server URL found"}
log.debug(f"Received response: {data}")
return data


@router.post("/sessions/{session_id}/multigrid_watcher")
async def setup_multigrid_watcher(
session_id: MurfeySessionID, watcher_spec: MultigridWatcherSetup, db=murfey_db
Expand Down Expand Up @@ -165,6 +190,36 @@ async def start_multigrid_watcher(session_id: MurfeySessionID, db=murfey_db):
return data


@router.post("/sessions/{session_id}/multigrid_controller/visit_end_time")
async def update_visit_end_time(
session_id: MurfeySessionID, end_time: datetime.datetime, db=murfey_db
):
# Load data for session
session_entry = db.exec(select(Session).where(Session.id == session_id)).one()
instrument_name = session_entry.instrument_name

# Update visit end time in database
session_entry.visit_end_time = end_time
db.add(session_entry)
db.commit()

# Update the multigrid controller
data = {}
machine_config = get_machine_config(instrument_name=instrument_name)[
instrument_name
]
if machine_config.instrument_server_url:
async with aiohttp.ClientSession() as clientsession:
async with clientsession.post(
f"{machine_config.instrument_server_url}{url_path_for('api.router', 'update_multigrid_controller_visit_end_time', session_id=session_id)}?end_time={quote(end_time.isoformat())}",
headers={
"Authorization": f"Bearer {instrument_server_tokens[session_id]['access_token']}"
},
) as resp:
data = await resp.json()
return data


class ProvidedProcessingParameters(BaseModel):
dose_per_frame: float
extract_downscale: bool = True
Expand Down Expand Up @@ -397,36 +452,6 @@ async def finalise_session(session_id: MurfeySessionID, db=murfey_db):
return data


@router.post("/sessions/{session_id}/multigrid_controller/visit_end_time")
async def update_visit_end_time(
session_id: MurfeySessionID, end_time: datetime.datetime, db=murfey_db
):
# Load data for session
session_entry = db.exec(select(Session).where(Session.id == session_id)).one()
instrument_name = session_entry.instrument_name

# Update visit end time in database
session_entry.visit_end_time = end_time
db.add(session_entry)
db.commit()

# Update the multigrid controller
data = {}
machine_config = get_machine_config(instrument_name=instrument_name)[
instrument_name
]
if machine_config.instrument_server_url:
async with aiohttp.ClientSession() as clientsession:
async with clientsession.post(
f"{machine_config.instrument_server_url}{url_path_for('api.router', 'update_multigrid_controller_visit_end_time', session_id=session_id)}?end_time={quote(end_time.isoformat())}",
headers={
"Authorization": f"Bearer {instrument_server_tokens[session_id]['access_token']}"
},
) as resp:
data = await resp.json()
return data


@router.post("/sessions/{session_id}/abandon_session")
async def abandon_session(session_id: MurfeySessionID, db=murfey_db):
data = {}
Expand Down
10 changes: 10 additions & 0 deletions src/murfey/util/route_manifest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ murfey.instrument_server.api.router:
path_params: []
methods:
- POST
- path: /sessions/{session_id}/multigrid_controller/status
function: check_multigrid_controller_exists
path_params: []
methods:
- GET
- path: /sessions/{session_id}/stop_rsyncer
function: stop_rsyncer
path_params: []
Expand Down Expand Up @@ -503,6 +508,11 @@ murfey.server.api.instrument.router:
path_params: []
methods:
- POST
- path: /instrument_server/sessions/{session_id}/multigrid_controller/status
function: check_multigrid_controller_exists
path_params: []
methods:
- GET
- path: /instrument_server/sessions/{session_id}/provided_processing_parameters
function: pass_proc_params_to_instrument_server
path_params: []
Expand Down
88 changes: 62 additions & 26 deletions tests/instrument_server/test_api.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,33 @@
from pathlib import Path
from typing import Optional
from unittest.mock import ANY, Mock, patch
from unittest.mock import ANY, MagicMock, patch
from urllib.parse import urlparse

from pytest import mark
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from pytest_mock import MockerFixture

from murfey.instrument_server.api import (
GainReference,
_get_murfey_url,
upload_gain_reference,
)
from murfey.instrument_server.api import _get_murfey_url
from murfey.instrument_server.api import router as client_router
from murfey.instrument_server.api import validate_session_token
from murfey.util import posix_path
from murfey.util.api import url_path_for


def set_up_test_client(session_id: Optional[int] = None):
"""
Helper function to set up a test client for the instrument server with validation
checks disabled.
"""
# Set up the instrument server
client_app = FastAPI()
if session_id:
client_app.dependency_overrides[validate_session_token] = lambda: session_id
client_app.include_router(client_router)
return TestClient(client_app)


test_get_murfey_url_params_matrix = (
# Server URL to use
("default",),
Expand All @@ -23,7 +38,7 @@
)


@mark.parametrize("test_params", test_get_murfey_url_params_matrix)
@pytest.mark.parametrize("test_params", test_get_murfey_url_params_matrix)
def test_get_murfey_url(
test_params: tuple[str],
mock_client_configuration, # From conftest.py
Expand Down Expand Up @@ -57,6 +72,24 @@ def test_get_murfey_url(
assert parsed_server.path == parsed_original.path


def test_check_multigrid_controller_exists(mocker: MockerFixture):
session_id = 1

# Patch out the multigrid controllers that have been stored in memory
mocker.patch("murfey.instrument_server.api.controllers", {session_id: MagicMock()})

# Set up the test client
client_server = set_up_test_client(session_id=session_id)
url_path = url_path_for(
"api.router", "check_multigrid_controller_exists", session_id=session_id
)
response = client_server.get(url_path)

# Check that the result is as expected
assert response.status_code == 200
assert response.json() == {"exists": True}


test_upload_gain_reference_params_matrix = (
# Rsync URL settings
("http://1.1.1.1",), # When rsync_url is provided
Expand All @@ -65,25 +98,23 @@ def test_get_murfey_url(
)


@mark.parametrize("test_params", test_upload_gain_reference_params_matrix)
@patch("murfey.instrument_server.api.subprocess")
@patch("murfey.instrument_server.api.tokens")
@patch("murfey.instrument_server.api._get_murfey_url")
@patch("murfey.instrument_server.api.requests")
@pytest.mark.parametrize("test_params", test_upload_gain_reference_params_matrix)
def test_upload_gain_reference(
mock_request,
mock_get_server_url,
mock_tokens,
mock_subprocess,
mocker: MockerFixture,
test_params: tuple[Optional[str]],
):

# Unpack test parameters and define other ones
(rsync_url_setting,) = test_params
server_url = "http://0.0.0.0:8000"
server_url = "https://murfey.server.test"
instrument_name = "murfey"
session_id = 1

# Mock out objects
mock_request = mocker.patch("murfey.instrument_server.api.requests")
mock_get_server_url = mocker.patch("murfey.instrument_server.api._get_murfey_url")
mock_subprocess = mocker.patch("murfey.instrument_server.api.subprocess")
mocker.patch("murfey.instrument_server.api.tokens", {session_id: ANY})

# Create a mock machine config base on the test params
rsync_module = "data"
gain_ref_dir = "C:/ProgramData/Gatan/Gain Reference"
Expand All @@ -95,12 +126,12 @@ def test_upload_gain_reference(
mock_machine_config["rsync_url"] = rsync_url_setting

# Assign expected values to the mock objects
mock_response = Mock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = mock_machine_config
mock_request.get.return_value = mock_response
mock_get_server_url.return_value = server_url
mock_subprocess.run.return_value = Mock(returncode=0)
mock_subprocess.run.return_value = MagicMock(returncode=0)

# Construct payload and pass request to function
gain_ref_file = f"{gain_ref_dir}/gain.mrc"
Expand All @@ -111,13 +142,18 @@ def test_upload_gain_reference(
"visit_path": visit_path,
"gain_destination_dir": gain_dest_dir,
}
result = upload_gain_reference(

# Set up instrument server test client
client_server = set_up_test_client(session_id=session_id)

# Poke the endpoint with the expected data
url_path = url_path_for(
"api.router",
"upload_gain_reference",
instrument_name=instrument_name,
session_id=session_id,
gain_reference=GainReference(
**payload,
),
)
response = client_server.post(url_path, json=payload)

# Check that the machine config request was called
machine_config_url = f"{server_url}{url_path_for('session_control.router', 'machine_info_by_instrument', instrument_name=instrument_name)}"
Expand Down Expand Up @@ -145,4 +181,4 @@ def test_upload_gain_reference(
)

# Check that the function ran through to completion successfully
assert result == {"success": True}
assert response.json() == {"success": True}
Loading