1717import math
1818import threading
1919from typing import Literal , Mapping , Optional , Sequence , Tuple
20- import weakref
2120
2221import google .api_core .exceptions
2322from google .cloud import bigquery
4746 semi_executor ,
4847)
4948import bigframes .session ._io .bigquery as bq_io
49+ import bigframes .session .execution_cache as execution_cache
5050import bigframes .session .execution_spec as ex_spec
5151import bigframes .session .metrics
5252import bigframes .session .planner
5959_MAX_CLUSTER_COLUMNS = 4
6060MAX_SMALL_RESULT_BYTES = 10 * 1024 * 1024 * 1024 # 10G
6161
SourceIdMapping = Mapping[str, str]


class ExecutionCache:
    """Session-lived cache of materialized query results.

    Tracks two weakly-keyed maps: plan nodes that have been executed and
    re-materialized as BigQuery tables, and local (Arrow) tables that have
    been uploaded to BigQuery. Weak keys let entries disappear once the
    original node/table is no longer referenced anywhere else.
    """

    def __init__(self):
        # Current assumption is only one cache entry per node; in the future
        # there may be multiple caches with different layouts/localities.
        self._cached_executions: weakref.WeakKeyDictionary[
            nodes.BigFrameNode, nodes.CachedTableNode
        ] = weakref.WeakKeyDictionary()
        self._uploaded_local_data: weakref.WeakKeyDictionary[
            local_data.ManagedArrowTable,
            tuple[bq_data.BigqueryDataSource, SourceIdMapping],
        ] = weakref.WeakKeyDictionary()

    @property
    def mapping(self) -> Mapping[nodes.BigFrameNode, nodes.BigFrameNode]:
        """View mapping original plan nodes to their cached replacements."""
        return self._cached_executions

    def cache_results_table(
        self,
        original_root: nodes.BigFrameNode,
        data: bq_data.BigqueryDataSource,
    ):
        """Record that *original_root* has been materialized as *data*.

        Assumption: the cached BigQuery table uses each field's name as its
        BQ column name, so the scan list maps ids straight through.
        """
        scan_items = tuple(
            nodes.ScanItem(field.id, field.id.sql) for field in original_root.fields
        )
        replacement = nodes.CachedTableNode(
            source=data,
            scan_list=nodes.ScanList(scan_items),
            table_session=original_root.session,
            original_node=original_root,
        )
        # The replacement must be a drop-in substitute for the original node.
        assert original_root.schema == replacement.schema
        self._cached_executions[original_root] = replacement

    def cache_remote_replacement(
        self,
        local_data: local_data.ManagedArrowTable,
        bq_data: bq_data.BigqueryDataSource,
    ):
        """Record that *local_data* was uploaded to BigQuery as *bq_data*,
        remembering the local-column -> BQ-column name mapping."""
        # The BQ table has one extra (trailing) column for offsets, which is
        # implicit for local data.
        assert len(local_data.schema.items) + 1 == len(bq_data.table.physical_schema)
        col_names = {}
        for i in range(len(local_data.schema)):
            col_names[
                local_data.schema.items[i].column
            ] = bq_data.table.physical_schema[i].name
        self._uploaded_local_data[local_data] = (bq_data, col_names)
11462
11563class BigQueryCachingExecutor (executor .Executor ):
11664 """Computes BigFrames values using BigQuery Engine.
@@ -128,20 +76,20 @@ def __init__(
12876 bqstoragereadclient : google .cloud .bigquery_storage_v1 .BigQueryReadClient ,
12977 loader : loader .GbqDataLoader ,
13078 * ,
131- strictly_ordered : bool = True ,
13279 metrics : Optional [bigframes .session .metrics .ExecutionMetrics ] = None ,
13380 enable_polars_execution : bool = False ,
13481 publisher : bigframes .core .events .Publisher ,
82+ labels : Mapping [str , str ] = {},
13583 ):
13684 self .bqclient = bqclient
13785 self .storage_manager = storage_manager
138- self .strictly_ordered : bool = strictly_ordered
139- self .cache : ExecutionCache = ExecutionCache ()
86+ self .cache : execution_cache .ExecutionCache = execution_cache .ExecutionCache ()
14087 self .metrics = metrics
14188 self .loader = loader
14289 self .bqstoragereadclient = bqstoragereadclient
14390 self ._enable_polars_execution = enable_polars_execution
14491 self ._publisher = publisher
92+ self ._labels = labels
14593
14694 # TODO(tswast): Send events from semi-executors, too.
14795 self ._semi_executors : Sequence [semi_executor .SemiExecutor ] = (
@@ -410,8 +358,8 @@ def _run_execute_query(
410358 bigframes .options .compute .maximum_bytes_billed
411359 )
412360
413- if not self .strictly_ordered :
414- job_config .labels [ "bigframes-mode" ] = "unordered"
361+ if self ._labels :
362+ job_config .labels . update ( self . _labels )
415363
416364 try :
417365 # Trick the type checker into thinking we got a literal.
@@ -450,9 +398,6 @@ def _run_execute_query(
450398 else :
451399 raise
452400
453- def replace_cached_subtrees (self , node : nodes .BigFrameNode ) -> nodes .BigFrameNode :
454- return nodes .top_down (node , lambda x : self .cache .mapping .get (x , x ))
455-
456401 def _is_trivially_executable (self , array_value : bigframes .core .ArrayValue ):
457402 """
458403 Can the block be evaluated very cheaply?
@@ -482,7 +427,7 @@ def prepare_plan(
482427 ):
483428 self ._simplify_with_caching (plan )
484429
485- plan = self .replace_cached_subtrees (plan )
430+ plan = self .cache . subsitute_cached_subplans (plan )
486431 plan = rewrite .column_pruning (plan )
487432 plan = plan .top_down (rewrite .fold_row_counts )
488433
@@ -527,7 +472,7 @@ def _cache_with_session_awareness(
527472 self ._cache_with_cluster_cols (
528473 bigframes .core .ArrayValue (target ), cluster_cols_sql_names
529474 )
530- elif self . strictly_ordered :
475+ elif not target . order_ambiguous :
531476 self ._cache_with_offsets (bigframes .core .ArrayValue (target ))
532477 else :
533478 self ._cache_with_cluster_cols (bigframes .core .ArrayValue (target ), [])
@@ -552,7 +497,7 @@ def _cache_most_complex_subtree(self, node: nodes.BigFrameNode) -> bool:
552497 node ,
553498 min_complexity = (QUERY_COMPLEXITY_LIMIT / 500 ),
554499 max_complexity = QUERY_COMPLEXITY_LIMIT ,
555- cache = dict ( self .cache . mapping ) ,
500+ cache = self .cache ,
556501 # Heuristic: subtree_compleixty * (copies of subtree)^2
557502 heuristic = lambda complexity , count : math .log (complexity )
558503 + 2 * math .log (count ),
@@ -581,32 +526,37 @@ def _substitute_large_local_sources(self, original_root: nodes.BigFrameNode):
581526 def map_local_scans (node : nodes .BigFrameNode ):
582527 if not isinstance (node , nodes .ReadLocalNode ):
583528 return node
584- if node .local_data_source not in self .cache ._uploaded_local_data :
585- return node
586- bq_source , source_mapping = self .cache ._uploaded_local_data [
529+ uploaded_local_data = self .cache .get_uploaded_local_data (
587530 node .local_data_source
588- ]
589- scan_list = node .scan_list .remap_source_ids (source_mapping )
531+ )
532+ if uploaded_local_data is None :
533+ return node
534+
535+ scan_list = node .scan_list .remap_source_ids (
536+ uploaded_local_data .source_mapping
537+ )
590538 # offsets_col isn't part of ReadTableNode, so emulate by adding to end of scan_list
591539 if node .offsets_col is not None :
592540 # Offsets are always implicitly the final column of uploaded data
593541 # See: Loader.load_data
594542 scan_list = scan_list .append (
595- bq_source .table .physical_schema [- 1 ].name ,
543+ uploaded_local_data . bq_source .table .physical_schema [- 1 ].name ,
596544 bigframes .dtypes .INT_DTYPE ,
597545 node .offsets_col ,
598546 )
599- return nodes .ReadTableNode (bq_source , scan_list , node .session )
547+ return nodes .ReadTableNode (
548+ uploaded_local_data .bq_source , scan_list , node .session
549+ )
600550
601551 return original_root .bottom_up (map_local_scans )
602552
603553 def _upload_local_data (self , local_table : local_data .ManagedArrowTable ):
604- if local_table in self .cache ._uploaded_local_data :
554+ if self .cache .get_uploaded_local_data ( local_table ) is not None :
605555 return
606556 # Lock prevents concurrent repeated work, but slows things down.
607557 # Might be better as a queue and a worker thread
608558 with self ._upload_lock :
609- if local_table not in self .cache ._uploaded_local_data :
559+ if self .cache .get_uploaded_local_data ( local_table ) is None :
610560 uploaded = self .loader .load_data_or_write_data (
611561 local_table , bigframes .core .guid .generate_guid ()
612562 )
0 commit comments