Skip to content

SEA: Reduce network calls for synchronous commands #633

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: sea-migration
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 55 additions & 52 deletions src/databricks/sql/backend/sea/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
DeleteSessionRequest,
StatementParameter,
ExecuteStatementResponse,
GetStatementResponse,
CreateSessionResponse,
)

Expand Down Expand Up @@ -324,7 +323,7 @@ def _extract_description_from_manifest(
return columns

def _results_message_to_execute_response(
self, response: GetStatementResponse
self, response: ExecuteStatementResponse
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't know that GetStatementResponse and ExecuteStatementResponse have the same fields wrt results (interchangeable).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I did not realise this at first either, but this can be confirmed by comparing the response in the REST reference as well:

This does make sense logically to me as well, the purpose of the GET is to get the info related to an execution statement.

) -> ExecuteResponse:
"""
Convert a SEA response to an ExecuteResponse and extract result data.
Expand Down Expand Up @@ -358,6 +357,28 @@ def _results_message_to_execute_response(

return execute_response

def _response_to_result_set(
self, response: ExecuteStatementResponse, cursor: Cursor
) -> SeaResultSet:
"""
Convert a SEA response to a SeaResultSet.
"""

# Create and return a SeaResultSet
from databricks.sql.backend.sea.result_set import SeaResultSet
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this lazy import has perf gain (lazily module loading when needed)?


execute_response = self._results_message_to_execute_response(response)

return SeaResultSet(
connection=cursor.connection,
execute_response=execute_response,
sea_client=self,
result_data=response.result,
manifest=response.manifest,
buffer_size_bytes=cursor.buffer_size_bytes,
arraysize=cursor.arraysize,
)

def _check_command_not_in_failed_or_closed_state(
self, state: CommandState, command_id: CommandId
) -> None:
Expand All @@ -378,7 +399,7 @@ def _check_command_not_in_failed_or_closed_state(

def _wait_until_command_done(
self, response: ExecuteStatementResponse
) -> CommandState:
) -> ExecuteStatementResponse:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems odd to me that this method does polling and still ends up return ExecuteStatementResponse. Semantically, this is a response to ExecuteRequest

"""
Wait until a command is done.
"""
Expand All @@ -388,11 +409,12 @@ def _wait_until_command_done(

while state in [CommandState.PENDING, CommandState.RUNNING]:
time.sleep(self.POLL_INTERVAL_SECONDS)
state = self.get_query_state(command_id)
response = self._poll_query(command_id)
state = response.status.state

self._check_command_not_in_failed_or_closed_state(state, command_id)

return state
return response

def execute_command(
self,
Expand Down Expand Up @@ -494,8 +516,12 @@ def execute_command(
if async_op:
return None

self._wait_until_command_done(response)
return self.get_execution_result(command_id, cursor)
if response.status.state == CommandState.SUCCEEDED:
# if the response succeeded within the wait_timeout, return the results immediately
return self._response_to_result_set(response, cursor)

response = self._wait_until_command_done(response)
return self._response_to_result_set(response, cursor)

def cancel_command(self, command_id: CommandId) -> None:
"""
Expand Down Expand Up @@ -547,18 +573,9 @@ def close_command(self, command_id: CommandId) -> None:
data=request.to_dict(),
)

def get_query_state(self, command_id: CommandId) -> CommandState:
def _poll_query(self, command_id: CommandId) -> ExecuteStatementResponse:
"""
Get the state of a running query.

Args:
command_id: Command identifier

Returns:
CommandState: The current state of the command

Raises:
ProgrammingError: If the command ID is invalid
Poll for the current command info.
"""

if command_id.backend_type != BackendType.SEA:
Expand All @@ -574,9 +591,25 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
data=request.to_dict(),
)
response = ExecuteStatementResponse.from_dict(response_data)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it okay to return ExecuteResponse as a result of Polling?


# Parse the response
response = GetStatementResponse.from_dict(response_data)
return response

def get_query_state(self, command_id: CommandId) -> CommandState:
"""
Get the state of a running query.

Args:
command_id: Command identifier

Returns:
CommandState: The current state of the command

Raises:
ProgrammingError: If the command ID is invalid
"""

response = self._poll_query(command_id)
return response.status.state

def get_execution_result(
Expand All @@ -598,38 +631,8 @@ def get_execution_result(
ValueError: If the command ID is invalid
"""

if command_id.backend_type != BackendType.SEA:
raise ValueError("Not a valid SEA command ID")

sea_statement_id = command_id.to_sea_statement_id()
if sea_statement_id is None:
raise ValueError("Not a valid SEA command ID")

# Create the request model
request = GetStatementRequest(statement_id=sea_statement_id)

# Get the statement result
response_data = self.http_client._make_request(
method="GET",
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
data=request.to_dict(),
)
response = GetStatementResponse.from_dict(response_data)

# Create and return a SeaResultSet
from databricks.sql.backend.sea.result_set import SeaResultSet

execute_response = self._results_message_to_execute_response(response)

return SeaResultSet(
connection=cursor.connection,
execute_response=execute_response,
sea_client=self,
result_data=response.result,
manifest=response.manifest,
buffer_size_bytes=cursor.buffer_size_bytes,
arraysize=cursor.arraysize,
)
response = self._poll_query(command_id)
return self._response_to_result_set(response, cursor)

# == Metadata Operations ==

Expand Down
2 changes: 0 additions & 2 deletions src/databricks/sql/backend/sea/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

from databricks.sql.backend.sea.models.responses import (
ExecuteStatementResponse,
GetStatementResponse,
CreateSessionResponse,
)

Expand All @@ -47,6 +46,5 @@
"DeleteSessionRequest",
# Response models
"ExecuteStatementResponse",
"GetStatementResponse",
"CreateSessionResponse",
]
20 changes: 0 additions & 20 deletions src/databricks/sql/backend/sea/models/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,26 +124,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse":
)


@dataclass
class GetStatementResponse:
"""Representation of the response from getting information about a statement."""

statement_id: str
status: StatementStatus
manifest: ResultManifest
result: ResultData

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse":
"""Create a GetStatementResponse from a dictionary."""
return cls(
statement_id=data.get("statement_id", ""),
status=_parse_status(data),
manifest=_parse_manifest(data),
result=_parse_result(data),
)


Comment on lines -127 to -146
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you would still need this? Probably part of different PRs

@dataclass
class CreateSessionResponse:
"""Representation of the response from creating a new session."""
Expand Down
9 changes: 3 additions & 6 deletions tests/unit/test_sea_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def test_command_execution_sync(
mock_http_client._make_request.return_value = execute_response

with patch.object(
sea_client, "get_execution_result", return_value="mock_result_set"
sea_client, "_response_to_result_set", return_value="mock_result_set"
) as mock_get_result:
result = sea_client.execute_command(
operation="SELECT 1",
Expand All @@ -242,9 +242,6 @@ def test_command_execution_sync(
enforce_embedded_schema_correctness=False,
)
assert result == "mock_result_set"
cmd_id_arg = mock_get_result.call_args[0][0]
assert isinstance(cmd_id_arg, CommandId)
assert cmd_id_arg.guid == "test-statement-123"

# Test with invalid session ID
with pytest.raises(ValueError) as excinfo:
Expand Down Expand Up @@ -332,7 +329,7 @@ def test_command_execution_advanced(
mock_http_client._make_request.side_effect = [initial_response, poll_response]

with patch.object(
sea_client, "get_execution_result", return_value="mock_result_set"
sea_client, "_response_to_result_set", return_value="mock_result_set"
) as mock_get_result:
with patch("time.sleep"):
result = sea_client.execute_command(
Expand Down Expand Up @@ -360,7 +357,7 @@ def test_command_execution_advanced(
dbsql_param = IntegerParameter(name="param1", value=1)
param = dbsql_param.as_tspark_param(named=True)

with patch.object(sea_client, "get_execution_result"):
with patch.object(sea_client, "_response_to_result_set"):
sea_client.execute_command(
operation="SELECT * FROM table WHERE col = :param1",
session_id=sea_session_id,
Expand Down
Loading