Skip to content

Commit 0d3c2e3

Browse files
authored
client: Add support for Session Tracking (#1076)
* client: Add support for Session Tracking * fixup
1 parent 8fd2c2a commit 0d3c2e3

File tree

6 files changed

+192
-8
lines changed

6 files changed

+192
-8
lines changed

client/auth.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ func (c *Conn) writeAuthHandshake() error {
218218
c.ccaps&mysql.CLIENT_MULTI_STATEMENTS | c.ccaps&mysql.CLIENT_MULTI_RESULTS |
219219
c.ccaps&mysql.CLIENT_PS_MULTI_RESULTS | c.ccaps&mysql.CLIENT_CONNECT_ATTRS |
220220
c.ccaps&mysql.CLIENT_COMPRESS | c.ccaps&mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM |
221-
c.ccaps&mysql.CLIENT_LOCAL_FILES
221+
c.ccaps&mysql.CLIENT_LOCAL_FILES | c.ccaps&mysql.CLIENT_SESSION_TRACK
222222

223223
capability &^= c.clientExplicitOffCaps
224224

client/conn.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -268,16 +268,23 @@ func (c *Conn) SetTLSConfig(config *tls.Config) {
268268
}
269269

270270
func (c *Conn) UseDB(dbName string) error {
271+
_, err := c.UseDBWithResult(dbName)
272+
return err
273+
}
274+
275+
func (c *Conn) UseDBWithResult(dbName string) (*mysql.Result, error) {
271276
if err := c.writeCommandStr(mysql.COM_INIT_DB, dbName); err != nil {
272-
return errors.Trace(err)
277+
return nil, errors.Trace(err)
273278
}
274279

275-
if _, err := c.readOK(); err != nil {
276-
return errors.Trace(err)
280+
var r *mysql.Result
281+
var err error
282+
if r, err = c.readOK(); err != nil {
283+
return r, errors.Trace(err)
277284
}
278285

279286
c.db = dbName
280-
return nil
287+
return r, nil
281288
}
282289

283290
func (c *Conn) GetDB() string {

client/resp.go

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,102 @@ func (c *Conn) handleOKPacket(data []byte) (*mysql.Result, error) {
3939

4040
//todo:strict_mode, check warnings as error
4141
r.Warnings = binary.LittleEndian.Uint16(data[pos:])
42-
// pos += 2
42+
pos += 2
4343
} else if c.capability&mysql.CLIENT_TRANSACTIONS > 0 {
4444
r.Status = binary.LittleEndian.Uint16(data[pos:])
4545
c.status = r.Status
46-
// pos += 2
46+
pos += 2
4747
}
4848

49-
// new ok package will check CLIENT_SESSION_TRACK too, but I don't support it now.
49+
if (c.capability&mysql.CLIENT_SESSION_TRACK > 0) &&
50+
(c.status&mysql.SERVER_SESSION_STATE_CHANGED > 0) {
51+
var err error
52+
53+
// Example status message:
54+
// "Records: 3 Duplicates: 0 Warnings: 0"
55+
statusMessageLength := int(data[pos])
56+
pos++
57+
if statusMessageLength > 0 {
58+
r.StatusMessage = utils.ByteSliceToString(data[pos : pos+statusMessageLength])
59+
pos += statusMessageLength
60+
}
61+
62+
sessionTrackingChangeLength := int(data[pos])
63+
pos++
64+
dataLength := len(data[pos:])
65+
if dataLength != sessionTrackingChangeLength {
66+
return nil, fmt.Errorf("incorrect data length for session tracking data: expected %d but got %d",
67+
sessionTrackingChangeLength, dataLength)
68+
}
69+
r.SessionTracking, err = decodeSessionTracking(data[pos:])
70+
if err != nil {
71+
return nil, err
72+
}
73+
}
5074

5175
// skip info
5276
return r, nil
5377
}
5478

79+
func decodeSessionTracking(data []byte) (s *mysql.SessionTrackingInfo, err error) {
80+
s = &mysql.SessionTrackingInfo{}
81+
pos := 0
82+
for pos < len(data) {
83+
sessionTrackingChangeType := data[pos]
84+
pos++ // session tracking type
85+
pos++ // length of session tracking data, unused
86+
87+
switch sessionTrackingChangeType {
88+
case mysql.SESSION_TRACK_SYSTEM_VARIABLES:
89+
if s.Variables == nil {
90+
s.Variables = make(map[string]string, 1)
91+
}
92+
varNameLength := data[pos]
93+
pos++
94+
varName := utils.ByteSliceToString(data[pos : pos+int(varNameLength)])
95+
pos += int(varNameLength)
96+
varValueLength := data[pos]
97+
pos++
98+
s.Variables[varName] = utils.ByteSliceToString(data[pos : pos+int(varValueLength)])
99+
pos += int(varValueLength)
100+
case mysql.SESSION_TRACK_SCHEMA:
101+
schemaInfoLength := data[pos]
102+
pos++
103+
s.Schema = utils.ByteSliceToString(data[pos : pos+int(schemaInfoLength)])
104+
pos += int(schemaInfoLength)
105+
case mysql.SESSION_TRACK_STATE_CHANGE:
106+
s.State = string(data[pos])
107+
pos++
108+
case mysql.SESSION_TRACK_GTIDS:
109+
gtidFormat := data[pos]
110+
if gtidFormat != 0 {
111+
return nil, fmt.Errorf("unexpected GTID format %d", gtidFormat)
112+
}
113+
pos++
114+
gtidLength := data[pos]
115+
pos++
116+
s.GTID = utils.ByteSliceToString(data[pos : pos+int(gtidLength)])
117+
pos += int(gtidLength)
118+
case mysql.SESSION_TRACK_TRANSACTION_CHARACTERISTICS:
119+
characteristicsLength := data[pos]
120+
pos++
121+
if characteristicsLength > 0 {
122+
s.Characteristics = utils.ByteSliceToString(data[pos : pos+int(characteristicsLength)])
123+
pos += int(characteristicsLength)
124+
}
125+
case mysql.SESSION_TRACK_TRANSACTION_STATE:
126+
transactionStateLength := data[pos]
127+
pos++
128+
s.TransactionState = utils.ByteSliceToString(data[pos : pos+int(transactionStateLength)])
129+
pos += int(transactionStateLength)
130+
default:
131+
return nil, fmt.Errorf("got unknown change type %v", sessionTrackingChangeType)
132+
}
133+
}
134+
135+
return s, nil
136+
}
137+
55138
func (c *Conn) handleErrorPacket(data []byte) error {
56139
e := new(mysql.MyError)
57140

client/resp_test.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package client
2+
3+
import (
4+
"testing"
5+
6+
"github.com/go-mysql-org/go-mysql/mysql"
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
func TestDecodeSessionTracking(t *testing.T) {
11+
data := []struct {
12+
input []byte
13+
result bool
14+
output *mysql.SessionTrackingInfo
15+
}{
16+
{
17+
[]byte{},
18+
true,
19+
&mysql.SessionTrackingInfo{GTID: "", TransactionState: "", Variables: map[string]string(nil), Schema: "", State: "", Characteristics: ""},
20+
},
21+
{
22+
// schema=mysql, state=1
23+
[]byte{0x1, 0x6, 0x5, 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x2, 0x1, 0x31},
24+
true,
25+
&mysql.SessionTrackingInfo{GTID: "", TransactionState: "", Variables: map[string]string(nil), Schema: "mysql", State: "1", Characteristics: ""},
26+
},
27+
{
28+
// got unknown change type 10
29+
[]byte{0xa, 0x6, 0x5, 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x2, 0x1, 0x31},
30+
false,
31+
nil,
32+
},
33+
{
34+
// GTID, autocommit, state
35+
[]byte{0x0, 0xf, 0xa, 0x61, 0x75, 0x74, 0x6f, 0x63, 0x6f, 0x6d, 0x6d, 0x69, 0x74, 0x3, 0x4f, 0x46, 0x46, 0x2, 0x1, 0x31, 0x3, 0x36, 0x0, 0x34, 0x66, 0x34, 0x39, 0x39, 0x33, 0x63, 0x35, 0x65, 0x2d, 0x64, 0x33, 0x35, 0x33, 0x2d, 0x31, 0x31, 0x66, 0x30, 0x2d, 0x39, 0x62, 0x35, 0x66, 0x2d, 0x65, 0x65, 0x64, 0x65, 0x36, 0x64, 0x35, 0x36, 0x32, 0x36, 0x63, 0x38, 0x3a, 0x31, 0x2d, 0x32, 0x33, 0x38, 0x3a, 0x78, 0x6d, 0x61, 0x73, 0x3a, 0x31, 0x2d, 0x32, 0x39},
36+
true,
37+
&mysql.SessionTrackingInfo{GTID: "f4993c5e-d353-11f0-9b5f-eede6d5626c8:1-238:xmas:1-29", TransactionState: "", Variables: map[string]string{"autocommit": "OFF"}, Schema: "", State: "1", Characteristics: ""},
38+
},
39+
{
40+
// Incorrect GTID format
41+
[]byte{0x0, 0xf, 0xa, 0x61, 0x75, 0x74, 0x6f, 0x63, 0x6f, 0x6d, 0x6d, 0x69, 0x74, 0x3, 0x4f, 0x46, 0x46, 0x2, 0x1, 0x31, 0x3, 0x36, 0x1, 0x34, 0x66, 0x34, 0x39, 0x39, 0x33, 0x63, 0x35, 0x65, 0x2d, 0x64, 0x33, 0x35, 0x33, 0x2d, 0x31, 0x31, 0x66, 0x30, 0x2d, 0x39, 0x62, 0x35, 0x66, 0x2d, 0x65, 0x65, 0x64, 0x65, 0x36, 0x64, 0x35, 0x36, 0x32, 0x36, 0x63, 0x38, 0x3a, 0x31, 0x2d, 0x32, 0x33, 0x38, 0x3a, 0x78, 0x6d, 0x61, 0x73, 0x3a, 0x31, 0x2d, 0x32, 0x39},
42+
false,
43+
nil,
44+
},
45+
{
46+
// TransactionState, Characteristics
47+
[]byte{0x5, 0x9, 0x8, 0x54, 0x5f, 0x5f, 0x5f, 0x5f, 0x5f, 0x5f, 0x5f, 0x4, 0x1d, 0x1c, 0x53, 0x54, 0x41, 0x52, 0x54, 0x20, 0x54, 0x52, 0x41, 0x4e, 0x53, 0x41, 0x43, 0x54, 0x49, 0x4f, 0x4e, 0x20, 0x52, 0x45, 0x41, 0x44, 0x20, 0x4f, 0x4e, 0x4c, 0x59, 0x3b},
48+
true,
49+
&mysql.SessionTrackingInfo{GTID: "", TransactionState: "T_______", Variables: map[string]string(nil), Schema: "", State: "", Characteristics: "START TRANSACTION READ ONLY;"},
50+
},
51+
}
52+
53+
for i, tc := range data {
54+
o, err := decodeSessionTracking(tc.input)
55+
if tc.result {
56+
require.NoError(t, err, "case %d", i)
57+
require.Equal(t, tc.output, o, "case %d", i)
58+
} else {
59+
require.Error(t, err, "case %d", i)
60+
}
61+
}
62+
}
63+
64+
func TestHandleOKPacket(t *testing.T) {
65+
c := Conn{
66+
capability: mysql.CLIENT_PROTOCOL_41 | mysql.CLIENT_SESSION_TRACK,
67+
status: mysql.SERVER_SESSION_STATE_CHANGED,
68+
}
69+
data := []byte{0x0, 0x3, 0x0, 0x2, 0x40, 0x0, 0x0, 0x26, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x3a, 0x20, 0x33, 0x20, 0x20, 0x44, 0x75, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x74, 0x65, 0x73, 0x3a, 0x20, 0x30, 0x20, 0x20, 0x57, 0x61, 0x72, 0x6e, 0x69, 0x6e, 0x67, 0x73, 0x3a, 0x20, 0x30, 0x38, 0x3, 0x36, 0x0, 0x34, 0x66, 0x34, 0x39, 0x39, 0x33, 0x63, 0x35, 0x65, 0x2d, 0x64, 0x33, 0x35, 0x33, 0x2d, 0x31, 0x31, 0x66, 0x30, 0x2d, 0x39, 0x62, 0x35, 0x66, 0x2d, 0x65, 0x65, 0x64, 0x65, 0x36, 0x64, 0x35, 0x36, 0x32, 0x36, 0x63, 0x38, 0x3a, 0x31, 0x2d, 0x32, 0x34, 0x31, 0x3a, 0x78, 0x6d, 0x61, 0x73, 0x3a, 0x31, 0x2d, 0x32, 0x39}
70+
_, err := c.handleOKPacket(data)
71+
require.NoError(t, err)
72+
}

mysql/const.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,3 +225,13 @@ const (
225225
CURSOR_TYPE_SCROLLABLE byte = 0x4
226226
PARAMETER_COUNT_AVAILABLE byte = 0x8
227227
)
228+
229+
// See enum_session_state_type in mysql_com.h
230+
const (
231+
SESSION_TRACK_SYSTEM_VARIABLES = iota
232+
SESSION_TRACK_SCHEMA
233+
SESSION_TRACK_STATE_CHANGE
234+
SESSION_TRACK_GTIDS
235+
SESSION_TRACK_TRANSACTION_CHARACTERISTICS
236+
SESSION_TRACK_TRANSACTION_STATE
237+
)

mysql/result.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,21 @@ type Result struct {
99
InsertId uint64
1010
AffectedRows uint64
1111

12+
StatusMessage string
13+
SessionTracking *SessionTrackingInfo
14+
1215
*Resultset
1316
}
1417

18+
type SessionTrackingInfo struct {
19+
GTID string
20+
TransactionState string
21+
Variables map[string]string
22+
Schema string
23+
State string
24+
Characteristics string
25+
}
26+
1527
func NewResult(resultset *Resultset) *Result {
1628
return &Result{
1729
Resultset: resultset,

0 commit comments

Comments
 (0)