Skip to content

Commit 3ccd082

Browse files
committed
Add iterators for fetching hostmaps
1 parent 750e4a8 commit 3ccd082

File tree

2 files changed

+133
-0
lines changed

2 files changed

+133
-0
lines changed

control.go

+45
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package nebula
22

33
import (
44
"context"
5+
"iter"
56
"net/netip"
67
"os"
78
"os/signal"
@@ -120,6 +121,15 @@ func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo {
120121
}
121122
}
122123

124+
// ListHostmapHostsIter returns an iter with details about the actual or pending (handshaking) hostmap by vpn ip
125+
func (c *Control) ListHostmapHostsIter(pendingMap bool) iter.Seq[*ControlHostInfo] {
126+
if pendingMap {
127+
return listHostMapHostsIter(c.f.handshakeManager)
128+
} else {
129+
return listHostMapHostsIter(c.f.hostMap)
130+
}
131+
}
132+
123133
// ListHostmapIndexes returns details about the actual or pending (handshaking) hostmap by local index id
124134
func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
125135
if pendingMap {
@@ -129,6 +139,15 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
129139
}
130140
}
131141

142+
// ListHostmapIndexesIter returns an iter with details about the actual or pending (handshaking) hostmap by local index id
143+
func (c *Control) ListHostmapIndexesIter(pendingMap bool) iter.Seq[*ControlHostInfo] {
144+
if pendingMap {
145+
return listHostMapIndexesIter(c.f.handshakeManager)
146+
} else {
147+
return listHostMapIndexesIter(c.f.hostMap)
148+
}
149+
}
150+
132151
// GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
133152
func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
134153
_, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
@@ -306,6 +325,19 @@ func listHostMapHosts(hl controlHostLister) []ControlHostInfo {
306325
return hosts
307326
}
308327

328+
func listHostMapHostsIter(hl controlHostLister) iter.Seq[*ControlHostInfo] {
329+
pr := hl.GetPreferredRanges()
330+
331+
return iter.Seq[*ControlHostInfo](func(yield func(*ControlHostInfo) bool) {
332+
hl.ForEachVpnIp(func(hostinfo *HostInfo) {
333+
host := copyHostInfo(hostinfo, pr)
334+
if !yield(&host) {
335+
return // Stop iteration early if yield returns false
336+
}
337+
})
338+
})
339+
}
340+
309341
func listHostMapIndexes(hl controlHostLister) []ControlHostInfo {
310342
hosts := make([]ControlHostInfo, 0)
311343
pr := hl.GetPreferredRanges()
@@ -314,3 +346,16 @@ func listHostMapIndexes(hl controlHostLister) []ControlHostInfo {
314346
})
315347
return hosts
316348
}
349+
350+
func listHostMapIndexesIter(hl controlHostLister) iter.Seq[*ControlHostInfo] {
351+
pr := hl.GetPreferredRanges()
352+
353+
return iter.Seq[*ControlHostInfo](func(yield func(*ControlHostInfo) bool) {
354+
hl.ForEachIndex(func(hostinfo *HostInfo) {
355+
host := copyHostInfo(hostinfo, pr)
356+
if !yield(&host) {
357+
return // Stop iteration early if yield returns false
358+
}
359+
})
360+
})
361+
}

control_test.go

+88
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,94 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
110110
})
111111
}
112112

113+
func TestListHostMapHostsIter(t *testing.T) {
114+
l := logrus.New()
115+
hm := newHostMap(l, netip.Prefix{})
116+
hm.preferredRanges.Store(&[]netip.Prefix{})
117+
118+
hosts := []struct {
119+
vpnIp netip.Addr
120+
remoteAddr netip.AddrPort
121+
localIndexId uint32
122+
remoteIndexId uint32
123+
}{
124+
{vpnIp: netip.MustParseAddr("0.0.0.2"), remoteAddr: netip.MustParseAddrPort("0.0.0.101:4445"), localIndexId: 202, remoteIndexId: 201},
125+
{vpnIp: netip.MustParseAddr("0.0.0.3"), remoteAddr: netip.MustParseAddrPort("0.0.0.102:4446"), localIndexId: 203, remoteIndexId: 202},
126+
{vpnIp: netip.MustParseAddr("0.0.0.4"), remoteAddr: netip.MustParseAddrPort("0.0.0.103:4447"), localIndexId: 204, remoteIndexId: 203},
127+
}
128+
129+
for _, h := range hosts {
130+
hm.unlockedAddHostInfo(&HostInfo{
131+
remote: h.remoteAddr,
132+
ConnectionState: &ConnectionState{
133+
peerCert: nil,
134+
},
135+
localIndexId: h.localIndexId,
136+
remoteIndexId: h.remoteIndexId,
137+
vpnIp: h.vpnIp,
138+
}, &Interface{})
139+
}
140+
141+
iter := listHostMapHostsIter(hm)
142+
var results []ControlHostInfo
143+
144+
for h := range iter {
145+
results = append(results, *h)
146+
}
147+
148+
assert.Equal(t, len(hosts), len(results), "expected number of hosts in iterator")
149+
for i, h := range hosts {
150+
assert.Equal(t, h.vpnIp, results[i].VpnIp)
151+
assert.Equal(t, h.localIndexId, results[i].LocalIndex)
152+
assert.Equal(t, h.remoteIndexId, results[i].RemoteIndex)
153+
assert.Equal(t, h.remoteAddr, results[i].CurrentRemote)
154+
}
155+
}
156+
157+
func TestListHostMapIndexesIter(t *testing.T) {
158+
l := logrus.New()
159+
hm := newHostMap(l, netip.Prefix{})
160+
hm.preferredRanges.Store(&[]netip.Prefix{})
161+
162+
hosts := []struct {
163+
vpnIp netip.Addr
164+
remoteAddr netip.AddrPort
165+
localIndexId uint32
166+
remoteIndexId uint32
167+
}{
168+
{vpnIp: netip.MustParseAddr("0.0.0.2"), remoteAddr: netip.MustParseAddrPort("0.0.0.101:4445"), localIndexId: 202, remoteIndexId: 201},
169+
{vpnIp: netip.MustParseAddr("0.0.0.3"), remoteAddr: netip.MustParseAddrPort("0.0.0.102:4446"), localIndexId: 203, remoteIndexId: 202},
170+
{vpnIp: netip.MustParseAddr("0.0.0.4"), remoteAddr: netip.MustParseAddrPort("0.0.0.103:4447"), localIndexId: 204, remoteIndexId: 203},
171+
}
172+
173+
for _, h := range hosts {
174+
hm.unlockedAddHostInfo(&HostInfo{
175+
remote: h.remoteAddr,
176+
ConnectionState: &ConnectionState{
177+
peerCert: nil,
178+
},
179+
localIndexId: h.localIndexId,
180+
remoteIndexId: h.remoteIndexId,
181+
vpnIp: h.vpnIp,
182+
}, &Interface{})
183+
}
184+
185+
iter := listHostMapIndexesIter(hm)
186+
var results []ControlHostInfo
187+
188+
for h := range iter {
189+
results = append(results, *h)
190+
}
191+
192+
assert.Equal(t, len(hosts), len(results), "expected number of hosts in iterator")
193+
for i, h := range hosts {
194+
assert.Equal(t, h.vpnIp, results[i].VpnIp)
195+
assert.Equal(t, h.localIndexId, results[i].LocalIndex)
196+
assert.Equal(t, h.remoteIndexId, results[i].RemoteIndex)
197+
assert.Equal(t, h.remoteAddr, results[i].CurrentRemote)
198+
}
199+
}
200+
113201
func assertFields(t *testing.T, expected []string, actualStruct interface{}) {
114202
val := reflect.ValueOf(actualStruct).Elem()
115203
fields := make([]string, val.NumField())

0 commit comments

Comments
 (0)