Skip to content

Commit

Permalink
Fix ::regclass::oid type casting for values with double quotes
Browse files Browse the repository at this point in the history
  • Loading branch information
exAspArk committed Feb 12, 2025
1 parent 1b3a1e1 commit 8f67293
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 12 deletions.
21 changes: 21 additions & 0 deletions src/custom_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"fmt"
"strings"
)

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -89,6 +90,26 @@ type QuerySchemaTable struct {
Alias string
}

func NewQuerySchemaTableFromString(schemaTable string) QuerySchemaTable {
parts := strings.Split(schemaTable, ".")

qSchemaTable := QuerySchemaTable{
Table: parts[len(parts)-1],
}
if len(parts) > 1 {
qSchemaTable.Schema = parts[0]
}

if !StringContainsUpper(qSchemaTable.Schema) {
qSchemaTable.Schema = strings.ReplaceAll(qSchemaTable.Schema, "\"", "")
}
if !StringContainsUpper(qSchemaTable.Table) {
qSchemaTable.Table = strings.ReplaceAll(qSchemaTable.Table, "\"", "")
}

return qSchemaTable
}

func (qSchemaTable QuerySchemaTable) ToIcebergSchemaTable() IcebergSchemaTable {
return IcebergSchemaTable{
Schema: qSchemaTable.Schema,
Expand Down
14 changes: 6 additions & 8 deletions src/parser_type_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (parser *ParserTypeCast) MakeListValueFromArray(node *pgQuery.Node) *pgQuer
// FROM pg_class c
// JOIN pg_namespace n ON n.oid = c.relnamespace
// WHERE n.nspname = 'schema' AND c.relname = 'table'
func (parser *ParserTypeCast) MakeSubselectOidBySchemaTable(argumentNode *pgQuery.Node) *pgQuery.Node {
func (parser *ParserTypeCast) MakeSubselectOidBySchemaTableArg(argumentNode *pgQuery.Node) *pgQuery.Node {
targetNode := pgQuery.MakeResTargetNodeWithVal(
pgQuery.MakeColumnRefNode([]*pgQuery.Node{
pgQuery.MakeStrNode("c"),
Expand Down Expand Up @@ -108,12 +108,10 @@ func (parser *ParserTypeCast) MakeSubselectOidBySchemaTable(argumentNode *pgQuer
)

value := argumentNode.GetAConst().GetSval().Sval
parts := strings.Split(value, ".")
schema := PG_SCHEMA_PUBLIC
if len(parts) > 1 {
schema = parts[0]
qSchemaTable := NewQuerySchemaTableFromString(value)
if qSchemaTable.Schema == "" {
qSchemaTable.Schema = PG_SCHEMA_PUBLIC
}
table := parts[len(parts)-1]

whereNode := pgQuery.MakeBoolExprNode(
pgQuery.BoolExprType_AND_EXPR,
Expand All @@ -127,7 +125,7 @@ func (parser *ParserTypeCast) MakeSubselectOidBySchemaTable(argumentNode *pgQuer
pgQuery.MakeStrNode("n"),
pgQuery.MakeStrNode("nspname"),
}, 0),
pgQuery.MakeAConstStrNode(schema, 0),
pgQuery.MakeAConstStrNode(qSchemaTable.Schema, 0),
0,
),
pgQuery.MakeAExprNode(
Expand All @@ -139,7 +137,7 @@ func (parser *ParserTypeCast) MakeSubselectOidBySchemaTable(argumentNode *pgQuer
pgQuery.MakeStrNode("c"),
pgQuery.MakeStrNode("relname"),
}, 0),
pgQuery.MakeAConstStrNode(table, 0),
pgQuery.MakeAConstStrNode(qSchemaTable.Table, 0),
0,
),
},
Expand Down
4 changes: 2 additions & 2 deletions src/query_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -672,15 +672,15 @@ func TestHandleQuery(t *testing.T) {
},

// Type casts
"SELECT 'public.test_table'::regclass::oid AS oid": {
"SELECT '\"public\".\"test_table\"'::regclass::oid AS oid": {
"description": {"oid"},
"types": {Uint32ToString(pgtype.OIDOID)},
"values": {"1270"},
},
"SELECT attrelid FROM pg_attribute WHERE attrelid = '\"public\".\"test_table\"'::regclass": {
"description": {"attrelid"},
"types": {Uint32ToString(pgtype.Int8OID)},
"values": {},
"values": {"1270"},
},
"SELECT objoid, classoid, objsubid, description FROM pg_description WHERE classoid = 'pg_class'::regclass": {
"description": {"objoid", "classoid", "objsubid", "description"},
Expand Down
4 changes: 2 additions & 2 deletions src/query_remapper_type_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (remapper *QueryRemapperTypeCast) RemapTypeCast(node *pgQuery.Node) *pgQuer
return pgQuery.MakeAConstStrNode(nameParts[len(nameParts)-1], 0)
case "regclass":
// 'schema.table'::regclass -> SELECT c.oid FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace WHERE n.nspname = 'schema' AND c.relname = 'table'
return remapper.parserTypeCast.MakeSubselectOidBySchemaTable(typeCast.Arg)
return remapper.parserTypeCast.MakeSubselectOidBySchemaTableArg(typeCast.Arg)
case "oid":
// 'schema.table'::regclass::oid -> SELECT c.oid FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace WHERE n.nspname = 'schema' AND c.relname = 'table'
nestedNode := typeCast.Arg
Expand All @@ -50,7 +50,7 @@ func (remapper *QueryRemapperTypeCast) RemapTypeCast(node *pgQuery.Node) *pgQuer
return node
}

return remapper.parserTypeCast.MakeSubselectOidBySchemaTable(nestedTypeCast.Arg)
return remapper.parserTypeCast.MakeSubselectOidBySchemaTableArg(nestedTypeCast.Arg)
}

return node
Expand Down
10 changes: 10 additions & 0 deletions src/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"golang.org/x/crypto/pbkdf2"
"os"
"strconv"
"unicode"
)

func PanicIfError(err error, message ...string) {
Expand Down Expand Up @@ -67,6 +68,15 @@ func StringToScramSha256(password string) string {
)
}

func StringContainsUpper(str string) bool {
for _, char := range str {
if unicode.IsUpper(char) {
return true
}
}
return false
}

func hmacSha256Hash(key []byte, message []byte) []byte {
hash := hmac.New(sha256.New, key)
hash.Write(message)
Expand Down

0 comments on commit 8f67293

Please sign in to comment.