Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions ratapi/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ def get_handle(self, index: int):
"""
custom_file = self.files[index]
full_path = os.path.join(custom_file["path"], custom_file["filename"])

if not os.path.isfile(full_path):
raise FileNotFoundError(f"The custom file ({custom_file['name']}) does not have a valid path.")

if not custom_file["function_name"] and custom_file["language"] != Languages.Matlab:
raise ValueError(f"The custom file ({custom_file['name']}) does not have a valid function name.")

if custom_file["language"] == Languages.Python:
file_handle = get_python_handle(custom_file["filename"], custom_file["function_name"], custom_file["path"])
elif custom_file["language"] == Languages.Matlab:
Expand Down
84 changes: 65 additions & 19 deletions ratapi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pathlib
import warnings
from itertools import count
from contextlib import suppress
from typing import Any

import numpy as np
Expand All @@ -18,14 +18,41 @@


# Create a counter for each model
background_number = count(1)
contrast_number = count(1)
custom_file_number = count(1)
data_number = count(1)
domain_contrast_number = count(1)
layer_number = count(1)
parameter_number = count(1)
resolution_number = count(1)
background_number = ["Background", 0]
contrast_number = ["Contrast", 0]
custom_file_number = ["Custom File", 0]
data_number = ["Data", 0]
domain_contrast_number = ["Domain Contrast", 0]
layer_number = ["Layer", 0]
parameter_number = ["Parameter", 0]
resolution_number = ["Resolution", 0]

_model_counter = {
"Background": background_number,
"Contrast": contrast_number,
"ContrastWithRatio": contrast_number,
"CustomFile": custom_file_number,
"Data": data_number,
"DomainContrast": domain_contrast_number,
"Layer": layer_number,
"AbsorptionLayer": layer_number,
"Parameter": parameter_number,
"ProtectedParameter": parameter_number,
"Resolution": resolution_number,
}


def _model_name_factory(model_name: str) -> str:
"""Generate a unique name for model using a global counter.

Parameters
----------
model_name : str
The name of the model class.
"""
title, number = _model_counter[model_name]
_model_counter[model_name][1] += 1
return f"New {title} {(number + 1)}"


class RATModel(BaseModel, validate_assignment=True, extra="forbid"):
Expand All @@ -38,6 +65,25 @@ def __repr__(self):
)
return f"{self.__repr_name__()}({fields_repr})"

@field_validator("name", mode="after", check_fields=False)
@classmethod
def update_counter(cls, name: str) -> str:
"""Update the auto name counter if a similar name is manually given.

Parameters
----------
name : str
The name of the model.
"""
title, number = _model_counter[cls.__name__]
prefix = f"New {title} "
if name.startswith(prefix):
with suppress(ValueError):
new_number = int(name[len(prefix) :])
if new_number > number:
_model_counter[cls.__name__][1] = new_number
return name

def __str__(self):
table = prettytable.PrettyTable()
table.field_names = [key.replace("_", " ") for key in self.display_fields]
Expand Down Expand Up @@ -116,7 +162,7 @@ class Background(Signal):

"""

name: str = Field(default_factory=lambda: f"New Background {next(background_number)}", min_length=1)
name: str = Field(default_factory=lambda: _model_name_factory("Background"), min_length=1)

@model_validator(mode="after")
def check_unsupported_parameters(self):
Expand Down Expand Up @@ -173,7 +219,7 @@ class Contrast(RATModel):

"""

name: str = Field(default_factory=lambda: f"New Contrast {next(contrast_number)}", min_length=1)
name: str = Field(default_factory=lambda: _model_name_factory("Contrast"), min_length=1)
data: str = ""
background: str = ""
background_action: BackgroundActions = BackgroundActions.Add
Expand Down Expand Up @@ -255,7 +301,7 @@ class ContrastWithRatio(RATModel):

"""

name: str = Field(default_factory=lambda: f"New Contrast {next(contrast_number)}", min_length=1)
name: str = Field(default_factory=lambda: _model_name_factory("ContrastWithRatio"), min_length=1)
data: str = ""
background: str = ""
background_action: BackgroundActions = BackgroundActions.Add
Expand Down Expand Up @@ -309,7 +355,7 @@ class CustomFile(RATModel):

"""

name: str = Field(default_factory=lambda: f"New Custom File {next(custom_file_number)}", min_length=1)
name: str = Field(default_factory=lambda: _model_name_factory("CustomFile"), min_length=1)
filename: str = ""
function_name: str = ""
language: Languages = Languages.Python
Expand Down Expand Up @@ -348,7 +394,7 @@ class Data(RATModel, arbitrary_types_allowed=True):

"""

name: str = Field(default_factory=lambda: f"New Data {next(data_number)}", min_length=1)
name: str = Field(default_factory=lambda: _model_name_factory("Data"), min_length=1)
data: np.ndarray = np.empty([0, 3])
data_range: list[float] = Field(default=[], min_length=2, max_length=2)
simulation_range: list[float] = Field(default=[], min_length=2, max_length=2)
Expand Down Expand Up @@ -453,7 +499,7 @@ class DomainContrast(RATModel):

"""

name: str = Field(default_factory=lambda: f"New Domain Contrast {next(domain_contrast_number)}", min_length=1)
name: str = Field(default_factory=lambda: _model_name_factory("DomainContrast"), min_length=1)
model: list[str] = []

def __str__(self):
Expand Down Expand Up @@ -483,7 +529,7 @@ class Layer(RATModel, populate_by_name=True):

"""

name: str = Field(default_factory=lambda: f"New Layer {next(layer_number)}", min_length=1)
name: str = Field(default_factory=lambda: _model_name_factory("Layer"), min_length=1)
thickness: str
SLD: str = Field(validation_alias="SLD_real")
roughness: str
Expand Down Expand Up @@ -522,7 +568,7 @@ class AbsorptionLayer(RATModel, populate_by_name=True):

"""

name: str = Field(default_factory=lambda: f"New Layer {next(layer_number)}", min_length=1)
name: str = Field(default_factory=lambda: _model_name_factory("AbsorptionLayer"), min_length=1)
thickness: str
SLD_real: str = Field(validation_alias="SLD")
SLD_imaginary: str
Expand Down Expand Up @@ -555,7 +601,7 @@ class Parameter(RATModel):

"""

name: str = Field(default_factory=lambda: f"New Parameter {next(parameter_number)}", min_length=1)
name: str = Field(default_factory=lambda: _model_name_factory("Parameter"), min_length=1)
min: float = 0.0
value: float = 0.0
max: float = 0.0
Expand Down Expand Up @@ -638,7 +684,7 @@ class Resolution(Signal):

"""

name: str = Field(default_factory=lambda: f"New Resolution {next(resolution_number)}", min_length=1)
name: str = Field(default_factory=lambda: _model_name_factory("Resolution"), min_length=1)

@field_validator("type")
@classmethod
Expand Down
26 changes: 26 additions & 0 deletions tests/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import pathlib
import pickle
import tempfile
from unittest.mock import patch

import numpy as np
import pytest
Expand Down Expand Up @@ -675,6 +677,30 @@ def test_make_controls(standard_layers_controls) -> None:
check_controls_equal(controls, standard_layers_controls)


@patch("ratapi.wrappers.MatlabWrapper")
def test_file_handles(wrapper):
handle = FileHandles([ratapi.models.CustomFile(name="Test Custom File", filename="cpp_test.dll", language="cpp")])

with pytest.raises(FileNotFoundError, match="The custom file \\(Test Custom File\\) does not have a valid path."):
handle.get_handle(0)

with tempfile.NamedTemporaryFile("w", suffix=".dll") as f:
tmp_file = pathlib.Path(f.name)
handle.files[0]["path"] = tmp_file.parent
handle.files[0]["filename"] = tmp_file.name
handle.files[0]["function_name"] = ""
# No function name should throw exception
with pytest.raises(
ValueError, match="The custom file \\(Test Custom File\\) does not have a valid function name."
):
handle.get_handle(0)

# Matlab does not need function name
handle.files[0]["language"] = "matlab"
handle.get_handle(0)
wrapper.assert_called()


def check_problem_equal(actual_problem, expected_problem) -> None:
"""Compare two instances of the "problem" object for equality."""
scalar_fields = [
Expand Down
22 changes: 12 additions & 10 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,23 @@ def test_default_names(model: Callable, model_name: str, model_params: dict) ->
format: "New <model name> <integer>".
"""
model_1 = model(**model_params)
prefix = f"New {model_name} "
assert model_1.name.startswith(prefix)
index = int(model_1.name[len(prefix) :])

model_2 = model(**model_params)
model_3 = model(name="Given Name", **model_params)
model_4 = model(**model_params)

assert model_1.name == f"New {model_name} 1"
assert model_2.name == f"New {model_name} 2"
assert model_1.name == f"New {model_name} {index}"
assert model_2.name == f"New {model_name} {index + 1}"
assert model_3.name == "Given Name"
assert model_4.name == f"New {model_name} 3"
assert model_4.name == f"New {model_name} {index + 2}"

# If user adds name in similar format. The next auto number will take it into account.
model(name=f"{prefix}{index + 20}", **model_params)
model_5 = model(**model_params)
assert model_5.name == f"New {model_name} {index + 21}"


@pytest.mark.parametrize(
Expand Down Expand Up @@ -100,13 +109,6 @@ def test_initialise_with_extra_fields(self, model: Callable, model_params: dict)
model(new_field=1, **model_params)


# def test_custom_file_path_is_absolute() -> None:
# """If we use provide a relative path to the custom file model, it should be converted to an absolute path."""
# relative_path = pathlib.Path("./relative_path")
# custom_file = ratapi.models.CustomFile(path=relative_path)
# assert custom_file.path.is_absolute()


def test_data_eq() -> None:
"""If we use the Data.__eq__ method with an object that is not a pydantic BaseModel, we should return
"NotImplemented".
Expand Down