@@ -19,6 +19,10 @@ type Stmt struct {
1919 params int
2020 columns int
2121 warnings int
22+
23+ // Field definitions from the PREPARE response (for proxy passthrough)
24+ ParamFields []* mysql.Field
25+ ColumnFields []* mysql.Field
2226}
2327
2428func (s * Stmt ) ParamNum () int {
@@ -33,6 +37,18 @@ func (s *Stmt) WarningsNum() int {
3337 return s .warnings
3438}
3539
40+ // GetParamFields returns the parameter field definitions from the PREPARE response.
41+ // Implements server.StmtFieldsProvider for proxy passthrough.
42+ func (s * Stmt ) GetParamFields () []* mysql.Field {
43+ return s .ParamFields
44+ }
45+
46+ // GetColumnFields returns the column field definitions from the PREPARE response.
47+ // Implements server.StmtFieldsProvider for proxy passthrough.
48+ func (s * Stmt ) GetColumnFields () []* mysql.Field {
49+ return s .ColumnFields
50+ }
51+
3652func (s * Stmt ) Execute (args ... interface {}) (* mysql.Result , error ) {
3753 if err := s .write (args ... ); err != nil {
3854 return nil , errors .Trace (err )
@@ -275,8 +291,14 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
275291 }
276292
277293 if s .params > 0 {
278- for range s .params {
279- if _ , err := s .conn .ReadPacket (); err != nil {
294+ s .ParamFields = make ([]* mysql.Field , s .params )
295+ for i := range s .params {
296+ data , err := s .conn .ReadPacket ()
297+ if err != nil {
298+ return nil , errors .Trace (err )
299+ }
300+ s .ParamFields [i ] = & mysql.Field {}
301+ if err := s .ParamFields [i ].Parse (data ); err != nil {
280302 return nil , errors .Trace (err )
281303 }
282304 }
@@ -290,9 +312,14 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
290312 }
291313
292314 if s .columns > 0 {
293- // TODO process when CLIENT_CACHE_METADATA enabled
294- for range s .columns {
295- if _ , err := s .conn .ReadPacket (); err != nil {
315+ s .ColumnFields = make ([]* mysql.Field , s .columns )
316+ for i := range s .columns {
317+ data , err := s .conn .ReadPacket ()
318+ if err != nil {
319+ return nil , errors .Trace (err )
320+ }
321+ s .ColumnFields [i ] = & mysql.Field {}
322+ if err := s .ColumnFields [i ].Parse (data ); err != nil {
296323 return nil , errors .Trace (err )
297324 }
298325 }
0 commit comments