Skip to content

Commit 03c279b

Browse files
feat: Implement IP range checking with caching (main)
Adds functions to check if an IP address is within a set of CIDR ranges. Includes caching for improved performance. Uses a LRU cache with an expiration time to store results. Adds tests for IP range checking and cache behavior.
1 parent 821ddcb commit 03c279b

File tree

4 files changed

+206
-3
lines changed

4 files changed

+206
-3
lines changed

Diff for: go.mod

+4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ go 1.23
44

55
require (
66
github.com/caddyserver/caddy/v2 v2.9.1
7+
github.com/hashicorp/golang-lru/v2 v2.0.7
8+
github.com/stretchr/testify v1.9.0
79
go.uber.org/zap v1.27.0
810
)
911

@@ -27,6 +29,7 @@ require (
2729
github.com/cespare/xxhash/v2 v2.3.0 // indirect
2830
github.com/chzyer/readline v1.5.1 // indirect
2931
github.com/cpuguy83/go-md2man/v2 v2.0.4 // indirect
32+
github.com/davecgh/go-spew v1.1.1 // indirect
3033
github.com/dgraph-io/badger v1.6.2 // indirect
3134
github.com/dgraph-io/badger/v2 v2.2007.4 // indirect
3235
github.com/dgraph-io/ristretto v0.1.0 // indirect
@@ -80,6 +83,7 @@ require (
8083
github.com/onsi/ginkgo/v2 v2.13.2 // indirect
8184
github.com/pires/go-proxyproto v0.7.1-0.20240628150027-b718e7ce4964 // indirect
8285
github.com/pkg/errors v0.9.1 // indirect
86+
github.com/pmezard/go-difflib v1.0.0 // indirect
8387
github.com/prometheus/client_golang v1.19.1 // indirect
8488
github.com/prometheus/client_model v0.5.0 // indirect
8589
github.com/prometheus/common v0.48.0 // indirect

Diff for: go.sum

+2
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,8 @@ github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:Fecb
243243
github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw=
244244
github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 h1:asbCHRVmodnJTuQ3qamDwqVOIjwqUPTYmYuemVOx+Ys=
245245
github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0/go.mod h1:ggCgvZ2r7uOoQjOyu2Y1NhHmEPPzzuhWgcza5M1Ji1I=
246+
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
247+
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
246248
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
247249
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
248250
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=

Diff for: utils/ip.go

+36-3
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,27 @@ package utils
22

33
import (
44
"fmt"
5+
"github.com/hashicorp/golang-lru/v2/expirable"
56
"github.com/jasonlovesdoggo/caddy-defender/ranges/data"
67
"go.uber.org/zap"
78
"net"
9+
"time"
810
)
911

10-
// IPInRanges checks if the given IP is within any of the provided CIDR ranges.
11-
// It returns true if the IP is in any of the ranges, false otherwise.
12-
func IPInRanges(clientIP net.IP, cidrRanges []string, log *zap.Logger) bool {
12+
const MaxKeys = 10000
13+
14+
var cache = expirable.NewLRU[string, bool](MaxKeys, nil, time.Minute*10)
15+
16+
// normalizeIP converts an IP to its normalized string representation.
17+
func normalizeIP(ip net.IP) string {
18+
if v4 := ip.To4(); v4 != nil {
19+
return v4.String()
20+
}
21+
return ip.String()
22+
}
23+
24+
// rawIPInRanges checks if the given IP is in the given CIDR ranges without using the cache.
25+
func rawIPInRanges(clientIP net.IP, cidrRanges []string, log *zap.Logger) bool {
1326
for _, cidr := range cidrRanges {
1427
// If the range is a predefined key (e.g., "openai"), use the corresponding CIDRs
1528
if ranges, ok := data.IPRanges[cidr]; ok {
@@ -37,3 +50,23 @@ func IPInRanges(clientIP net.IP, cidrRanges []string, log *zap.Logger) bool {
3750
}
3851
return false
3952
}
53+
54+
// IPInRanges checks if the given IP is within any of the provided CIDR ranges.
55+
// It returns true if the IP is in any of the ranges, false otherwise.
56+
func IPInRanges(clientIP net.IP, cidrRanges []string, log *zap.Logger) bool {
57+
// Normalize the IP for consistent cache keys
58+
cacheKey := normalizeIP(clientIP)
59+
60+
// Check the cache first
61+
if val, ok := cache.Get(cacheKey); ok {
62+
return val
63+
}
64+
65+
// If not in the cache, check the ranges
66+
inRanges := rawIPInRanges(clientIP, cidrRanges, log)
67+
68+
// Add the result to the cache
69+
cache.Add(cacheKey, inRanges)
70+
71+
return inRanges
72+
}

Diff for: utils/ip_test.go

+164
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
package utils
2+
3+
import (
4+
"github.com/hashicorp/golang-lru/v2/expirable"
5+
"github.com/jasonlovesdoggo/caddy-defender/ranges/data"
6+
"net"
7+
"testing"
8+
"time"
9+
10+
"github.com/stretchr/testify/assert"
11+
"go.uber.org/zap"
12+
)
13+
14+
// Test data
15+
var (
16+
validCIDRs = []string{
17+
"192.168.1.0/24",
18+
"10.0.0.0/8",
19+
"2001:db8::/32",
20+
}
21+
invalidCIDRs = []string{
22+
"invalid-cidr",
23+
"192.168.1.0/33", // Invalid subnet mask
24+
}
25+
predefinedCIDRs = map[string][]string{
26+
"openai": {
27+
"203.0.113.0/24",
28+
"2001:db8:1::/48",
29+
},
30+
}
31+
)
32+
33+
// Mock logger for testing
34+
var testLogger = zap.NewNop()
35+
36+
// TestRawIPInRanges tests the rawIPInRanges function.
37+
func TestRawIPInRanges(t *testing.T) {
38+
tests := []struct {
39+
name string
40+
ip string
41+
cidrRanges []string
42+
expected bool
43+
}{
44+
{
45+
name: "IPv4 in range",
46+
ip: "192.168.1.100",
47+
cidrRanges: validCIDRs,
48+
expected: true,
49+
},
50+
{
51+
name: "IPv4 not in range",
52+
ip: "192.168.2.100",
53+
cidrRanges: validCIDRs,
54+
expected: false,
55+
},
56+
{
57+
name: "IPv6 in range",
58+
ip: "2001:db8::1",
59+
cidrRanges: validCIDRs,
60+
expected: true,
61+
},
62+
{
63+
name: "IPv6 not in range",
64+
ip: "2001:db8:1::1",
65+
cidrRanges: []string{"2001:db8::/48"}, // Narrower range
66+
expected: false,
67+
},
68+
{
69+
name: "Invalid CIDR",
70+
ip: "192.168.1.100",
71+
cidrRanges: invalidCIDRs,
72+
expected: false,
73+
},
74+
{
75+
name: "Predefined CIDR (IPv4)",
76+
ip: "203.0.113.10",
77+
cidrRanges: []string{"openai"},
78+
expected: true,
79+
},
80+
{
81+
name: "Predefined CIDR (IPv6)",
82+
ip: "2001:db8:1::10",
83+
cidrRanges: []string{"openai"},
84+
expected: true,
85+
},
86+
}
87+
88+
for _, tt := range tests {
89+
t.Run(tt.name, func(t *testing.T) {
90+
clientIP := net.ParseIP(tt.ip)
91+
assert.NotNil(t, clientIP, "Failed to parse IP")
92+
93+
// Mock predefined CIDRs
94+
data.IPRanges = predefinedCIDRs
95+
96+
result := rawIPInRanges(clientIP, tt.cidrRanges, testLogger)
97+
assert.Equal(t, tt.expected, result, "Unexpected result for IP %s", tt.ip)
98+
})
99+
}
100+
}
101+
102+
// TestIPInRanges tests the IPInRanges function, including caching behavior.
103+
func TestIPInRanges(t *testing.T) {
104+
tests := []struct {
105+
name string
106+
ip string
107+
cidrRanges []string
108+
expected bool
109+
}{
110+
{
111+
name: "IPv4 in range (cached)",
112+
ip: "192.168.1.100",
113+
cidrRanges: validCIDRs,
114+
expected: true,
115+
},
116+
{
117+
name: "IPv4 not in range (cached)",
118+
ip: "192.168.2.100",
119+
cidrRanges: validCIDRs,
120+
expected: false,
121+
},
122+
}
123+
124+
for _, tt := range tests {
125+
t.Run(tt.name, func(t *testing.T) {
126+
clientIP := net.ParseIP(tt.ip)
127+
assert.NotNil(t, clientIP, "Failed to parse IP")
128+
129+
// Mock predefined CIDRs
130+
data.IPRanges = predefinedCIDRs
131+
132+
// First call (not cached)
133+
result := IPInRanges(clientIP, tt.cidrRanges, testLogger)
134+
assert.Equal(t, tt.expected, result, "Unexpected result for IP %s (first call)", tt.ip)
135+
136+
// Second call (cached)
137+
result = IPInRanges(clientIP, tt.cidrRanges, testLogger)
138+
assert.Equal(t, tt.expected, result, "Unexpected result for IP %s (second call)", tt.ip)
139+
})
140+
}
141+
}
142+
143+
// TestIPInRangesCacheExpiration tests the cache expiration behavior.
144+
func TestIPInRangesCacheExpiration(t *testing.T) {
145+
// Set a short cache expiration time for testing
146+
cache = expirable.NewLRU[string, bool](MaxKeys, nil, time.Millisecond*10)
147+
148+
clientIP := net.ParseIP("192.168.1.100")
149+
assert.NotNil(t, clientIP, "Failed to parse IP")
150+
151+
// Mock predefined CIDRs
152+
data.IPRanges = predefinedCIDRs
153+
154+
// First call (not cached)
155+
result := IPInRanges(clientIP, validCIDRs, testLogger)
156+
assert.True(t, result, "Expected IP to be in range (first call)")
157+
158+
// Wait for cache to expire
159+
time.Sleep(time.Millisecond * 20)
160+
161+
// Second call (cache expired)
162+
result = IPInRanges(clientIP, validCIDRs, testLogger)
163+
assert.True(t, result, "Expected IP to be in range (second call, cache expired)")
164+
}

0 commit comments

Comments
 (0)