-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdbms_test.go
150 lines (129 loc) · 4.63 KB
/
dbms_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
package sqllexer
import (
"embed"
"encoding/json"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
// testdata is the testdata directory tree, embedded into the test binary by
// the directive below. TestQueriesPerDBMS enumerates and reads its test cases
// from here. Paths passed to it must be slash-separated (io/fs convention).
//
//go:embed testdata/*
var testdata embed.FS
type (
	// output is one expected result for a test case's input. The optional
	// obfuscator/normalizer configs, when present, are used instead of the
	// test's defaults for this output only; the optional statement metadata
	// is compared only when present.
	output struct {
		Expected          string             `json:"expected"`
		ObfuscatorConfig  *obfuscatorConfig  `json:"obfuscator_config,omitempty"`
		NormalizerConfig  *normalizerConfig  `json:"normalizer_config,omitempty"`
		StatementMetadata *StatementMetadata `json:"statement_metadata,omitempty"`
	}

	// testcase is the decoded contents of a single testdata JSON file: one
	// input query and the list of outputs expected from it.
	testcase struct {
		Input   string   `json:"input"`
		Outputs []output `json:"outputs"`
	}
)
// TestQueriesPerDBMS runs every JSON test case embedded under testdata through
// ObfuscateAndNormalize for each DBMS listed below and compares the results
// with the expected outputs recorded in the file.
//
// Test folder structure:
//
//	testdata/
//	    <dbms_type>/
//	        <query_type>/
//	            <query_name>.json
//
// Each JSON file decodes into a testcase: one input query plus one or more
// outputs. An output may carry its own obfuscator_config/normalizer_config;
// when it does not, the defaults defined in this test are used. Subtests are
// named <dbms_type>/<query_type>/<query_name>, so files sharing a name in
// different folders are reported separately.
func TestQueriesPerDBMS(t *testing.T) {
	dbmsTypes := []DBMSType{
		DBMSPostgres,
		DBMSOracle,
		DBMSSQLServer,
		DBMSMySQL,
		DBMSSnowflake,
	}

	// Configs used when an output does not specify its own. They are only
	// read, never modified, so every output can share them.
	defaultObfuscatorConfig := obfuscatorConfig{
		DollarQuotedFunc:           true,
		ReplaceDigits:              true,
		ReplacePositionalParameter: true,
		ReplaceBoolean:             true,
		ReplaceNull:                true,
		KeepJsonPath:               false,
	}
	defaultNormalizerConfig := normalizerConfig{
		CollectComments:               true,
		CollectCommands:               true,
		CollectTables:                 true,
		CollectProcedure:              true,
		KeepSQLAlias:                  false,
		UppercaseKeywords:             false,
		RemoveSpaceBetweenParentheses: false,
		KeepTrailingSemicolon:         false,
		KeepIdentifierQuotation:       false,
	}

	// embedPath joins path elements for use with the embedded testdata FS.
	// embed.FS (like every io/fs implementation) only understands
	// slash-separated paths, while filepath.Join uses the OS separator
	// ('\' on Windows), so the result is converted back to slashes.
	embedPath := func(elem ...string) string {
		return filepath.ToSlash(filepath.Join(elem...))
	}

	// runCase decodes the test file at queryPath and checks every expected
	// output it lists when the input is processed for dbms.
	runCase := func(t *testing.T, dbms DBMSType, queryPath string) {
		raw, err := testdata.ReadFile(queryPath)
		if err != nil {
			t.Fatalf("reading %s: %v", queryPath, err)
		}
		var tc testcase
		if err := json.Unmarshal(raw, &tc); err != nil {
			t.Fatalf("decoding %s: %v", queryPath, err)
		}

		for i, want := range tc.Outputs {
			// A per-output config replaces the default entirely.
			obfCfg := &defaultObfuscatorConfig
			if want.ObfuscatorConfig != nil {
				obfCfg = want.ObfuscatorConfig
			}
			normCfg := &defaultNormalizerConfig
			if want.NormalizerConfig != nil {
				normCfg = want.NormalizerConfig
			}

			obfuscator := NewObfuscator(
				WithDollarQuotedFunc(obfCfg.DollarQuotedFunc),
				WithReplaceDigits(obfCfg.ReplaceDigits),
				WithReplacePositionalParameter(obfCfg.ReplacePositionalParameter),
				WithReplaceBoolean(obfCfg.ReplaceBoolean),
				WithReplaceNull(obfCfg.ReplaceNull),
				WithKeepJsonPath(obfCfg.KeepJsonPath),
			)
			normalizer := NewNormalizer(
				WithCollectComments(normCfg.CollectComments),
				WithCollectCommands(normCfg.CollectCommands),
				WithCollectTables(normCfg.CollectTables),
				WithCollectProcedures(normCfg.CollectProcedure),
				WithKeepSQLAlias(normCfg.KeepSQLAlias),
				WithUppercaseKeywords(normCfg.UppercaseKeywords),
				WithRemoveSpaceBetweenParentheses(normCfg.RemoveSpaceBetweenParentheses),
				WithKeepTrailingSemicolon(normCfg.KeepTrailingSemicolon),
				WithKeepIdentifierQuotation(normCfg.KeepIdentifierQuotation),
			)

			got, statementMetadata, err := ObfuscateAndNormalize(tc.Input, obfuscator, normalizer, WithDBMS(dbms))
			if err != nil {
				t.Fatalf("output %d: %v", i, err)
			}
			assert.Equal(t, want.Expected, got, "output %d", i)
			// Statement metadata is only checked when the test file specifies it.
			if want.StatementMetadata != nil {
				assertStatementMetadataEqual(t, want.StatementMetadata, statementMetadata)
			}
		}
	}

	for _, dbms := range dbmsTypes {
		t.Run(string(dbms), func(t *testing.T) {
			baseDir := embedPath("testdata", string(dbms))
			queryTypes, err := testdata.ReadDir(baseDir)
			if err != nil {
				t.Fatalf("reading %s: %v", baseDir, err)
			}
			for _, qt := range queryTypes {
				if !qt.IsDir() {
					continue // only query-type folders hold test cases
				}
				dirPath := embedPath(baseDir, qt.Name())
				t.Run(qt.Name(), func(t *testing.T) {
					files, err := testdata.ReadDir(dirPath)
					if err != nil {
						t.Fatalf("reading %s: %v", dirPath, err)
					}
					for _, file := range files {
						if file.IsDir() || !strings.HasSuffix(file.Name(), ".json") {
							continue
						}
						queryPath := embedPath(dirPath, file.Name())
						t.Run(strings.TrimSuffix(file.Name(), ".json"), func(t *testing.T) {
							runCase(t, dbms, queryPath)
						})
					}
				})
			}
		})
	}
}