Skip to content
This repository was archived by the owner on Feb 27, 2025. It is now read-only.

[DRAFT] Support bulk insert into SQL Graph tables #131

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@ class SQLServerBulkJdbcOptions(val params: CaseInsensitiveMap[String])
val allowEncryptedValueModifications =
params.getOrElse("allowEncryptedValueModifications", "false").toBoolean


val schemaCheckEnabled =
params.getOrElse("schemaCheckEnabled", "true").toBoolean

val hideGraphColumns =
params.getOrElse("hideGraphColumns", "true").toBoolean

// Not a feature
// Only used for internally testing data idempotency
val testDataIdempotency =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,29 @@ object BulkCopyUtils extends Logging {
*/
private[spark] def getComputedCols(
conn: Connection,
table: String): List[String] = {
val queryStr = s"SELECT name FROM sys.computed_columns WHERE object_id = OBJECT_ID('${table}');"
table: String,
hideGraphColumns: Boolean): List[String] = {
// TODO can optimize this, also evaluate SQLi issues
val queryStr = if (hideGraphColumns) s"""IF (SERVERPROPERTY('EngineEdition') = 5 OR SERVERPROPERTY('ProductMajorVersion') >= 14)
exec sp_executesql N'SELECT name
FROM sys.computed_columns
WHERE object_id = OBJECT_ID(''${table}'')
UNION ALL
SELECT C.name
FROM sys.tables AS T
JOIN sys.columns AS C
ON T.object_id = C.object_id
WHERE T.object_id = OBJECT_ID(''${table}'')
AND (T.is_edge = 1 OR T.is_node = 1)
AND C.is_hidden = 0
AND C.graph_type = 2'
ELSE
SELECT name
FROM sys.computed_columns
WHERE object_id = OBJECT_ID('${table}')
"""
else s"SELECT name FROM sys.computed_columns WHERE object_id = OBJECT_ID('${table}');"

val computedColRs = conn.createStatement.executeQuery(queryStr)
val computedCols = ListBuffer[String]()
while (computedColRs.next()) {
Expand Down Expand Up @@ -263,7 +284,7 @@ object BulkCopyUtils extends Logging {
val colMetaData = {
if(checkSchema) {
checkExTableType(conn, options)
matchSchemas(conn, options.dbtable, df, rs, options.url, isCaseSensitive, options.schemaCheckEnabled)
matchSchemas(conn, options.dbtable, df, rs, options.url, isCaseSensitive, options.schemaCheckEnabled, options.hideGraphColumns)
} else {
defaultColMetadataMap(rs.getMetaData())
}
Expand All @@ -289,6 +310,7 @@ object BulkCopyUtils extends Logging {
* @param url: String,
* @param isCaseSensitive: Boolean
* @param strictSchemaCheck: Boolean
* @param hideGraphColumns - Whether to hide the $node_id, $from_id, $to_id, $edge_id columns in SQL graph tables
*/
private[spark] def matchSchemas(
conn: Connection,
Expand All @@ -297,13 +319,14 @@ object BulkCopyUtils extends Logging {
rs: ResultSet,
url: String,
isCaseSensitive: Boolean,
strictSchemaCheck: Boolean): Array[ColumnMetadata]= {
strictSchemaCheck: Boolean,
hideGraphColumns: Boolean): Array[ColumnMetadata]= {
val dfColCaseMap = (df.schema.fieldNames.map(item => item.toLowerCase)
zip df.schema.fieldNames.toList).toMap
val dfCols = df.schema

val tableCols = getSchema(rs, JdbcDialects.get(url))
val computedCols = getComputedCols(conn, dbtable)
val computedCols = getComputedCols(conn, dbtable, hideGraphColumns)

val prefix = "Spark Dataframe and SQL Server table have differing"

Expand Down