Skip to content

Commit 1184c48

Browse files
authored
Add tests (#20)
* Add tests * Fix linter errors * Add tests to the workflow
1 parent 4580bad commit 1184c48

File tree

5 files changed

+301
-0
lines changed

5 files changed

+301
-0
lines changed

.github/workflows/test.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,6 @@ jobs:
4545
version: "v1.57"
4646
skip-pkg-cache: true
4747
install-mode: "goinstall"
48+
49+
- name: Run tests 🧪
50+
run: make test

Makefile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,6 @@ update-all:
1313

1414
build-dev: tidy
1515
go build
16+
17+
test:
18+
go test -v ./...

go.mod

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@ require (
1111
github.com/jackc/pgx/v5 v5.5.5
1212
github.com/prometheus/client_golang v1.19.0
1313
github.com/spf13/cast v1.6.0
14+
github.com/stretchr/testify v1.8.4
1415
google.golang.org/grpc v1.63.0
1516
)
1617

1718
require (
1819
github.com/beorn7/perks v1.0.1 // indirect
1920
github.com/cespare/xxhash/v2 v2.3.0 // indirect
21+
github.com/davecgh/go-spew v1.1.1 // indirect
2022
github.com/expr-lang/expr v1.16.3 // indirect
2123
github.com/fatih/color v1.16.0 // indirect
2224
github.com/golang/protobuf v1.5.4 // indirect
@@ -26,6 +28,7 @@ require (
2628
github.com/mitchellh/go-testing-interface v1.14.1 // indirect
2729
github.com/oklog/run v1.1.0 // indirect
2830
github.com/pganalyze/pg_query_go/v5 v5.1.0 // indirect
31+
github.com/pmezard/go-difflib v1.0.0 // indirect
2932
github.com/prometheus/client_model v0.6.1 // indirect
3033
github.com/prometheus/common v0.52.2 // indirect
3134
github.com/prometheus/procfs v0.13.0 // indirect
@@ -37,4 +40,5 @@ require (
3740
golang.org/x/text v0.14.0 // indirect
3841
google.golang.org/genproto/googleapis/rpc v0.0.0-20240401170217-c3f982113cda // indirect
3942
google.golang.org/protobuf v1.33.0 // indirect
43+
gopkg.in/yaml.v3 v3.0.1 // indirect
4044
)

go.sum

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

plugin/plugin_test.go

Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
package plugin
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"net/http"
8+
"net/http/httptest"
9+
"testing"
10+
11+
sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act"
12+
v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1"
13+
"github.com/hashicorp/go-hclog"
14+
"github.com/jackc/pgx/v5/pgproto3"
15+
"github.com/stretchr/testify/assert"
16+
"github.com/stretchr/testify/require"
17+
)
18+
19+
func Test_isSQLi(t *testing.T) {
20+
p := &Plugin{
21+
EnableLibinjection: true,
22+
Logger: hclog.NewNullLogger(),
23+
}
24+
// This is a false positive, since the query is not an SQL injection.
25+
assert.True(t, p.isSQLi("SELECT * FROM users WHERE id = 1"))
26+
// This is an SQL injection.
27+
assert.True(t, p.isSQLi("SELECT * FROM users WHERE id = 1 OR 1=1"))
28+
}
29+
30+
func Test_isSQLiDisabled(t *testing.T) {
31+
p := &Plugin{
32+
EnableLibinjection: false,
33+
Logger: hclog.NewNullLogger(),
34+
}
35+
// This is an SQL injection, but the libinjection is disabled.
36+
assert.False(t, p.isSQLi("SELECT * FROM users WHERE id = 1 OR 1=1"))
37+
}
38+
39+
func Test_errorResponse(t *testing.T) {
40+
p := &Plugin{
41+
Logger: hclog.NewNullLogger(),
42+
}
43+
44+
query := pgproto3.Query{String: "SELECT * FROM users WHERE id = 1 OR 1=1"}
45+
queryBytes, err := query.Encode(nil)
46+
require.NoError(t, err)
47+
48+
req := map[string]any{
49+
"request": queryBytes,
50+
}
51+
reqJSON, err := v1.NewStruct(req)
52+
require.NoError(t, err)
53+
assert.NotNil(t, reqJSON)
54+
55+
resp := p.errorResponse(
56+
reqJSON,
57+
map[string]any{
58+
"score": 0.9999,
59+
"detector": "deep_learning_model",
60+
},
61+
)
62+
// We are modifying the pointer to the object, so they should be the same.
63+
assert.Equal(t, reqJSON, resp)
64+
assert.Len(t, resp.GetFields(), 3)
65+
assert.Contains(t, resp.GetFields(), "request")
66+
assert.Contains(t, resp.GetFields(), "response")
67+
assert.Contains(t, resp.GetFields(), sdkAct.Signals)
68+
// 2 signals: Terminate and Log.
69+
assert.Len(t, resp.Fields[sdkAct.Signals].GetListValue().AsSlice(), 2)
70+
}
71+
72+
func Test_OnTrafficFromClinet(t *testing.T) {
73+
p := &Plugin{
74+
Logger: hclog.NewNullLogger(),
75+
ModelName: "sqli_model",
76+
ModelVersion: "2",
77+
}
78+
79+
server := httptest.NewServer(
80+
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
81+
fmt.Println(r.URL.Path)
82+
switch r.URL.Path {
83+
case TokenizeAndSequencePath:
84+
w.WriteHeader(http.StatusOK)
85+
w.Header().Set("Content-Type", "application/json")
86+
// This is the tokenized query:
87+
// {"query":"select * from users where id = 1 or 1=1"}
88+
resp := map[string][]float32{
89+
"tokens": {
90+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 6, 5, 73, 7, 68, 4, 11, 12,
91+
},
92+
}
93+
data, _ := json.Marshal(resp)
94+
_, err := w.Write(data)
95+
require.NoError(t, err)
96+
case fmt.Sprintf(PredictPath, p.ModelName, p.ModelVersion):
97+
w.WriteHeader(http.StatusOK)
98+
w.Header().Set("Content-Type", "application/json")
99+
// This is the output of the deep learning model.
100+
resp := map[string][][]float32{"outputs": {{0.999909341}}}
101+
data, _ := json.Marshal(resp)
102+
_, err := w.Write(data)
103+
require.NoError(t, err)
104+
default:
105+
w.WriteHeader(http.StatusNotFound)
106+
}
107+
}),
108+
)
109+
defer server.Close()
110+
111+
p.TokenizerAPIAddress = server.URL
112+
p.ServingAPIAddress = server.URL
113+
114+
query := pgproto3.Query{String: "SELECT * FROM users WHERE id = 1 OR 1=1"}
115+
queryBytes, err := query.Encode(nil)
116+
require.NoError(t, err)
117+
118+
req := map[string]any{
119+
"request": queryBytes,
120+
}
121+
reqJSON, err := v1.NewStruct(req)
122+
require.NoError(t, err)
123+
assert.NotNil(t, reqJSON)
124+
125+
resp, err := p.OnTrafficFromClient(context.Background(), reqJSON)
126+
require.NoError(t, err)
127+
assert.NotNil(t, resp)
128+
assert.Len(t, resp.GetFields(), 4)
129+
assert.Contains(t, resp.GetFields(), "request")
130+
assert.Contains(t, resp.GetFields(), "query")
131+
assert.Contains(t, resp.GetFields(), "response")
132+
assert.Contains(t, resp.GetFields(), sdkAct.Signals)
133+
// 2 signals: Terminate and Log.
134+
assert.Len(t, resp.Fields[sdkAct.Signals].GetListValue().AsSlice(), 2)
135+
}
136+
137+
func Test_OnTrafficFromClinetFailedTokenization(t *testing.T) {
138+
plugins := []*Plugin{
139+
{
140+
Logger: hclog.NewNullLogger(),
141+
ModelName: "sqli_model",
142+
ModelVersion: "2",
143+
// If libinjection is enabled, the response should contain the "response" field,
144+
// and the "signals" field, which means the plugin will terminate the request.
145+
EnableLibinjection: true,
146+
},
147+
{
148+
Logger: hclog.NewNullLogger(),
149+
ModelName: "sqli_model",
150+
ModelVersion: "2",
151+
// If libinjection is disabled, the response should not contain the "response" field,
152+
// and the "signals" field, which means the plugin will not terminate the request.
153+
EnableLibinjection: false,
154+
},
155+
}
156+
157+
server := httptest.NewServer(
158+
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
159+
fmt.Println(r.URL.Path)
160+
switch r.URL.Path {
161+
case TokenizeAndSequencePath:
162+
w.WriteHeader(http.StatusInternalServerError)
163+
default:
164+
w.WriteHeader(http.StatusNotFound)
165+
}
166+
}),
167+
)
168+
defer server.Close()
169+
170+
for i := range plugins {
171+
plugins[i].TokenizerAPIAddress = server.URL
172+
plugins[i].ServingAPIAddress = server.URL
173+
174+
query := pgproto3.Query{String: "SELECT * FROM users WHERE id = 1 OR 1=1"}
175+
queryBytes, err := query.Encode(nil)
176+
require.NoError(t, err)
177+
178+
req := map[string]any{
179+
"request": queryBytes,
180+
}
181+
reqJSON, err := v1.NewStruct(req)
182+
require.NoError(t, err)
183+
assert.NotNil(t, reqJSON)
184+
185+
resp, err := plugins[i].OnTrafficFromClient(context.Background(), reqJSON)
186+
require.NoError(t, err)
187+
assert.NotNil(t, resp)
188+
if plugins[i].EnableLibinjection {
189+
assert.Len(t, resp.GetFields(), 4)
190+
assert.Contains(t, resp.GetFields(), "request")
191+
assert.Contains(t, resp.GetFields(), "query")
192+
assert.Contains(t, resp.GetFields(), "response")
193+
assert.Contains(t, resp.GetFields(), sdkAct.Signals)
194+
// 2 signals: Terminate and Log.
195+
assert.Len(t, resp.Fields[sdkAct.Signals].GetListValue().AsSlice(), 2)
196+
} else {
197+
assert.Len(t, resp.GetFields(), 2)
198+
assert.Contains(t, resp.GetFields(), "request")
199+
assert.Contains(t, resp.GetFields(), "query")
200+
assert.NotContains(t, resp.GetFields(), "response")
201+
assert.NotContains(t, resp.GetFields(), sdkAct.Signals)
202+
}
203+
}
204+
}
205+
206+
func Test_OnTrafficFromClinetFailedPrediction(t *testing.T) {
207+
plugins := []*Plugin{
208+
{
209+
Logger: hclog.NewNullLogger(),
210+
ModelName: "sqli_model",
211+
ModelVersion: "2",
212+
// If libinjection is disabled, the response should not contain the "response" field,
213+
// and the "signals" field, which means the plugin will not terminate the request.
214+
EnableLibinjection: false,
215+
},
216+
{
217+
Logger: hclog.NewNullLogger(),
218+
ModelName: "sqli_model",
219+
ModelVersion: "2",
220+
// If libinjection is enabled, the response should contain the "response" field,
221+
// and the "signals" field, which means the plugin will terminate the request.
222+
EnableLibinjection: true,
223+
},
224+
}
225+
226+
// This is the same for both plugins.
227+
predictPath := fmt.Sprintf(PredictPath, plugins[0].ModelName, plugins[1].ModelVersion)
228+
229+
server := httptest.NewServer(
230+
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
231+
fmt.Println(r.URL.Path)
232+
switch r.URL.Path {
233+
case TokenizeAndSequencePath:
234+
w.WriteHeader(http.StatusOK)
235+
w.Header().Set("Content-Type", "application/json")
236+
// This is the tokenized query:
237+
// {"query":"select * from users where id = 1 or 1=1"}
238+
resp := map[string][]float32{
239+
"tokens": {
240+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 6, 5, 73, 7, 68, 4, 11, 12,
241+
},
242+
}
243+
data, _ := json.Marshal(resp)
244+
_, err := w.Write(data)
245+
require.NoError(t, err)
246+
case predictPath:
247+
w.WriteHeader(http.StatusInternalServerError)
248+
default:
249+
w.WriteHeader(http.StatusNotFound)
250+
}
251+
}),
252+
)
253+
defer server.Close()
254+
255+
for i := range plugins {
256+
plugins[i].TokenizerAPIAddress = server.URL
257+
plugins[i].ServingAPIAddress = server.URL
258+
259+
query := pgproto3.Query{String: "SELECT * FROM users WHERE id = 1 OR 1=1"}
260+
queryBytes, err := query.Encode(nil)
261+
require.NoError(t, err)
262+
263+
req := map[string]any{
264+
"request": queryBytes,
265+
}
266+
reqJSON, err := v1.NewStruct(req)
267+
require.NoError(t, err)
268+
assert.NotNil(t, reqJSON)
269+
270+
resp, err := plugins[i].OnTrafficFromClient(context.Background(), reqJSON)
271+
require.NoError(t, err)
272+
assert.NotNil(t, resp)
273+
if plugins[i].EnableLibinjection {
274+
assert.Len(t, resp.GetFields(), 4)
275+
assert.Contains(t, resp.GetFields(), "request")
276+
assert.Contains(t, resp.GetFields(), "query")
277+
assert.Contains(t, resp.GetFields(), "response")
278+
assert.Contains(t, resp.GetFields(), sdkAct.Signals)
279+
// 2 signals: Terminate and Log.
280+
assert.Len(t, resp.Fields[sdkAct.Signals].GetListValue().AsSlice(), 2)
281+
} else {
282+
assert.Len(t, resp.GetFields(), 2)
283+
assert.Contains(t, resp.GetFields(), "request")
284+
assert.Contains(t, resp.GetFields(), "query")
285+
assert.NotContains(t, resp.GetFields(), "response")
286+
assert.NotContains(t, resp.GetFields(), sdkAct.Signals)
287+
}
288+
}
289+
}

0 commit comments

Comments
 (0)