Skip to content

Commit

Permalink
made time_query actually apply timeouts
Browse files Browse the repository at this point in the history
  • Loading branch information
wangpatrick57 committed Nov 14, 2024
1 parent 5f006ea commit 31f782d
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 34 deletions.
23 changes: 14 additions & 9 deletions env/integtest_pg_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@
from pathlib import Path

import yaml
from psycopg.errors import QueryCanceled

from env.pg_conn import PostgresConn
from util.pg import (
DEFAULT_POSTGRES_PORT,
get_is_postgres_running,
get_running_postgres_ports,
)
from util.pg import DEFAULT_POSTGRES_PORT, get_is_postgres_running, get_running_postgres_ports
from util.workspace import (
DEFAULT_BOOT_CONFIG_FPATH,
DBGymConfig,
Expand Down Expand Up @@ -180,19 +177,27 @@ def test_time_query(self) -> None:
pg_conn.restart_postgres()

# Test
# No explain
runtime, did_time_out, explain_data = pg_conn.time_query("select pg_sleep(1)", 2)
# Testing no explain no timeout.
runtime, did_time_out, explain_data = pg_conn.time_query("select pg_sleep(1)")
# The runtime should be about 1 second.
self.assertTrue(abs(runtime - 1_000_000) < 100_000)
self.assertFalse(did_time_out)
self.assertIsNone(explain_data)

# With explain
runtime, did_time_out, explain_data = pg_conn.time_query("explain (analyze, format json, timing off) select pg_sleep(1)", 2)
# Testing with explain.
runtime, did_time_out, explain_data = pg_conn.time_query(
"explain (analyze, format json, timing off) select pg_sleep(1)"
)
self.assertTrue(abs(runtime - 1_000_000) < 100_000)
self.assertFalse(did_time_out)
self.assertIsNotNone(explain_data)

# Testing with timeout.
runtime, did_time_out, _ = pg_conn.time_query("select pg_sleep(3)", 2)
# The runtime should be about what the timeout is.
self.assertTrue(abs(runtime - 2_000_000) < 100_000)
self.assertTrue(did_time_out)

# Cleanup
pg_conn.shutdown_postgres()

Expand Down
28 changes: 25 additions & 3 deletions env/pg_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,30 @@ def move_log(self) -> None:
shutil.move(pglog_fpath, pglog_this_step_fpath)
self.log_step += 1

def time_query(self, query: str, timeout: float) -> tuple[float, bool, Any]:
def force_statement_timeout(self, timeout_sec: float) -> None:
    """Set Postgres's ``statement_timeout`` for this connection.

    Following Postgres's convention, a value of 0 disables the timeout.
    The SET statement itself can be canceled by a previously armed
    timeout, so we retry until it succeeds.

    Args:
        timeout_sec: Timeout in seconds; converted to milliseconds for Postgres.
    """
    timeout_ms = timeout_sec * 1000
    while True:
        try:
            self.conn().execute(f"SET statement_timeout = {timeout_ms}")
        except QueryCanceled:
            # A still-active statement_timeout canceled our SET; try again.
            continue
        break

def time_query(self, query: str, timeout_sec: float = 0) -> tuple[float, bool, Any]:
"""
Run a query with a timeout. If you want to attach per-query knobs, attach them to the query string itself.
Following Postgres's convention, timeout=0 indicates "disable timeout"
It returns the runtime, whether the query timed out, and the explain data.
"""
if timeout_sec > 0:
self.force_statement_timeout(timeout_sec)
else:
assert (
timeout_sec == 0
), f'Setting timeout_sec to 0 indicates "disable timeout". However, setting timeout_sec ({timeout_sec}) < 0 is a bug.'

did_time_out = False
has_explain = "explain" in query.lower()
explain_data = None
Expand All @@ -126,12 +144,16 @@ def time_query(self, query: str, timeout: float) -> tuple[float, bool, Any]:

except QueryCanceled:
logging.getLogger(DBGYM_LOGGER_NAME).debug(
f"{query} exceeded evaluation timeout {timeout}"
f"{query} exceeded evaluation timeout {timeout_sec}"
)
qid_runtime = timeout * 1e6
qid_runtime = timeout_sec * 1e6
did_time_out = True
except Exception as e:
assert False, e
finally:
# Wipe the statement timeout.
self.force_statement_timeout(0)

# qid_runtime is in microseconds.
return qid_runtime, did_time_out, explain_data

Expand Down
23 changes: 1 addition & 22 deletions tune/protox/env/util/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,39 +21,18 @@
from util.log import DBGYM_LOGGER_NAME


def _force_statement_timeout(
    connection: psycopg.Connection[Any], timeout_ms: float
) -> None:
    """Apply ``statement_timeout`` (in milliseconds) on *connection*.

    The SET itself may be canceled by an already-armed timeout, so keep
    retrying until it takes effect.
    """
    while True:
        try:
            connection.execute(f"SET statement_timeout = {timeout_ms}")
        except QueryCanceled:
            # The previous timeout fired while we were issuing the SET; retry.
            continue
        break


def _acquire_metrics_around_query(
pg_conn: PostgresConn,
query: str,
query_timeout: float = 0.0,
observation_space: Optional[StateSpace] = None,
) -> tuple[float, bool, Any, Any]:
_force_statement_timeout(pg_conn.conn(), 0)
pg_conn.force_statement_timeout(0)
if observation_space and observation_space.require_metrics():
initial_metrics = observation_space.construct_online(pg_conn.conn())

if query_timeout > 0:
_force_statement_timeout(pg_conn.conn(), query_timeout * 1000)
else:
assert (
query_timeout == 0
), f'Setting query_timeout to 0 indicates "timeout". However, setting query_timeout ({query_timeout}) < 0 is a bug.'

qid_runtime, did_time_out, explain_data = pg_conn.time_query(query, query_timeout)

# Wipe the statement timeout.
_force_statement_timeout(pg_conn.conn(), 0)
if observation_space and observation_space.require_metrics():
final_metrics = observation_space.construct_online(pg_conn.conn())
diff = observation_space.state_delta(initial_metrics, final_metrics)
Expand Down

0 comments on commit 31f782d

Please sign in to comment.