Skip to content

Commit 4cd040a

Browse files
authored
Merge pull request #305 from se7entyse7en/query-cancel-conn-id
Modifies killQuery in order to use the connection id
2 parents 6dcbed5 + 8134886 commit 4cd040a

File tree

4 files changed

+56
-52
lines changed

4 files changed

+56
-52
lines changed

server/handler/query.go

Lines changed: 21 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import (
1111
"strconv"
1212
"strings"
1313

14-
"github.com/pressly/lg"
1514
"github.com/src-d/gitbase-web/server/serializer"
1615
"github.com/src-d/gitbase-web/server/service"
1716

@@ -74,8 +73,16 @@ func Query(db service.SQLDB) RequestProcessFunc {
7473
c := make(chan error, 1)
7574

7675
var rows *sql.Rows
76+
conn, err := db.Conn(r.Context())
77+
defer conn.Close()
78+
79+
connID, err := getConnID(r, conn)
80+
if err != nil {
81+
return nil, fmt.Errorf("failed to get connection id: %s", err)
82+
}
83+
7784
go func() {
78-
rows, err = db.QueryContext(r.Context(), query)
85+
rows, err = conn.QueryContext(r.Context(), query)
7986
c <- err
8087
}()
8188

@@ -88,13 +95,14 @@ func Query(db service.SQLDB) RequestProcessFunc {
8895
}
8996

9097
if r.Context().Err() != nil {
91-
killQuery(r, db, query)
98+
db.Exec(fmt.Sprintf("KILL %d", connID))
9299
return nil, dbError(r.Context().Err())
93100
}
94101

95102
if err != nil {
96103
return nil, dbError(err)
97104
}
105+
98106
defer rows.Close()
99107

100108
columnNames, columnTypes, err := columnsInfo(rows)
@@ -128,46 +136,20 @@ func Query(db service.SQLDB) RequestProcessFunc {
128136
}
129137
}
130138

131-
func killQuery(r *http.Request, db service.SQLDB, query string) {
132-
const showProcessList = "SHOW FULL PROCESSLIST"
133-
pRows, pErr := db.Query(showProcessList)
134-
if pErr != nil {
135-
lg.RequestLog(r).WithError(pErr).Errorf("failed to execute %q", showProcessList)
136-
return
137-
}
138-
defer pRows.Close()
139-
140-
found := false
141-
var foundID int
142-
143-
for pRows.Next() {
144-
var id int
145-
var info sql.NullString
146-
var rb sql.RawBytes
147-
// The columns are:
148-
// Id, User, Host, db, Command, Time, State, Info
149-
// gitbase returns the query on "Info".
150-
if err := pRows.Scan(&id, &rb, &rb, &rb, &rb, &rb, &rb, &info); err != nil {
151-
lg.RequestLog(r).WithError(err).Errorf("failed to scan the results of %q", showProcessList)
152-
return
153-
}
154-
155-
if info.Valid && info.String == query {
156-
if found {
157-
// Found more than one match for current query, we cannot know which
158-
// one is ours. Skip the cancellation
159-
lg.RequestLog(r).Errorf("cannot cancel the query, found more than one match in gitbase")
160-
return
161-
}
139+
func getConnID(r *http.Request, conn *sql.Conn) (uint32, error) {
140+
const connIDQuery = "SELECT CONNECTION_ID()"
141+
var connID uint32
162142

163-
found = true
164-
foundID = id
165-
}
143+
row := conn.QueryRowContext(context.Background(), connIDQuery)
144+
if row == nil {
145+
return 0, fmt.Errorf("failed to execute %q", connIDQuery)
166146
}
167147

168-
if found {
169-
db.Exec(fmt.Sprintf("KILL %d", foundID))
148+
if err := row.Scan(&connID); err != nil {
149+
return 0, fmt.Errorf("failed to scan the results of %q: %s", connIDQuery, err)
170150
}
151+
152+
return connID, nil
171153
}
172154

173155
// columnsInfo returns the column names and column types, or error

server/handler/query_test.go

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package handler
22

33
import (
44
"context"
5+
"database/sql"
56
"fmt"
67
"net/http"
78
"net/http/httptest"
@@ -14,8 +15,9 @@ import (
1415

1516
"github.com/pressly/lg"
1617
"github.com/stretchr/testify/assert"
18+
"github.com/stretchr/testify/require"
1719
"github.com/stretchr/testify/suite"
18-
"gopkg.in/DATA-DOG/go-sqlmock.v1"
20+
sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1"
1921
"gopkg.in/bblfsh/sdk.v2/uast/nodes"
2022
)
2123

@@ -43,7 +45,7 @@ func (suite *QuerySuite) TestAddLimit() {
4345
`, "SELECT * FROM repositories LIMIT 100"},
4446
{" SELECT * FROM repositories ", "SELECT * FROM repositories LIMIT 100"},
4547
{" SELECT * FROM repositories ; ", "SELECT * FROM repositories LIMIT 100"},
46-
{` SELECT * FROM repositories
48+
{` SELECT * FROM repositories
4749
; `, "SELECT * FROM repositories LIMIT 100"},
4850
{"/* comment */ SELECT * FROM repositories", "SELECT * FROM repositories LIMIT 100"},
4951
{"SELECT * FROM repositories /* comment */", "SELECT * FROM repositories LIMIT 100"},
@@ -56,7 +58,7 @@ func (suite *QuerySuite) TestAddLimit() {
5658
{"select * from repositories limit 1 ;", "select * from repositories limit 1"},
5759
{`select * from repositories limit 1
5860
;`, "select * from repositories limit 1"},
59-
{`select * from repositories limit 1
61+
{`select * from repositories limit 1
6062
; `, "select * from repositories limit 1"},
6163
{"select * from repositories limit 900", "select * from repositories LIMIT 100"},
6264
{"select * from repositories limit 900;", "select * from repositories LIMIT 100"},
@@ -98,6 +100,8 @@ func (suite *QuerySuite) TestBadRequest() {
98100
}
99101

100102
func (suite *QuerySuite) TestQueryErr() {
103+
mockProcessRows := sqlmock.NewRows([]string{"Id"}).AddRow(1288)
104+
suite.mock.ExpectQuery("SELECT CONNECTION_ID()").WillReturnRows(mockProcessRows)
101105
suite.mock.ExpectQuery(".*").WillReturnError(fmt.Errorf("forced err"))
102106

103107
json := `{"query": "select * from repositories"}`
@@ -108,12 +112,25 @@ func (suite *QuerySuite) TestQueryErr() {
108112
suite.Equal(http.StatusBadRequest, res.Code)
109113
}
110114

115+
func (suite *QuerySuite) TestQueryConnIdErr() {
116+
suite.mock.ExpectQuery("SELECT CONNECTION_ID()").WillReturnError(sql.ErrNoRows)
117+
118+
json := `{"query": "select * from repositories"}`
119+
req, _ := http.NewRequest("POST", "/query", strings.NewReader(json))
120+
res := httptest.NewRecorder()
121+
suite.handler.ServeHTTP(res, req)
122+
123+
suite.Equal(http.StatusInternalServerError, res.Code)
124+
}
125+
111126
func (suite *QuerySuite) TestQuery() {
112127
rows := sqlmock.NewRows([]string{"a", "b", "c", "d"}).
113128
AddRow(1, "one", 1.5, 100).
114129
AddRow(nil, nil, nil, nil)
115130

116-
suite.mock.ExpectQuery(".*").WillReturnRows(rows)
131+
mockProcessRows := sqlmock.NewRows([]string{"Id"}).AddRow(1288)
132+
suite.mock.ExpectQuery("SELECT CONNECTION_ID()").WillReturnRows(mockProcessRows)
133+
suite.mock.ExpectQuery(`select \* from repositories`).WillReturnRows(rows)
117134

118135
json := `{"query": "select * from repositories"}`
119136
req, _ := http.NewRequest("POST", "/query", strings.NewReader(json))
@@ -199,15 +216,14 @@ func (suite *QuerySuite) TestQueryAbort() {
199216
// Ideally we would test that the sql query context is canceled, but
200217
// go-sqlmock does not have something like ExpectContextCancellation
201218

219+
require := require.New(suite.T())
220+
221+
mockProcessRows := sqlmock.NewRows([]string{"Id"}).AddRow(1288)
222+
suite.mock.ExpectQuery("SELECT CONNECTION_ID()").WillReturnRows(mockProcessRows)
223+
202224
mockRows := sqlmock.NewRows([]string{"a", "b", "c", "d"}).AddRow(1, "one", 1.5, 100)
203225
suite.mock.ExpectQuery(`select \* from repositories`).WillDelayFor(2 * time.Second).WillReturnRows(mockRows)
204226

205-
mockProcessRows := sqlmock.NewRows(
206-
[]string{"Id", "User", "Host", "db", "Command", "Time", "State", "Info"}).
207-
AddRow(1234, nil, "localhost:3306", nil, "query", 2, "SquashedTable(refs, commit_files, files)(1/5)", "select * from files").
208-
AddRow(1288, nil, "localhost:3306", nil, "query", 2, "SquashedTable(refs, commit_files, files)(1/5)", "select * from repositories")
209-
suite.mock.ExpectQuery("SHOW FULL PROCESSLIST").WillReturnRows(mockProcessRows)
210-
211227
suite.mock.ExpectExec("KILL 1288")
212228

213229
json := `{"query": "select * from repositories"}`
@@ -224,8 +240,8 @@ func (suite *QuerySuite) TestQueryAbort() {
224240
defer wg.Done()
225241

226242
_, err := suite.requestProcessFunc(suite.db)(r)
227-
suite.Error(err)
228-
suite.Equal(context.Canceled, err)
243+
require.Error(err)
244+
require.Equal(context.Canceled, err)
229245
}
230246

231247
go func() {
@@ -242,5 +258,5 @@ func (suite *QuerySuite) TestQueryAbort() {
242258

243259
wg.Wait()
244260

245-
suite.Equal(context.Canceled, ctx.Err())
261+
require.Equal(context.Canceled, ctx.Err())
246262
}

server/service/common.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
// SQLDB describes a *sql.DB
99
type SQLDB interface {
1010
Close() error
11+
Conn(context.Context) (*sql.Conn, error)
1112
Query(query string, args ...interface{}) (*sql.Rows, error)
1213
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
1314
QueryRow(query string, args ...interface{}) *sql.Row

server/testing/common.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ import (
88
// MockDB is a mock of *sql.DB
99
type MockDB struct{}
1010

11+
// Conn returns a conn from the pool
12+
func (db *MockDB) Conn(ctx context.Context) (*sql.Conn, error) {
13+
return nil, nil
14+
}
15+
1116
// Close closes the DB connection
1217
func (db *MockDB) Close() error {
1318
return nil

0 commit comments

Comments
 (0)