Skip to content

Commit 647b192

Browse files
committed
Add a test
1 parent d6fca48 commit 647b192

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

server/stmt_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package server
33
import (
44
"testing"
55

6+
"github.com/go-mysql-org/go-mysql/mysql"
67
"github.com/stretchr/testify/require"
78
)
89

@@ -46,3 +47,52 @@ func TestHandleStmtExecute(t *testing.T) {
4647
}
4748
}
4849
}
50+
51+
type mockPrepareHandler struct {
52+
EmptyHandler
53+
context any
54+
paramCount, columnCount int
55+
}
56+
57+
func TestStmtPrepareWithoutFieldsProvider(t *testing.T) {
58+
c := &Conn{
59+
h: &mockPrepareHandler{context: "plain string", paramCount: 1, columnCount: 1},
60+
stmts: make(map[uint32]*Stmt),
61+
}
62+
63+
result := c.dispatch(append([]byte{mysql.COM_STMT_PREPARE}, "SELECT * FROM t"...))
64+
65+
stmt := result.(*Stmt)
66+
require.Nil(t, stmt.ParamFields)
67+
require.Nil(t, stmt.ColumnFields)
68+
}
69+
70+
type mockFieldsProvider struct {
71+
paramFields, columnFields []*mysql.Field
72+
}
73+
74+
func (m *mockFieldsProvider) GetParamFields() []*mysql.Field { return m.paramFields }
75+
func (m *mockFieldsProvider) GetColumnFields() []*mysql.Field { return m.columnFields }
76+
77+
func (h *mockPrepareHandler) HandleStmtPrepare(query string) (int, int, any, error) {
78+
return h.paramCount, h.columnCount, h.context, nil
79+
}
80+
81+
func TestStmtPrepareWithFieldsProvider(t *testing.T) {
82+
provider := &mockFieldsProvider{
83+
paramFields: []*mysql.Field{{Name: []byte("?"), Type: mysql.MYSQL_TYPE_LONG}},
84+
columnFields: []*mysql.Field{{Name: []byte("id"), Type: mysql.MYSQL_TYPE_LONGLONG}},
85+
}
86+
c := &Conn{
87+
h: &mockPrepareHandler{context: provider, paramCount: 1, columnCount: 1},
88+
stmts: make(map[uint32]*Stmt),
89+
}
90+
91+
result := c.dispatch(append([]byte{mysql.COM_STMT_PREPARE}, "SELECT id FROM t WHERE id = ?"...))
92+
93+
stmt := result.(*Stmt)
94+
require.NotNil(t, stmt.ParamFields)
95+
require.NotNil(t, stmt.ColumnFields)
96+
require.Equal(t, mysql.MYSQL_TYPE_LONG, stmt.ParamFields[0].Type)
97+
require.Equal(t, mysql.MYSQL_TYPE_LONGLONG, stmt.ColumnFields[0].Type)
98+
}

0 commit comments

Comments
 (0)