Skip to content

Commit 41572cd

Browse files
authored
Adds save and load method to project and controls (#123)
* added save and load method to project and controls * moved json save/load to the project method * backwards compatibility * fixed docstring and changed to model dump
1 parent 9f73677 commit 41572cd

File tree

5 files changed

+130
-100
lines changed

5 files changed

+130
-100
lines changed

RATapi/controls.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import os
33
import tempfile
44
import warnings
5+
from pathlib import Path
6+
from typing import Union
57

68
import prettytable
79
from pydantic import (
@@ -220,3 +222,30 @@ def delete_IPC(self):
220222
with contextlib.suppress(FileNotFoundError):
221223
os.remove(self._IPCFilePath)
222224
return None
225+
226+
def save(self, path: Union[str, Path], filename: str = "controls"):
227+
"""Save a controls object to a JSON file.
228+
229+
Parameters
230+
----------
231+
path : str or Path
232+
The directory in which the controls object will be written.
233+
filename : str
234+
The name for the JSON file containing the controls object.
235+
236+
"""
237+
file = Path(path, f"{filename.removesuffix('.json')}.json")
238+
file.write_text(self.model_dump_json())
239+
240+
@classmethod
241+
def load(cls, path: Union[str, Path]) -> "Controls":
242+
"""Load a controls object from file.
243+
244+
Parameters
245+
----------
246+
path : str or Path
247+
The path to the controls object file.
248+
249+
"""
250+
file = Path(path)
251+
return cls.model_validate_json(file.read_text())

RATapi/project.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import collections
44
import copy
55
import functools
6+
import json
67
from enum import Enum
78
from pathlib import Path
89
from textwrap import indent
@@ -834,6 +835,75 @@ def classlist_script(name, classlist):
834835
+ "\n)"
835836
)
836837

838+
def save(self, path: Union[str, Path], filename: str = "project"):
839+
"""Save a project to a JSON file.
840+
841+
Parameters
842+
----------
843+
path : str or Path
844+
The path in which the project will be written.
845+
filename : str
846+
The name of the generated project file.
847+
848+
"""
849+
json_dict = {}
850+
for field in self.model_fields:
851+
attr = getattr(self, field)
852+
853+
if field == "data":
854+
855+
def make_data_dict(item):
856+
return {
857+
"name": item.name,
858+
"data": item.data.tolist(),
859+
"data_range": item.data_range,
860+
"simulation_range": item.simulation_range,
861+
}
862+
863+
json_dict["data"] = [make_data_dict(data) for data in attr]
864+
865+
elif field == "custom_files":
866+
867+
def make_custom_file_dict(item):
868+
return {
869+
"name": item.name,
870+
"filename": item.filename,
871+
"language": item.language,
872+
"path": str(item.path),
873+
}
874+
875+
json_dict["custom_files"] = [make_custom_file_dict(file) for file in attr]
876+
877+
elif isinstance(attr, ClassList):
878+
json_dict[field] = [item.model_dump() for item in attr]
879+
else:
880+
json_dict[field] = attr
881+
882+
file = Path(path, f"{filename.removesuffix('.json')}.json")
883+
file.write_text(json.dumps(json_dict))
884+
885+
@classmethod
886+
def load(cls, path: Union[str, Path]) -> "Project":
887+
"""Load a project from file.
888+
889+
Parameters
890+
----------
891+
path : str or Path
892+
The path to the project file.
893+
894+
"""
895+
input = Path(path).read_text()
896+
model_dict = json.loads(input)
897+
for i in range(0, len(model_dict["data"])):
898+
if model_dict["data"][i]["name"] == "Simulation":
899+
model_dict["data"][i]["data"] = np.empty([0, 3])
900+
del model_dict["data"][i]["data_range"]
901+
else:
902+
data = model_dict["data"][i]["data"]
903+
model_dict["data"][i]["data"] = np.array(data)
904+
905+
return cls.model_validate(model_dict)
906+
837907
def _classlist_wrapper(self, class_list: ClassList, func: Callable):
838908
"""Defines the function used to wrap around ClassList routines to force revalidation.
839909

RATapi/utils/convert.py

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Utilities for converting input files to Python `Project`s."""
22

3-
import json
43
import warnings
54
from collections.abc import Iterable
65
from os import PathLike
@@ -553,72 +552,3 @@ def convert_parameters(
553552
eng.save(str(filename), "problem", nargout=0)
554553
eng.exit()
555554
return None
556-
557-
558-
def project_to_json(project: Project) -> str:
559-
"""Write a Project as a JSON file.
560-
561-
Parameters
562-
----------
563-
project : Project
564-
The input Project object to convert.
565-
566-
Returns
567-
-------
568-
str
569-
A string representing the class in JSON format.
570-
"""
571-
json_dict = {}
572-
for field in project.model_fields:
573-
attr = getattr(project, field)
574-
575-
if field == "data":
576-
577-
def make_data_dict(item):
578-
return {
579-
"name": item.name,
580-
"data": item.data.tolist(),
581-
"data_range": item.data_range,
582-
"simulation_range": item.simulation_range,
583-
}
584-
585-
json_dict["data"] = [make_data_dict(data) for data in attr]
586-
587-
elif field == "custom_files":
588-
589-
def make_custom_file_dict(item):
590-
return {"name": item.name, "filename": item.filename, "language": item.language, "path": str(item.path)}
591-
592-
json_dict["custom_files"] = [make_custom_file_dict(file) for file in attr]
593-
594-
elif isinstance(attr, ClassList):
595-
json_dict[field] = [dict(item) for item in attr]
596-
else:
597-
json_dict[field] = attr
598-
599-
return json.dumps(json_dict)
600-
601-
602-
def project_from_json(input: str) -> Project:
603-
"""Read a Project from a JSON string generated by `to_json`.
604-
605-
Parameters
606-
----------
607-
input : str
608-
The JSON input as a string.
609-
610-
Returns
611-
-------
612-
Project
613-
The project corresponding to that JSON input.
614-
"""
615-
model_dict = json.loads(input)
616-
for i in range(0, len(model_dict["data"])):
617-
if model_dict["data"][i]["name"] == "Simulation":
618-
model_dict["data"][i]["data"] = empty([0, 3])
619-
del model_dict["data"][i]["data_range"]
620-
else:
621-
data = model_dict["data"][i]["data"]
622-
model_dict["data"][i]["data"] = array(data)
623-
624-
return Project.model_validate(model_dict)

tests/test_convert.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pytest
99

1010
import RATapi
11-
from RATapi.utils.convert import project_class_to_r1, project_from_json, project_to_json, r1_to_project_class
11+
from RATapi.utils.convert import project_class_to_r1, r1_to_project_class
1212

1313
TEST_DIR_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test_data")
1414

@@ -110,35 +110,6 @@ def test_invalid_constraints():
110110
assert output_project.background_parameters[0].min == output_project.background_parameters[0].value
111111

112112

113-
@pytest.mark.parametrize(
114-
"project",
115-
[
116-
"r1_default_project",
117-
"r1_monolayer",
118-
"r1_monolayer_8_contrasts",
119-
"r1_orso_polymer",
120-
"r1_motofit_bench_mark",
121-
"dspc_bilayer",
122-
"dspc_standard_layers",
123-
"dspc_custom_layers",
124-
"dspc_custom_xy",
125-
"domains_standard_layers",
126-
"domains_custom_layers",
127-
"domains_custom_xy",
128-
"absorption",
129-
],
130-
)
131-
def test_json_involution(project, request):
132-
"""Test that converting a Project to JSON and back returns the same project."""
133-
original_project = request.getfixturevalue(project)
134-
json_data = project_to_json(original_project)
135-
136-
converted_project = project_from_json(json_data)
137-
138-
for field in RATapi.Project.model_fields:
139-
assert getattr(converted_project, field) == getattr(original_project, field)
140-
141-
142113
@pytest.mark.skipif(importlib.util.find_spec("matlab") is None, reason="Matlab not installed")
143114
@pytest.mark.parametrize("path_type", [os.path.join, pathlib.Path])
144115
def test_matlab_save(path_type, request):

tests/test_project.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Test the project module."""
22

33
import copy
4+
import tempfile
45
from pathlib import Path
56
from typing import Callable
67

@@ -1531,3 +1532,32 @@ def test_wrap_extend(test_project, class_list: str, model_type: str, field: str,
15311532

15321533
# Ensure invalid model was not appended
15331534
assert test_attribute == orig_class_list
1535+
1536+
1537+
@pytest.mark.parametrize(
1538+
"project",
1539+
[
1540+
"r1_default_project",
1541+
"r1_monolayer",
1542+
"r1_monolayer_8_contrasts",
1543+
"r1_orso_polymer",
1544+
"r1_motofit_bench_mark",
1545+
"dspc_standard_layers",
1546+
"dspc_custom_layers",
1547+
"dspc_custom_xy",
1548+
"domains_standard_layers",
1549+
"domains_custom_layers",
1550+
"domains_custom_xy",
1551+
"absorption",
1552+
],
1553+
)
1554+
def test_save_load(project, request):
1555+
"""Test that saving and loading a project returns the same project."""
1556+
original_project = request.getfixturevalue(project)
1557+
1558+
with tempfile.TemporaryDirectory() as tmp:
1559+
original_project.save(tmp)
1560+
converted_project = RATapi.Project.load(Path(tmp, "project.json"))
1561+
1562+
for field in RATapi.Project.model_fields:
1563+
assert getattr(converted_project, field) == getattr(original_project, field)

0 commit comments

Comments
 (0)