Skip to content

Commit

Permalink
Clean up typing (#64)
Browse files Browse the repository at this point in the history
* clean up typing

* fix tests

* lint
  • Loading branch information
blink1073 authored Jan 12, 2023
1 parent 88cc054 commit ac65980
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 44 deletions.
2 changes: 1 addition & 1 deletion jupyter_events/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def main():
@click.command()
@click.argument("schema")
@click.pass_context
def validate(ctx: click.Context, schema: str):
def validate(ctx: click.Context, schema: str) -> int:
"""Validate a SCHEMA against Jupyter Event's meta schema.
SCHEMA can be a JSON/YAML string or filepath to a schema.
Expand Down
30 changes: 16 additions & 14 deletions jupyter_events/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
import logging
import warnings
from datetime import datetime
from pathlib import PurePath
from typing import Callable, Optional, Union
from typing import Any, Callable, Coroutine, Optional, Union

from jsonschema import ValidationError
from pythonjsonlogger import jsonlogger # type:ignore
from traitlets import Dict, Instance, Set, default
from traitlets.config import Config, LoggingConfigurable

from .schema import SchemaType
from .schema_registry import SchemaRegistry
from .traits import Handlers
from .validators import JUPYTER_EVENTS_CORE_VALIDATOR
Expand Down Expand Up @@ -131,7 +131,7 @@ def get_handlers():
eventlogger_cfg = Config({"EventLogger": my_cfg})
super()._load_config(eventlogger_cfg, section_names=None, traits=None)

def register_event_schema(self, schema: Union[dict, str, PurePath]):
def register_event_schema(self, schema: SchemaType) -> None:
"""Register this schema with the schema registry.
Get this registered schema using the EventLogger.schema.get() method.
Expand All @@ -143,7 +143,7 @@ def register_event_schema(self, schema: Union[dict, str, PurePath]):
self._modified_listeners[key] = set()
self._unmodified_listeners[key] = set()

def register_handler(self, handler: logging.Handler):
def register_handler(self, handler: logging.Handler) -> None:
"""Register a new logging handler to the Event Logger.
All outgoing messages will be formatted as a JSON string.
Expand All @@ -164,7 +164,7 @@ def _skip_message(record, **kwargs):
if handler not in self.handlers:
self.handlers.append(handler)

def remove_handler(self, handler: logging.Handler):
def remove_handler(self, handler: logging.Handler) -> None:
"""Remove a logging handler from the logger and list of handlers."""
self._logger.removeHandler(handler)
if handler in self.handlers:
Expand All @@ -175,7 +175,7 @@ def add_modifier(
*,
schema_id: Union[str, None] = None,
modifier: Callable[[str, dict], dict],
):
) -> None:
"""Add a modifier (callable) to a registered event.
Parameters
Expand Down Expand Up @@ -249,8 +249,8 @@ def add_listener(
*,
modified: bool = True,
schema_id: Union[str, None] = None,
listener: Callable[["EventLogger", str, dict], None],
):
listener: Callable[["EventLogger", str, dict], Coroutine[Any, Any, None]],
) -> None:
"""Add a listener (callable) to a registered event.
Parameters
Expand Down Expand Up @@ -304,7 +304,7 @@ def remove_listener(
self,
*,
schema_id: Optional[str] = None,
listener: Callable[["EventLogger", str, dict], None],
listener: Callable[["EventLogger", str, dict], Coroutine[Any, Any, None]],
) -> None:
"""Remove a listener from an event or all events.
Expand All @@ -327,7 +327,9 @@ def remove_listener(
self._modified_listeners[schema_id].discard(listener)
self._unmodified_listeners[schema_id].discard(listener)

def emit(self, *, schema_id: str, data: dict, timestamp_override=None):
def emit(
self, *, schema_id: str, data: dict, timestamp_override: Optional[datetime] = None
) -> Optional[dict]:
"""
Record given event with schema has occurred.
Expand All @@ -351,7 +353,7 @@ def emit(self, *, schema_id: str, data: dict, timestamp_override=None):
and not self._modified_listeners[schema_id]
and not self._unmodified_listeners[schema_id]
):
return
return None

# If the schema hasn't been registered, raise a warning to make sure
# this was intended.
Expand All @@ -362,7 +364,7 @@ def emit(self, *, schema_id: str, data: dict, timestamp_override=None):
"`register_event_schema` method.",
SchemaNotRegistered,
)
return
return None

schema = self.schemas.get(schema_id)

Expand Down Expand Up @@ -400,7 +402,7 @@ def emit(self, *, schema_id: str, data: dict, timestamp_override=None):

# callback for removing from finished listeners
# from active listeners set.
def _listener_task_done(task: asyncio.Task):
def _listener_task_done(task: asyncio.Task) -> None:
# If an exception happens, log it to the main
# applications logger
err = task.exception()
Expand Down Expand Up @@ -429,7 +431,7 @@ def _listener_task_done(task: asyncio.Task):
self._active_listeners.add(task)

# Remove task from active listeners once its finished.
def _listener_task_done(task: asyncio.Task):
def _listener_task_done(task: asyncio.Task) -> None:
# If an exception happens, log it to the main
# applications logger
err = task.exception()
Expand Down
21 changes: 12 additions & 9 deletions jupyter_events/schema.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Event schema objects."""
import json
from pathlib import Path, PurePath
from typing import Type, Union
from typing import Optional, Type, Union

from jsonschema import FormatChecker, validators
from jsonschema import FormatChecker, RefResolver, validators

try:
from jsonschema.protocols import Validator
Expand Down Expand Up @@ -34,6 +34,9 @@ class EventSchemaFileAbsent(Exception): # noqa
pass


SchemaType = Union[dict, str, PurePath]


class EventSchema:
"""A validated schema that can be used.
Expand All @@ -58,10 +61,10 @@ class EventSchema:

def __init__(
self,
schema: Union[dict, str, PurePath],
validator_class: Type[Validator] = validators.Draft7Validator, # type:ignore
schema: SchemaType,
validator_class: Type[Validator] = validators.Draft7Validator, # type:ignore[assignment]
format_checker: FormatChecker = draft7_format_checker,
resolver=None,
resolver: Optional[RefResolver] = None,
):
"""Initialize an event schema."""
_schema = self._load_schema(schema)
Expand All @@ -76,29 +79,29 @@ def __repr__(self):
return json.dumps(self._schema, indent=2)

@staticmethod
def _ensure_yaml_loaded(schema, was_str=False) -> None:
def _ensure_yaml_loaded(schema: SchemaType, was_str: bool = False) -> None:
"""Ensures schema was correctly loaded into a dictionary. Raises
EventSchemaLoadingError otherwise."""
if isinstance(schema, dict):
return

error_msg = "Could not deserialize schema into a dictionary."

def intended_as_path(schema):
def intended_as_path(schema: str) -> bool:
path = Path(schema)
return path.match("*.yml") or path.match("*.yaml") or path.match("*.json")

# detect whether the user specified a string but intended a PurePath to
# generate a more helpful error message
if was_str and intended_as_path(schema):
if was_str and intended_as_path(schema): # type:ignore[arg-type]
error_msg += " Paths to schema files must be explicitly wrapped in a Pathlib object."
else:
error_msg += " Double check the schema and ensure it is in the proper form."

raise EventSchemaLoadingError(error_msg)

@staticmethod
def _load_schema(schema: Union[dict, str, PurePath]) -> dict:
def _load_schema(schema: SchemaType) -> dict:
"""Load a JSON schema from different sources/data types.
`schema` could be a dictionary or serialized string representing the
Expand Down
4 changes: 2 additions & 2 deletions jupyter_events/schema_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ def __init__(self, schemas: Optional[dict] = None):
"""Initialize the registry."""
self._schemas = schemas or {}

def __contains__(self, key: str):
def __contains__(self, key: str) -> bool:
"""Syntax sugar to check if a schema is found in the registry"""
return key in self._schemas

def __repr__(self) -> str:
"""The str repr of the registry."""
return ",\n".join([str(s) for s in self._schemas.values()])

def _add(self, schema_obj: EventSchema):
def _add(self, schema_obj: EventSchema) -> None:
if schema_obj.id in self._schemas:
msg = (
f"The schema, {schema_obj.id}, is already "
Expand Down
2 changes: 1 addition & 1 deletion jupyter_events/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
)


def validate_schema(schema: dict):
def validate_schema(schema: dict) -> None:
"""Validate a schema dict."""
try:
# Validate the schema against Jupyter Events metaschema.
Expand Down
14 changes: 14 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,20 @@ exclude_lines = [
"@(abc\\.)?abstractmethod",
]

[tool.mypy]
check_untyped_defs = true
disallow_incomplete_defs = true
no_implicit_optional = true
pretty = true
show_error_context = true
show_error_codes = true
strict_equality = true
warn_unused_configs = true
warn_unused_ignores = true
warn_redundant_casts = true
explicit_package_bases = true
namespace_packages = true

[tool.black]
line-length = 100
skip-string-normalization = true
Expand Down
15 changes: 6 additions & 9 deletions tests/test_listeners.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@ def jp_event_schemas(schema):

async def test_listener_function(jp_event_logger, schema):
event_logger = jp_event_logger
global listener_was_called
listener_was_called = False

async def my_listener(logger: EventLogger, schema_id: str, data: dict) -> None:
global listener_was_called
listener_was_called = True # type: ignore
nonlocal listener_was_called
listener_was_called = True

# Add the modifier
event_logger.add_listener(schema_id=schema.id, listener=my_listener)
Expand All @@ -41,12 +40,11 @@ async def my_listener(logger: EventLogger, schema_id: str, data: dict) -> None:

async def test_remove_listener_function(jp_event_logger, schema):
event_logger = jp_event_logger
global listener_was_called
listener_was_called = False

async def my_listener(logger: EventLogger, schema_id: str, data: dict) -> None:
global listener_was_called
listener_was_called = True # type: ignore
nonlocal listener_was_called
listener_was_called = True

# Add the modifier
event_logger.add_listener(schema_id=schema.id, listener=my_listener)
Expand Down Expand Up @@ -114,15 +112,14 @@ async def test_bad_listener_does_not_break_good_listener(jp_event_logger, schema
h = logging.StreamHandler(log_stream)
app_log.addHandler(h)

global listener_was_called
listener_was_called = False

async def listener_raise_exception(logger: EventLogger, schema_id: str, data: dict) -> None:
raise Exception("This failed") # noqa

async def my_listener(logger: EventLogger, schema_id: str, data: dict) -> None:
global listener_was_called
listener_was_called = True # type: ignore
nonlocal listener_was_called
listener_was_called = True

# Add a bad listener and a good listener and ensure that
# emitting still works and the bad listener's exception is is logged.
Expand Down
10 changes: 5 additions & 5 deletions tests/test_modifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,20 @@ def redact(self, schema_id: str, data: dict) -> dict:
assert output["username"] == "<masked>"


def test_bad_modifier_functions(jp_event_logger, schema: EventSchema):
def test_bad_modifier_functions(jp_event_logger: EventLogger, schema: EventSchema) -> None:
event_logger = jp_event_logger

def modifier_with_extra_args(schema_id: str, data: dict, unknown_arg: dict) -> dict:
return data

with pytest.raises(ModifierError):
event_logger.add_modifier(modifier=modifier_with_extra_args)
event_logger.add_modifier(modifier=modifier_with_extra_args) # type:ignore[arg-type]

# Ensure no modifier was added.
assert len(event_logger._modifiers[schema.id]) == 0


def test_bad_modifier_method(jp_event_logger, schema: EventSchema):
def test_bad_modifier_method(jp_event_logger: EventLogger, schema: EventSchema) -> None:
event_logger = jp_event_logger

class Redactor:
Expand All @@ -77,7 +77,7 @@ def redact(self, schema_id: str, data: dict, extra_args: dict) -> dict:
redactor = Redactor()

with pytest.raises(ModifierError):
event_logger.add_modifier(modifier=redactor.redact)
event_logger.add_modifier(modifier=redactor.redact) # type:ignore[arg-type]

# Ensure no modifier was added
assert len(event_logger._modifiers[schema.id]) == 0
Expand All @@ -90,7 +90,7 @@ def modifier_with_extra_args(event):
return event

with pytest.raises(ModifierError):
logger.add_modifier(modifier=modifier_with_extra_args)
logger.add_modifier(modifier=modifier_with_extra_args) # type:ignore[arg-type]


def test_remove_modifier(schema, jp_event_logger, jp_read_emitted_events):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_string_intended_as_path():
def test_unrecognized_type():
"""Validation fails because file is not of valid type."""
with pytest.raises(EventSchemaUnrecognized):
EventSchema(9001)
EventSchema(9001) # type:ignore[arg-type]


def test_invalid_yaml():
Expand Down
4 changes: 2 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ def get_event_data(event, schema, schema_id, version, unredacted_policies):
handler = logging.StreamHandler(sink)

e = EventLogger(handlers=[handler], unredacted_policies=unredacted_policies)
e.register_schema(schema)
e.register_event_schema(schema)

# Record event and read output
e.emit(schema_id, version, deepcopy(event))
e.emit(schema_id=schema_id, data=deepcopy(event))

recorded_event = json.loads(sink.getvalue())
return {key: value for key, value in recorded_event.items() if not key.startswith("__")}

0 comments on commit ac65980

Please sign in to comment.