From 8f67293c0d68531e92552b0e508945d4542c2669 Mon Sep 17 00:00:00 2001 From: exAspArk Date: Wed, 12 Feb 2025 12:16:54 -0500 Subject: [PATCH] Fix ::regclass::oid type casting for values with double quotes --- src/custom_types.go | 21 +++++++++++++++++++++ src/parser_type_cast.go | 14 ++++++-------- src/query_handler_test.go | 4 ++-- src/query_remapper_type_cast.go | 4 ++-- src/utils.go | 10 ++++++++++ 5 files changed, 41 insertions(+), 12 deletions(-) diff --git a/src/custom_types.go b/src/custom_types.go index ba2128b..263489e 100644 --- a/src/custom_types.go +++ b/src/custom_types.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "strings" ) //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -89,6 +90,26 @@ type QuerySchemaTable struct { Alias string } +func NewQuerySchemaTableFromString(schemaTable string) QuerySchemaTable { + parts := strings.Split(schemaTable, ".") + + qSchemaTable := QuerySchemaTable{ + Table: parts[len(parts)-1], + } + if len(parts) > 1 { + qSchemaTable.Schema = parts[0] + } + + if !StringContainsUpper(qSchemaTable.Schema) { + qSchemaTable.Schema = strings.ReplaceAll(qSchemaTable.Schema, "\"", "") + } + if !StringContainsUpper(qSchemaTable.Table) { + qSchemaTable.Table = strings.ReplaceAll(qSchemaTable.Table, "\"", "") + } + + return qSchemaTable +} + func (qSchemaTable QuerySchemaTable) ToIcebergSchemaTable() IcebergSchemaTable { return IcebergSchemaTable{ Schema: qSchemaTable.Schema, diff --git a/src/parser_type_cast.go b/src/parser_type_cast.go index 0294046..b5fe869 100644 --- a/src/parser_type_cast.go +++ b/src/parser_type_cast.go @@ -77,7 +77,7 @@ func (parser *ParserTypeCast) MakeListValueFromArray(node *pgQuery.Node) *pgQuer // FROM pg_class c // JOIN pg_namespace n ON n.oid = c.relnamespace // WHERE n.nspname = 'schema' AND c.relname = 'table' -func (parser *ParserTypeCast) MakeSubselectOidBySchemaTable(argumentNode *pgQuery.Node) *pgQuery.Node { +func (parser *ParserTypeCast) MakeSubselectOidBySchemaTableArg(argumentNode *pgQuery.Node) *pgQuery.Node { targetNode := pgQuery.MakeResTargetNodeWithVal( pgQuery.MakeColumnRefNode([]*pgQuery.Node{ pgQuery.MakeStrNode("c"), @@ -108,12 +108,10 @@ func (parser *ParserTypeCast) MakeSubselectOidBySchemaTable(argumentNode *pgQuer ) value := argumentNode.GetAConst().GetSval().Sval - parts := strings.Split(value, ".") - schema := PG_SCHEMA_PUBLIC - if len(parts) > 1 { - schema = parts[0] + qSchemaTable := NewQuerySchemaTableFromString(value) + if qSchemaTable.Schema == "" { + qSchemaTable.Schema = PG_SCHEMA_PUBLIC } - table := parts[len(parts)-1] whereNode := pgQuery.MakeBoolExprNode( pgQuery.BoolExprType_AND_EXPR, @@ -127,7 +125,7 @@ func (parser *ParserTypeCast) MakeSubselectOidBySchemaTable(argumentNode *pgQuer pgQuery.MakeStrNode("n"), pgQuery.MakeStrNode("nspname"), }, 0), - pgQuery.MakeAConstStrNode(schema, 0), + pgQuery.MakeAConstStrNode(qSchemaTable.Schema, 0), 0, ), pgQuery.MakeAExprNode( @@ -139,7 +137,7 @@ func (parser *ParserTypeCast) MakeSubselectOidBySchemaTable(argumentNode *pgQuer pgQuery.MakeStrNode("c"), pgQuery.MakeStrNode("relname"), }, 0), - pgQuery.MakeAConstStrNode(table, 0), + pgQuery.MakeAConstStrNode(qSchemaTable.Table, 0), 0, ), }, diff --git a/src/query_handler_test.go b/src/query_handler_test.go index 2b1b865..a6076bc 100644 --- a/src/query_handler_test.go +++ b/src/query_handler_test.go @@ -672,7 +672,7 @@ func TestHandleQuery(t *testing.T) { }, // Type casts - "SELECT 'public.test_table'::regclass::oid AS oid": { + "SELECT '\"public\".\"test_table\"'::regclass::oid AS oid": { "description": {"oid"}, "types": {Uint32ToString(pgtype.OIDOID)}, "values": {"1270"}, @@ -680,7 +680,7 @@ func TestHandleQuery(t *testing.T) { "SELECT attrelid FROM pg_attribute WHERE attrelid = '\"public\".\"test_table\"'::regclass": { "description": {"attrelid"}, "types": {Uint32ToString(pgtype.Int8OID)}, - "values": {}, + "values": {"1270"}, }, "SELECT objoid, classoid, objsubid, description FROM pg_description WHERE classoid = 'pg_class'::regclass": { "description": {"objoid", "classoid", "objsubid", "description"}, diff --git a/src/query_remapper_type_cast.go b/src/query_remapper_type_cast.go index 4ac0c7c..2047cf8 100644 --- a/src/query_remapper_type_cast.go +++ b/src/query_remapper_type_cast.go @@ -37,7 +37,7 @@ func (remapper *QueryRemapperTypeCast) RemapTypeCast(node *pgQuery.Node) *pgQuer return pgQuery.MakeAConstStrNode(nameParts[len(nameParts)-1], 0) case "regclass": // 'schema.table'::regclass -> SELECT c.oid FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace WHERE n.nspname = 'schema' AND c.relname = 'table' - return remapper.parserTypeCast.MakeSubselectOidBySchemaTable(typeCast.Arg) + return remapper.parserTypeCast.MakeSubselectOidBySchemaTableArg(typeCast.Arg) case "oid": // 'schema.table'::regclass::oid -> SELECT c.oid FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace WHERE n.nspname = 'schema' AND c.relname = 'table' nestedNode := typeCast.Arg @@ -50,7 +50,7 @@ func (remapper *QueryRemapperTypeCast) RemapTypeCast(node *pgQuery.Node) *pgQuer return node } - return remapper.parserTypeCast.MakeSubselectOidBySchemaTable(nestedTypeCast.Arg) + return remapper.parserTypeCast.MakeSubselectOidBySchemaTableArg(nestedTypeCast.Arg) } return node diff --git a/src/utils.go b/src/utils.go index 2935f19..92cde53 100644 --- a/src/utils.go +++ b/src/utils.go @@ -9,6 +9,7 @@ import ( "golang.org/x/crypto/pbkdf2" "os" "strconv" + "unicode" ) func PanicIfError(err error, message ...string) { @@ -67,6 +68,15 @@ func StringToScramSha256(password string) string { ) } +func StringContainsUpper(str string) bool { + for _, char := range str { + if unicode.IsUpper(char) { + return true + } + } + return false +} + func hmacSha256Hash(key []byte, message []byte) []byte { hash := hmac.New(sha256.New, key) hash.Write(message)