Skip to content

Commit 150568a

Browse files
committed
Add check in create_submission(s) for supported language. Update tests to use fixtures.
1 parent 5efa149 commit 150568a

File tree

3 files changed

+101
-57
lines changed

3 files changed

+101
-57
lines changed

src/judge0/clients.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def __init__(self, endpoint, auth_headers, *, wait=False) -> None:
1111
self.wait = wait
1212

1313
try:
14-
self.languages = self.get_languages()
14+
self.languages = {lang["id"]: lang for lang in self.get_languages()}
1515
except Exception:
1616
raise RuntimeError("Client authentication failed.")
1717

@@ -51,8 +51,17 @@ def get_statuses(self) -> list[dict]:
5151
r.raise_for_status()
5252
return r.json()
5353

54+
def is_language_supported(self, language_id: int) -> bool:
55+
return language_id in self.languages
56+
5457
def create_submission(self, submission: Submission) -> None:
55-
# TODO: check if client supports specified language_id
58+
# Check if submission contains supported language.
59+
if not self.is_language_supported(language_id=submission.language_id):
60+
raise RuntimeError(
61+
f"Client {type(self).__name__} does not support language with "
62+
f"id {submission.language_id}!"
63+
)
64+
5665
params = {
5766
"base64_encoded": "true",
5867
"wait": str(self.wait).lower(),
@@ -90,6 +99,14 @@ def get_submission(self, submission: Submission, *, fields=None) -> None:
9099
submission.set_attributes(resp.json())
91100

92101
def create_submissions(self, submissions: list[Submission]) -> None:
102+
# Check if all submissions contain supported language.
103+
for submission in submissions:
104+
if not self.is_language_supported(language_id=submission.language_id):
105+
raise RuntimeError(
106+
f"Client {type(self).__name__} does not support language with "
107+
f"id {submission.language_id}!"
108+
)
109+
93110
params = {
94111
"base64_encoded": "true",
95112
"wait": str(self.wait).lower(),
@@ -195,11 +212,11 @@ def get_submission(self, submission: Submission, *, fields=None) -> None:
195212
self._update_endpoint_header(self.DEFAULT_GET_SUBMISSION_ENDPOINT)
196213
return super().get_submission(submission, fields=fields)
197214

198-
def create_submissions(self, submissions: Submission) -> None:
215+
def create_submissions(self, submissions: list[Submission]) -> None:
199216
self._update_endpoint_header(self.DEFAULT_CREATE_SUBMISSIONS_ENDPOINT)
200217
return super().create_submissions(submissions)
201218

202-
def get_submissions(self, submissions: Submission, *, fields=None) -> None:
219+
def get_submissions(self, submissions: list[Submission], *, fields=None) -> None:
203220
self._update_endpoint_header(self.DEFAULT_GET_SUBMISSIONS_ENDPOINT)
204221
return super().get_submissions(submissions, fields=fields)
205222

@@ -253,11 +270,11 @@ def get_submission(self, submission: Submission, *, fields=None):
253270
self._update_endpoint_header(self.DEFAULT_GET_SUBMISSION_ENDPOINT)
254271
return super().get_submission(submission, fields=fields)
255272

256-
def create_submissions(self, submission: Submission) -> None:
273+
def create_submissions(self, submission: list[Submission]) -> None:
257274
self._update_endpoint_header(self.DEFAULT_CREATE_SUBMISSIONS_ENDPOINT)
258275
return super().create_submissions(submission)
259276

260-
def get_submissions(self, submission: Submission, *, fields=None) -> None:
277+
def get_submissions(self, submission: list[Submission], *, fields=None) -> None:
261278
self._update_endpoint_header(self.DEFAULT_GET_SUBMISSIONS_ENDPOINT)
262279
return super().get_submissions(submission, fields=fields)
263280

tests/conftest.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import os
2+
3+
import pytest
4+
from dotenv import load_dotenv
5+
6+
from judge0 import clients
7+
8+
load_dotenv()
9+
10+
11+
@pytest.fixture(scope="session")
12+
def atd_ce_client():
13+
api_key = os.getenv("ATD_API_KEY")
14+
client = clients.ATDJudge0CE(api_key)
15+
return client
16+
17+
18+
@pytest.fixture(scope="session")
19+
def atd_extra_ce_client():
20+
api_key = os.getenv("ATD_API_KEY")
21+
client = clients.ATDJudge0ExtraCE(api_key)
22+
return client
23+
24+
25+
@pytest.fixture(scope="session")
26+
def rapid_ce_client():
27+
api_key = os.getenv("RAPID_API_KEY")
28+
client = clients.RapidJudge0CE(api_key)
29+
return client
30+
31+
32+
@pytest.fixture(scope="session")
33+
def rapid_extra_ce_client():
34+
api_key = os.getenv("RAPID_API_KEY")
35+
client = clients.RapidJudge0ExtraCE(api_key)
36+
return client
37+
38+
39+
@pytest.fixture(scope="session")
40+
def sulu_ce_client():
41+
api_key = os.getenv("SULU_API_KEY")
42+
client = clients.SuluJudge0CE(api_key)
43+
return client
44+
45+
46+
@pytest.fixture(scope="session")
47+
def sulu_extra_ce_client():
48+
api_key = os.getenv("SULU_API_KEY")
49+
client = clients.SuluJudge0ExtraCE(api_key)
50+
return client

tests/test_clients.py

Lines changed: 28 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,69 +1,46 @@
1-
import os
2-
31
import pytest
4-
from dotenv import load_dotenv
5-
6-
from judge0 import clients
7-
8-
load_dotenv()
9-
102

11-
def test_atd_ce_client():
12-
api_key = os.getenv("ATD_API_KEY")
3+
DEFAULT_CLIENTS = [
4+
"atd_ce_client",
5+
"atd_extra_ce_client",
6+
"rapid_ce_client",
7+
"rapid_extra_ce_client",
8+
"sulu_ce_client",
9+
"sulu_extra_ce_client",
10+
]
1311

14-
client = clients.ATDJudge0CE(api_key)
1512

13+
@pytest.mark.parametrize("client", DEFAULT_CLIENTS)
14+
def test_get_about(client, request):
15+
client = request.getfixturevalue(client)
1616
client.get_about()
17-
client.get_config_info()
18-
client.get_languages()
19-
client.get_statuses()
20-
2117

22-
def test_atd_extra_ce_client():
23-
api_key = os.getenv("ATD_API_KEY")
24-
client = clients.ATDJudge0ExtraCE(api_key)
2518

26-
client.get_about()
19+
@pytest.mark.parametrize("client", DEFAULT_CLIENTS)
20+
def test_get_config_info(client, request):
21+
client = request.getfixturevalue(client)
2722
client.get_config_info()
28-
client.get_languages()
29-
client.get_statuses()
30-
3123

32-
def test_rapid_ce_client():
33-
api_key = os.getenv("RAPID_API_KEY")
34-
client = clients.RapidJudge0CE(api_key)
3524

36-
client.get_about()
37-
client.get_config_info()
25+
@pytest.mark.parametrize("client", DEFAULT_CLIENTS)
26+
def test_get_languages(client, request):
27+
client = request.getfixturevalue(client)
3828
client.get_languages()
39-
client.get_statuses()
4029

4130

42-
def test_rapid_extra_ce_client():
43-
api_key = os.getenv("RAPID_API_KEY")
44-
client = clients.RapidJudge0ExtraCE(api_key)
45-
46-
client.get_about()
47-
client.get_config_info()
48-
client.get_languages()
31+
@pytest.mark.parametrize("client", DEFAULT_CLIENTS)
32+
def test_get_statuses(client, request):
33+
client = request.getfixturevalue(client)
4934
client.get_statuses()
5035

5136

52-
def test_sulu_ce_client():
53-
api_key = os.getenv("SULU_API_KEY")
54-
client = clients.SuluJudge0CE(api_key)
55-
56-
client.get_about()
57-
client.get_config_info()
58-
client.get_languages()
59-
client.get_statuses()
60-
37+
@pytest.mark.parametrize("client", DEFAULT_CLIENTS)
38+
def test_is_language_supported_multi_file_submission(client, request):
39+
client = request.getfixturevalue(client)
40+
assert client.is_language_supported(89)
6141

62-
def test_sulu_extra_ce_client():
63-
api_key = os.getenv("SULU_API_KEY")
64-
client = clients.SuluJudge0ExtraCE(api_key)
6542

66-
client.get_about()
67-
client.get_config_info()
68-
client.get_languages()
69-
client.get_statuses()
43+
@pytest.mark.parametrize("client", DEFAULT_CLIENTS)
44+
def test_is_language_supported_non_valid_lang_id(client, request):
45+
client = request.getfixturevalue(client)
46+
assert not client.is_language_supported(-1)

0 commit comments

Comments
 (0)