Commit

Merge branch 'main' into release
Myles Bartlett committed Dec 15, 2022
2 parents c91e837 + 5dfd907 commit 474c26c
Showing 16 changed files with 267 additions and 133 deletions.
13 changes: 1 addition & 12 deletions mypy.ini
@@ -10,18 +10,7 @@ disallow_untyped_defs = True
 disallow_incomplete_defs = True
 warn_incomplete_stub = True
 show_error_codes = True
 
-[mypy-hydra.*]
-follow_imports = skip
-follow_imports_for_stubs = True
-
-[mypy-omegaconf.*]
-follow_imports = skip
-follow_imports_for_stubs = True
-warn_unused_ignores = True
-
-[mypy-pytest.*]
-ignore_missing_imports = True
-
 [mypy-torch.*]
 follow_imports = skip
 follow_imports_for_stubs = True
27 changes: 13 additions & 14 deletions poetry.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -79,7 +79,7 @@ darglint = "^1.8.0"
 pandas-stubs = ">=1.4.3.220718"
 
 [tool.pyright]
-exclude = ["**/node_modules", "**/__pycache__", "**/."]
+exclude = ["**/node_modules", "**/__pycache__", "**/.*"]
 typeCheckingMode = "basic"
 pythonVersion = "3.8"
 reportUnusedImport = "error"
@@ -93,6 +93,7 @@ reportMissingTypeStubs = "warning"
 strictListInference = true
 strictSetInference = true
 strictParameterNoneValue = true
+reportUnnecessaryTypeIgnoreComment = "warning"
 
 [build-system]
 requires = ["poetry-core>=1.0.0"]
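
For context on the new pyright setting: with reportUnnecessaryTypeIgnoreComment enabled, pyright warns on suppression comments that no longer suppress anything, which is what motivates the "# type: ignore" removals (and the two deliberate additions in schedulers.py) in the Python diffs below. The exclude fix on the same file makes the last glob actually match hidden directories. A minimal sketch, with a hypothetical file covered by the [tool.pyright] config above:

    # example.py -- hypothetical; pyright now flags this suppression
    # as unnecessary, because the assignment already type-checks
    # without it.
    x: int = 1  # type: ignore
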
42 changes: 16 additions & 26 deletions ranzen/hydra/relay.py
@@ -70,13 +70,11 @@ def _to_yaml_value(default: Any, *, indent_level: int = 0) -> str | None:
     elif isinstance(default, (tuple, list)):
         str_ = ""
         indent_level += 1
-        str_ls = []
         for elem in default:
             elem_str = _to_yaml_value(elem, indent_level=indent_level)
             if elem_str is None:
                 return None
             str_ += f"\n{YAML_INDENT * indent_level}- {elem_str}"
-        str_ = str(str_ls)
     elif isinstance(default, dict):
         str_ = ""
         indent_level += 1
@@ -110,7 +108,7 @@ def name(self) -> str:
         return self._name
 
     @name.setter
-    def name(self, name: str | None) -> None:  # type: ignore
+    def name(self, name: str | None) -> None:
         self._name = name


@@ -154,7 +152,7 @@ class Relay:
     _logger: ClassVar[Optional[logging.Logger]] = None
 
     @classmethod
-    def _get_logger(cls: type[Self]) -> logging.Logger:
+    def _get_logger(cls) -> logging.Logger:
         if cls._logger is None:
             logger = logging.getLogger(__name__)
             logger.addHandler(logging.StreamHandler(sys.stdout))
@@ -163,18 +161,16 @@ def _get_logger(cls: type[Self]) -> logging.Logger:
         return cls._logger
 
     @classmethod
-    def _log(cls: type[Self], msg: str) -> None:
+    def _log(cls, msg: str) -> None:
         cls._get_logger().info(msg)
 
     @classmethod
-    def _config_dir_name(cls: type[Self]) -> str:
+    def _config_dir_name(cls) -> str:
         return _camel_to_snake(cls.__name__)
 
     @final
     @classmethod
-    def _init_yaml_files(
-        cls: type[Self], *, config_dir: Path, config_dict: dict[str, list[Any]]
-    ) -> None:
+    def _init_yaml_files(cls, *, config_dir: Path, config_dict: dict[str, list[Any]]) -> None:
         primary_conf_fp = (config_dir / cls._CONFIG_NAME).with_suffix(".yaml")
         primary_conf_exists = primary_conf_fp.exists()
         with primary_conf_fp.open("a+") as primary_conf:
@@ -226,15 +222,13 @@ def _init_yaml_files(
         cls._log(f"Finished initialising config directory initialised at '{config_dir}'")
 
     @classmethod
-    def _module_to_fp(cls: type[Self], module: ModuleType | str):
+    def _module_to_fp(cls, module: ModuleType | str) -> str:
         if isinstance(module, ModuleType):
             module = module.__name__
         return module.replace(".", "/")
 
     @classmethod
-    def _generate_conf(
-        cls: type[Self], output_dir: Path, *, module_class_dict: dict[str, List[str]]
-    ) -> None:
+    def _generate_conf(cls, output_dir: Path, *, module_class_dict: dict[str, List[str]]) -> None:
         from configen.config import ConfigenConf, ModuleConf  # type: ignore
         from configen.configen import generate_module  # type: ignore
 
@@ -255,7 +249,7 @@ def _generate_conf(
             file.write(code)
 
     @classmethod
-    def _load_module_from_path(cls: type[Self], filepath: Path) -> ModuleType:
+    def _load_module_from_path(cls, filepath: Path) -> ModuleType:
         import sys
 
         spec = importlib.util.spec_from_file_location(  # type: ignore
@@ -268,7 +262,7 @@ def _load_module_from_path(cls: type[Self], filepath: Path) -> ModuleType:
 
     @classmethod
     def _load_schemas(
-        cls: type[Self],
+        cls,
         config_dir: Path,
         *,
         clear_cache: bool = False,
@@ -319,9 +313,7 @@ def _load_schemas(
                 if schema is None:
                     schema_missing = True
                 else:
-                    imported_schemas[group].append(
-                        replace(option, class_=schema)  # type: ignore
-                    )
+                    imported_schemas[group].append(replace(option, class_=schema))
                 if schema_missing:
                     schemas_to_generate[option.class_.__module__].append(cls_name)
             import_info = _SchemaImportInfo(
@@ -351,16 +343,14 @@ def _load_schemas(
                 # attribute-retrieval during unpickling when using a paralllielising hydra
                 # launcher and implement a more graceful solution.
                 schema.__module__ = "__main__"
-                imported_schemas[group].append(
-                    Option(class_=schema, name=info.name)
-                )  # type: ignore
+                imported_schemas[group].append(Option(class_=schema, name=info.name))
 
         return primary_schema, imported_schemas, schemas_to_init
 
     @final
     @classmethod
     def _launch(
-        cls: type[Self],
+        cls,
         *,
         root: Path | str,
         clear_cache: bool = False,
@@ -388,9 +378,9 @@ def _launch(
         sr = SchemaRegistration()
         sr.register(path=cls._PRIMARY_SCHEMA_NAME, config_class=primary_schema)
         for group, schema_ls in schemas.items():
-            with sr.new_group(group_name=f"schema/{group}", target_path=f"{group}") as group:
+            with sr.new_group(group_name=f"schema/{group}", target_path=f"{group}") as group_:
                 for info in schema_ls:
-                    group.add_option(name=info.name, config_class=info.class_)
+                    group_.add_option(name=info.name, config_class=info.class_)
 
         # config_path only allows for relative paths; we need to resort to construct a
         # searchpath plugin on-the-fly in order to set the config directory with an absolute path
@@ -403,7 +393,7 @@ def manipulate_search_path(self, search_path: ConfigSearchPath) -> None:
 
         @hydra.main(config_path=None, config_name=cls._CONFIG_NAME, version_base=None)
         def launcher(cfg: Any, /) -> Any:
-            relay: cls = instantiate(cfg, _recursive_=instantiate_recursively)
+            relay: Self = instantiate(cfg, _recursive_=instantiate_recursively)
             config_dict = cast(
                 Dict[str, Any],
                 OmegaConf.to_container(cfg, throw_on_missing=True, enum_to_str=False, resolve=True),
@@ -414,7 +404,7 @@ def launcher(cfg: Any, /) -> Any:
 
     @classmethod
     def with_hydra(
-        cls: type[Self],
+        cls,
        root: Path | str,
        *,
        clear_cache: bool = False,
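
The hunks above show the classmethod surface of Relay (with_hydra(cls, root, *, clear_cache, ...) dispatching through hydra.main). A minimal usage sketch — the subclass, its fields, and the run entry point are all hypothetical, and the import path is inferred from this file's location:

    from dataclasses import dataclass

    from ranzen.hydra import Relay  # import path assumed

    @dataclass
    class TrainRelay(Relay):
        # dataclass fields are what end up in the generated YAML configs
        seed: int = 0
        lr: float = 1e-3

        def run(self, raw_config=None) -> None:  # entry-point name assumed
            print(f"lr={self.lr}, seed={self.seed}")

    if __name__ == "__main__":
        # initialises the config directory on first run, then launches
        # through hydra using the schemas loaded/generated above
        TrainRelay.with_hydra(root="conf", clear_cache=False)
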
7 changes: 4 additions & 3 deletions ranzen/misc.py
@@ -116,7 +116,7 @@ def str_to_enum(str_: str | E, *, enum: type[E]) -> E:
 
 if sys.version_info >= (3, 11):
     # will be available in python 3.11
-    from enum import StrEnum  # type: ignore
+    from enum import StrEnum
 else:
     #
     # the following is copied straight from https://github.com/python/cpython/blob/3.11/Lib/enum.py
@@ -160,9 +160,10 @@ def __new__(cls: Type[_S], *values: str) -> _S:
             member._value_ = value
             return member
 
-        __str__ = str.__str__  # type: ignore
+        def __str__(self) -> str:
+            return str.__str__(self)
 
-        def _generate_next_value_(name: str, start: int, count: int, last_values: list[str]) -> str:
+        def _generate_next_value_(name: str, start: int, count: int, last_values: list[Any]) -> str:
             """
             Return the lower-cased version of the member name.
             """
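
A short behavioural sketch of the StrEnum imported or backported above (enum name and members hypothetical): members are plain strings, __str__ returns the bare value, and _generate_next_value_ lower-cases the member name for auto():

    from enum import auto

    class Fruit(StrEnum):
        APPLE = auto()   # _generate_next_value_ -> "apple"
        BANANA = auto()

    assert Fruit.APPLE == "apple"         # str subclass: compares equal to its value
    assert str(Fruit.BANANA) == "banana"  # __str__ yields the value, not "Fruit.BANANA"
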
2 changes: 1 addition & 1 deletion ranzen/torch/data.py
@@ -693,7 +693,7 @@ def __init__(
 
     @classmethod
     def from_labels(
-        cls: type[Self],
+        cls,
         labels: Sequence[int] | Tensor,
         *,
         batch_size: int,
2 changes: 1 addition & 1 deletion ranzen/torch/loss.py
@@ -238,7 +238,7 @@ def reduction(self) -> ReductionType:
         return self._reduction
 
     @reduction.setter
-    def reduction(self, value: ReductionType | str) -> None:  # type: ignore
+    def reduction(self, value: ReductionType | str) -> None:
         if isinstance(value, str):
             value = str_to_enum(str_=value, enum=ReductionType)
         self._reduction = value
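
The setter above accepts either the enum or its string name, normalising via str_to_enum. A standalone sketch of the pattern — the class and enum here are simplified stand-ins, not the ranzen definitions:

    from enum import Enum

    class ReductionType(Enum):  # stand-in for ranzen's enum
        mean = "mean"
        sum = "sum"

    class Loss:  # hypothetical holder of the property
        def __init__(self) -> None:
            self._reduction = ReductionType.mean

        @property
        def reduction(self) -> ReductionType:
            return self._reduction

        @reduction.setter
        def reduction(self, value: "ReductionType | str") -> None:
            if isinstance(value, str):
                value = ReductionType[value]  # simplified str_to_enum
            self._reduction = value

    loss = Loss()
    loss.reduction = "sum"  # strings are coerced to the enum
    assert loss.reduction is ReductionType.sum
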
2 changes: 1 addition & 1 deletion ranzen/torch/module.py
@@ -11,7 +11,7 @@
 @dataclass(unsafe_hash=True)
 class DcModule(nn.Module):
     @final
-    def __new__(cls: type[Self], *args: Any, **kwargs: Any) -> Self:
+    def __new__(cls, *args: Any, **kwargs: Any) -> Self:
         obj = object.__new__(cls)
         nn.Module.__init__(obj)
         return obj
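
DcModule's __new__ runs nn.Module.__init__ before the dataclass-generated __init__ assigns any fields, so subclasses can be declared as plain dataclasses without calling super().__init__(). A small sketch, with a hypothetical subclass:

    import torch
    import torch.nn as nn
    from dataclasses import dataclass

    from ranzen.torch.module import DcModule  # import path per this diff

    @dataclass(unsafe_hash=True)
    class Scaler(DcModule):
        scale: float = 2.0

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.scale * x

    print(Scaler(scale=3.0)(torch.ones(2)))  # tensor([3., 3.])
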
2 changes: 1 addition & 1 deletion ranzen/torch/sampling.py
@@ -33,7 +33,7 @@ def batched_randint(
     :returns: A tensor of random-sampled integers upper-bounded by the values in ``high``.
     """
-    total_size = high.size()
+    total_size: torch.Size | list[int] = high.size()
     if size is not None:
         total_size = list(total_size)
         if isinstance(size, int):
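
Going only by the signature and docstring visible here, batched_randint draws one integer per entry of high, each strictly below the corresponding bound — a vectorised counterpart of torch.randint with per-element upper bounds. A hypothetical call (the calling convention beyond high and the optional size is assumed):

    import torch

    from ranzen.torch.sampling import batched_randint  # import path per this diff

    high = torch.tensor([2, 5, 10])
    draws = batched_randint(high)  # e.g. tensor([0, 3, 7]); draws[i] < high[i]
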
6 changes: 3 additions & 3 deletions ranzen/torch/schedulers.py
@@ -281,7 +281,7 @@ class LinearWarmup(WarmupScheduler[T]):
     def __post_init__(self) -> None:
         super().__post_init__()
         if self.warmup_steps == 0:
-            self.step_size = 0
+            self.step_size = 0  # type: ignore
         else:
             self.step_size = (self.end_val - self.start_val) / self.warmup_steps

@@ -299,9 +299,9 @@ class ExponentialWarmup(WarmupScheduler[T]):
     def __post_init__(self) -> None:
         super().__post_init__()
         if self.warmup_steps == 0:
-            self.step_size = 0
+            self.step_size = 0  # type: ignore
         else:
-            self.step_size = (self.end_val / self.start_val) ** (1 / self.warmup_steps)
+            self.step_size = (self.end_val / self.start_val) ** (1 / self.warmup_steps)  # type: ignore
 
     @implements(WarmupScheduler)
     def _update(self, value: T) -> T:
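
The two step_size definitions encode the warm-up arithmetic: LinearWarmup adds a constant increment each step, ExponentialWarmup multiplies by a constant ratio, and both reach end_val after warmup_steps steps. A standalone numeric check in plain Python (not the ranzen classes):

    start_val, end_val, warmup_steps = 0.1, 1.0, 5

    linear_step = (end_val - start_val) / warmup_steps      # additive increment
    exp_step = (end_val / start_val) ** (1 / warmup_steps)  # multiplicative ratio

    lin = exp = start_val
    for _ in range(warmup_steps):
        lin += linear_step
        exp *= exp_step

    assert abs(lin - end_val) < 1e-9 and abs(exp - end_val) < 1e-9
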

0 comments on commit 474c26c
