Skip to content

Commit 11fd2f1

Browse files
authored
refactor: use sqlglot to build load_data ddl (#2515)
Fixes internal issue 418025765 🦕
1 parent 265376f commit 11fd2f1

File tree

15 files changed

+227
-184
lines changed

15 files changed

+227
-184
lines changed

bigframes/bigquery/_operations/io.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
import pandas as pd
2020

2121
from bigframes.bigquery._operations.table import _get_table_metadata
22+
import bigframes.core.compile.sqlglot.sql as sql
2223
import bigframes.core.logging.log_adapter as log_adapter
23-
import bigframes.core.sql.io
2424
import bigframes.session
2525

2626

@@ -73,7 +73,7 @@ def load_data(
7373
"""
7474
import bigframes.pandas as bpd
7575

76-
sql = bigframes.core.sql.io.load_data_ddl(
76+
load_data_expr = sql.load_data(
7777
table_name=table_name,
7878
write_disposition=write_disposition,
7979
columns=columns,
@@ -84,11 +84,12 @@ def load_data(
8484
with_partition_columns=with_partition_columns,
8585
connection_name=connection_name,
8686
)
87+
sql_text = sql.to_sql(load_data_expr)
8788

8889
if session is None:
89-
bpd.read_gbq_query(sql)
90+
bpd.read_gbq_query(sql_text)
9091
session = bpd.get_global_session()
9192
else:
92-
session.read_gbq_query(sql)
93+
session.read_gbq_query(sql_text)
9394

9495
return _get_table_metadata(bqclient=session.bqclient, table_name=table_name)

bigframes/core/compile/sqlglot/sql/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
table,
2323
to_sql,
2424
)
25+
from bigframes.core.compile.sqlglot.sql.ddl import load_data
2526
from bigframes.core.compile.sqlglot.sql.dml import insert, replace
2627

2728
__all__ = [
@@ -33,6 +34,8 @@
3334
"literal",
3435
"table",
3536
"to_sql",
37+
# From ddl.py
38+
"load_data",
3639
# From dml.py
3740
"insert",
3841
"replace",
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import Mapping, Optional, Union
18+
19+
import bigframes_vendored.sqlglot as sg
20+
import bigframes_vendored.sqlglot.expressions as sge
21+
22+
from bigframes.core.compile.sqlglot.sql import base
23+
24+
25+
def _loaddata_sql(self: sg.Generator, expression: sge.LoadData) -> str:
26+
out = ["LOAD DATA"]
27+
if expression.args.get("overwrite"):
28+
out.append("OVERWRITE")
29+
30+
out.append(f"INTO {self.sql(expression, 'this').strip()}")
31+
32+
# We ignore inpath as it's just a dummy to satisfy sqlglot requirements
33+
# but BigQuery uses FROM FILES instead.
34+
35+
columns = self.sql(expression, "columns").strip()
36+
if columns:
37+
out.append(columns)
38+
39+
partition_by = self.sql(expression, "partition_by").strip()
40+
if partition_by:
41+
out.append(partition_by)
42+
43+
cluster_by = self.sql(expression, "cluster_by").strip()
44+
if cluster_by:
45+
out.append(cluster_by)
46+
47+
options = self.sql(expression, "options").strip()
48+
if options:
49+
out.append(options)
50+
51+
from_files = self.sql(expression, "from_files").strip()
52+
if from_files:
53+
out.append(f"FROM FILES {from_files}")
54+
55+
with_partition_columns = self.sql(expression, "with_partition_columns").strip()
56+
if with_partition_columns:
57+
out.append(f"WITH PARTITION COLUMNS {with_partition_columns}")
58+
59+
connection = self.sql(expression, "connection").strip()
60+
if connection:
61+
out.append(f"WITH CONNECTION {connection}")
62+
63+
return " ".join(out)
64+
65+
66+
# Register the transform for BigQuery generator
67+
sg.dialects.bigquery.BigQuery.Generator.TRANSFORMS[sge.LoadData] = _loaddata_sql
68+
69+
70+
def load_data(
71+
table_name: str,
72+
*,
73+
write_disposition: str = "INTO",
74+
columns: Optional[Mapping[str, str]] = None,
75+
partition_by: Optional[list[str]] = None,
76+
cluster_by: Optional[list[str]] = None,
77+
table_options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
78+
from_files_options: Mapping[str, Union[str, int, float, bool, list]],
79+
with_partition_columns: Optional[Mapping[str, str]] = None,
80+
connection_name: Optional[str] = None,
81+
) -> sge.LoadData:
82+
"""Generates the LOAD DATA DDL statement."""
83+
# We use a Table with a simple identifier for the table name.
84+
# Quoting is handled by the dialect.
85+
table_expr = sge.Table(this=base.identifier(table_name))
86+
87+
sge_columns = (
88+
sge.Schema(
89+
this=None,
90+
expressions=[
91+
sge.ColumnDef(
92+
this=base.identifier(name),
93+
kind=sge.DataType.build(typ, dialect="bigquery"),
94+
)
95+
for name, typ in columns.items()
96+
],
97+
)
98+
if columns
99+
else None
100+
)
101+
102+
sge_partition_by = (
103+
sge.PartitionedByProperty(
104+
this=base.identifier(partition_by[0])
105+
if len(partition_by) == 1
106+
else sge.Tuple(expressions=[base.identifier(col) for col in partition_by])
107+
)
108+
if partition_by
109+
else None
110+
)
111+
112+
sge_cluster_by = (
113+
sge.Cluster(expressions=[base.identifier(col) for col in cluster_by])
114+
if cluster_by
115+
else None
116+
)
117+
118+
sge_table_options = (
119+
sge.Properties(
120+
expressions=[
121+
sge.Property(this=base.identifier(k), value=base.literal(v))
122+
for k, v in table_options.items()
123+
]
124+
)
125+
if table_options
126+
else None
127+
)
128+
129+
sge_from_files = sge.Tuple(
130+
expressions=[
131+
sge.Property(this=base.identifier(k), value=base.literal(v))
132+
for k, v in from_files_options.items()
133+
]
134+
)
135+
136+
sge_with_partition_columns = (
137+
sge.Schema(
138+
this=None,
139+
expressions=[
140+
sge.ColumnDef(
141+
this=base.identifier(name),
142+
kind=sge.DataType.build(typ, dialect="bigquery"),
143+
)
144+
for name, typ in with_partition_columns.items()
145+
],
146+
)
147+
if with_partition_columns
148+
else None
149+
)
150+
151+
sge_connection = base.identifier(connection_name) if connection_name else None
152+
153+
return sge.LoadData(
154+
this=table_expr,
155+
overwrite=(write_disposition == "OVERWRITE"),
156+
inpath=sge.convert("fake"), # satisfy sqlglot's required inpath arg
157+
columns=sge_columns,
158+
partition_by=sge_partition_by,
159+
cluster_by=sge_cluster_by,
160+
options=sge_table_options,
161+
from_files=sge_from_files,
162+
with_partition_columns=sge_with_partition_columns,
163+
connection=sge_connection,
164+
)

bigframes/core/sql/io.py

Lines changed: 0 additions & 87 deletions
This file was deleted.

tests/unit/bigquery/_operations/test_io.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import pytest
1818

1919
import bigframes.bigquery._operations.io
20-
import bigframes.core.sql.io
2120
import bigframes.session
2221

2322

@@ -36,6 +35,6 @@ def test_load_data(get_table_metadata_mock, mock_session):
3635
)
3736
mock_session.read_gbq_query.assert_called_once()
3837
generated_sql = mock_session.read_gbq_query.call_args[0][0]
39-
expected = "LOAD DATA INTO my-project.my_dataset.my_table (col1 INT64, col2 STRING) FROM FILES (format = 'CSV', uris = ['gs://bucket/path*'])"
38+
expected = "LOAD DATA INTO `my-project.my_dataset.my_table` (\n `col1` INT64,\n `col2` STRING\n) FROM FILES (format='CSV', uris=['gs://bucket/path*'])"
4039
assert generated_sql == expected
4140
get_table_metadata_mock.assert_called_once()
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
LOAD DATA OVERWRITE INTO `my-project.my_dataset.my_table` (
2+
`col1` INT64,
3+
`col2` STRING
4+
) PARTITION BY `date_col` CLUSTER BY
5+
`cluster_col` OPTIONS (
6+
description='my table'
7+
) FROM FILES (format='CSV', uris=['gs://bucket/path*']) WITH PARTITION COLUMNS (
8+
`part1` DATE,
9+
`part2` STRING
10+
) WITH CONNECTION `my-connection`
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
LOAD DATA INTO `my-project.my_dataset.my_table` FROM FILES (format='CSV', uris=['gs://bucket/path*'])

tests/unit/core/compile/sqlglot/snapshots/test_dml/test_insert_from_select/out.sql renamed to tests/unit/core/compile/sqlglot/sql/snapshots/test_dml/test_insert_from_select/out.sql

File renamed without changes.

tests/unit/core/compile/sqlglot/snapshots/test_dml/test_insert_from_table/out.sql renamed to tests/unit/core/compile/sqlglot/sql/snapshots/test_dml/test_insert_from_table/out.sql

File renamed without changes.

tests/unit/core/compile/sqlglot/snapshots/test_dml/test_replace_from_select/out.sql renamed to tests/unit/core/compile/sqlglot/sql/snapshots/test_dml/test_replace_from_select/out.sql

File renamed without changes.

0 commit comments

Comments
 (0)