Skip to content

Commit

Permalink
Fix querying JOINed pg_catalog.pg_namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
exAspArk committed Dec 12, 2024
1 parent bb9f8d1 commit 5eca926
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"time"
)

const VERSION = "0.24.0"
const VERSION = "0.24.1"

func main() {
config := LoadConfig()
Expand Down
4 changes: 4 additions & 0 deletions src/query_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ func TestHandleQuery(t *testing.T) {
"description": {"nspname"},
"values": {},
},
"SELECT n.nspname FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid = n.oid ORDER BY n.oid LIMIT 1": {
"description": {"nspname"},
"values": {"public"},
},
// pg_statio_user_tables
"SELECT pg_total_relation_size(relid) AS total_size FROM pg_catalog.pg_statio_user_tables WHERE schemaname = 'public'": {
"description": {"total_size"},
Expand Down
46 changes: 32 additions & 14 deletions src/select_remapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@ var KNOWN_SET_STATEMENTS = NewSet([]string{
"standard_conforming_strings", // SET standard_conforming_strings = on
"intervalstyle", // SET intervalstyle = iso_8601
"timezone", // SET SESSION timezone TO 'UTC'
"extra_float_digits", // SET extra_float_digits = 3
"application_name", // SET application_name = 'psql'
})

type SelectRemapper struct {
parserTable *QueryParserTable
remapperTable *SelectRemapperTable
remapperWhere *SelectRemapperWhere
remapperSelect *SelectRemapperSelect
Expand All @@ -34,6 +37,7 @@ type SelectRemapper struct {

func NewSelectRemapper(config *Config, icebergReader *IcebergReader) *SelectRemapper {
return &SelectRemapper{
parserTable: NewQueryParserTable(config),
remapperTable: NewSelectRemapperTable(config, icebergReader),
remapperWhere: NewSelectRemapperWhere(config),
remapperSelect: NewSelectRemapperSelect(config),
Expand All @@ -60,7 +64,7 @@ func (selectRemapper *SelectRemapper) RemapQueryTreeWithSet(queryTree *pgQuery.P

queryTree.Stmts[0].Stmt.GetVariableSetStmt().Name = "schema"
queryTree.Stmts[0].Stmt.GetVariableSetStmt().Args = []*pgQuery.Node{
pgQuery.MakeAConstStrNode("main", 0),
pgQuery.MakeAConstStrNode(PG_SCHEMA_PUBLIC, 0),
}

return queryTree
Expand All @@ -87,32 +91,34 @@ func (selectRemapper *SelectRemapper) remapSelectStatement(selectStatement *pgQu
// JOIN
if len(selectStatement.FromClause) > 0 && selectStatement.FromClause[0].GetJoinExpr() != nil {
// SELECT
selectStatement = selectRemapper.remapSelect(selectStatement, indentLevel) // recursive
selectStatement.FromClause[0] = selectRemapper.remapJoinExpressions(selectStatement.FromClause[0], indentLevel) // recursive with self-recursion
selectStatement = selectRemapper.remapSelect(selectStatement, indentLevel) // recursive
selectStatement.FromClause[0] = selectRemapper.remapJoinExpressions(selectStatement, selectStatement.FromClause[0], indentLevel+1) // recursive with self-recursion
return selectStatement
}

// FROM
if len(selectStatement.FromClause) > 0 {
// WHERE
if selectStatement.FromClause[0].GetRangeVar() != nil {
selectRemapper.logTreeTraversal("WHERE statements", indentLevel)
selectStatement = selectRemapper.remapperWhere.RemapWhere(selectStatement)
}

// SELECT
selectStatement = selectRemapper.remapSelect(selectStatement, indentLevel) // recursive

// FROM
for i, fromNode := range selectStatement.FromClause {
if fromNode.GetRangeVar() != nil {
// WHERE
selectRemapper.logTreeTraversal("WHERE statements", indentLevel)
schemaTable := selectRemapper.parserTable.NodeToSchemaTable(fromNode)
selectStatement = selectRemapper.remapperWhere.RemapWhere(schemaTable, selectStatement)
// TABLE
selectRemapper.logTreeTraversal("FROM table", indentLevel)
selectStatement.FromClause[i] = selectRemapper.remapperTable.RemapTable(fromNode)
} else if fromNode.GetRangeSubselect() != nil {
// FROM (SELECT ...)
selectRemapper.logTreeTraversal("FROM subselect", indentLevel)
subSelectStatement := fromNode.GetRangeSubselect().Subquery.GetSelectStmt()
subSelectStatement = selectRemapper.remapSelectStatement(subSelectStatement, indentLevel+1) // self-recursion
}

// FROM PG_FUNCTION()
if fromNode.GetRangeFunction() != nil {
selectStatement.FromClause[i] = selectRemapper.remapTableFunction(fromNode, indentLevel+1) // recursive
}
Expand Down Expand Up @@ -257,24 +263,36 @@ func (selectRemapper *SelectRemapper) remapTypeCastsInNode(node *pgQuery.Node) *
return node
}

func (selectRemapper *SelectRemapper) remapJoinExpressions(node *pgQuery.Node, indentLevel int) *pgQuery.Node {
selectRemapper.logTreeTraversal("JOIN left", indentLevel+1)
func (selectRemapper *SelectRemapper) remapJoinExpressions(selectStatement *pgQuery.SelectStmt, node *pgQuery.Node, indentLevel int) *pgQuery.Node {
selectRemapper.logTreeTraversal("JOIN left", indentLevel)
leftJoinNode := node.GetJoinExpr().Larg
if leftJoinNode.GetJoinExpr() != nil {
leftJoinNode = selectRemapper.remapJoinExpressions(leftJoinNode, indentLevel+1) // self-recursion
leftJoinNode = selectRemapper.remapJoinExpressions(selectStatement, leftJoinNode, indentLevel+1) // self-recursion
} else if leftJoinNode.GetRangeVar() != nil {
// WHERE
selectRemapper.logTreeTraversal("WHERE left", indentLevel+1)
schemaTable := selectRemapper.parserTable.NodeToSchemaTable(leftJoinNode)
selectStatement = selectRemapper.remapperWhere.RemapWhere(schemaTable, selectStatement)
// TABLE
selectRemapper.logTreeTraversal("TABLE left", indentLevel+1)
leftJoinNode = selectRemapper.remapperTable.RemapTable(leftJoinNode)
} else if leftJoinNode.GetRangeSubselect() != nil {
leftSelectStatement := leftJoinNode.GetRangeSubselect().Subquery.GetSelectStmt()
leftSelectStatement = selectRemapper.remapSelectStatement(leftSelectStatement, indentLevel+1) // parent-recursion
}
node.GetJoinExpr().Larg = leftJoinNode

selectRemapper.logTreeTraversal("JOIN right", indentLevel+1)
selectRemapper.logTreeTraversal("JOIN right", indentLevel)
rightJoinNode := node.GetJoinExpr().Rarg
if rightJoinNode.GetJoinExpr() != nil {
rightJoinNode = selectRemapper.remapJoinExpressions(rightJoinNode, indentLevel+1) // self-recursion
rightJoinNode = selectRemapper.remapJoinExpressions(selectStatement, rightJoinNode, indentLevel+1) // self-recursion
} else if rightJoinNode.GetRangeVar() != nil {
// WHERE
selectRemapper.logTreeTraversal("WHERE right", indentLevel+1)
schemaTable := selectRemapper.parserTable.NodeToSchemaTable(rightJoinNode)
selectStatement = selectRemapper.remapperWhere.RemapWhere(schemaTable, selectStatement)
// TABLE
selectRemapper.logTreeTraversal("TABLE right", indentLevel+1)
rightJoinNode = selectRemapper.remapperTable.RemapTable(rightJoinNode)
} else if rightJoinNode.GetRangeSubselect() != nil {
rightSelectStatement := rightJoinNode.GetRangeSubselect().Subquery.GetSelectStmt()
Expand Down
4 changes: 1 addition & 3 deletions src/select_remapper_where.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ func NewSelectRemapperWhere(config *Config) *SelectRemapperWhere {
}

// WHERE [CONDITION]
func (remapper *SelectRemapperWhere) RemapWhere(selectStatement *pgQuery.SelectStmt) *pgQuery.SelectStmt {
schemaTable := remapper.parserTable.NodeToSchemaTable(selectStatement.FromClause[0])

func (remapper *SelectRemapperWhere) RemapWhere(schemaTable SchemaTable, selectStatement *pgQuery.SelectStmt) *pgQuery.SelectStmt {
// FROM pg_catalog.pg_namespace -> FROM pg_catalog.pg_namespace WHERE nspname != 'main'
if remapper.parserTable.IsPgNamespaceTable(schemaTable) {
withoutMainSchemaWhereCondition := remapper.parserWhere.MakeExpressionNode("nspname", "!=", "main")
Expand Down

0 comments on commit 5eca926

Please sign in to comment.