Skip to content

Commit 388aeec

Browse files
committed
refactored how workspace works in gymlib tests
1 parent e6ed566 commit 388aeec

6 files changed

+45
-31
lines changed

env/tests/gymlib_integtest_util.py

+12-22
Original file line numberDiff line numberDiff line change
@@ -35,22 +35,20 @@ class GymlibIntegtestManager:
3535
BENCHMARK = "tpch"
3636
SCALE_FACTOR = 0.01
3737
DBGYM_CONFIG_PATH = Path("env/tests/gymlib_integtest_dbgym_config.yaml")
38-
39-
# This is set at most once by set_up_workspace().
40-
DBGYM_WORKSPACE: Optional[DBGymWorkspace] = None
38+
WORKSPACE_PATH: Optional[Path] = None
4139

4240
@staticmethod
4341
def set_up_workspace() -> None:
4442
"""
4543
Set up the workspace if it has not already been set up.
4644
None of the integtest_*.py files will delete the workspace so that future tests run faster.
4745
"""
48-
workspace_path = get_workspace_path_from_config(
46+
GymlibIntegtestManager.WORKSPACE_PATH = get_workspace_path_from_config(
4947
GymlibIntegtestManager.DBGYM_CONFIG_PATH
5048
)
5149

5250
# This if statement prevents us from setting up the workspace twice, which saves time.
53-
if not workspace_path.exists():
51+
if not GymlibIntegtestManager.WORKSPACE_PATH.exists():
5452
subprocess.run(
5553
["./env/tests/_set_up_gymlib_integtest_workspace.sh"],
5654
env={
@@ -64,23 +62,13 @@ def set_up_workspace() -> None:
6462
check=True,
6563
)
6664

67-
# Once we get here, we have an invariant that the workspace exists. We need this
68-
# invariant to be true in order to create the DBGymWorkspace.
69-
#
70-
# However, it also can't be created more than once so we need to check `is None`.
71-
if GymlibIntegtestManager.DBGYM_WORKSPACE is None:
72-
# Reset this in case it had been created by a test *not* using GymlibIntegtestManager.set_up_workspace().
73-
DBGymWorkspace._num_times_created_this_run = 0
74-
GymlibIntegtestManager.DBGYM_WORKSPACE = DBGymWorkspace(workspace_path)
75-
7665
@staticmethod
77-
def get_dbgym_workspace() -> DBGymWorkspace:
78-
assert GymlibIntegtestManager.DBGYM_WORKSPACE is not None
79-
return GymlibIntegtestManager.DBGYM_WORKSPACE
66+
def get_workspace_path() -> Path:
67+
assert GymlibIntegtestManager.WORKSPACE_PATH is not None
68+
return GymlibIntegtestManager.WORKSPACE_PATH
8069

8170
@staticmethod
8271
def get_default_metadata() -> TuningMetadata:
83-
dbgym_workspace = GymlibIntegtestManager.get_dbgym_workspace()
8472
assert GymlibIntegtestManager.BENCHMARK == "tpch"
8573
suffix = get_workload_suffix(
8674
GymlibIntegtestManager.BENCHMARK,
@@ -91,23 +79,25 @@ def get_default_metadata() -> TuningMetadata:
9179
return TuningMetadata(
9280
workload_path=fully_resolve_path(
9381
get_workload_symlink_path(
94-
dbgym_workspace.dbgym_workspace_path,
82+
GymlibIntegtestManager.get_workspace_path(),
9583
GymlibIntegtestManager.BENCHMARK,
9684
GymlibIntegtestManager.SCALE_FACTOR,
9785
suffix,
9886
),
9987
),
10088
pristine_dbdata_snapshot_path=fully_resolve_path(
10189
get_dbdata_tgz_symlink_path(
102-
dbgym_workspace.dbgym_workspace_path,
90+
GymlibIntegtestManager.get_workspace_path(),
10391
GymlibIntegtestManager.BENCHMARK,
10492
GymlibIntegtestManager.SCALE_FACTOR,
10593
),
10694
),
10795
dbdata_parent_path=fully_resolve_path(
108-
get_tmp_path_from_workspace_path(dbgym_workspace.dbgym_workspace_path),
96+
get_tmp_path_from_workspace_path(
97+
GymlibIntegtestManager.get_workspace_path()
98+
),
10999
),
110100
pgbin_path=fully_resolve_path(
111-
get_pgbin_symlink_path(dbgym_workspace.dbgym_workspace_path),
101+
get_pgbin_symlink_path(GymlibIntegtestManager.get_workspace_path()),
112102
),
113103
)

env/tests/integtest_pg_conn.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,18 @@
1010
get_is_postgres_running,
1111
get_running_postgres_ports,
1212
)
13+
from util.workspace import DBGymWorkspace
1314

1415

1516
class PostgresConnTests(unittest.TestCase):
17+
workspace: DBGymWorkspace
18+
1619
@staticmethod
1720
def setUpClass() -> None:
1821
GymlibIntegtestManager.set_up_workspace()
22+
PostgresConnTests.workspace = DBGymWorkspace(
23+
GymlibIntegtestManager.get_workspace_path()
24+
)
1925

2026
def setUp(self) -> None:
2127
self.assertFalse(
@@ -38,7 +44,7 @@ def tearDown(self) -> None:
3844

3945
def create_pg_conn(self, pgport: int = DEFAULT_POSTGRES_PORT) -> PostgresConn:
4046
return PostgresConn(
41-
GymlibIntegtestManager.get_dbgym_workspace(),
47+
PostgresConnTests.workspace,
4248
pgport,
4349
self.metadata.pristine_dbdata_snapshot_path,
4450
self.metadata.dbdata_parent_path,

env/tests/integtest_replay.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,22 @@
1010
SysKnobsDelta,
1111
TuningArtifactsWriter,
1212
)
13+
from util.workspace import DBGymWorkspace
1314

1415

1516
class ReplayTests(unittest.TestCase):
17+
workspace: DBGymWorkspace
18+
1619
@staticmethod
1720
def setUpClass() -> None:
1821
GymlibIntegtestManager.set_up_workspace()
22+
ReplayTests.workspace = DBGymWorkspace(
23+
GymlibIntegtestManager.get_workspace_path()
24+
)
1925

2026
def test_replay(self) -> None:
2127
writer = TuningArtifactsWriter(
22-
GymlibIntegtestManager.get_dbgym_workspace(),
28+
ReplayTests.workspace,
2329
GymlibIntegtestManager.get_default_metadata(),
2430
)
2531
writer.write_step(
@@ -41,7 +47,7 @@ def test_replay(self) -> None:
4147
)
4248
)
4349
replay_data = replay(
44-
GymlibIntegtestManager.get_dbgym_workspace(),
50+
ReplayTests.workspace,
4551
writer.tuning_artifacts_path,
4652
)
4753

env/tests/integtest_tuning_artifacts.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,19 @@
99
TuningArtifactsReader,
1010
TuningArtifactsWriter,
1111
)
12+
from util.workspace import DBGymWorkspace
1213

1314

1415
class PostgresConnTests(unittest.TestCase):
1516
@staticmethod
1617
def setUpClass() -> None:
1718
GymlibIntegtestManager.set_up_workspace()
1819

20+
def setUp(self) -> None:
21+
# We re-create a workspace for each test because each test will create its own TuningArtifactsWriter.
22+
DBGymWorkspace._num_times_created_this_run = 0
23+
self.workspace = DBGymWorkspace(GymlibIntegtestManager.get_workspace_path())
24+
1925
@staticmethod
2026
def make_config(letter: str) -> DBMSConfigDelta:
2127
return DBMSConfigDelta(
@@ -26,7 +32,7 @@ def make_config(letter: str) -> DBMSConfigDelta:
2632

2733
def test_get_delta_at_step(self) -> None:
2834
writer = TuningArtifactsWriter(
29-
GymlibIntegtestManager.get_dbgym_workspace(),
35+
self.workspace,
3036
GymlibIntegtestManager.get_default_metadata(),
3137
)
3238

@@ -51,7 +57,7 @@ def test_get_delta_at_step(self) -> None:
5157

5258
def test_get_all_deltas_in_order(self) -> None:
5359
writer = TuningArtifactsWriter(
54-
GymlibIntegtestManager.get_dbgym_workspace(),
60+
self.workspace,
5561
GymlibIntegtestManager.get_default_metadata(),
5662
)
5763

@@ -72,7 +78,7 @@ def test_get_all_deltas_in_order(self) -> None:
7278

7379
def test_get_metadata(self) -> None:
7480
writer = TuningArtifactsWriter(
75-
GymlibIntegtestManager.get_dbgym_workspace(),
81+
self.workspace,
7682
GymlibIntegtestManager.get_default_metadata(),
7783
)
7884
reader = TuningArtifactsReader(writer.tuning_artifacts_path)

env/tests/integtest_workload.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,22 @@
33
from benchmark.tpch.constants import DEFAULT_TPCH_SEED, NUM_TPCH_QUERIES
44
from env.tests.gymlib_integtest_util import GymlibIntegtestManager
55
from env.workload import Workload
6+
from util.workspace import DBGymWorkspace
67

78

89
class WorkloadTests(unittest.TestCase):
10+
workspace: DBGymWorkspace
11+
912
@staticmethod
1013
def setUpClass() -> None:
1114
GymlibIntegtestManager.set_up_workspace()
15+
WorkloadTests.workspace = DBGymWorkspace(
16+
GymlibIntegtestManager.get_workspace_path()
17+
)
1218

1319
def test_workload(self) -> None:
1420
workload_path = GymlibIntegtestManager.get_default_metadata().workload_path
15-
16-
workload = Workload(GymlibIntegtestManager.get_dbgym_workspace(), workload_path)
21+
workload = Workload(WorkloadTests.workspace, workload_path)
1722

1823
# Check the order of query IDs.
1924
self.assertEqual(

env/tuning_artifacts.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
from dataclasses import asdict, dataclass
33
from pathlib import Path
4-
from typing import Any, NewType, TypedDict
4+
from typing import Any, NewType
55

66
from util.workspace import DBGymWorkspace, is_fully_resolved
77

@@ -84,6 +84,7 @@ def __init__(
8484
self.tuning_artifacts_path = (
8585
self.dbgym_workspace.dbgym_this_run_path / "tuning_artifacts"
8686
)
87+
# exist_ok is False because you should only create one TuningArtifactsWriter per run.
8788
self.tuning_artifacts_path.mkdir(parents=False, exist_ok=False)
8889
assert is_fully_resolved(self.tuning_artifacts_path)
8990
self.next_step_num = 0

0 commit comments

Comments
 (0)