Skip to content

Commit

Permalink
Add ::regclass support
Browse files Browse the repository at this point in the history
  • Loading branch information
arjunlol committed Dec 4, 2024
1 parent be711d6 commit caa03bb
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 0 deletions.
25 changes: 25 additions & 0 deletions src/query_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,31 @@ func TestHandleQuery(t *testing.T) {
"description": {"user_defined_column"},
"values": {""},
},
// Typecasts
"SELECT objoid, classoid, objsubid, description FROM pg_description WHERE classoid = 'pg_class'::regclass": {
"description": {"objoid", "classoid", "objsubid", "description"},
"values": {},
},
"SELECT d.objoid, d.classoid, d.description, c.relname FROM pg_description d JOIN pg_class c ON d.classoid = 'pg_class'::regclass": {
"description": {"objoid", "classoid", "description", "relname"},
"values": {},
},
"SELECT objoid, classoid, objsubid, description FROM (SELECT * FROM pg_description WHERE classoid = 'pg_class'::regclass) d": {
"description": {"objoid", "classoid", "objsubid", "description"},
"values": {},
},
"SELECT objoid, classoid, objsubid, description FROM pg_description WHERE (classoid = 'pg_class'::regclass AND objsubid = 0) OR classoid = 'pg_type'::regclass": {
"description": {"objoid", "classoid", "objsubid", "description"},
"values": {},
},
"SELECT objoid, classoid, objsubid, description FROM pg_description WHERE classoid IN ('pg_class'::regclass, 'pg_type'::regclass)": {
"description": {"objoid", "classoid", "objsubid", "description"},
"values": {},
},
"SELECT objoid FROM pg_description WHERE classoid = CASE WHEN true THEN 'pg_class'::regclass ELSE 'pg_type'::regclass END": {
"description": {"objoid"},
"values": {},
},
}

for query, responses := range responsesByQuery {
Expand Down
107 changes: 107 additions & 0 deletions src/select_remapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ func (selectRemapper *SelectRemapper) RemapQueryTreeWithSet(queryTree *pgQuery.P
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

func (selectRemapper *SelectRemapper) remapSelectStatement(selectStatement *pgQuery.SelectStmt, indentLevel int) *pgQuery.SelectStmt {
selectStatement = selectRemapper.remapTypeCastsInSelect(selectStatement)

if selectStatement.FromClause == nil && selectStatement.Larg != nil && selectStatement.Rarg != nil {
LogDebug(selectRemapper.config, strings.Repeat(">", indentLevel+1)+" UNION left")
leftSelectStatement := selectStatement.Larg
Expand Down Expand Up @@ -105,6 +107,98 @@ func (selectRemapper *SelectRemapper) remapSelectStatement(selectStatement *pgQu
return selectStatement
}

func (selectRemapper *SelectRemapper) remapTypeCastsInSelect(selectStatement *pgQuery.SelectStmt) *pgQuery.SelectStmt {
// WHERE [CONDITION]
if selectStatement.WhereClause != nil {
selectStatement.WhereClause = selectRemapper.remapTypeCastsInNode(selectStatement.WhereClause)
}

// FROM / JOIN [TABLE] and VALUES
if len(selectStatement.FromClause) > 0 {
for _, fromNode := range selectStatement.FromClause {
if fromNode.GetJoinExpr() != nil {
joinExpr := fromNode.GetJoinExpr()
if joinExpr.Quals != nil {
joinExpr.Quals = selectRemapper.remapTypeCastsInNode(joinExpr.Quals)
}
}
// Subqueries
if fromNode.GetRangeSubselect() != nil {
subSelect := fromNode.GetRangeSubselect().Subquery.GetSelectStmt()
selectRemapper.remapTypeCastsInSelect(subSelect)
}
}
}

// VALUES list
if len(selectStatement.ValuesLists) > 0 {
for i, valuesList := range selectStatement.ValuesLists {
for j, value := range valuesList.GetList().Items {
selectStatement.ValuesLists[i].GetList().Items[j] = selectRemapper.remapTypeCastsInNode(value)
}
}
}

return selectStatement
}

func (selectRemapper *SelectRemapper) remapTypeCastsInNode(node *pgQuery.Node) *pgQuery.Node {
if node == nil {
return nil
}

// Direct typecast
if node.GetTypeCast() != nil {
return selectRemapper.remapTypecast(node)
}

// Handle CASE expressions
if node.GetCaseExpr() != nil {
caseExpr := node.GetCaseExpr()
// Handle WHEN clauses
for i, when := range caseExpr.Args {
whenClause := when.GetCaseWhen()
if whenClause.Result != nil {
whenClause.Result = selectRemapper.remapTypeCastsInNode(whenClause.Result)
}
caseExpr.Args[i] = when
}
// Handle ELSE clause
if caseExpr.Defresult != nil {
caseExpr.Defresult = selectRemapper.remapTypeCastsInNode(caseExpr.Defresult)
}
}

// AND/OR expressions
if node.GetBoolExpr() != nil {
boolExpr := node.GetBoolExpr()
for i, arg := range boolExpr.Args {
boolExpr.Args[i] = selectRemapper.remapTypeCastsInNode(arg)
}
}

// Comparison expressions
if node.GetAExpr() != nil {
aExpr := node.GetAExpr()
if aExpr.Lexpr != nil {
aExpr.Lexpr = selectRemapper.remapTypeCastsInNode(aExpr.Lexpr)
}
if aExpr.Rexpr != nil {
aExpr.Rexpr = selectRemapper.remapTypeCastsInNode(aExpr.Rexpr)
}
}

// IN expressions
if node.GetList() != nil {
list := node.GetList()
for i, item := range list.Items {
list.Items[i] = selectRemapper.remapTypeCastsInNode(item)
}
}

return node
}

func (selectRemapper *SelectRemapper) remapJoinExpressions(node *pgQuery.Node, indentLevel int) *pgQuery.Node {
LogDebug(selectRemapper.config, strings.Repeat(">", indentLevel+1)+" JOIN left")
leftJoinNode := node.GetJoinExpr().Larg
Expand Down Expand Up @@ -318,3 +412,16 @@ func (selectRemapper *SelectRemapper) remappedConstantNode(functionCall *pgQuery

return nil
}

func (selectRemapper *SelectRemapper) remapTypecast(node *pgQuery.Node) *pgQuery.Node {
if node.GetTypeCast() != nil {
typeCast := node.GetTypeCast()
if len(typeCast.TypeName.Names) > 0 {
typeName := typeCast.TypeName.Names[0].GetString_().Sval
if typeName == "regclass" {
return typeCast.Arg
}
}
}
return node
}

0 comments on commit caa03bb

Please sign in to comment.