Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 34 additions & 5 deletions client/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ type Stmt struct {
params int
columns int
warnings int

// Field definitions from the PREPARE response (for proxy passthrough)
paramFields []*mysql.Field
columnFields []*mysql.Field
}

func (s *Stmt) ParamNum() int {
Expand All @@ -33,6 +37,20 @@ func (s *Stmt) WarningsNum() int {
return s.warnings
}

// GetParamFields returns the parameter field definitions from the PREPARE response.
// Implements server.StmtFieldsProvider for proxy passthrough.
// The caller should not modify the returned slice.
func (s *Stmt) GetParamFields() []*mysql.Field {
return s.paramFields
}

// GetColumnFields returns the column field definitions from the PREPARE response.
// Implements server.StmtFieldsProvider for proxy passthrough.
// The caller should not modify the returned slice.
func (s *Stmt) GetColumnFields() []*mysql.Field {
return s.columnFields
}

func (s *Stmt) Execute(args ...interface{}) (*mysql.Result, error) {
if err := s.write(args...); err != nil {
return nil, errors.Trace(err)
Expand Down Expand Up @@ -275,8 +293,14 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
}

if s.params > 0 {
for range s.params {
if _, err := s.conn.ReadPacket(); err != nil {
s.paramFields = make([]*mysql.Field, s.params)
for i := range s.params {
data, err := s.conn.ReadPacket()
if err != nil {
return nil, errors.Trace(err)
}
s.paramFields[i] = &mysql.Field{}
if err := s.paramFields[i].Parse(data); err != nil {
return nil, errors.Trace(err)
}
}
Expand All @@ -290,9 +314,14 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
}

if s.columns > 0 {
// TODO process when CLIENT_CACHE_METADATA enabled
for range s.columns {
if _, err := s.conn.ReadPacket(); err != nil {
s.columnFields = make([]*mysql.Field, s.columns)
for i := range s.columns {
data, err := s.conn.ReadPacket()
if err != nil {
return nil, errors.Trace(err)
}
s.columnFields[i] = &mysql.Field{}
if err := s.columnFields[i].Parse(data); err != nil {
return nil, errors.Trace(err)
}
}
Expand Down
13 changes: 13 additions & 0 deletions server/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ import (
"github.com/go-mysql-org/go-mysql/utils"
)

// StmtFieldsProvider is an optional interface that prepared statement contexts can implement
// to provide field definitions for proxy passthrough scenarios.
type StmtFieldsProvider interface {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to verify why we need this interface? If the only reason is to cast st.Context, I think maybe we should move the common struct of client and server package to a common package, like utils or a new package, to avoid a runtime cast.

Copy link
Contributor Author

@ramnes ramnes Dec 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you suggesting that we create something like

type StmtFields struct {
    ParamFields []*Field
    ClumnFields []*Field
}

and then use it in client.Stmt?

GetParamFields() []*mysql.Field
GetColumnFields() []*mysql.Field
}

// Handler is what a server needs to implement the client-server protocol
type Handler interface {
// handle COM_INIT_DB command, you can check whether the dbName is valid, or other.
Expand Down Expand Up @@ -112,6 +119,12 @@ func (c *Conn) dispatch(data []byte) interface{} {
if st.Params, st.Columns, st.Context, err = c.h.HandleStmtPrepare(st.Query); err != nil {
return err
} else {
// If context provides field definitions (e.g., from a backend prepared statement),
// use them for accurate metadata passthrough in proxy scenarios.
if provider, ok := st.Context.(StmtFieldsProvider); ok {
st.ParamFields = provider.GetParamFields()
st.ColumnFields = provider.GetColumnFields()
}
st.ResetParams()
c.stmts[c.stmtID] = st
return st
Expand Down
16 changes: 14 additions & 2 deletions server/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ type Stmt struct {
Args []interface{}

Context interface{}

// Field definitions for proxy passthrough (optional, uses dummy fields if nil)
ParamFields []*mysql.Field
ColumnFields []*mysql.Field
}

func (s *Stmt) Rest(params int, columns int, context interface{}) {
Expand Down Expand Up @@ -61,7 +65,11 @@ func (c *Conn) writePrepare(s *Stmt) error {
if s.Params > 0 {
for i := 0; i < s.Params; i++ {
data = data[0:4]
data = append(data, paramFieldData...)
if s.ParamFields != nil && i < len(s.ParamFields) {
data = append(data, s.ParamFields[i].Dump()...)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems no need to parse it? We can only forward the raw value.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The rationale was that for a high-level library that users can use to inspect fields metadata, it makes sense that they don't have to parse the bytes themselves. But for a pure proxy use-case, yeah, the raw value would be better. Both work for me!

} else {
data = append(data, paramFieldData...)
}

if err := c.WritePacket(data); err != nil {
return errors.Trace(err)
Expand All @@ -76,7 +84,11 @@ func (c *Conn) writePrepare(s *Stmt) error {
if s.Columns > 0 {
for i := 0; i < s.Columns; i++ {
data = data[0:4]
data = append(data, columnFieldData...)
if s.ColumnFields != nil && i < len(s.ColumnFields) {
data = append(data, s.ColumnFields[i].Dump()...)
} else {
data = append(data, columnFieldData...)
}

if err := c.WritePacket(data); err != nil {
return errors.Trace(err)
Expand Down
50 changes: 50 additions & 0 deletions server/stmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package server
import (
"testing"

"github.com/go-mysql-org/go-mysql/mysql"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -46,3 +47,52 @@ func TestHandleStmtExecute(t *testing.T) {
}
}
}

type mockPrepareHandler struct {
EmptyHandler
context any
paramCount, columnCount int
}

func (h *mockPrepareHandler) HandleStmtPrepare(query string) (int, int, any, error) {
return h.paramCount, h.columnCount, h.context, nil
}

func TestStmtPrepareWithoutFieldsProvider(t *testing.T) {
c := &Conn{
h: &mockPrepareHandler{context: "plain string", paramCount: 1, columnCount: 1},
stmts: make(map[uint32]*Stmt),
}

result := c.dispatch(append([]byte{mysql.COM_STMT_PREPARE}, "SELECT * FROM t"...))

stmt := result.(*Stmt)
require.Nil(t, stmt.ParamFields)
require.Nil(t, stmt.ColumnFields)
}

type mockFieldsProvider struct {
paramFields, columnFields []*mysql.Field
}

func (m *mockFieldsProvider) GetParamFields() []*mysql.Field { return m.paramFields }
func (m *mockFieldsProvider) GetColumnFields() []*mysql.Field { return m.columnFields }

func TestStmtPrepareWithFieldsProvider(t *testing.T) {
provider := &mockFieldsProvider{
paramFields: []*mysql.Field{{Name: []byte("?"), Type: mysql.MYSQL_TYPE_LONG}},
columnFields: []*mysql.Field{{Name: []byte("id"), Type: mysql.MYSQL_TYPE_LONGLONG}},
}
c := &Conn{
h: &mockPrepareHandler{context: provider, paramCount: 1, columnCount: 1},
stmts: make(map[uint32]*Stmt),
}

result := c.dispatch(append([]byte{mysql.COM_STMT_PREPARE}, "SELECT id FROM t WHERE id = ?"...))

stmt := result.(*Stmt)
require.NotNil(t, stmt.ParamFields)
require.NotNil(t, stmt.ColumnFields)
require.Equal(t, mysql.MYSQL_TYPE_LONG, stmt.ParamFields[0].Type)
require.Equal(t, mysql.MYSQL_TYPE_LONGLONG, stmt.ColumnFields[0].Type)
}
Loading