diff --git a/src/query_handler.go b/src/query_handler.go index 0dedaa9..4c153fd 100644 --- a/src/query_handler.go +++ b/src/query_handler.go @@ -179,6 +179,10 @@ func (queryHandler *QueryHandler) HandleQuery(originalQuery string) ([]pgproto3. return nil, err } + if query == "" { + return []pgproto3.Message{&pgproto3.EmptyQueryResponse{}}, nil + } + rows, err := queryHandler.duckdb.QueryContext(context.Background(), query) if err != nil { errorMessage := err.Error() @@ -217,23 +221,24 @@ func (queryHandler *QueryHandler) HandleParseQuery(message *pgproto3.Parse) ([]p return nil, nil, err } - statement, err := queryHandler.duckdb.PrepareContext(ctx, query) - if err != nil { - LogError(queryHandler.config, "Couldn't prepare query via DuckDB:", query+"\n"+err.Error()) - return nil, nil, err - } - preparedStatement := &PreparedStatement{ Name: message.Name, OriginalQuery: originalQuery, Query: query, - Statement: statement, ParameterOIDs: message.ParameterOIDs, } + if query == "" { + return []pgproto3.Message{&pgproto3.EmptyQueryResponse{}}, preparedStatement, nil + } - messages := []pgproto3.Message{&pgproto3.ParseComplete{}} + statement, err := queryHandler.duckdb.PrepareContext(ctx, query) + preparedStatement.Statement = statement + if err != nil { + LogError(queryHandler.config, "Couldn't prepare query via DuckDB:", query+"\n"+err.Error()) + return nil, nil, err + } - return messages, preparedStatement, nil + return []pgproto3.Message{&pgproto3.ParseComplete{}}, preparedStatement, nil } func (queryHandler *QueryHandler) HandleBindQuery(message *pgproto3.Bind, preparedStatement *PreparedStatement) ([]pgproto3.Message, *PreparedStatement, error) { @@ -287,6 +292,10 @@ func (queryHandler *QueryHandler) HandleDescribeQuery(message *pgproto3.Describe } } + if preparedStatement.Query == "" { + return []pgproto3.Message{&pgproto3.NoData{}}, preparedStatement, nil + } + if len(preparedStatement.ParameterOIDs) != len(preparedStatement.Variables) { // Bind step didn't happen before return []pgproto3.Message{&pgproto3.NoData{}}, preparedStatement, nil } @@ -311,6 +320,10 @@ func (queryHandler *QueryHandler) HandleExecuteQuery(message *pgproto3.Execute, return nil, errors.New("portal mismatch") } + if preparedStatement.Query == "" { + return []pgproto3.Message{&pgproto3.EmptyQueryResponse{}}, nil + } + if preparedStatement.Rows == nil { // If Describe step didn't have Bind step before rows, err := preparedStatement.Statement.QueryContext(context.Background(), preparedStatement.Variables...) if err != nil { @@ -396,7 +409,7 @@ func (queryHandler *QueryHandler) remapQuery(query string) (string, error) { } if strings.HasSuffix(query, " --INSPECT") { - LogDebug(queryHandler.config, queryTree.Stmts[0].Stmt) + LogDebug(queryHandler.config, queryTree.Stmts) } queryTree.Stmts, err = queryHandler.queryRemapper.RemapStatements(queryTree.Stmts) diff --git a/src/query_handler_test.go b/src/query_handler_test.go index f2ee103..74546c3 100644 --- a/src/query_handler_test.go +++ b/src/query_handler_test.go @@ -233,13 +233,6 @@ func TestHandleQuery(t *testing.T) { "values": {"memory", "public", "test_table"}, }, - // Empty query - "-- ping": { - "description": {"1"}, - "types": {Uint32ToString(pgtype.Int4OID)}, - "values": {"1"}, - }, - // DISCARD "DISCARD ALL": { "description": {"1"}, @@ -921,6 +914,17 @@ func TestHandleQuery(t *testing.T) { testDataRowValues(t, messages[1], []string{"UTC"}) testCommandCompleteTag(t, messages[2], "SHOW") }) + + t.Run("Handles an empty query", func(t *testing.T) { + queryHandler := initQueryHandler() + + messages, err := queryHandler.HandleQuery("-- ping") + + testNoError(t, err) + testMessageTypes(t, messages, []pgproto3.Message{ + &pgproto3.EmptyQueryResponse{}, + }) + }) } func TestHandleParseQuery(t *testing.T) { @@ -1024,6 +1028,22 @@ func TestHandleDescribeQuery(t *testing.T) { } }) + t.Run("Handles DESCRIBE extended query step if query is empty", func(t *testing.T) { + queryHandler := initQueryHandler() + parseMessage := &pgproto3.Parse{Query: ""} + _, preparedStatement, _ := queryHandler.HandleParseQuery(parseMessage) + bindMessage := &pgproto3.Bind{} + _, preparedStatement, _ = queryHandler.HandleBindQuery(bindMessage, preparedStatement) + message := &pgproto3.Describe{ObjectType: 'P'} + + messages, _, err := queryHandler.HandleDescribeQuery(message, preparedStatement) + + testNoError(t, err) + testMessageTypes(t, messages, []pgproto3.Message{ + &pgproto3.NoData{}, + }) + }) + t.Run("Handles DESCRIBE (Statement) extended query step if there was no BIND step", func(t *testing.T) { queryHandler := initQueryHandler() query := "SELECT usename, passwd FROM pg_shadow WHERE usename=$1" @@ -1061,6 +1081,24 @@ func TestHandleExecuteQuery(t *testing.T) { }) testDataRowValues(t, messages[0], []string{"bemidb", "bemidb-encrypted"}) }) + + t.Run("Handles EXECUTE extended query step if query is empty", func(t *testing.T) { + queryHandler := initQueryHandler() + parseMessage := &pgproto3.Parse{Query: ""} + _, preparedStatement, _ := queryHandler.HandleParseQuery(parseMessage) + bindMessage := &pgproto3.Bind{} + _, preparedStatement, _ = queryHandler.HandleBindQuery(bindMessage, preparedStatement) + describeMessage := &pgproto3.Describe{ObjectType: 'P'} + _, preparedStatement, _ = queryHandler.HandleDescribeQuery(describeMessage, preparedStatement) + message := &pgproto3.Execute{} + + messages, err := queryHandler.HandleExecuteQuery(message, preparedStatement) + + testNoError(t, err) + testMessageTypes(t, messages, []pgproto3.Message{ + &pgproto3.EmptyQueryResponse{}, + }) + }) } func TestHandleMultipleQueries(t *testing.T) { diff --git a/src/query_remapper.go b/src/query_remapper.go index 549a3da..a063949 100644 --- a/src/query_remapper.go +++ b/src/query_remapper.go @@ -52,17 +52,18 @@ func NewQueryRemapper(config *Config, icebergReader *IcebergReader, duckdb *Duck } func (remapper *QueryRemapper) RemapStatements(statements []*pgQuery.RawStmt) ([]*pgQuery.RawStmt, error) { + // Empty query if len(statements) == 0 { - return FALLBACK_QUERY_TREE.Stmts, nil + return statements, nil } for i, stmt := range statements { node := stmt.Stmt switch { - // Empty query + // Empty statement case node == nil: - return nil, errors.New("empty query") + return nil, errors.New("empty statement") // SELECT ... case node.GetSelectStmt() != nil: