Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion client/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ func (c *Conn) writeAuthHandshake() error {
c.ccaps&mysql.CLIENT_MULTI_STATEMENTS | c.ccaps&mysql.CLIENT_MULTI_RESULTS |
c.ccaps&mysql.CLIENT_PS_MULTI_RESULTS | c.ccaps&mysql.CLIENT_CONNECT_ATTRS |
c.ccaps&mysql.CLIENT_COMPRESS | c.ccaps&mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM |
c.ccaps&mysql.CLIENT_LOCAL_FILES
c.ccaps&mysql.CLIENT_LOCAL_FILES | c.ccaps&mysql.CLIENT_SESSION_TRACK

capability &^= c.clientExplicitOffCaps

Expand Down
15 changes: 11 additions & 4 deletions client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,16 +268,23 @@ func (c *Conn) SetTLSConfig(config *tls.Config) {
}

func (c *Conn) UseDB(dbName string) error {
_, err := c.UseDBWithResult(dbName)
return err
}

func (c *Conn) UseDBWithResult(dbName string) (*mysql.Result, error) {
if err := c.writeCommandStr(mysql.COM_INIT_DB, dbName); err != nil {
return errors.Trace(err)
return nil, errors.Trace(err)
}

if _, err := c.readOK(); err != nil {
return errors.Trace(err)
var r *mysql.Result
var err error
if r, err = c.readOK(); err != nil {
return r, errors.Trace(err)
}

c.db = dbName
return nil
return r, nil
}

func (c *Conn) GetDB() string {
Expand Down
89 changes: 86 additions & 3 deletions client/resp.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,102 @@ func (c *Conn) handleOKPacket(data []byte) (*mysql.Result, error) {

//todo:strict_mode, check warnings as error
r.Warnings = binary.LittleEndian.Uint16(data[pos:])
// pos += 2
pos += 2
} else if c.capability&mysql.CLIENT_TRANSACTIONS > 0 {
r.Status = binary.LittleEndian.Uint16(data[pos:])
c.status = r.Status
// pos += 2
pos += 2
}

// new ok package will check CLIENT_SESSION_TRACK too, but I don't support it now.
if (c.capability&mysql.CLIENT_SESSION_TRACK > 0) &&
(c.status&mysql.SERVER_SESSION_STATE_CHANGED > 0) {
var err error

// Example status message:
// "Records: 3 Duplicates: 0 Warnings: 0"
statusMessageLength := int(data[pos])
pos++
if statusMessageLength > 0 {
r.StatusMessage = utils.ByteSliceToString(data[pos : pos+statusMessageLength])
pos += statusMessageLength
}

sessionTrackingChangeLength := int(data[pos])
pos++
dataLength := len(data[pos:])
if dataLength != sessionTrackingChangeLength {
return nil, fmt.Errorf("incorrect data length for session tracking data: expected %d but got %d",
sessionTrackingChangeLength, dataLength)
}
r.SessionTracking, err = decodeSessionTracking(data[pos:])
if err != nil {
return nil, err
}
}

// skip info
return r, nil
}

func decodeSessionTracking(data []byte) (s *mysql.SessionTrackingInfo, err error) {
s = &mysql.SessionTrackingInfo{}
pos := 0
for pos < len(data) {
sessionTrackingChangeType := data[pos]
pos++ // session tracking type
pos++ // length of session tracking data, unused

switch sessionTrackingChangeType {
case mysql.SESSION_TRACK_SYSTEM_VARIABLES:
if s.Variables == nil {
s.Variables = make(map[string]string, 1)
}
varNameLength := data[pos]
pos++
varName := utils.ByteSliceToString(data[pos : pos+int(varNameLength)])
pos += int(varNameLength)
varValueLength := data[pos]
pos++
s.Variables[varName] = utils.ByteSliceToString(data[pos : pos+int(varValueLength)])
pos += int(varValueLength)
case mysql.SESSION_TRACK_SCHEMA:
schemaInfoLength := data[pos]
pos++
s.Schema = utils.ByteSliceToString(data[pos : pos+int(schemaInfoLength)])
pos += int(schemaInfoLength)
case mysql.SESSION_TRACK_STATE_CHANGE:
s.State = string(data[pos])
pos++
case mysql.SESSION_TRACK_GTIDS:
gtidFormat := data[pos]
if gtidFormat != 0 {
return nil, fmt.Errorf("unexpected GTID format %d", gtidFormat)
}
pos++
gtidLength := data[pos]
pos++
s.GTID = utils.ByteSliceToString(data[pos : pos+int(gtidLength)])
pos += int(gtidLength)
case mysql.SESSION_TRACK_TRANSACTION_CHARACTERISTICS:
characteristicsLength := data[pos]
pos++
if characteristicsLength > 0 {
s.Characteristics = utils.ByteSliceToString(data[pos : pos+int(characteristicsLength)])
pos += int(characteristicsLength)
}
case mysql.SESSION_TRACK_TRANSACTION_STATE:
transactionStateLength := data[pos]
pos++
s.TransactionState = utils.ByteSliceToString(data[pos : pos+int(transactionStateLength)])
pos += int(transactionStateLength)
default:
return nil, fmt.Errorf("got unknown change type %v", sessionTrackingChangeType)
}
}

return s, nil
}

func (c *Conn) handleErrorPacket(data []byte) error {
e := new(mysql.MyError)

Expand Down
72 changes: 72 additions & 0 deletions client/resp_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package client

import (
"testing"

"github.com/go-mysql-org/go-mysql/mysql"
"github.com/stretchr/testify/require"
)

func TestDecodeSessionTracking(t *testing.T) {
data := []struct {
input []byte
result bool
output *mysql.SessionTrackingInfo
}{
{
[]byte{},
true,
&mysql.SessionTrackingInfo{GTID: "", TransactionState: "", Variables: map[string]string(nil), Schema: "", State: "", Characteristics: ""},
},
{
// schema=mysql, state=1
[]byte{0x1, 0x6, 0x5, 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x2, 0x1, 0x31},
true,
&mysql.SessionTrackingInfo{GTID: "", TransactionState: "", Variables: map[string]string(nil), Schema: "mysql", State: "1", Characteristics: ""},
},
{
// got unknown change type 10
[]byte{0xa, 0x6, 0x5, 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x2, 0x1, 0x31},
false,
nil,
},
{
// GTID, autocommit, state
[]byte{0x0, 0xf, 0xa, 0x61, 0x75, 0x74, 0x6f, 0x63, 0x6f, 0x6d, 0x6d, 0x69, 0x74, 0x3, 0x4f, 0x46, 0x46, 0x2, 0x1, 0x31, 0x3, 0x36, 0x0, 0x34, 0x66, 0x34, 0x39, 0x39, 0x33, 0x63, 0x35, 0x65, 0x2d, 0x64, 0x33, 0x35, 0x33, 0x2d, 0x31, 0x31, 0x66, 0x30, 0x2d, 0x39, 0x62, 0x35, 0x66, 0x2d, 0x65, 0x65, 0x64, 0x65, 0x36, 0x64, 0x35, 0x36, 0x32, 0x36, 0x63, 0x38, 0x3a, 0x31, 0x2d, 0x32, 0x33, 0x38, 0x3a, 0x78, 0x6d, 0x61, 0x73, 0x3a, 0x31, 0x2d, 0x32, 0x39},
true,
&mysql.SessionTrackingInfo{GTID: "f4993c5e-d353-11f0-9b5f-eede6d5626c8:1-238:xmas:1-29", TransactionState: "", Variables: map[string]string{"autocommit": "OFF"}, Schema: "", State: "1", Characteristics: ""},
},
{
// Incorrect GTID format
[]byte{0x0, 0xf, 0xa, 0x61, 0x75, 0x74, 0x6f, 0x63, 0x6f, 0x6d, 0x6d, 0x69, 0x74, 0x3, 0x4f, 0x46, 0x46, 0x2, 0x1, 0x31, 0x3, 0x36, 0x1, 0x34, 0x66, 0x34, 0x39, 0x39, 0x33, 0x63, 0x35, 0x65, 0x2d, 0x64, 0x33, 0x35, 0x33, 0x2d, 0x31, 0x31, 0x66, 0x30, 0x2d, 0x39, 0x62, 0x35, 0x66, 0x2d, 0x65, 0x65, 0x64, 0x65, 0x36, 0x64, 0x35, 0x36, 0x32, 0x36, 0x63, 0x38, 0x3a, 0x31, 0x2d, 0x32, 0x33, 0x38, 0x3a, 0x78, 0x6d, 0x61, 0x73, 0x3a, 0x31, 0x2d, 0x32, 0x39},
false,
nil,
},
{
// TransactionState, Characteristics
[]byte{0x5, 0x9, 0x8, 0x54, 0x5f, 0x5f, 0x5f, 0x5f, 0x5f, 0x5f, 0x5f, 0x4, 0x1d, 0x1c, 0x53, 0x54, 0x41, 0x52, 0x54, 0x20, 0x54, 0x52, 0x41, 0x4e, 0x53, 0x41, 0x43, 0x54, 0x49, 0x4f, 0x4e, 0x20, 0x52, 0x45, 0x41, 0x44, 0x20, 0x4f, 0x4e, 0x4c, 0x59, 0x3b},
true,
&mysql.SessionTrackingInfo{GTID: "", TransactionState: "T_______", Variables: map[string]string(nil), Schema: "", State: "", Characteristics: "START TRANSACTION READ ONLY;"},
},
}

for i, tc := range data {
o, err := decodeSessionTracking(tc.input)
if tc.result {
require.NoError(t, err, "case %d", i)
require.Equal(t, tc.output, o, "case %d", i)
} else {
require.Error(t, err, "case %d", i)
}
}
}

func TestHandleOKPacket(t *testing.T) {
c := Conn{
capability: mysql.CLIENT_PROTOCOL_41 | mysql.CLIENT_SESSION_TRACK,
status: mysql.SERVER_SESSION_STATE_CHANGED,
}
data := []byte{0x0, 0x3, 0x0, 0x2, 0x40, 0x0, 0x0, 0x26, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x3a, 0x20, 0x33, 0x20, 0x20, 0x44, 0x75, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x74, 0x65, 0x73, 0x3a, 0x20, 0x30, 0x20, 0x20, 0x57, 0x61, 0x72, 0x6e, 0x69, 0x6e, 0x67, 0x73, 0x3a, 0x20, 0x30, 0x38, 0x3, 0x36, 0x0, 0x34, 0x66, 0x34, 0x39, 0x39, 0x33, 0x63, 0x35, 0x65, 0x2d, 0x64, 0x33, 0x35, 0x33, 0x2d, 0x31, 0x31, 0x66, 0x30, 0x2d, 0x39, 0x62, 0x35, 0x66, 0x2d, 0x65, 0x65, 0x64, 0x65, 0x36, 0x64, 0x35, 0x36, 0x32, 0x36, 0x63, 0x38, 0x3a, 0x31, 0x2d, 0x32, 0x34, 0x31, 0x3a, 0x78, 0x6d, 0x61, 0x73, 0x3a, 0x31, 0x2d, 0x32, 0x39}
_, err := c.handleOKPacket(data)
require.NoError(t, err)
}
10 changes: 10 additions & 0 deletions mysql/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,13 @@ const (
CURSOR_TYPE_SCROLLABLE byte = 0x4
PARAMETER_COUNT_AVAILABLE byte = 0x8
)

// See enum_session_state_type in mysql_com.h
const (
SESSION_TRACK_SYSTEM_VARIABLES = iota
SESSION_TRACK_SCHEMA
SESSION_TRACK_STATE_CHANGE
SESSION_TRACK_GTIDS
SESSION_TRACK_TRANSACTION_CHARACTERISTICS
SESSION_TRACK_TRANSACTION_STATE
)
12 changes: 12 additions & 0 deletions mysql/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,21 @@ type Result struct {
InsertId uint64
AffectedRows uint64

StatusMessage string
SessionTracking *SessionTrackingInfo

*Resultset
}

type SessionTrackingInfo struct {
GTID string
TransactionState string
Variables map[string]string
Schema string
State string
Characteristics string
}

func NewResult(resultset *Resultset) *Result {
return &Result{
Resultset: resultset,
Expand Down
Loading