1
+ use itertools:: { iproduct, Product } ;
1
2
use std:: net:: { IpAddr , SocketAddr } ;
2
3
3
4
pub struct SocketIterator < ' s > {
4
- ips : & ' s [ IpAddr ] ,
5
- ports : & ' s [ u16 ] ,
6
- ip_idx : usize ,
7
- ip_len : usize ,
8
- port_idx : usize ,
9
- port_len : usize ,
5
+ // product_it is a cartesian product iterator over
6
+ // the slices of ports and IP addresses.
7
+ //
8
+ // The IP/port order is intentionally reversed here since we want
9
+ // the itertools::iproduct! macro below to generate the pairs with
10
+ // all the IPs for one port before moving on to the next one
11
+ // ("hold the port, go through all the IPs, then advance the port...").
12
+ // See also the comments in the iterator implementation for an example.
13
+ product_it :
14
+ Product < Box < std:: slice:: Iter < ' s , u16 > > , Box < std:: slice:: Iter < ' s , std:: net:: IpAddr > > > ,
10
15
}
11
16
12
17
/// An iterator that receives a slice of IPs and ports and returns a Socket
@@ -16,13 +21,10 @@ pub struct SocketIterator<'s> {
16
21
/// generating a vector containing all these combinations.
17
22
impl < ' s > SocketIterator < ' s > {
18
23
pub fn new ( ips : & ' s [ IpAddr ] , ports : & ' s [ u16 ] ) -> Self {
24
+ let ports_it = Box :: new ( ports. into_iter ( ) ) ;
25
+ let ips_it = Box :: new ( ips. into_iter ( ) ) ;
19
26
Self {
20
- ip_idx : 0 ,
21
- ip_len : ips. len ( ) ,
22
- port_idx : 0 ,
23
- port_len : ports. len ( ) ,
24
- ips,
25
- ports,
27
+ product_it : iproduct ! ( ports_it, ips_it) ,
26
28
}
27
29
}
28
30
}
@@ -41,20 +43,10 @@ impl<'s> Iterator for SocketIterator<'s> {
41
43
/// it.next(); // 192.168.0.1:443
42
44
/// it.next(); // None
43
45
fn next ( & mut self ) -> Option < Self :: Item > {
44
- if self . port_idx == self . port_len {
45
- return None ;
46
+ match self . product_it . next ( ) {
47
+ None => None ,
48
+ Some ( ( port, ip) ) => Some ( SocketAddr :: new ( * ip, * port) ) ,
46
49
}
47
-
48
- self . ip_idx = self . ip_idx % self . ip_len ;
49
-
50
- let socket = SocketAddr :: new ( self . ips [ self . ip_idx ] , self . ports [ self . port_idx ] ) ;
51
- self . ip_idx += 1 ;
52
-
53
- if self . ip_idx == self . ip_len {
54
- self . port_idx += 1 ;
55
- }
56
-
57
- Some ( socket)
58
50
}
59
51
}
60
52
0 commit comments