Skip to content

Commit c055e17

Browse files
committed
Add retry strategy to clients. Make submissions and test cases iterable. Increase retry frequency for default implicit Sulu client.
1 parent 0d3d82b commit c055e17

File tree

7 files changed

+77
-71
lines changed

7 files changed

+77
-71
lines changed

src/judge0/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,9 @@ def _get_implicit_client(flavor: Flavor) -> Client:
9898
# the preview Sulu client based on the flavor.
9999
if client is None:
100100
if flavor == Flavor.CE:
101-
client = SuluJudge0CE()
101+
client = SuluJudge0CE(retry_strategy=RegularPeriodRetry(0.5))
102102
else:
103-
client = SuluJudge0ExtraCE()
103+
client = SuluJudge0ExtraCE(retry_strategy=RegularPeriodRetry(0.5))
104104

105105
if flavor == Flavor.CE:
106106
JUDGE0_IMPLICIT_CE_CLIENT = client

src/judge0/api.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from typing import Iterable, Optional, Union
1+
from typing import Optional, Union
22

3-
from .base_types import Flavor, TestCase, TestCases
3+
from .base_types import Flavor, Iterable, TestCase, TestCases
44
from .clients import Client
55
from .common import batched
66

7-
from .retry import RegularPeriodRetry, RetryMechanism
7+
from .retry import RegularPeriodRetry, RetryStrategy
88
from .submission import Submission, Submissions
99

1010

@@ -31,7 +31,7 @@ def _resolve_client(
3131
if isinstance(client, Flavor):
3232
return get_client(client)
3333

34-
if client is None and isinstance(submissions, list) and len(submissions) == 0:
34+
if client is None and isinstance(submissions, Iterable) and len(submissions) == 0:
3535
raise ValueError("Client cannot be determined from empty submissions.")
3636

3737
# client is None and we have to determine a flavor of the client from the
@@ -57,6 +57,7 @@ def _resolve_client(
5757

5858

5959
def create_submissions(
60+
*,
6061
client: Optional[Client] = None,
6162
submissions: Optional[Union[Submission, Submissions]] = None,
6263
) -> Union[Submission, Submissions]:
@@ -81,7 +82,7 @@ def get_submissions(
8182
*,
8283
client: Optional[Client] = None,
8384
submissions: Optional[Union[Submission, Submissions]] = None,
84-
fields: Union[str, Iterable[str], None] = None,
85+
fields: Optional[Union[str, Iterable[str]]] = None,
8586
) -> Union[Submission, Submissions]:
8687
client = _resolve_client(client=client, submissions=submissions)
8788

@@ -108,12 +109,15 @@ def wait(
108109
*,
109110
client: Optional[Client] = None,
110111
submissions: Optional[Union[Submission, Submissions]] = None,
111-
retry_mechanism: Optional[RetryMechanism] = None,
112+
retry_strategy: Optional[RetryStrategy] = None,
112113
) -> Union[Submission, Submissions]:
113114
client = _resolve_client(client, submissions)
114115

115-
if retry_mechanism is None:
116-
retry_mechanism = RegularPeriodRetry()
116+
if retry_strategy is None:
117+
if client.retry_strategy is None:
118+
retry_strategy = RegularPeriodRetry()
119+
else:
120+
retry_strategy = client.retry_strategy
117121

118122
if isinstance(submissions, Submission):
119123
submissions_to_check = {
@@ -124,7 +128,7 @@ def wait(
124128
submission.token: submission for submission in submissions
125129
}
126130

127-
while len(submissions_to_check) > 0 and not retry_mechanism.is_done():
131+
while len(submissions_to_check) > 0 and not retry_strategy.is_done():
128132
get_submissions(client=client, submissions=list(submissions_to_check.values()))
129133
for token in list(submissions_to_check):
130134
submission = submissions_to_check[token]
@@ -135,8 +139,8 @@ def wait(
135139
if len(submissions_to_check) == 0:
136140
break
137141

138-
retry_mechanism.wait()
139-
retry_mechanism.step()
142+
retry_strategy.wait()
143+
retry_strategy.step()
140144

141145
return submissions
142146

@@ -204,6 +208,7 @@ def _execute(
204208
if submissions is None and source_code is None:
205209
raise ValueError("Neither source_code nor submissions argument are provided.")
206210

211+
# Internally, let's rely on Submission's dataclass.
207212
if source_code is not None:
208213
submissions = Submission(source_code=source_code, **kwargs)
209214

src/judge0/base_types.py

+5-13
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,11 @@
11
from abc import ABC, abstractmethod
22
from dataclasses import dataclass
33
from enum import IntEnum
4-
from typing import Optional, Union
5-
6-
7-
TestCases = Union[
8-
list["TestCase"],
9-
tuple["TestCase"],
10-
list[dict],
11-
tuple[dict],
12-
list[list],
13-
list[tuple],
14-
tuple[list],
15-
tuple[tuple],
16-
]
4+
from typing import Optional, Sequence, Union
5+
6+
Iterable = Sequence
7+
8+
TestCases = Iterable["TestCase"]
179

1810

1911
@dataclass(frozen=True)

src/judge0/clients.py

+43-29
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,29 @@
1-
from typing import Iterable, Union
1+
from typing import Optional, Union
22

33
import requests
44

5-
from .base_types import Config, Language, LanguageAlias
5+
from .base_types import Config, Iterable, Language, LanguageAlias
66
from .data import LANGUAGE_TO_LANGUAGE_ID
7+
from .retry import RetryStrategy
78
from .submission import Submission, Submissions
89

910

1011
class Client:
11-
API_KEY_ENV = "JUDGE0_API_KEY"
12-
DEFAULT_MAX_SUBMISSION_BATCH_SIZE = 20
13-
ENABLED_BATCHED_SUBMISSIONS = True
14-
EFFECTIVE_SUBMISSION_BATCH_SIZE = (
15-
DEFAULT_MAX_SUBMISSION_BATCH_SIZE if ENABLED_BATCHED_SUBMISSIONS else 1
16-
)
12+
API_KEY_ENV = None
1713

18-
def __init__(self, endpoint, auth_headers) -> None:
14+
def __init__(
15+
self,
16+
endpoint,
17+
auth_headers,
18+
*,
19+
retry_strategy: Optional[RetryStrategy] = None,
20+
) -> None:
1921
self.endpoint = endpoint
2022
self.auth_headers = auth_headers
23+
self.retry_strategy = retry_strategy
2124

2225
try:
23-
self.languages = [Language(**lang) for lang in self.get_languages()]
26+
self.languages = tuple(Language(**lang) for lang in self.get_languages())
2427
self.config = Config(**self.get_config_info())
2528
except Exception as e:
2629
raise RuntimeError(
@@ -113,7 +116,7 @@ def get_submission(
113116
self,
114117
submission: Submission,
115118
*,
116-
fields: Union[str, Iterable[str], None] = None,
119+
fields: Optional[Union[str, Iterable[str]]] = None,
117120
) -> Submission:
118121
"""Check the submission status."""
119122

@@ -168,7 +171,7 @@ def get_submissions(
168171
self,
169172
submissions: Submissions,
170173
*,
171-
fields: Union[str, Iterable[str], None] = None,
174+
fields: Optional[Union[str, Iterable[str]]] = None,
172175
) -> Submissions:
173176
params = {
174177
"base64_encoded": "true",
@@ -201,14 +204,15 @@ def get_submissions(
201204
class ATD(Client):
202205
API_KEY_ENV = "JUDGE0_ATD_API_KEY"
203206

204-
def __init__(self, endpoint, host_header_value, api_key):
207+
def __init__(self, endpoint, host_header_value, api_key, **kwargs):
205208
self.api_key = api_key
206209
super().__init__(
207210
endpoint,
208211
{
209212
"x-apihub-host": host_header_value,
210213
"x-apihub-key": api_key,
211214
},
215+
**kwargs,
212216
)
213217

214218
def _update_endpoint_header(self, header_value):
@@ -232,11 +236,12 @@ class ATDJudge0CE(ATD):
232236
DEFAULT_CREATE_SUBMISSIONS_ENDPOINT: str = "402b857c-1126-4450-bfd8-22e1f2cbff2f"
233237
DEFAULT_GET_SUBMISSIONS_ENDPOINT: str = "e42f2a26-5b02-472a-80c9-61c4bdae32ec"
234238

235-
def __init__(self, api_key):
239+
def __init__(self, api_key, **kwargs):
236240
super().__init__(
237241
self.DEFAULT_ENDPOINT,
238242
self.DEFAULT_HOST,
239243
api_key,
244+
**kwargs,
240245
)
241246

242247
def get_about(self) -> dict:
@@ -267,7 +272,7 @@ def get_submission(
267272
self,
268273
submission: Submission,
269274
*,
270-
fields: Union[str, Iterable[str], None] = None,
275+
fields: Optional[Union[str, Iterable[str]]] = None,
271276
) -> Submission:
272277
self._update_endpoint_header(self.DEFAULT_GET_SUBMISSION_ENDPOINT)
273278
return super().get_submission(submission, fields=fields)
@@ -280,7 +285,7 @@ def get_submissions(
280285
self,
281286
submissions: Submissions,
282287
*,
283-
fields: Union[str, Iterable[str], None] = None,
288+
fields: Optional[Union[str, Iterable[str]]] = None,
284289
) -> Submissions:
285290
self._update_endpoint_header(self.DEFAULT_GET_SUBMISSIONS_ENDPOINT)
286291
return super().get_submissions(submissions, fields=fields)
@@ -303,11 +308,12 @@ class ATDJudge0ExtraCE(ATD):
303308
DEFAULT_CREATE_SUBMISSIONS_ENDPOINT: str = "c64df5d3-edfd-4b08-8687-561af2f80d2f"
304309
DEFAULT_GET_SUBMISSIONS_ENDPOINT: str = "5d173718-8e6a-4cf5-9d8c-db5e6386d037"
305310

306-
def __init__(self, api_key):
311+
def __init__(self, api_key, **kwargs):
307312
super().__init__(
308313
self.DEFAULT_ENDPOINT,
309314
self.DEFAULT_HOST,
310315
api_key,
316+
**kwargs,
311317
)
312318

313319
def get_about(self) -> dict:
@@ -338,7 +344,7 @@ def get_submission(
338344
self,
339345
submission: Submission,
340346
*,
341-
fields: Union[str, Iterable[str], None] = None,
347+
fields: Optional[Union[str, Iterable[str]]] = None,
342348
) -> Submission:
343349
self._update_endpoint_header(self.DEFAULT_GET_SUBMISSION_ENDPOINT)
344350
return super().get_submission(submission, fields=fields)
@@ -351,7 +357,7 @@ def get_submissions(
351357
self,
352358
submissions: Submissions,
353359
*,
354-
fields: Union[str, Iterable[str], None] = None,
360+
fields: Optional[Union[str, Iterable[str]]] = None,
355361
) -> Submissions:
356362
self._update_endpoint_header(self.DEFAULT_GET_SUBMISSIONS_ENDPOINT)
357363
return super().get_submissions(submissions, fields=fields)
@@ -360,14 +366,15 @@ def get_submissions(
360366
class Rapid(Client):
361367
API_KEY_ENV = "JUDGE0_RAPID_API_KEY"
362368

363-
def __init__(self, endpoint, host_header_value, api_key):
369+
def __init__(self, endpoint, host_header_value, api_key, **kwargs):
364370
self.api_key = api_key
365371
super().__init__(
366372
endpoint,
367373
{
368374
"x-rapidapi-host": host_header_value,
369375
"x-rapidapi-key": api_key,
370376
},
377+
**kwargs,
371378
)
372379

373380

@@ -376,11 +383,12 @@ class RapidJudge0CE(Rapid):
376383
DEFAULT_HOST: str = "judge0-ce.p.rapidapi.com"
377384
HOME_URL: str = "https://rapidapi.com/judge0-official/api/judge0-ce"
378385

379-
def __init__(self, api_key):
386+
def __init__(self, api_key, **kwargs):
380387
super().__init__(
381388
self.DEFAULT_ENDPOINT,
382389
self.DEFAULT_HOST,
383390
api_key,
391+
**kwargs,
384392
)
385393

386394

@@ -389,40 +397,46 @@ class RapidJudge0ExtraCE(Rapid):
389397
DEFAULT_HOST: str = "judge0-extra-ce.p.rapidapi.com"
390398
HOME_URL: str = "https://rapidapi.com/judge0-official/api/judge0-extra-ce"
391399

392-
def __init__(self, api_key):
400+
def __init__(self, api_key, **kwargs):
393401
super().__init__(
394402
self.DEFAULT_ENDPOINT,
395403
self.DEFAULT_HOST,
396404
api_key,
405+
**kwargs,
397406
)
398407

399408

400409
class Sulu(Client):
401410
API_KEY_ENV = "JUDGE0_SULU_API_KEY"
402411

403-
def __init__(self, endpoint, api_key=None):
412+
def __init__(self, endpoint, api_key=None, **kwargs):
404413
self.api_key = api_key
405414
super().__init__(
406415
endpoint,
407416
{"Authorization": f"Bearer {api_key}"} if api_key else None,
417+
**kwargs,
408418
)
409419

410420

411421
class SuluJudge0CE(Sulu):
412422
DEFAULT_ENDPOINT: str = "https://judge0-ce.p.sulu.sh"
413423
HOME_URL: str = "https://sparkhub.sulu.sh/apis/judge0/judge0-ce/readme"
414424

415-
def __init__(self, api_key=None):
416-
super().__init__(self.DEFAULT_ENDPOINT, api_key)
425+
def __init__(self, api_key=None, **kwargs):
426+
super().__init__(
427+
self.DEFAULT_ENDPOINT,
428+
api_key,
429+
**kwargs,
430+
)
417431

418432

419433
class SuluJudge0ExtraCE(Sulu):
420434
DEFAULT_ENDPOINT: str = "https://judge0-extra-ce.p.sulu.sh"
421435
HOME_URL: str = "https://sparkhub.sulu.sh/apis/judge0/judge0-extra-ce/readme"
422436

423-
def __init__(self, api_key=None):
424-
super().__init__(self.DEFAULT_ENDPOINT, api_key)
437+
def __init__(self, api_key=None, **kwargs):
438+
super().__init__(self.DEFAULT_ENDPOINT, api_key, **kwargs)
425439

426440

427-
CE = [RapidJudge0CE, SuluJudge0CE, ATDJudge0CE]
428-
EXTRA_CE = [RapidJudge0ExtraCE, SuluJudge0ExtraCE, ATDJudge0ExtraCE]
441+
CE = (RapidJudge0CE, SuluJudge0CE, ATDJudge0CE)
442+
EXTRA_CE = (RapidJudge0ExtraCE, SuluJudge0ExtraCE, ATDJudge0ExtraCE)

src/judge0/filesystem.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
import zipfile
44

55
from base64 import b64decode, b64encode
6-
from collections import abc
7-
from typing import Iterable, Optional, Union
6+
from typing import Optional, Union
87

9-
from .base_types import Encodeable
8+
from .base_types import Encodeable, Iterable
109

1110

1211
class File:
@@ -42,7 +41,7 @@ def __init__(
4241
for file_name in zip_file.namelist():
4342
with zip_file.open(file_name) as fp:
4443
self.files.append(File(file_name, fp.read()))
45-
elif isinstance(content, abc.Iterable):
44+
elif isinstance(content, Iterable):
4645
self.files = list(content)
4746
elif isinstance(content, File):
4847
self.files = [content]

0 commit comments

Comments
 (0)