@@ -3,6 +3,7 @@ package server
33import (
44 "testing"
55
6+ "github.com/go-mysql-org/go-mysql/mysql"
67 "github.com/stretchr/testify/require"
78)
89
@@ -46,3 +47,52 @@ func TestHandleStmtExecute(t *testing.T) {
4647 }
4748 }
4849}
50+
51+ type mockPrepareHandler struct {
52+ EmptyHandler
53+ context any
54+ paramCount , columnCount int
55+ }
56+
57+ func TestStmtPrepareWithoutFieldsProvider (t * testing.T ) {
58+ c := & Conn {
59+ h : & mockPrepareHandler {context : "plain string" , paramCount : 1 , columnCount : 1 },
60+ stmts : make (map [uint32 ]* Stmt ),
61+ }
62+
63+ result := c .dispatch (append ([]byte {mysql .COM_STMT_PREPARE }, "SELECT * FROM t" ... ))
64+
65+ stmt := result .(* Stmt )
66+ require .Nil (t , stmt .ParamFields )
67+ require .Nil (t , stmt .ColumnFields )
68+ }
69+
70+ type mockFieldsProvider struct {
71+ paramFields , columnFields []* mysql.Field
72+ }
73+
74+ func (m * mockFieldsProvider ) GetParamFields () []* mysql.Field { return m .paramFields }
75+ func (m * mockFieldsProvider ) GetColumnFields () []* mysql.Field { return m .columnFields }
76+
77+ func (h * mockPrepareHandler ) HandleStmtPrepare (query string ) (int , int , any , error ) {
78+ return h .paramCount , h .columnCount , h .context , nil
79+ }
80+
81+ func TestStmtPrepareWithFieldsProvider (t * testing.T ) {
82+ provider := & mockFieldsProvider {
83+ paramFields : []* mysql.Field {{Name : []byte ("?" ), Type : mysql .MYSQL_TYPE_LONG }},
84+ columnFields : []* mysql.Field {{Name : []byte ("id" ), Type : mysql .MYSQL_TYPE_LONGLONG }},
85+ }
86+ c := & Conn {
87+ h : & mockPrepareHandler {context : provider , paramCount : 1 , columnCount : 1 },
88+ stmts : make (map [uint32 ]* Stmt ),
89+ }
90+
91+ result := c .dispatch (append ([]byte {mysql .COM_STMT_PREPARE }, "SELECT id FROM t WHERE id = ?" ... ))
92+
93+ stmt := result .(* Stmt )
94+ require .NotNil (t , stmt .ParamFields )
95+ require .NotNil (t , stmt .ColumnFields )
96+ require .Equal (t , mysql .MYSQL_TYPE_LONG , stmt .ParamFields [0 ].Type )
97+ require .Equal (t , mysql .MYSQL_TYPE_LONGLONG , stmt .ColumnFields [0 ].Type )
98+ }
0 commit comments