Skip to content

Commit 460dbdc

Browse files
refactor: Improve cache encapsulation (#2525)
1 parent 0ebc733 commit 460dbdc

File tree

4 files changed

+124
-77
lines changed

4 files changed

+124
-77
lines changed

bigframes/core/tree_properties.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@
1515

1616
import functools
1717
import itertools
18-
from typing import Callable, Dict, Optional, Sequence
18+
from typing import Callable, Dict, Optional, Sequence, TYPE_CHECKING
1919

2020
import bigframes.core.nodes as nodes
2121

22+
if TYPE_CHECKING:
23+
import bigframes.session.execution_cache as execution_cache
24+
2225

2326
def is_trivially_executable(node: nodes.BigFrameNode) -> bool:
2427
if local_only(node):
@@ -65,7 +68,7 @@ def select_cache_target(
6568
root: nodes.BigFrameNode,
6669
min_complexity: float,
6770
max_complexity: float,
68-
cache: dict[nodes.BigFrameNode, nodes.BigFrameNode],
71+
cache: execution_cache.ExecutionCache,
6972
heuristic: Callable[[int, int], float],
7073
) -> Optional[nodes.BigFrameNode]:
7174
"""Take tree, and return candidate nodes with (# of occurences, post-caching planning complexity).
@@ -75,7 +78,7 @@ def select_cache_target(
7578

7679
@functools.cache
7780
def _with_caching(subtree: nodes.BigFrameNode) -> nodes.BigFrameNode:
78-
return nodes.top_down(subtree, lambda x: cache.get(x, x))
81+
return cache.subsitute_cached_subplans(subtree)
7982

8083
def _combine_counts(
8184
left: Dict[nodes.BigFrameNode, int], right: Dict[nodes.BigFrameNode, int]
@@ -106,6 +109,7 @@ def _node_counts_inner(
106109
if len(node_counts) == 0:
107110
raise ValueError("node counts should be non-zero")
108111

112+
# for each considered node, calculate heuristic value, and return node with max value
109113
return max(
110114
node_counts.keys(),
111115
key=lambda node: heuristic(

bigframes/session/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,15 +265,20 @@ def __init__(
265265
metrics=self._metrics,
266266
publisher=self._publisher,
267267
)
268+
269+
labels = {}
270+
if not self._strictly_ordered:
271+
labels["bigframes-mode"] = "unordered"
272+
268273
self._executor: executor.Executor = bq_caching_executor.BigQueryCachingExecutor(
269274
bqclient=self._clients_provider.bqclient,
270275
bqstoragereadclient=self._clients_provider.bqstoragereadclient,
271276
loader=self._loader,
272277
storage_manager=self._temp_storage_manager,
273-
strictly_ordered=self._strictly_ordered,
274278
metrics=self._metrics,
275279
enable_polars_execution=context.enable_polars_execution,
276280
publisher=self._publisher,
281+
labels=labels,
277282
)
278283

279284
def __del__(self):

bigframes/session/bq_caching_executor.py

Lines changed: 23 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import math
1818
import threading
1919
from typing import Literal, Mapping, Optional, Sequence, Tuple
20-
import weakref
2120

2221
import google.api_core.exceptions
2322
from google.cloud import bigquery
@@ -47,6 +46,7 @@
4746
semi_executor,
4847
)
4948
import bigframes.session._io.bigquery as bq_io
49+
import bigframes.session.execution_cache as execution_cache
5050
import bigframes.session.execution_spec as ex_spec
5151
import bigframes.session.metrics
5252
import bigframes.session.planner
@@ -59,58 +59,6 @@
5959
_MAX_CLUSTER_COLUMNS = 4
6060
MAX_SMALL_RESULT_BYTES = 10 * 1024 * 1024 * 1024 # 10G
6161

62-
SourceIdMapping = Mapping[str, str]
63-
64-
65-
class ExecutionCache:
66-
def __init__(self):
67-
# current assumption is only 1 cache of a given node
68-
# in future, might have multiple caches, with different layout, localities
69-
self._cached_executions: weakref.WeakKeyDictionary[
70-
nodes.BigFrameNode, nodes.CachedTableNode
71-
] = weakref.WeakKeyDictionary()
72-
self._uploaded_local_data: weakref.WeakKeyDictionary[
73-
local_data.ManagedArrowTable,
74-
tuple[bq_data.BigqueryDataSource, SourceIdMapping],
75-
] = weakref.WeakKeyDictionary()
76-
77-
@property
78-
def mapping(self) -> Mapping[nodes.BigFrameNode, nodes.BigFrameNode]:
79-
return self._cached_executions
80-
81-
def cache_results_table(
82-
self,
83-
original_root: nodes.BigFrameNode,
84-
data: bq_data.BigqueryDataSource,
85-
):
86-
# Assumption: GBQ cached table uses field name as bq column name
87-
scan_list = nodes.ScanList(
88-
tuple(
89-
nodes.ScanItem(field.id, field.id.sql) for field in original_root.fields
90-
)
91-
)
92-
cached_replacement = nodes.CachedTableNode(
93-
source=data,
94-
scan_list=scan_list,
95-
table_session=original_root.session,
96-
original_node=original_root,
97-
)
98-
assert original_root.schema == cached_replacement.schema
99-
self._cached_executions[original_root] = cached_replacement
100-
101-
def cache_remote_replacement(
102-
self,
103-
local_data: local_data.ManagedArrowTable,
104-
bq_data: bq_data.BigqueryDataSource,
105-
):
106-
# bq table has one extra column for offsets, those are implicit for local data
107-
assert len(local_data.schema.items) + 1 == len(bq_data.table.physical_schema)
108-
mapping = {
109-
local_data.schema.items[i].column: bq_data.table.physical_schema[i].name
110-
for i in range(len(local_data.schema))
111-
}
112-
self._uploaded_local_data[local_data] = (bq_data, mapping)
113-
11462

11563
class BigQueryCachingExecutor(executor.Executor):
11664
"""Computes BigFrames values using BigQuery Engine.
@@ -128,20 +76,20 @@ def __init__(
12876
bqstoragereadclient: google.cloud.bigquery_storage_v1.BigQueryReadClient,
12977
loader: loader.GbqDataLoader,
13078
*,
131-
strictly_ordered: bool = True,
13279
metrics: Optional[bigframes.session.metrics.ExecutionMetrics] = None,
13380
enable_polars_execution: bool = False,
13481
publisher: bigframes.core.events.Publisher,
82+
labels: Mapping[str, str] = {},
13583
):
13684
self.bqclient = bqclient
13785
self.storage_manager = storage_manager
138-
self.strictly_ordered: bool = strictly_ordered
139-
self.cache: ExecutionCache = ExecutionCache()
86+
self.cache: execution_cache.ExecutionCache = execution_cache.ExecutionCache()
14087
self.metrics = metrics
14188
self.loader = loader
14289
self.bqstoragereadclient = bqstoragereadclient
14390
self._enable_polars_execution = enable_polars_execution
14491
self._publisher = publisher
92+
self._labels = labels
14593

14694
# TODO(tswast): Send events from semi-executors, too.
14795
self._semi_executors: Sequence[semi_executor.SemiExecutor] = (
@@ -410,8 +358,8 @@ def _run_execute_query(
410358
bigframes.options.compute.maximum_bytes_billed
411359
)
412360

413-
if not self.strictly_ordered:
414-
job_config.labels["bigframes-mode"] = "unordered"
361+
if self._labels:
362+
job_config.labels.update(self._labels)
415363

416364
try:
417365
# Trick the type checker into thinking we got a literal.
@@ -450,9 +398,6 @@ def _run_execute_query(
450398
else:
451399
raise
452400

453-
def replace_cached_subtrees(self, node: nodes.BigFrameNode) -> nodes.BigFrameNode:
454-
return nodes.top_down(node, lambda x: self.cache.mapping.get(x, x))
455-
456401
def _is_trivially_executable(self, array_value: bigframes.core.ArrayValue):
457402
"""
458403
Can the block be evaluated very cheaply?
@@ -482,7 +427,7 @@ def prepare_plan(
482427
):
483428
self._simplify_with_caching(plan)
484429

485-
plan = self.replace_cached_subtrees(plan)
430+
plan = self.cache.subsitute_cached_subplans(plan)
486431
plan = rewrite.column_pruning(plan)
487432
plan = plan.top_down(rewrite.fold_row_counts)
488433

@@ -527,7 +472,7 @@ def _cache_with_session_awareness(
527472
self._cache_with_cluster_cols(
528473
bigframes.core.ArrayValue(target), cluster_cols_sql_names
529474
)
530-
elif self.strictly_ordered:
475+
elif not target.order_ambiguous:
531476
self._cache_with_offsets(bigframes.core.ArrayValue(target))
532477
else:
533478
self._cache_with_cluster_cols(bigframes.core.ArrayValue(target), [])
@@ -552,7 +497,7 @@ def _cache_most_complex_subtree(self, node: nodes.BigFrameNode) -> bool:
552497
node,
553498
min_complexity=(QUERY_COMPLEXITY_LIMIT / 500),
554499
max_complexity=QUERY_COMPLEXITY_LIMIT,
555-
cache=dict(self.cache.mapping),
500+
cache=self.cache,
556501
# Heuristic: subtree_compleixty * (copies of subtree)^2
557502
heuristic=lambda complexity, count: math.log(complexity)
558503
+ 2 * math.log(count),
@@ -581,32 +526,37 @@ def _substitute_large_local_sources(self, original_root: nodes.BigFrameNode):
581526
def map_local_scans(node: nodes.BigFrameNode):
582527
if not isinstance(node, nodes.ReadLocalNode):
583528
return node
584-
if node.local_data_source not in self.cache._uploaded_local_data:
585-
return node
586-
bq_source, source_mapping = self.cache._uploaded_local_data[
529+
uploaded_local_data = self.cache.get_uploaded_local_data(
587530
node.local_data_source
588-
]
589-
scan_list = node.scan_list.remap_source_ids(source_mapping)
531+
)
532+
if uploaded_local_data is None:
533+
return node
534+
535+
scan_list = node.scan_list.remap_source_ids(
536+
uploaded_local_data.source_mapping
537+
)
590538
# offsets_col isn't part of ReadTableNode, so emulate by adding to end of scan_list
591539
if node.offsets_col is not None:
592540
# Offsets are always implicitly the final column of uploaded data
593541
# See: Loader.load_data
594542
scan_list = scan_list.append(
595-
bq_source.table.physical_schema[-1].name,
543+
uploaded_local_data.bq_source.table.physical_schema[-1].name,
596544
bigframes.dtypes.INT_DTYPE,
597545
node.offsets_col,
598546
)
599-
return nodes.ReadTableNode(bq_source, scan_list, node.session)
547+
return nodes.ReadTableNode(
548+
uploaded_local_data.bq_source, scan_list, node.session
549+
)
600550

601551
return original_root.bottom_up(map_local_scans)
602552

603553
def _upload_local_data(self, local_table: local_data.ManagedArrowTable):
604-
if local_table in self.cache._uploaded_local_data:
554+
if self.cache.get_uploaded_local_data(local_table) is not None:
605555
return
606556
# Lock prevents concurrent repeated work, but slows things down.
607557
# Might be better as a queue and a worker thread
608558
with self._upload_lock:
609-
if local_table not in self.cache._uploaded_local_data:
559+
if self.cache.get_uploaded_local_data(local_table) is None:
610560
uploaded = self.loader.load_data_or_write_data(
611561
local_table, bigframes.core.guid.generate_guid()
612562
)
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import dataclasses
18+
from typing import Mapping, Optional
19+
import weakref
20+
21+
from bigframes.core import bq_data, local_data, nodes
22+
23+
SourceIdMapping = Mapping[str, str]
24+
25+
26+
@dataclasses.dataclass(frozen=True)
27+
class UploadedLocalData:
28+
bq_source: bq_data.BigqueryDataSource
29+
source_mapping: SourceIdMapping
30+
31+
32+
class ExecutionCache:
33+
def __init__(self):
34+
# effectively two separate caches that don't interact
35+
self._cached_executions: weakref.WeakKeyDictionary[
36+
nodes.BigFrameNode, bq_data.BigqueryDataSource
37+
] = weakref.WeakKeyDictionary()
38+
# This upload cache is entirely independent of the plan cache.
39+
self._uploaded_local_data: weakref.WeakKeyDictionary[
40+
local_data.ManagedArrowTable,
41+
UploadedLocalData,
42+
] = weakref.WeakKeyDictionary()
43+
44+
def subsitute_cached_subplans(self, root: nodes.BigFrameNode) -> nodes.BigFrameNode:
45+
def replace_if_cached(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
46+
if node not in self._cached_executions:
47+
return node
48+
# Assumption: GBQ cached table uses field name as bq column name
49+
scan_list = nodes.ScanList(
50+
tuple(nodes.ScanItem(field.id, field.id.sql) for field in node.fields)
51+
)
52+
bq_data = self._cached_executions[node]
53+
cached_replacement = nodes.CachedTableNode(
54+
source=bq_data,
55+
scan_list=scan_list,
56+
table_session=node.session,
57+
original_node=node,
58+
)
59+
assert node.schema == cached_replacement.schema
60+
return cached_replacement
61+
62+
return nodes.top_down(root, replace_if_cached)
63+
64+
def cache_results_table(
65+
self,
66+
original_root: nodes.BigFrameNode,
67+
data: bq_data.BigqueryDataSource,
68+
):
69+
self._cached_executions[original_root] = data
70+
71+
## Local data upload caching
72+
def cache_remote_replacement(
73+
self,
74+
local_data: local_data.ManagedArrowTable,
75+
bq_data: bq_data.BigqueryDataSource,
76+
):
77+
# bq table has one extra column for offsets, those are implicit for local data
78+
assert len(local_data.schema.items) + 1 == len(bq_data.table.physical_schema)
79+
mapping = {
80+
local_data.schema.items[i].column: bq_data.table.physical_schema[i].name
81+
for i in range(len(local_data.schema))
82+
}
83+
self._uploaded_local_data[local_data] = UploadedLocalData(bq_data, mapping)
84+
85+
def get_uploaded_local_data(
86+
self, local_data: local_data.ManagedArrowTable
87+
) -> Optional[UploadedLocalData]:
88+
return self._uploaded_local_data.get(local_data)

0 commit comments

Comments (0)