Skip to content

Commit 8e13d8e

Browse files
committed
Add a test
1 parent d6fca48 commit 8e13d8e

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

server/stmt_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package server
33
import (
44
"testing"
55

6+
"github.com/go-mysql-org/go-mysql/mysql"
67
"github.com/stretchr/testify/require"
78
)
89

@@ -46,3 +47,52 @@ func TestHandleStmtExecute(t *testing.T) {
4647
}
4748
}
4849
}
50+
51+
type mockPrepareHandler struct {
52+
EmptyHandler
53+
context any
54+
paramCount, columnCount int
55+
}
56+
57+
func (h *mockPrepareHandler) HandleStmtPrepare(query string) (int, int, any, error) {
58+
return h.paramCount, h.columnCount, h.context, nil
59+
}
60+
61+
func TestStmtPrepareWithoutFieldsProvider(t *testing.T) {
62+
c := &Conn{
63+
h: &mockPrepareHandler{context: "plain string", paramCount: 1, columnCount: 1},
64+
stmts: make(map[uint32]*Stmt),
65+
}
66+
67+
result := c.dispatch(append([]byte{mysql.COM_STMT_PREPARE}, "SELECT * FROM t"...))
68+
69+
stmt := result.(*Stmt)
70+
require.Nil(t, stmt.ParamFields)
71+
require.Nil(t, stmt.ColumnFields)
72+
}
73+
74+
type mockFieldsProvider struct {
75+
paramFields, columnFields []*mysql.Field
76+
}
77+
78+
func (m *mockFieldsProvider) GetParamFields() []*mysql.Field { return m.paramFields }
79+
func (m *mockFieldsProvider) GetColumnFields() []*mysql.Field { return m.columnFields }
80+
81+
func TestStmtPrepareWithFieldsProvider(t *testing.T) {
82+
provider := &mockFieldsProvider{
83+
paramFields: []*mysql.Field{{Name: []byte("?"), Type: mysql.MYSQL_TYPE_LONG}},
84+
columnFields: []*mysql.Field{{Name: []byte("id"), Type: mysql.MYSQL_TYPE_LONGLONG}},
85+
}
86+
c := &Conn{
87+
h: &mockPrepareHandler{context: provider, paramCount: 1, columnCount: 1},
88+
stmts: make(map[uint32]*Stmt),
89+
}
90+
91+
result := c.dispatch(append([]byte{mysql.COM_STMT_PREPARE}, "SELECT id FROM t WHERE id = ?"...))
92+
93+
stmt := result.(*Stmt)
94+
require.NotNil(t, stmt.ParamFields)
95+
require.NotNil(t, stmt.ColumnFields)
96+
require.Equal(t, mysql.MYSQL_TYPE_LONG, stmt.ParamFields[0].Type)
97+
require.Equal(t, mysql.MYSQL_TYPE_LONGLONG, stmt.ColumnFields[0].Type)
98+
}

0 commit comments

Comments
 (0)