Skip to content

Commit

Permalink
Fix type cast for ::regproc
Browse files Browse the repository at this point in the history
  • Loading branch information
exAspArk committed Feb 3, 2025
1 parent 247a8b0 commit fe7e7b1
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 145 deletions.
96 changes: 0 additions & 96 deletions src/parser_type.go

This file was deleted.

85 changes: 85 additions & 0 deletions src/parser_type_cast.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package main

import (
"strings"

pgQuery "github.com/pganalyze/pg_query_go/v5"
)

type ParserTypeCast struct {
utils *ParserUtils
config *Config
}

func NewParserTypeCast(config *Config) *ParserTypeCast {
return &ParserTypeCast{utils: NewParserUtils(config), config: config}
}

func (parser *ParserTypeCast) TypeCast(node *pgQuery.Node) *pgQuery.TypeCast {
if node.GetTypeCast() == nil {
return nil
}

typeCast := node.GetTypeCast()
if len(typeCast.TypeName.Names) == 0 {
return nil
}

return typeCast
}

func (parser *ParserTypeCast) TypeName(typeCast *pgQuery.TypeCast) string {
return typeCast.TypeName.Names[0].GetString_().Sval
}

func (parser *ParserTypeCast) ArgStringValue(typeCast *pgQuery.TypeCast) string {
return typeCast.Arg.GetAConst().GetSval().Sval
}

func (parser *ParserTypeCast) MakeCaseTypeCastNode(arg *pgQuery.Node, typeName string) *pgQuery.Node {
if existingType := parser.inferNodeType(arg); existingType == typeName {
return arg
}
return parser.utils.MakeTypeCastNode(arg, typeName)
}

func (parser *ParserTypeCast) MakeListValueFromArray(node *pgQuery.Node) *pgQuery.Node {
arrayStr := node.GetAConst().GetSval().Sval
arrayStr = strings.Trim(arrayStr, "{}")
elements := strings.Split(arrayStr, ",")

funcCall := &pgQuery.FuncCall{
Funcname: []*pgQuery.Node{
pgQuery.MakeStrNode("list_value"),
},
}

for _, elem := range elements {
funcCall.Args = append(funcCall.Args,
pgQuery.MakeAConstStrNode(elem, 0))
}

return &pgQuery.Node{
Node: &pgQuery.Node_FuncCall{
FuncCall: funcCall,
},
}
}

func (parser *ParserTypeCast) inferNodeType(node *pgQuery.Node) string {
if typeCast := node.GetTypeCast(); typeCast != nil {
return typeCast.TypeName.Names[0].GetString_().Sval
}

if aConst := node.GetAConst(); aConst != nil {
switch {
case aConst.GetBoolval() != nil:
return "boolean"
case aConst.GetIval() != nil:
return "int8"
case aConst.GetSval() != nil:
return "text"
}
}
return ""
}
23 changes: 18 additions & 5 deletions src/parser_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ func (utils *ParserUtils) MakeSubselectWithRowsNode(tableName string, tableDef T
}

func (utils *ParserUtils) MakeSubselectWithoutRowsNode(tableName string, tableDef TableDefinition, alias string) *pgQuery.Node {
parserType := NewParserType(utils.config)
columnNodes := make([]*pgQuery.Node, len(tableDef.Columns))
for i, col := range tableDef.Columns {
columnNodes[i] = pgQuery.MakeStrNode(col.Name)
Expand All @@ -85,7 +84,7 @@ func (utils *ParserUtils) MakeSubselectWithoutRowsNode(tableName string, tableDe
},
},
}
typedNullNode := parserType.MakeTypeCastNode(nullNode, col.Type)
typedNullNode := utils.MakeTypeCastNode(nullNode, col.Type)
targetList[i] = pgQuery.MakeResTargetNodeWithVal(typedNullNode, 0)
}

Expand Down Expand Up @@ -154,8 +153,6 @@ func (utils *ParserUtils) MakeAConstBoolNode(val bool) *pgQuery.Node {
}

func (utils *ParserUtils) makeTypedConstNode(val string, pgType string) *pgQuery.Node {
parserType := NewParserType(utils.config)

if val == "NULL" {
return &pgQuery.Node{
Node: &pgQuery.Node_AConst{
Expand All @@ -168,5 +165,21 @@ func (utils *ParserUtils) makeTypedConstNode(val string, pgType string) *pgQuery

constNode := pgQuery.MakeAConstStrNode(val, 0)

return parserType.MakeTypeCastNode(constNode, pgType)
return utils.MakeTypeCastNode(constNode, pgType)
}

func (utils *ParserUtils) MakeTypeCastNode(arg *pgQuery.Node, typeName string) *pgQuery.Node {
return &pgQuery.Node{
Node: &pgQuery.Node_TypeCast{
TypeCast: &pgQuery.TypeCast{
Arg: arg,
TypeName: &pgQuery.TypeName{
Names: []*pgQuery.Node{
pgQuery.MakeStrNode(typeName),
},
Location: 0,
},
},
},
}
}
5 changes: 5 additions & 0 deletions src/query_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,11 @@ func TestHandleQuery(t *testing.T) {
"types": {Uint32ToString(pgtype.Int2OID)},
"values": {"1"},
},
"SELECT 'pg_catalog.array_in'::regproc AS regproc": {
"description": {"regproc"},
"types": {Uint32ToString(pgtype.TextOID)},
"values": {"array_in"},
},

// SELECT * FROM function()
"SELECT * FROM pg_catalog.pg_get_keywords() LIMIT 1": {
Expand Down
Loading

0 comments on commit fe7e7b1

Please sign in to comment.