Skip to content

Commit 309c6ec

Browse files
committedJan 24, 2025·
Missing tests added.
1 parent afd81fc commit 309c6ec

13 files changed

+1283
-260
lines changed
 

‎caddywaf.go

+8-3
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,8 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
208208
}
209209

210210
// Load IP blacklist
211-
m.ipBlacklist = NewCIDRTrie() // Initialize as CIDRTrie
211+
m.ipBlacklist = NewCIDRTrie() // Initialize as CIDRTrie
212+
m.logger.Debug("ipBlacklist initialized in Provision", zap.Bool("isNil", m.ipBlacklist == nil)) // ADDED LOGGING - Check if nil right after initialization
212213
if m.IPBlacklistFile != "" {
213214
err = m.loadIPBlacklistIntoMap(m.IPBlacklistFile, m.ipBlacklist)
214215
if err != nil {
@@ -226,8 +227,12 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
226227
}
227228

228229
// Load WAF rules - calling the new external loadRules function
229-
if err := m.loadRules(m.RuleFiles); err != nil {
230-
return fmt.Errorf("failed to load rules: %w", err)
230+
if len(m.RuleFiles) > 0 { // Modified condition to check for rule files before loading
231+
if err := m.loadRules(m.RuleFiles); err != nil {
232+
return fmt.Errorf("failed to load rules: %w", err)
233+
}
234+
} else {
235+
m.logger.Warn("No rule files specified, WAF will run without rules.") // Log a warning instead of error
231236
}
232237

233238
m.logger.Info("WAF middleware provisioned successfully")

‎caddywaf_test.go

-254
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"strings"
1010

1111
"bytes"
12-
"encoding/json"
1312
"net/http"
1413
"net/http/httptest"
1514
"os"
@@ -552,15 +551,6 @@ func TestUnmarshalCaddyfile_MissingRuleFile(t *testing.T) {
552551
// MockGeoIPReader is a mock implementation of GeoIP reader for testing
553552
type MockGeoIPReader struct{}
554553

555-
// TestWithGeoIPCache tests the WithGeoIPCache method.
556-
func TestWithGeoIPCache(t *testing.T) {
557-
logger := zap.NewNop()
558-
handler := NewGeoIPHandler(logger)
559-
560-
handler.WithGeoIPCache(time.Minute * 10)
561-
assert.Equal(t, time.Minute*10, handler.geoIPCacheTTL)
562-
}
563-
564554
// TestWithGeoIPLookupFallbackBehavior tests the WithGeoIPLookupFallbackBehavior method.
565555
func TestWithGeoIPLookupFallbackBehavior(t *testing.T) {
566556
logger := zap.NewNop()
@@ -570,50 +560,6 @@ func TestWithGeoIPLookupFallbackBehavior(t *testing.T) {
570560
assert.Equal(t, "default", handler.geoIPLookupFallbackBehavior)
571561
}
572562

573-
// TestLoadGeoIPDatabase tests the LoadGeoIPDatabase method.
574-
func TestLoadGeoIPDatabase(t *testing.T) {
575-
logger := zap.NewNop()
576-
handler := NewGeoIPHandler(logger)
577-
578-
// Test with a valid GeoIP database
579-
// Mock the GeoIP database loading
580-
reader := &MockGeoIPReader{}
581-
err := error(nil)
582-
assert.NoError(t, err)
583-
assert.NotNil(t, reader)
584-
585-
// Test with an invalid path
586-
_, err = handler.LoadGeoIPDatabase("nonexistent.mmdb")
587-
assert.Error(t, err)
588-
}
589-
590-
func TestFileExists(t *testing.T) {
591-
// Create a temporary file for testing
592-
tmpFile, err := os.CreateTemp("", "testfile")
593-
if err != nil {
594-
t.Fatalf("Failed to create temporary file: %v", err)
595-
}
596-
defer os.Remove(tmpFile.Name()) // Clean up the file after the test
597-
598-
// Test case: File exists
599-
assert.True(t, fileExists(tmpFile.Name()), "Expected file to exist")
600-
601-
// Test case: File does not exist
602-
assert.False(t, fileExists("nonexistentfile.txt"), "Expected file to not exist")
603-
604-
// Test case: Path is empty
605-
assert.False(t, fileExists(""), "Expected empty path to return false")
606-
607-
// Test case: Path is a directory
608-
tmpDir, err := os.MkdirTemp("", "testdir")
609-
if err != nil {
610-
t.Fatalf("Failed to create temporary directory: %v", err)
611-
}
612-
defer os.Remove(tmpDir) // Clean up the directory after the test
613-
614-
assert.False(t, fileExists(tmpDir), "Expected directory path to return false")
615-
}
616-
617563
func TestLogRequest(t *testing.T) {
618564
// Create a test logger using zaptest
619565
logger := zaptest.NewLogger(t)
@@ -639,56 +585,6 @@ func TestLogRequest(t *testing.T) {
639585
time.Sleep(100 * time.Millisecond)
640586
}
641587

642-
func TestLogWorker(t *testing.T) {
643-
// Create a test logger using zaptest
644-
logger := zaptest.NewLogger(t)
645-
646-
// Create a Middleware instance with the test logger
647-
middleware := &Middleware{
648-
logger: logger,
649-
logLevel: zapcore.DebugLevel,
650-
}
651-
652-
// Start the log worker
653-
middleware.StartLogWorker()
654-
655-
// Send a log entry
656-
middleware.logChan <- LogEntry{
657-
Level: zapcore.InfoLevel,
658-
Message: "Worker test message",
659-
Fields: []zap.Field{zap.String("field", "value")},
660-
}
661-
662-
// Wait for the log entry to be processed
663-
time.Sleep(100 * time.Millisecond)
664-
665-
// Stop the log worker
666-
middleware.StopLogWorker()
667-
}
668-
669-
func TestNewRateLimiter(t *testing.T) {
670-
config := RateLimit{
671-
Requests: 10,
672-
Window: time.Minute,
673-
CleanupInterval: time.Minute,
674-
Paths: []string{"/api/.*"},
675-
MatchAllPaths: false,
676-
}
677-
678-
rl, err := NewRateLimiter(config)
679-
if err != nil {
680-
t.Fatalf("Failed to create rate limiter: %v", err)
681-
}
682-
683-
assert.NotNil(t, rl)
684-
assert.Equal(t, 10, rl.config.Requests)
685-
assert.Equal(t, time.Minute, rl.config.Window)
686-
assert.Equal(t, time.Minute, rl.config.CleanupInterval)
687-
assert.Equal(t, 1, len(rl.config.PathRegexes))
688-
assert.Equal(t, "/api/.*", rl.config.Paths[0])
689-
assert.False(t, rl.config.MatchAllPaths)
690-
}
691-
692588
func TestIsRateLimited_PathMatching(t *testing.T) {
693589
config := RateLimit{
694590
Requests: 2,
@@ -1014,149 +910,6 @@ func newMockLogger() *MockLogger {
1014910
return &MockLogger{Logger: logger}
1015911
}
1016912

1017-
func TestValidateRule(t *testing.T) {
1018-
// Test valid rule
1019-
validRule := &Rule{
1020-
ID: "rule1",
1021-
Pattern: ".*",
1022-
Targets: []string{"header"},
1023-
Phase: 1,
1024-
Score: 5,
1025-
Action: "block",
1026-
}
1027-
assert.NoError(t, validateRule(validRule))
1028-
1029-
// Test invalid rule (empty ID)
1030-
invalidRule := &Rule{
1031-
ID: "",
1032-
Pattern: ".*",
1033-
Targets: []string{"header"},
1034-
Phase: 1,
1035-
Score: 5,
1036-
Action: "block",
1037-
}
1038-
assert.Error(t, validateRule(invalidRule))
1039-
1040-
// Test invalid rule (invalid phase)
1041-
invalidRule.Phase = 5
1042-
assert.Error(t, validateRule(invalidRule))
1043-
1044-
// Test invalid rule (negative score)
1045-
invalidRule.Phase = 1
1046-
invalidRule.Score = -1
1047-
assert.Error(t, validateRule(invalidRule))
1048-
1049-
// Test invalid rule (invalid action)
1050-
invalidRule.Score = 5
1051-
invalidRule.Action = "invalid"
1052-
assert.Error(t, validateRule(invalidRule))
1053-
}
1054-
1055-
func TestProcessRuleMatch(t *testing.T) {
1056-
logger := newMockLogger()
1057-
middleware := &Middleware{
1058-
logger: logger.Logger,
1059-
AnomalyThreshold: 10,
1060-
ruleHits: sync.Map{},
1061-
muMetrics: sync.RWMutex{},
1062-
}
1063-
1064-
rule := &Rule{
1065-
ID: "rule1",
1066-
Targets: []string{"header"},
1067-
Description: "Test rule",
1068-
Score: 5,
1069-
Action: "block",
1070-
}
1071-
1072-
state := &WAFState{
1073-
TotalScore: 0,
1074-
ResponseWritten: false,
1075-
}
1076-
1077-
req := httptest.NewRequest("GET", "http://example.com", nil)
1078-
1079-
// Create a context and add logID to it
1080-
ctx := context.Background()
1081-
logID := "test-log-id" // Or generate a UUID if needed: uuid.New().String()
1082-
ctx = context.WithValue(ctx, ContextKeyLogId("logID"), logID)
1083-
1084-
// Create a new request with the context
1085-
req = req.WithContext(ctx)
1086-
1087-
// Create a ResponseRecorder
1088-
w := NewResponseRecorder(httptest.NewRecorder())
1089-
1090-
// Test blocking rule
1091-
shouldContinue := middleware.processRuleMatch(w, req, rule, "value", state)
1092-
assert.False(t, shouldContinue)
1093-
assert.Equal(t, http.StatusForbidden, w.StatusCode())
1094-
assert.True(t, state.Blocked)
1095-
assert.Equal(t, 5, state.TotalScore)
1096-
1097-
// Test logging rule
1098-
rule.Action = "log"
1099-
state = &WAFState{
1100-
TotalScore: 0,
1101-
ResponseWritten: false,
1102-
}
1103-
// Re-create a ResponseRecorder for the second test
1104-
w = NewResponseRecorder(httptest.NewRecorder())
1105-
shouldContinue = middleware.processRuleMatch(w, req, rule, "value", state)
1106-
assert.True(t, shouldContinue)
1107-
assert.False(t, state.Blocked)
1108-
assert.Equal(t, 5, state.TotalScore)
1109-
}
1110-
1111-
func TestLoadRules(t *testing.T) {
1112-
logger := newMockLogger()
1113-
middleware := &Middleware{
1114-
logger: logger.Logger, // Use the embedded *zap.Logger
1115-
ruleCache: NewRuleCache(),
1116-
mu: sync.RWMutex{}, // Use sync.RWMutex directly
1117-
}
1118-
1119-
// Create a temporary rule file
1120-
ruleFile, err := os.CreateTemp("", "rules.json")
1121-
assert.NoError(t, err)
1122-
defer os.Remove(ruleFile.Name())
1123-
1124-
rules := []Rule{
1125-
{
1126-
ID: "rule1",
1127-
Pattern: ".*",
1128-
Targets: []string{"header"},
1129-
Phase: 1,
1130-
Score: 5,
1131-
Action: "block",
1132-
},
1133-
{
1134-
ID: "rule2",
1135-
Pattern: ".*",
1136-
Targets: []string{"header"},
1137-
Phase: 2,
1138-
Score: 10,
1139-
Action: "log",
1140-
},
1141-
}
1142-
1143-
// Write rules to the temporary file
1144-
ruleData, err := json.Marshal(rules)
1145-
assert.NoError(t, err)
1146-
_, err = ruleFile.Write(ruleData)
1147-
assert.NoError(t, err)
1148-
ruleFile.Close()
1149-
1150-
// Test loading rules
1151-
err = middleware.loadRules([]string{ruleFile.Name()})
1152-
assert.NoError(t, err)
1153-
assert.Equal(t, 2, len(middleware.Rules[1])+len(middleware.Rules[2]))
1154-
1155-
// Test loading invalid rule file
1156-
err = middleware.loadRules([]string{"nonexistent.json"})
1157-
assert.Error(t, err)
1158-
}
1159-
1160913
func TestProcessRuleMatch_HighScore(t *testing.T) {
1161914
logger := newMockLogger()
1162915
middleware := &Middleware{
@@ -1212,13 +965,6 @@ func TestValidateRule_EmptyTargets(t *testing.T) {
1212965
assert.Contains(t, err.Error(), "has no targets")
1213966
}
1214967

1215-
func TestUnique(t *testing.T) {
1216-
// Test removing duplicates from a slice of strings
1217-
ips := []string{"1.1.1.1", "2.2.2.2", "1.1.1.1", "3.3.3.3"}
1218-
uniqueIPs := unique(ips)
1219-
assert.Equal(t, []string{"1.1.1.1", "2.2.2.2", "3.3.3.3"}, uniqueIPs)
1220-
}
1221-
1222968
func TestNewRequestValueExtractor(t *testing.T) {
1223969
logger := zap.NewNop()
1224970
redactSensitiveData := true

0 commit comments

Comments
 (0)
Please sign in to comment.