Skip to content

Commit bd92eee

Browse files
authored
Merge pull request #315 from gkreitz/refactor_formatversion
Replace formatversion.FormatData with a StrEnum
2 parents 98290cf + 450d453 commit bd92eee

File tree

5 files changed

+89
-143
lines changed

5 files changed

+89
-143
lines changed

problemtools/formatversion.py

Lines changed: 46 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,92 +1,47 @@
1-
import os
21
import yaml
3-
from dataclasses import dataclass
4-
5-
6-
VERSION_LEGACY = 'legacy'
7-
VERSION_2023_07 = '2023-07-draft'
8-
9-
10-
@dataclass(frozen=True)
11-
class FormatData:
12-
"""
13-
A class containing data specific to the format version.
14-
name: the version name.
15-
statement_directory: the directory where the statements should be found.
16-
statement_extensions: the allowed extensions for the statements.
17-
"""
18-
19-
name: str
20-
statement_directory: str
21-
statement_extensions: list[str]
22-
output_validator_directory: str
23-
24-
25-
FORMAT_DATACLASSES = {
26-
VERSION_LEGACY: FormatData(
27-
name=VERSION_LEGACY,
28-
statement_directory='problem_statement',
29-
statement_extensions=['tex'],
30-
output_validator_directory='output_validators',
31-
),
32-
VERSION_2023_07: FormatData(
33-
name=VERSION_2023_07,
34-
statement_directory='statement',
35-
statement_extensions=['md', 'tex'],
36-
output_validator_directory='output_validator',
37-
),
38-
}
39-
FORMAT_DATACLASSES['2023-07'] = FORMAT_DATACLASSES[VERSION_2023_07] # Accept non-draft version string too
40-
41-
42-
def detect_problem_version(path: str) -> str:
43-
"""
44-
Returns the problem version value of problem.yaml or throws an error if it is unable to read the file.
45-
Args:
46-
path: the problem path
47-
48-
Returns:
49-
the version name as a String
50-
51-
"""
52-
config_path = os.path.join(path, 'problem.yaml')
53-
try:
54-
with open(config_path) as f:
55-
config: dict = yaml.safe_load(f) or {}
56-
except Exception as e:
57-
raise VersionError(f'Error reading problem.yaml: {e}')
58-
return config.get('problem_format_version', VERSION_LEGACY)
59-
60-
61-
def get_format_data(path: str) -> FormatData:
62-
"""
63-
Gets the dataclass object containing the necessary data for a problem format.
64-
Args:
65-
path: the problem path
66-
67-
Returns:
68-
the dataclass object containing the necessary data for a problem format
69-
70-
"""
71-
return get_format_data_by_name(detect_problem_version(path))
72-
73-
74-
def get_format_data_by_name(name: str) -> FormatData:
75-
"""
76-
Gets the dataclass object containing the necessary data for a problem format given the format name.
77-
Args:
78-
name: the format name
79-
80-
Returns:
81-
the dataclass object containing the necessary data for a problem format
82-
83-
"""
84-
data = FORMAT_DATACLASSES.get(name)
85-
if not data:
86-
raise VersionError(f'No version found with name {name}')
87-
else:
88-
return data
89-
90-
91-
class VersionError(Exception):
92-
pass
2+
from enum import StrEnum
3+
from pathlib import Path
4+
5+
6+
class FormatVersion(StrEnum):
7+
LEGACY = 'legacy'
8+
V_2023_07 = '2023-07-draft' # When 2023-07 is finalized, replace this and update _missing_
9+
10+
@property
11+
def statement_directory(self) -> str:
12+
match self:
13+
case FormatVersion.LEGACY:
14+
return 'problem_statement'
15+
case FormatVersion.V_2023_07:
16+
return 'statement'
17+
18+
@property
19+
def statement_extensions(self) -> list[str]:
20+
match self:
21+
case FormatVersion.LEGACY:
22+
return ['tex']
23+
case FormatVersion.V_2023_07:
24+
return ['md', 'tex']
25+
26+
@property
27+
def output_validator_directory(self) -> str:
28+
match self:
29+
case FormatVersion.LEGACY:
30+
return 'output_validators'
31+
case FormatVersion.V_2023_07:
32+
return 'output_validator'
33+
34+
# Support 2023-07 and 2023-07-draft strings.
35+
# This method should be replaced with an alias once we require python 3.13
36+
@classmethod
37+
def _missing_(cls, value):
38+
if value == '2023-07':
39+
return cls.V_2023_07
40+
return None
41+
42+
43+
def get_format_version(problem_root: Path) -> FormatVersion:
44+
"""Loads the version from the problem in problem_root"""
45+
with open(problem_root / 'problem.yaml') as f:
46+
config: dict = yaml.safe_load(f) or {}
47+
return FormatVersion(config.get('problem_format_version', FormatVersion.LEGACY))

problemtools/metadata.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
import yaml
1212

1313
from . import config
14-
from . import formatversion
1514
from . import statement_util
15+
from .formatversion import FormatVersion
1616

1717

1818
class ProblemType(StrEnum):
@@ -164,7 +164,7 @@ class MetadataLegacy(BaseModel):
164164
which pre-date the version called legacy).
165165
"""
166166

167-
problem_format_version: str = formatversion.VERSION_LEGACY
167+
problem_format_version: FormatVersion = FormatVersion.LEGACY
168168
type: Literal['pass-fail'] | Literal['scoring'] = 'pass-fail'
169169
name: str | None = None
170170
uuid: UUID | None = None
@@ -191,7 +191,7 @@ class Metadata(BaseModel):
191191
Metadata serializes to a valid 2023-07-draft configuration.
192192
"""
193193

194-
problem_format_version: str
194+
problem_format_version: FormatVersion
195195
type: list[ProblemType]
196196
name: dict[str, str]
197197
uuid: UUID | None
@@ -309,7 +309,7 @@ def parse_person(person: str | Person) -> Person:
309309

310310

311311
def parse_metadata(
312-
version: formatversion.FormatData,
312+
version: FormatVersion,
313313
problem_yaml_data: dict[str, Any],
314314
names_from_statements: dict[str, str] | None = None,
315315
) -> Metadata:
@@ -326,11 +326,11 @@ def parse_metadata(
326326
system_defaults = config.load_config('problem.yaml')
327327
data['limits'] = system_defaults['limits'] | data.get('limits', {})
328328

329-
if version.name == formatversion.VERSION_LEGACY:
329+
if version is FormatVersion.LEGACY:
330330
legacy_model = MetadataLegacy.model_validate(data)
331331
return Metadata.from_legacy(legacy_model, names_from_statements or {})
332332
else:
333-
assert version.name == formatversion.VERSION_2023_07
333+
assert version is FormatVersion.V_2023_07
334334
model_2023_07 = Metadata2023_07.model_validate(data)
335335
return Metadata.from_2023_07(model_2023_07)
336336

@@ -347,8 +347,8 @@ def load_metadata(problem_root: Path) -> tuple[Metadata, dict]:
347347
if data is None: # Loading empty yaml returns None
348348
data = {}
349349

350-
version = formatversion.get_format_data_by_name(data.get('problem_format_version', formatversion.VERSION_LEGACY))
351-
if version.name == formatversion.VERSION_LEGACY:
350+
version = FormatVersion(data.get('problem_format_version', FormatVersion.LEGACY))
351+
if version is FormatVersion.LEGACY:
352352
names_from_statements = statement_util.load_names_from_statements(problem_root, version)
353353
else:
354354
names_from_statements = None

problemtools/statement_util.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@
88
from pathlib import Path
99
from typing import Optional, List, Tuple
1010

11-
from . import formatversion
1211
from . import metadata
12+
from .formatversion import FormatVersion, get_format_version
1313

1414
ALLOWED_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg') # ".svg"
1515
FOOTNOTES_STRINGS = ['<section class="footnotes">', '<aside class="footnotes">']
1616

1717

18-
def find_statements(problem_root: Path, version: formatversion.FormatData) -> dict[str, list[Path]]:
18+
def find_statements(problem_root: Path, version: FormatVersion) -> dict[str, list[Path]]:
1919
"""Returns a dict mapping language code to a list of paths to statements (relative to problem_root)
2020
2121
Note that in well-formed problem packages, there should only be a single
@@ -30,17 +30,17 @@ def find_statements(problem_root: Path, version: formatversion.FormatData) -> di
3030
for file in directory.iterdir():
3131
if m := filename_re.search(file.name):
3232
if m.group(2) is None: # problem.tex is allowed and assumed to be 'en' in legacy. We ignore it in newer formats.
33-
if version.name == formatversion.VERSION_LEGACY:
33+
if version is FormatVersion.LEGACY:
3434
ret['en'].append(file)
3535
else:
3636
ret[m.group(2)].append(file)
3737
return dict(ret)
3838

3939

40-
def load_names_from_statements(problem_root: Path, version: formatversion.FormatData) -> dict[str, str]:
40+
def load_names_from_statements(problem_root: Path, version: FormatVersion) -> dict[str, str]:
4141
"""Returns a dict mapping language code => problem name"""
4242

43-
assert version.name == formatversion.VERSION_LEGACY, 'load_names_from_statements only makes sense for legacy format'
43+
assert version is FormatVersion.LEGACY, 'load_names_from_statements only makes sense for legacy format'
4444
ret: dict[str, str] = {}
4545
for lang, files in find_statements(problem_root, version).items():
4646
hit = re.search(r'\\problemname{(.*)}', files[0].read_text(), re.MULTILINE)
@@ -56,7 +56,7 @@ def find_statement(problem_root: Path, language: str) -> Path:
5656
ValueError: if there are multiple statements in language.
5757
FileNotFoundError: if there are no statements in language.
5858
"""
59-
candidates = find_statements(problem_root, formatversion.get_format_data(str(problem_root)))
59+
candidates = find_statements(problem_root, get_format_version(problem_root))
6060
if language not in candidates:
6161
raise FileNotFoundError(f'No statement found in language {language}. Found languages: {", ".join(candidates)}')
6262
elif len(candidates[language]) > 1:

problemtools/verifyproblem.py

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@
2828

2929
from . import config
3030
from . import languages
31-
from . import formatversion
3231
from . import metadata
3332
from . import problem2html
3433
from . import problem2pdf
3534
from . import run
3635
from . import statement_util
36+
from .formatversion import FormatVersion, get_format_version
3737

3838
from abc import ABC
3939
from typing import Any, Callable, ClassVar, Literal, Pattern, Match, ParamSpec, Type, TypeVar
@@ -819,16 +819,12 @@ def setup(self):
819819
error_str = '\n'.join([f' {"->".join((str(loc) for loc in err["loc"]))}: {err["msg"]}' for err in e.errors()])
820820
self.error(f'Failed parsing problem.yaml. Found {len(e.errors())} errors:\n{error_str}')
821821
# For now, set metadata to an empty legacy config to avoid crashing.
822-
self.problem.setMetadata(
823-
metadata.parse_metadata(formatversion.get_format_data_by_name(formatversion.VERSION_LEGACY), {})
824-
)
822+
self.problem.setMetadata(metadata.parse_metadata(FormatVersion.LEGACY, {}))
825823
except Exception as e:
826824
# This should likely be a fatal error, but I'm not sure there's a clean way to fail from setup
827825
self.error(f'Failed loading problem configuration: {e}')
828826
# For now, set metadata to an empty legacy config to avoid crashing.
829-
self.problem.setMetadata(
830-
metadata.parse_metadata(formatversion.get_format_data_by_name(formatversion.VERSION_LEGACY), {})
831-
)
827+
self.problem.setMetadata(metadata.parse_metadata(FormatVersion.LEGACY, {}))
832828
return {}
833829

834830
def __str__(self) -> str:
@@ -853,7 +849,7 @@ def check(self, context: Context) -> bool:
853849

854850
if self._metadata.uuid is None:
855851
uuid_msg = f'Missing uuid from problem.yaml. Add "uuid: {uuid.uuid4()}" to problem.yaml.'
856-
if self.problem.format.name == formatversion.VERSION_LEGACY:
852+
if self.problem.format is FormatVersion.LEGACY:
857853
self.warning(uuid_msg)
858854
else:
859855
self.error(uuid_msg)
@@ -864,7 +860,7 @@ def check(self, context: Context) -> bool:
864860
not self._metadata.is_pass_fail()
865861
and self.problem.get(ProblemTestCases)['root_group'].has_custom_groups()
866862
and 'show_test_data_groups' not in self._origdata.get('grading', {})
867-
and self.problem.format.name == formatversion.VERSION_LEGACY
863+
and self.problem.format is FormatVersion.LEGACY
868864
):
869865
self.warning(
870866
'Problem has custom testcase groups, but does not specify a value for grading.show_test_data_groups; defaulting to false'
@@ -1217,10 +1213,7 @@ class OutputValidators(ProblemPart):
12171213
PART_NAME = 'output_validator'
12181214

12191215
def setup(self):
1220-
if (
1221-
self.problem.format.name != formatversion.VERSION_LEGACY
1222-
and (Path(self.problem.probdir) / 'output_validators').exists()
1223-
):
1216+
if self.problem.format is FormatVersion.LEGACY and (Path(self.problem.probdir) / 'output_validators').exists():
12241217
self.error('output_validators is not supported after Legacy; please use output_validator instead')
12251218

12261219
self._validators = run.find_programs(
@@ -1351,7 +1344,7 @@ def _parse_validator_results(self, val, status: int, feedbackdir, testcase: Test
13511344
def _actual_validators(self) -> list:
13521345
vals = self._validators
13531346
if self.problem.getMetadata().legacy_validation == 'default' or (
1354-
self.problem.format.name == formatversion.VERSION_2023_07 and not vals
1347+
self.problem.format is FormatVersion.V_2023_07 and not vals
13551348
):
13561349
vals = [self._default_validator]
13571350
return [val for val in vals if val is not None]
@@ -1739,16 +1732,16 @@ def check(self, context: Context) -> bool:
17391732
return self._check_res
17401733

17411734

1742-
PROBLEM_FORMATS: dict[str, dict[str, list[Type[ProblemPart]]]] = {
1743-
formatversion.VERSION_LEGACY: {
1735+
PROBLEM_FORMATS: dict[FormatVersion, dict[str, list[Type[ProblemPart]]]] = {
1736+
FormatVersion.LEGACY: {
17441737
'config': [ProblemConfig],
17451738
'statement': [ProblemStatement, Attachments],
17461739
'validators': [InputValidators, OutputValidators],
17471740
'graders': [Graders],
17481741
'data': [ProblemTestCases],
17491742
'submissions': [Submissions],
17501743
},
1751-
formatversion.VERSION_2023_07: { # TODO: Add all the parts
1744+
FormatVersion.V_2023_07: { # TODO: Add all the parts
17521745
'config': [ProblemConfig],
17531746
'statement': [ProblemStatement, Attachments],
17541747
'validators': [InputValidators, OutputValidators],
@@ -1773,14 +1766,14 @@ class Problem(ProblemAspect):
17731766
of category -> part-types. You could for example have 'validators' -> [InputValidators, OutputValidators].
17741767
"""
17751768

1776-
def __init__(self, probdir: str, parts: dict[str, list[type]] = PROBLEM_FORMATS[formatversion.VERSION_LEGACY]):
1769+
def __init__(self, probdir: str, parts: dict[str, list[type]] = PROBLEM_FORMATS[FormatVersion.LEGACY]):
17771770
self.part_mapping: dict[str, list[Type[ProblemPart]]] = parts
17781771
self.aspects: set[type] = {v for s in parts.values() for v in s}
17791772
self.probdir = os.path.realpath(probdir)
17801773
self.shortname: str | None = os.path.basename(self.probdir)
17811774
super().__init__(self.shortname)
17821775
self.language_config = languages.load_language_config()
1783-
self.format = formatversion.get_format_data(self.probdir)
1776+
self.format = get_format_version(Path(self.probdir))
17841777
self._data: dict[str, dict] = {}
17851778
self._metadata: metadata.Metadata | None = None
17861779
self.debug(f'Problem-format: {parts}')
@@ -1860,8 +1853,8 @@ def check(self, args: argparse.Namespace) -> tuple[int, int]:
18601853
try:
18611854
if not re.match('^[a-z0-9]+$', self.shortname):
18621855
self.error(f"Invalid shortname '{self.shortname}' (must be [a-z0-9]+)")
1863-
if self.format.name == formatversion.VERSION_2023_07:
1864-
self.warning(f'Support for version {self.format.name} is very incomplete. Verification may not work as expected.')
1856+
if self.format is FormatVersion.V_2023_07:
1857+
self.warning(f'Support for version {self.format} is very incomplete. Verification may not work as expected.')
18651858

18661859
self._check_symlinks()
18671860

@@ -2007,17 +2000,16 @@ def main() -> None:
20072000
for problemdir in args.problemdir:
20082001
try:
20092002
if args.problem_format == 'automatic':
2010-
version_data = formatversion.get_format_data(problemdir)
2003+
formatversion = get_format_version(Path(problemdir))
20112004
else:
2012-
version_data = formatversion.get_format_data_by_name(args.problem_format)
2013-
except formatversion.VersionError as e:
2005+
formatversion = FormatVersion(args.problem_format)
2006+
except Exception as e:
20142007
total_errors += 1
20152008
print(f'ERROR: problem version could not be decided for {os.path.basename(os.path.realpath(problemdir))}: {e}')
20162009
continue
20172010

2018-
print(f'Loading problem {os.path.basename(os.path.realpath(problemdir))} with format version {version_data.name}')
2019-
format = PROBLEM_FORMATS[version_data.name]
2020-
with Problem(problemdir, format) as prob:
2011+
print(f'Loading problem {os.path.basename(os.path.realpath(problemdir))} with format version {formatversion}')
2012+
with Problem(problemdir, PROBLEM_FORMATS[formatversion]) as prob:
20212013
errors, warnings = prob.check(args)
20222014

20232015
def p(x: int) -> str:

0 commit comments

Comments
 (0)