diff --git a/src/parser_type.go b/src/parser_type.go deleted file mode 100644 index 1c75869..0000000 --- a/src/parser_type.go +++ /dev/null @@ -1,96 +0,0 @@ -package main - -import ( - "strings" - - pgQuery "github.com/pganalyze/pg_query_go/v5" -) - -type ParserType struct { - config *Config -} - -func NewParserType(config *Config) *ParserType { - return &ParserType{config: config} -} - -func (parser *ParserType) MakeTypeCastNode(arg *pgQuery.Node, typeName string) *pgQuery.Node { - return &pgQuery.Node{ - Node: &pgQuery.Node_TypeCast{ - TypeCast: &pgQuery.TypeCast{ - Arg: arg, - TypeName: &pgQuery.TypeName{ - Names: []*pgQuery.Node{ - pgQuery.MakeStrNode(typeName), - }, - Location: 0, - }, - }, - }, - } -} - -func (parser *ParserType) inferNodeType(node *pgQuery.Node) string { - if typeCast := node.GetTypeCast(); typeCast != nil { - return typeCast.TypeName.Names[0].GetString_().Sval - } - - if aConst := node.GetAConst(); aConst != nil { - switch { - case aConst.GetBoolval() != nil: - return "boolean" - case aConst.GetIval() != nil: - return "int8" - case aConst.GetSval() != nil: - return "text" - } - } - return "" -} - -func (parser *ParserType) MakeCaseTypeCastNode(arg *pgQuery.Node, typeName string) *pgQuery.Node { - if existingType := parser.inferNodeType(arg); existingType == typeName { - return arg - } - return parser.MakeTypeCastNode(arg, typeName) -} - -func (parser *ParserType) RemapTypeCast(node *pgQuery.Node) *pgQuery.Node { - if node.GetTypeCast() != nil { - typeCast := node.GetTypeCast() - if len(typeCast.TypeName.Names) > 0 { - typeName := typeCast.TypeName.Names[0].GetString_().Sval - if typeName == "regclass" { - return typeCast.Arg - } - - if typeName == "text" { - return parser.MakeListValueFromArray(typeCast.Arg) - } - } - } - return node -} - -func (parser *ParserType) MakeListValueFromArray(node *pgQuery.Node) *pgQuery.Node { - arrayStr := node.GetAConst().GetSval().Sval - arrayStr = strings.Trim(arrayStr, "{}") - elements := strings.Split(arrayStr, ",") - - funcCall := &pgQuery.FuncCall{ - Funcname: []*pgQuery.Node{ - pgQuery.MakeStrNode("list_value"), - }, - } - - for _, elem := range elements { - funcCall.Args = append(funcCall.Args, - pgQuery.MakeAConstStrNode(elem, 0)) - } - - return &pgQuery.Node{ - Node: &pgQuery.Node_FuncCall{ - FuncCall: funcCall, - }, - } -} diff --git a/src/parser_type_cast.go b/src/parser_type_cast.go new file mode 100644 index 0000000..a9a3c10 --- /dev/null +++ b/src/parser_type_cast.go @@ -0,0 +1,85 @@ +package main + +import ( + "strings" + + pgQuery "github.com/pganalyze/pg_query_go/v5" +) + +type ParserTypeCast struct { + utils *ParserUtils + config *Config +} + +func NewParserTypeCast(config *Config) *ParserTypeCast { + return &ParserTypeCast{utils: NewParserUtils(config), config: config} +} + +func (parser *ParserTypeCast) TypeCast(node *pgQuery.Node) *pgQuery.TypeCast { + if node.GetTypeCast() == nil { + return nil + } + + typeCast := node.GetTypeCast() + if len(typeCast.TypeName.Names) == 0 { + return nil + } + + return typeCast +} + +func (parser *ParserTypeCast) TypeName(typeCast *pgQuery.TypeCast) string { + return typeCast.TypeName.Names[0].GetString_().Sval +} + +func (parser *ParserTypeCast) ArgStringValue(typeCast *pgQuery.TypeCast) string { + return typeCast.Arg.GetAConst().GetSval().Sval +} + +func (parser *ParserTypeCast) MakeCaseTypeCastNode(arg *pgQuery.Node, typeName string) *pgQuery.Node { + if existingType := parser.inferNodeType(arg); existingType == typeName { + return arg + } + return parser.utils.MakeTypeCastNode(arg, typeName) +} + +func (parser *ParserTypeCast) MakeListValueFromArray(node *pgQuery.Node) *pgQuery.Node { + arrayStr := node.GetAConst().GetSval().Sval + arrayStr = strings.Trim(arrayStr, "{}") + elements := strings.Split(arrayStr, ",") + + funcCall := &pgQuery.FuncCall{ + Funcname: []*pgQuery.Node{ + pgQuery.MakeStrNode("list_value"), + }, + } + + for _, elem := range elements { + funcCall.Args = append(funcCall.Args, + pgQuery.MakeAConstStrNode(elem, 0)) + } + + return &pgQuery.Node{ + Node: &pgQuery.Node_FuncCall{ + FuncCall: funcCall, + }, + } +} + +func (parser *ParserTypeCast) inferNodeType(node *pgQuery.Node) string { + if typeCast := node.GetTypeCast(); typeCast != nil { + return typeCast.TypeName.Names[0].GetString_().Sval + } + + if aConst := node.GetAConst(); aConst != nil { + switch { + case aConst.GetBoolval() != nil: + return "boolean" + case aConst.GetIval() != nil: + return "int8" + case aConst.GetSval() != nil: + return "text" + } + } + return "" +} diff --git a/src/parser_utils.go b/src/parser_utils.go index 32af04d..547de0b 100644 --- a/src/parser_utils.go +++ b/src/parser_utils.go @@ -70,7 +70,6 @@ func (utils *ParserUtils) MakeSubselectWithRowsNode(tableName string, tableDef T } func (utils *ParserUtils) MakeSubselectWithoutRowsNode(tableName string, tableDef TableDefinition, alias string) *pgQuery.Node { - parserType := NewParserType(utils.config) columnNodes := make([]*pgQuery.Node, len(tableDef.Columns)) for i, col := range tableDef.Columns { columnNodes[i] = pgQuery.MakeStrNode(col.Name) @@ -85,7 +84,7 @@ func (utils *ParserUtils) MakeSubselectWithoutRowsNode(tableName string, tableDe }, }, } - typedNullNode := parserType.MakeTypeCastNode(nullNode, col.Type) + typedNullNode := utils.MakeTypeCastNode(nullNode, col.Type) targetList[i] = pgQuery.MakeResTargetNodeWithVal(typedNullNode, 0) } @@ -154,8 +153,6 @@ func (utils *ParserUtils) MakeAConstBoolNode(val bool) *pgQuery.Node { } func (utils *ParserUtils) makeTypedConstNode(val string, pgType string) *pgQuery.Node { - parserType := NewParserType(utils.config) - if val == "NULL" { return &pgQuery.Node{ Node: &pgQuery.Node_AConst{ @@ -168,5 +165,21 @@ func (utils *ParserUtils) makeTypedConstNode(val string, pgType string) *pgQuery constNode := pgQuery.MakeAConstStrNode(val, 0) - return parserType.MakeTypeCastNode(constNode, pgType) + return utils.MakeTypeCastNode(constNode, pgType) +} + +func (utils *ParserUtils) MakeTypeCastNode(arg *pgQuery.Node, typeName string) *pgQuery.Node { + return &pgQuery.Node{ + Node: &pgQuery.Node_TypeCast{ + TypeCast: &pgQuery.TypeCast{ + Arg: arg, + TypeName: &pgQuery.TypeName{ + Names: []*pgQuery.Node{ + pgQuery.MakeStrNode(typeName), + }, + Location: 0, + }, + }, + }, + } } diff --git a/src/query_handler_test.go b/src/query_handler_test.go index e2a13e8..8d9ca80 100644 --- a/src/query_handler_test.go +++ b/src/query_handler_test.go @@ -679,6 +679,11 @@ func TestHandleQuery(t *testing.T) { "types": {Uint32ToString(pgtype.Int2OID)}, "values": {"1"}, }, + "SELECT 'pg_catalog.array_in'::regproc AS regproc": { + "description": {"regproc"}, + "types": {Uint32ToString(pgtype.TextOID)}, + "values": {"array_in"}, + }, // SELECT * FROM function() "SELECT * FROM pg_catalog.pg_get_keywords() LIMIT 1": { diff --git a/src/query_remapper.go b/src/query_remapper.go index 6b11afa..d2186ed 100644 --- a/src/query_remapper.go +++ b/src/query_remapper.go @@ -26,28 +26,30 @@ var FALLBACK_QUERY_TREE, _ = pgQuery.Parse(FALLBACK_SQL_QUERY) var FALLBACK_SET_QUERY_TREE, _ = pgQuery.Parse("SET schema TO public") type QueryRemapper struct { - parserTable *ParserTable - parserType *ParserType - remapperTable *QueryRemapperTable - remapperWhere *QueryRemapperWhere - remapperSelect *QueryRemapperSelect - remapperShow *QueryRemapperShow - icebergReader *IcebergReader - duckdb *Duckdb - config *Config + parserTable *ParserTable + parserTypeCast *ParserTypeCast + remapperTable *QueryRemapperTable + remapperTypeCast *QueryRemapperTypeCast + remapperWhere *QueryRemapperWhere + remapperSelect *QueryRemapperSelect + remapperShow *QueryRemapperShow + icebergReader *IcebergReader + duckdb *Duckdb + config *Config } func NewQueryRemapper(config *Config, icebergReader *IcebergReader, duckdb *Duckdb) *QueryRemapper { return &QueryRemapper{ - parserTable: NewParserTable(config), - parserType: NewParserType(config), - remapperTable: NewQueryRemapperTable(config, icebergReader, duckdb), - remapperWhere: NewQueryRemapperWhere(config), - remapperSelect: NewQueryRemapperSelect(config), - remapperShow: NewQueryRemapperShow(config), - icebergReader: icebergReader, - duckdb: duckdb, - config: config, + parserTable: NewParserTable(config), + parserTypeCast: NewParserTypeCast(config), + remapperTable: NewQueryRemapperTable(config, icebergReader, duckdb), + remapperTypeCast: NewQueryRemapperTypeCast(config), + remapperWhere: NewQueryRemapperWhere(config), + remapperSelect: NewQueryRemapperSelect(config), + remapperShow: NewQueryRemapperShow(config), + icebergReader: icebergReader, + duckdb: duckdb, + config: config, } } @@ -120,12 +122,6 @@ func (remapper *QueryRemapper) remapSelectStatement(selectStatement *pgQuery.Sel } } - // CASE - if hasCaseExpr := remapper.hasCaseExpressions(selectStatement); hasCaseExpr { - remapper.traceTreeTraversal("CASE expressions", indentLevel) - remapper.remapCaseExpressions(selectStatement, indentLevel) // recursive - } - // UNION if selectStatement.FromClause == nil && selectStatement.Larg != nil && selectStatement.Rarg != nil { remapper.traceTreeTraversal("UNION left", indentLevel) @@ -144,7 +140,7 @@ func (remapper *QueryRemapper) remapSelectStatement(selectStatement *pgQuery.Sel // WHERE if selectStatement.WhereClause != nil { - selectStatement.WhereClause = remapper.remapTypeCastsInNode(selectStatement.WhereClause) + selectStatement.WhereClause = remapper.remapTypeCastsInNode(selectStatement.WhereClause) // recursive selectStatement = remapper.remapWhereExpressions(selectStatement, selectStatement.WhereClause, indentLevel) // recursive } @@ -180,17 +176,10 @@ func (remapper *QueryRemapper) remapSelectStatement(selectStatement *pgQuery.Sel } } + // SELECT selectStatement = remapper.remapSelect(selectStatement, indentLevel) // recursive - return selectStatement -} -func (remapper *QueryRemapper) hasCaseExpressions(selectStatement *pgQuery.SelectStmt) bool { - for _, target := range selectStatement.TargetList { - if target.GetResTarget().Val.GetCaseExpr() != nil { - return true - } - } - return false + return selectStatement } func (remapper *QueryRemapper) remapCaseExpressions(selectStatement *pgQuery.SelectStmt, indentLevel int) *pgQuery.SelectStmt { @@ -252,16 +241,16 @@ func (remapper *QueryRemapper) remapCaseExpressions(selectStatement *pgQuery.Sel func (remapper *QueryRemapper) ensureConsistentCaseTypes(caseExpr *pgQuery.CaseExpr) { if len(caseExpr.Args) > 0 { if when := caseExpr.Args[0].GetCaseWhen(); when != nil && when.Result != nil { - if typeName := remapper.parserType.inferNodeType(when.Result); typeName != "" { + if typeName := remapper.parserTypeCast.inferNodeType(when.Result); typeName != "" { // WHEN for i := 1; i < len(caseExpr.Args); i++ { if whenClause := caseExpr.Args[i].GetCaseWhen(); whenClause != nil && whenClause.Result != nil { - whenClause.Result = remapper.parserType.MakeCaseTypeCastNode(whenClause.Result, typeName) + whenClause.Result = remapper.parserTypeCast.MakeCaseTypeCastNode(whenClause.Result, typeName) } } // ELSE if caseExpr.Defresult != nil { - caseExpr.Defresult = remapper.parserType.MakeCaseTypeCastNode(caseExpr.Defresult, typeName) + caseExpr.Defresult = remapper.parserTypeCast.MakeCaseTypeCastNode(caseExpr.Defresult, typeName) } } } @@ -316,7 +305,7 @@ func (remapper *QueryRemapper) remapTypeCastsInNode(node *pgQuery.Node) *pgQuery // Direct typecast if node.GetTypeCast() != nil { - return remapper.remapTypecast(node) + return remapper.remapperTypeCast.RemapTypeCast(node) } // Handle CASE expressions @@ -405,7 +394,7 @@ func (remapper *QueryRemapper) remapJoinExpressions(selectStatement *pgQuery.Sel remapper.traceTreeTraversal("JOIN on", indentLevel) if node.GetJoinExpr().Quals != nil { - node.GetJoinExpr().Quals = remapper.remapTypeCastsInNode(node.GetJoinExpr().Quals) + node.GetJoinExpr().Quals = remapper.remapTypeCastsInNode(node.GetJoinExpr().Quals) // recursive } return node @@ -442,11 +431,19 @@ func (remapper *QueryRemapper) remapWhereExpressions(selectStatement *pgQuery.Se return selectStatement } +// SELECT ... func (remapper *QueryRemapper) remapSelect(selectStatement *pgQuery.SelectStmt, indentLevel int) *pgQuery.SelectStmt { remapper.traceTreeTraversal("SELECT statements", indentLevel) // SELECT ... for i, targetNode := range selectStatement.TargetList { + if targetNode.GetResTarget().Val.GetCaseExpr() == nil { + targetNode.GetResTarget().Val = remapper.remapTypeCastsInNode(targetNode.GetResTarget().Val) // recursive + } else { + // CASE + remapper.remapCaseExpressions(selectStatement, indentLevel) // recursive + } + targetNode = remapper.remapperSelect.RemapSelect(targetNode) selectStatement.TargetList[i] = targetNode } @@ -455,7 +452,7 @@ func (remapper *QueryRemapper) remapSelect(selectStatement *pgQuery.SelectStmt, if len(selectStatement.ValuesLists) > 0 { for i, valuesList := range selectStatement.ValuesLists { for j, value := range valuesList.GetList().Items { - selectStatement.ValuesLists[i].GetList().Items[j] = remapper.remapTypeCastsInNode(value) + selectStatement.ValuesLists[i].GetList().Items[j] = remapper.remapTypeCastsInNode(value) // recursive } } } @@ -463,10 +460,6 @@ func (remapper *QueryRemapper) remapSelect(selectStatement *pgQuery.SelectStmt, return selectStatement } -func (remapper *QueryRemapper) remapTypecast(node *pgQuery.Node) *pgQuery.Node { - return remapper.parserType.RemapTypeCast(node) -} - func (remapper *QueryRemapper) traceTreeTraversal(label string, indentLevel int) { LogTrace(remapper.config, strings.Repeat(">", indentLevel), label) } diff --git a/src/query_remapper_type_cast.go b/src/query_remapper_type_cast.go new file mode 100644 index 0000000..a0c05a4 --- /dev/null +++ b/src/query_remapper_type_cast.go @@ -0,0 +1,41 @@ +package main + +import ( + "strings" + + pgQuery "github.com/pganalyze/pg_query_go/v5" +) + +type QueryRemapperTypeCast struct { + parserTypeCast *ParserTypeCast + config *Config +} + +func NewQueryRemapperTypeCast(config *Config) *QueryRemapperTypeCast { + remapper := &QueryRemapperTypeCast{ + parserTypeCast: NewParserTypeCast(config), + config: config, + } + return remapper +} + +// value::type -> value +func (remapper *QueryRemapperTypeCast) RemapTypeCast(node *pgQuery.Node) *pgQuery.Node { + typeCast := remapper.parserTypeCast.TypeCast(node) + if typeCast == nil { + return node + } + + typeName := remapper.parserTypeCast.TypeName(typeCast) + switch typeName { + case "regclass": + return typeCast.Arg + case "text": + return remapper.parserTypeCast.MakeListValueFromArray(typeCast.Arg) + case "regproc": + functionNameParts := strings.Split(remapper.parserTypeCast.ArgStringValue(typeCast), ".") // pg_catalog.func_name + return pgQuery.MakeAConstStrNode(functionNameParts[len(functionNameParts)-1], 0) + } + + return node +}