Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
# YAML formatting specifications: https://yaml.org/spec/1.2.2/

repos:
# Generate up-to-date Murfey route manifest
# NOTE: This will only work if Murfey is installed in the current Python environment
- repo: local
hooks:
- id: generate-route-manifest
name: Generating Murfey route manifest
entry: murfey.generate_route_manifest
language: system
# Only run if FastAPI router-related modules are changed
files: ^src/murfey/(instrument_server/.+\.py|server/(main|api/.+)\.py)$
pass_filenames: false

# Syntax validation and some basic sanity checks
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0 # Released 2023-10-09
Expand Down
117 changes: 88 additions & 29 deletions src/murfey/cli/generate_route_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
server and backend server to enable lookup of the URLs based on function name.
"""

import contextlib
import importlib
import inspect
import io
import pkgutil
from argparse import ArgumentParser
from pathlib import Path
Expand All @@ -17,6 +19,54 @@
import murfey


class PrettierDumper(yaml.Dumper):
"""
Custom YAML Dumper class that sets `indentless` to False. This generates a YAML
file that is then compliant with Prettier's formatting style
"""

def increase_indent(self, flow=False, indentless=False):
# Force 'indentless=False' so list items align with Prettier
return super(PrettierDumper, self).increase_indent(flow, indentless=False)


def prettier_str_representer(dumper, data):
"""
Helper function to format strings according to Prettier's standards:
- No quoting unless it can be misinterpreted as another data type
- When quoting, use double quotes unless string already contains double quotes
"""

def is_implicitly_resolved(value: str) -> bool:
for (
first_char,
resolvers,
) in yaml.resolver.Resolver.yaml_implicit_resolvers.items():
if first_char is None or (value and value[0] in first_char):
for resolver in resolvers:
if len(resolver) == 3:
_, regexp, _ = resolver
else:
_, regexp = resolver
if regexp.match(value):
return True
return False

# If no quoting is needed, use default plain style
if not is_implicitly_resolved(data):
return dumper.represent_scalar("tag:yaml.org,2002:str", data)

# If the string already contains double quotes, fall back to single quotes
if '"' in data and "'" not in data:
return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="'")

# Otherwise, prefer double quotes
return dumper.represent_scalar("tag:yaml.org,2002:str", data, style='"')


PrettierDumper.add_representer(str, prettier_str_representer)


def find_routers(name: str) -> dict[str, APIRouter]:

def _extract_routers_from_module(module: ModuleType):
Expand All @@ -30,34 +80,36 @@ def _extract_routers_from_module(module: ModuleType):

routers = {}

# Import the module or package
try:
root = importlib.import_module(name)
except ImportError:
raise ImportError(
f"Cannot import '{name}'. Please ensure that you've installed all the "
"dependencies for the client, instrument server, and backend server "
"before running this command."
)

# If it's a package, walk through submodules and extract routers from each
if hasattr(root, "__path__"):
module_list = pkgutil.walk_packages(root.__path__, prefix=name + ".")
for _, module_name, _ in module_list:
try:
module = importlib.import_module(module_name)
except ImportError:
raise ImportError(
f"Cannot import '{module_name}'. Please ensure that you've "
"installed all the dependencies for the client, instrument "
"server, and backend server before running this command."
)

routers.update(_extract_routers_from_module(module))

# Extract directly from single module
else:
routers.update(_extract_routers_from_module(root))
# Silence output during import and only return messages if imports fail
buffer = io.StringIO()
with contextlib.redirect_stdout(buffer), contextlib.redirect_stderr(buffer):
# Import the module or package
try:
root = importlib.import_module(name)
except Exception as e:
captured_logs = buffer.getvalue().strip()
message = f"Cannot import '{name}': {e}"
if captured_logs:
message += f"\n--- Captured output ---\n{captured_logs}"
raise ImportError(message) from e

# If it's a package, walk through submodules and extract routers from each
if hasattr(root, "__path__"):
module_list = pkgutil.walk_packages(root.__path__, prefix=name + ".")
for _, module_name, _ in module_list:
try:
module = importlib.import_module(module_name)
except Exception as e:
captured_logs = buffer.getvalue().strip()
message = f"Cannot import '{name}': {e}"
if captured_logs:
message += f"\n--- Captured output ---\n{captured_logs}"
raise ImportError(message) from e
routers.update(_extract_routers_from_module(module))

# Extract directly from single module
else:
routers.update(_extract_routers_from_module(root))

return routers

Expand Down Expand Up @@ -138,7 +190,14 @@ def run():
murfey_dir = Path(murfey.__path__[0])
manifest_file = murfey_dir / "util" / "route_manifest.yaml"
with open(manifest_file, "w") as file:
yaml.dump(manifest, file, default_flow_style=False, sort_keys=False)
yaml.dump(
manifest,
file,
Dumper=PrettierDumper,
default_flow_style=False,
sort_keys=False,
indent=2,
)
print(
"Route manifest for instrument and backend servers saved to "
f"{str(manifest_file)!r}"
Expand Down
10 changes: 5 additions & 5 deletions src/murfey/server/api/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import logging
from datetime import datetime
from typing import Any, TypeVar, Union
from typing import Any, TypeVar

from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from sqlmodel import Session, select
Expand All @@ -27,7 +27,7 @@ def __init__(self):
async def connect(
self,
websocket: WebSocket,
client_id: Union[int, str],
client_id: int | str,
register_client: bool = True,
):
await websocket.accept()
Expand All @@ -48,7 +48,7 @@ def _register_new_client(client_id: int):
murfey_db.commit()
murfey_db.close()

def disconnect(self, client_id: Union[int, str], unregister_client: bool = True):
def disconnect(self, client_id: int | str, unregister_client: bool = True):
self.active_connections.pop(client_id)
if unregister_client:
murfey_db: Session = next(get_murfey_db_session())
Expand Down Expand Up @@ -97,7 +97,7 @@ async def websocket_endpoint(websocket: WebSocket, client_id: int):
@ws.websocket("/connect/{client_id}")
async def websocket_connection_endpoint(
websocket: WebSocket,
client_id: Union[int, str],
client_id: int | str,
):
await manager.connect(websocket, client_id, register_client=False)
await manager.broadcast(f"Client {client_id} joined")
Expand Down Expand Up @@ -161,7 +161,7 @@ async def close_ws_connection(client_id: int):


@ws.delete("/connect/{client_id}")
async def close_unrecorded_ws_connection(client_id: Union[int, str]):
async def close_unrecorded_ws_connection(client_id: int | str):
client_id_str = str(client_id).replace("\r\n", "").replace("\n", "")
log.info(f"Disconnecting {client_id_str}")
manager.disconnect(client_id)
Expand Down
16 changes: 3 additions & 13 deletions src/murfey/util/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,11 @@ def url_path_for(
)
logger.error(message)
raise KeyError(message)
# Skip complicated type resolution for now
# Skip complicated type resolution
if param_type.startswith("typing."):
continue
elif type(kwargs[param_name]).__name__ not in param_type:
# Validate incoming type against allowed ones
if type(kwargs[param_name]).__name__ not in param_type:
message = (
f"Error validating parameters for {function_name!r}; "
f"{param_name!r} must be {param_type!r}, "
Expand All @@ -135,14 +136,3 @@ def url_path_for(

# Render and return the path
return render_path(route_path, kwargs)


if __name__ == "__main__":
# Run test on some existing routes
url_path = url_path_for(
"workflow.tomo_router",
"register_tilt",
visit_name="nt15587-15",
session_id=2,
)
print(url_path)
4 changes: 2 additions & 2 deletions src/murfey/util/route_manifest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1194,7 +1194,7 @@ murfey.server.api.websocket.ws:
function: websocket_connection_endpoint
path_params:
- name: client_id
type: typing.Union[int, str]
type: int | str
methods: []
- path: /ws/test/{client_id}
function: close_ws_connection
Expand All @@ -1207,7 +1207,7 @@ murfey.server.api.websocket.ws:
function: close_unrecorded_ws_connection
path_params:
- name: client_id
type: typing.Union[int, str]
type: int | str
methods:
- DELETE
murfey.server.api.workflow.correlative_router:
Expand Down
26 changes: 26 additions & 0 deletions tests/cli/test_generate_route_manifest_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import sys

from pytest_mock import MockerFixture

from murfey.cli.generate_route_manifest import run


def test_run(
mocker: MockerFixture,
):
# Mock out print() and exit()
mock_print = mocker.patch("builtins.print")
mock_exit = mocker.patch("builtins.exit")

# Run the function with its args
sys.argv = ["", "--debug"]
run()

# Check that the final print message and exit() are called
print_calls = mock_print.call_args_list
last_print_call = print_calls[-1]
last_printed = last_print_call.args[0]
assert last_printed.startswith(
"Route manifest for instrument and backend servers saved to"
)
mock_exit.assert_called_once()
44 changes: 44 additions & 0 deletions tests/util/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pytest

from murfey.util.api import url_path_for

url_path_test_matrix: tuple[tuple[str, str, dict[str, str | int], str], ...] = (
# Router name | Function name | kwargs | Expected URL
(
"instrument_server.api.router",
"health",
{},
"/health",
),
(
"instrument_server.api.router",
"stop_multigrid_watcher",
{"session_id": 0, "label": "some_label"},
"/sessions/0/multigrid_watcher/some_label",
),
(
"api.hub.router",
"get_instrument_image",
{"instrument_name": "test"},
"/instrument/test/image",
),
(
"api.instrument.router",
"check_if_session_is_active",
{
"instrument_name": "test",
"session_id": 0,
},
"/instrument_server/instruments/test/sessions/0/active",
),
)


@pytest.mark.parametrize("test_params", url_path_test_matrix)
def test_url_path_for(test_params: tuple[str, str, dict[str, str | int], str]):
# Unpack test params
router_name, function_name, kwargs, expected_url_path = test_params
assert (
url_path_for(router_name=router_name, function_name=function_name, **kwargs)
== expected_url_path
)