Skip to content

Commit 1ff4dd5

Browse files
committed
refactor: add failsafe mechanism for the stable compiler configuration
1 parent 1c5d411 commit 1ff4dd5

File tree

3 files changed

+239
-69
lines changed

3 files changed

+239
-69
lines changed

bigframes/core/compile/__init__.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,28 +13,30 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16-
from typing import Any
16+
from typing import Literal
1717

18-
from bigframes import options
1918
from bigframes.core.compile.api import test_only_ibis_inferred_schema
2019
from bigframes.core.compile.configs import CompileRequest, CompileResult
2120

2221

23-
def compiler() -> Any:
24-
"""Returns the appropriate compiler module based on session options."""
25-
if options.experiments.sql_compiler == "experimental":
22+
def compile_sql(
23+
request: CompileRequest,
24+
compiler_name: Literal["sqlglot", "ibis"] = "sqlglot",
25+
) -> CompileResult:
26+
"""Compiles a BigFrameNode according to the request into SQL."""
27+
if compiler_name == "sqlglot":
2628
import bigframes.core.compile.sqlglot.compiler as sqlglot_compiler
2729

28-
return sqlglot_compiler
30+
return sqlglot_compiler.compile_sql(request)
2931
else:
3032
import bigframes.core.compile.ibis_compiler.ibis_compiler as ibis_compiler
3133

32-
return ibis_compiler
34+
return ibis_compiler.compile_sql(request)
3335

3436

3537
__all__ = [
3638
"test_only_ibis_inferred_schema",
3739
"CompileRequest",
3840
"CompileResult",
39-
"compiler",
41+
"compile_sql",
4042
]

bigframes/session/bq_caching_executor.py

Lines changed: 119 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,16 @@
1717
import math
1818
import threading
1919
from typing import Literal, Mapping, Optional, Sequence, Tuple
20+
import uuid
21+
import warnings
2022
import weakref
2123

2224
import google.api_core.exceptions
2325
from google.cloud import bigquery
2426
import google.cloud.bigquery.job as bq_job
2527
import google.cloud.bigquery.table as bq_table
2628
import google.cloud.bigquery_storage_v1
29+
import google.cloud.exceptions
2730

2831
import bigframes
2932
from bigframes import exceptions as bfe
@@ -160,6 +163,43 @@ def __init__(
160163
)
161164
self._upload_lock = threading.Lock()
162165

166+
def _compile(
167+
self,
168+
node: nodes.BigFrameNode,
169+
*,
170+
ordered: bool = False,
171+
peek: Optional[int] = None,
172+
materialize_all_order_keys: bool = False,
173+
compiler_name: Literal["sqlglot", "ibis"] = "sqlglot",
174+
) -> compile.CompileResult:
175+
return compile.compile_sql(
176+
compile.CompileRequest(
177+
node,
178+
sort_rows=ordered,
179+
peek_count=peek,
180+
materialize_all_order_keys=materialize_all_order_keys,
181+
),
182+
compiler_name=compiler_name,
183+
)
184+
185+
def _with_fallback(self, run_fn):
186+
compiler_option = bigframes.options.experiments.sql_compiler
187+
if compiler_option == "legacy":
188+
return run_fn("ibis")
189+
elif compiler_option == "experimental":
190+
return run_fn("sqlglot")
191+
else: # stable
192+
compiler_id = f"{uuid.uuid1().hex[:12]}"
193+
try:
194+
return run_fn("sqlglot", compiler_id=compiler_id)
195+
except google.cloud.exceptions.BadRequest as e:
196+
msg = bfe.format_message(
197+
f"Compiler ID {compiler_id}: BadRequest on sqlglot. "
198+
f"Falling back to ibis. Details: {e.message}"
199+
)
200+
warnings.warn(msg, category=UserWarning)
201+
return run_fn("ibis", compiler_id=compiler_id)
202+
163203
def to_sql(
164204
self,
165205
array_value: bigframes.core.ArrayValue,
@@ -175,9 +215,7 @@ def to_sql(
175215
else array_value.node
176216
)
177217
node = self._substitute_large_local_sources(node)
178-
compiled = compile.compiler().compile_sql(
179-
compile.CompileRequest(node, sort_rows=ordered)
180-
)
218+
compiled = self._compile(node, ordered=ordered)
181219
return compiled.sql
182220

183221
def execute(
@@ -293,46 +331,56 @@ def _export_gbq(
293331
# validate destination table
294332
existing_table = self._maybe_find_existing_table(spec)
295333

296-
compiled = compile.compiler().compile_sql(
297-
compile.CompileRequest(plan, sort_rows=False)
298-
)
299-
sql = compiled.sql
334+
def run_with_compiler(compiler_name, compiler_id=None):
335+
compiled = self._compile(plan, ordered=False, compiler_name=compiler_name)
336+
sql = compiled.sql
300337

301-
if (existing_table is not None) and _if_schema_match(
302-
existing_table.schema, array_value.schema
303-
):
304-
# b/409086472: Uses DML for table appends and replacements to avoid
305-
# BigQuery `RATE_LIMIT_EXCEEDED` errors, as per quota limits:
306-
# https://cloud.google.com/bigquery/quotas#standard_tables
307-
job_config = bigquery.QueryJobConfig()
338+
if (existing_table is not None) and _if_schema_match(
339+
existing_table.schema, array_value.schema
340+
):
341+
# b/409086472: Uses DML for table appends and replacements to avoid
342+
# BigQuery `RATE_LIMIT_EXCEEDED` errors, as per quota limits:
343+
# https://cloud.google.com/bigquery/quotas#standard_tables
344+
job_config = bigquery.QueryJobConfig()
345+
346+
ir = sqlglot_ir.SQLGlotIR.from_unparsed_query(sql)
347+
if spec.if_exists == "append":
348+
sql = sg_sql.to_sql(
349+
sg_sql.insert(ir.expr.as_select_all(), spec.table)
350+
)
351+
else: # for "replace"
352+
assert spec.if_exists == "replace"
353+
sql = sg_sql.to_sql(
354+
sg_sql.replace(ir.expr.as_select_all(), spec.table)
355+
)
356+
else:
357+
dispositions = {
358+
"fail": bigquery.WriteDisposition.WRITE_EMPTY,
359+
"replace": bigquery.WriteDisposition.WRITE_TRUNCATE,
360+
"append": bigquery.WriteDisposition.WRITE_APPEND,
361+
}
362+
job_config = bigquery.QueryJobConfig(
363+
write_disposition=dispositions[spec.if_exists],
364+
destination=spec.table,
365+
clustering_fields=spec.cluster_cols if spec.cluster_cols else None,
366+
)
308367

309-
ir = sqlglot_ir.SQLGlotIR.from_unparsed_query(sql)
310-
if spec.if_exists == "append":
311-
sql = sg_sql.to_sql(sg_sql.insert(ir.expr.as_select_all(), spec.table))
312-
else: # for "replace"
313-
assert spec.if_exists == "replace"
314-
sql = sg_sql.to_sql(sg_sql.replace(ir.expr.as_select_all(), spec.table))
315-
else:
316-
dispositions = {
317-
"fail": bigquery.WriteDisposition.WRITE_EMPTY,
318-
"replace": bigquery.WriteDisposition.WRITE_TRUNCATE,
319-
"append": bigquery.WriteDisposition.WRITE_APPEND,
320-
}
321-
job_config = bigquery.QueryJobConfig(
322-
write_disposition=dispositions[spec.if_exists],
323-
destination=spec.table,
324-
clustering_fields=spec.cluster_cols if spec.cluster_cols else None,
368+
# Attach data type usage to the job labels
369+
job_config.labels["bigframes-dtypes"] = compiled.encoded_type_refs
370+
job_config.labels["bigframes-compiler"] = (
371+
f"{compiler_name}-{compiler_id}" if compiler_id else compiler_name
325372
)
326373

327-
# Attach data type usage to the job labels
328-
job_config.labels["bigframes-dtypes"] = compiled.encoded_type_refs
329-
# TODO(swast): plumb through the api_name of the user-facing api that
330-
# caused this query.
331-
iterator, job = self._run_execute_query(
332-
sql=sql,
333-
job_config=job_config,
334-
session=array_value.session,
335-
)
374+
# TODO(swast): plumb through the api_name of the user-facing api that
375+
# caused this query.
376+
iterator, job = self._run_execute_query(
377+
sql=sql,
378+
job_config=job_config,
379+
session=array_value.session,
380+
)
381+
return iterator, job
382+
383+
iterator, job = self._with_fallback(run_with_compiler)
336384

337385
has_timedelta_col = any(
338386
t == bigframes.dtypes.TIMEDELTA_DTYPE for t in array_value.schema.dtypes
@@ -648,34 +696,44 @@ def _execute_plan_gbq(
648696
]
649697
cluster_cols = cluster_cols[:_MAX_CLUSTER_COLUMNS]
650698

651-
compiled = compile.compiler().compile_sql(
652-
compile.CompileRequest(
699+
def run_with_compiler(compiler_name, compiler_id=None):
700+
compiled = self._compile(
653701
plan,
654-
sort_rows=ordered,
655-
peek_count=peek,
702+
ordered=ordered,
703+
peek=peek,
656704
materialize_all_order_keys=(cache_spec is not None),
705+
compiler_name=compiler_name,
657706
)
658-
)
659-
# might have more columns than og schema, for hidden ordering columns
660-
compiled_schema = compiled.sql_schema
707+
# might have more columns than og schema, for hidden ordering columns
708+
compiled_schema = compiled.sql_schema
661709

662-
destination_table: Optional[bigquery.TableReference] = None
710+
destination_table: Optional[bigquery.TableReference] = None
663711

664-
job_config = bigquery.QueryJobConfig()
665-
if create_table:
666-
destination_table = self.storage_manager.create_temp_table(
667-
compiled_schema, cluster_cols
712+
job_config = bigquery.QueryJobConfig()
713+
if create_table:
714+
destination_table = self.storage_manager.create_temp_table(
715+
compiled_schema, cluster_cols
716+
)
717+
job_config.destination = destination_table
718+
719+
# Attach data type usage to the job labels
720+
job_config.labels["bigframes-dtypes"] = compiled.encoded_type_refs
721+
job_config.labels["bigframes-compiler"] = (
722+
f"{compiler_name}-{compiler_id}" if compiler_id else compiler_name
668723
)
669-
job_config.destination = destination_table
670-
671-
# Attach data type usage to the job labels
672-
job_config.labels["bigframes-dtypes"] = compiled.encoded_type_refs
673-
iterator, query_job = self._run_execute_query(
674-
sql=compiled.sql,
675-
job_config=job_config,
676-
query_with_job=(destination_table is not None),
677-
session=plan.session,
678-
)
724+
725+
iterator, query_job = self._run_execute_query(
726+
sql=compiled.sql,
727+
job_config=job_config,
728+
query_with_job=(destination_table is not None),
729+
session=plan.session,
730+
)
731+
return iterator, query_job, compiled
732+
733+
iterator, query_job, compiled = self._with_fallback(run_with_compiler)
734+
735+
# might have more columns than og schema, for hidden ordering columns
736+
compiled_schema = compiled.sql_schema
679737

680738
# we could actually cache even when caching is not explicitly requested, but being conservative for now
681739
result_bq_data = None
[third file in the tree — a new unit-test file for the compiler fallback; exact filename not captured in this page extract]

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest import mock
16+
17+
import google.cloud.bigquery as bigquery
18+
import google.cloud.exceptions
19+
import pyarrow as pa
20+
import pytest
21+
22+
import bigframes
23+
import bigframes.core.nodes as nodes
24+
import bigframes.core.schema as schemata
25+
from bigframes.session.bq_caching_executor import BigQueryCachingExecutor
26+
27+
28+
@pytest.fixture
29+
def mock_executor():
30+
bqclient = mock.create_autospec(bigquery.Client)
31+
bqclient.project = "test-project"
32+
storage_manager = mock.Mock()
33+
bqstoragereadclient = mock.Mock()
34+
loader = mock.Mock()
35+
publisher = mock.Mock()
36+
return BigQueryCachingExecutor(
37+
bqclient, storage_manager, bqstoragereadclient, loader, publisher=publisher
38+
)
39+
40+
41+
def test_compiler_with_fallback_legacy(mock_executor):
42+
run_fn = mock.Mock()
43+
with bigframes.option_context("experiments.sql_compiler", "legacy"):
44+
mock_executor._with_fallback(run_fn)
45+
run_fn.assert_called_once_with("ibis")
46+
47+
48+
def test_compiler_with_fallback_experimental(mock_executor):
49+
run_fn = mock.Mock()
50+
with bigframes.option_context("experiments.sql_compiler", "experimental"):
51+
mock_executor._with_fallback(run_fn)
52+
run_fn.assert_called_once_with("sqlglot")
53+
54+
55+
def test_compiler_with_fallback_stable_success(mock_executor):
56+
run_fn = mock.Mock()
57+
with bigframes.option_context("experiments.sql_compiler", "stable"):
58+
mock_executor._with_fallback(run_fn)
59+
run_fn.assert_called_once_with("sqlglot", compiler_id=mock.ANY)
60+
61+
62+
def test_compiler_execute_plan_gbq_fallback_labels(mock_executor):
63+
plan = mock.create_autospec(nodes.BigFrameNode)
64+
plan.schema = schemata.ArraySchema(tuple())
65+
plan.session = None
66+
67+
# Mock prepare_plan
68+
mock_executor.prepare_plan = mock.Mock(return_value=plan)
69+
70+
# Mock _compile
71+
from bigframes.core.compile.configs import CompileResult
72+
73+
fake_compiled = CompileResult(
74+
sql="SELECT 1", sql_schema=[], row_order=None, encoded_type_refs="fake_refs"
75+
)
76+
mock_executor._compile = mock.Mock(return_value=fake_compiled)
77+
78+
# Mock _run_execute_query to fail first time, then succeed
79+
mock_iterator = mock.Mock()
80+
mock_iterator.total_rows = 0
81+
mock_iterator.to_arrow.return_value = pa.Table.from_arrays([], names=[])
82+
mock_query_job = mock.Mock(spec=bigquery.QueryJob)
83+
mock_query_job.destination = None
84+
85+
error = google.cloud.exceptions.BadRequest("failed")
86+
error.job = mock.Mock(spec=bigquery.QueryJob) # type: ignore
87+
error.job.job_id = "failed_job_id" # type: ignore
88+
89+
mock_executor._run_execute_query = mock.Mock(
90+
side_effect=[error, (mock_iterator, mock_query_job)]
91+
)
92+
93+
with bigframes.option_context("experiments.sql_compiler", "stable"), pytest.warns(
94+
UserWarning, match="Falling back to ibis"
95+
):
96+
mock_executor._execute_plan_gbq(plan, ordered=False, must_create_table=False)
97+
98+
# Verify labels for both calls
99+
assert mock_executor._run_execute_query.call_count == 2
100+
101+
call_1_kwargs = mock_executor._run_execute_query.call_args_list[0][1]
102+
call_2_kwargs = mock_executor._run_execute_query.call_args_list[1][1]
103+
104+
label_1 = call_1_kwargs["job_config"].labels["bigframes-compiler"]
105+
label_2 = call_2_kwargs["job_config"].labels["bigframes-compiler"]
106+
107+
assert label_1.startswith("sqlglot-")
108+
assert label_2.startswith("ibis-")
109+
# Both should have the same compiler_id suffix
110+
assert label_1.split("-")[1] == label_2.split("-")[1]

0 commit comments

Comments (0)