@@ -37,6 +37,8 @@ pin_project_lite::pin_project! {
37
37
#[ pin]
38
38
body: Body ,
39
39
local: TypeMap ,
40
+ local_addr: Option <String >,
41
+ peer_addr: Option <String >,
40
42
}
41
43
}
42
44
@@ -53,6 +55,56 @@ impl Request {
53
55
sender : Some ( sender) ,
54
56
receiver : Some ( receiver) ,
55
57
local : TypeMap :: new ( ) ,
58
+ peer_addr : None ,
59
+ local_addr : None ,
60
+ }
61
+ }
62
+
63
+ /// Sets a string representation of the peer address of this
64
+ /// request. This might take the form of an ip/fqdn and port or a
65
+ /// local socket address.
66
+ pub fn set_peer_addr ( & mut self , peer_addr : Option < impl std:: string:: ToString > ) {
67
+ self . peer_addr = peer_addr. map ( |addr| addr. to_string ( ) ) ;
68
+ }
69
+
70
+ /// Sets a string representation of the local address that this
71
+ /// request was received on. This might take the form of an ip/fqdn and
72
+ /// port, or a local socket address.
73
+ pub fn set_local_addr ( & mut self , local_addr : Option < impl std:: string:: ToString > ) {
74
+ self . local_addr = local_addr. map ( |addr| addr. to_string ( ) ) ;
75
+ }
76
+
77
+ /// Get the peer socket address for the underlying transport, if
78
+ /// that information is available for this request.
79
+ pub fn peer_addr ( & self ) -> Option < & str > {
80
+ self . peer_addr . as_deref ( )
81
+ }
82
+
83
+ /// Get the local socket address for the underlying transport, if
84
+ /// that information is available for this request.
85
+ pub fn local_addr ( & self ) -> Option < & str > {
86
+ self . local_addr . as_deref ( )
87
+ }
88
+
89
+ /// Get the remote address for this request.
90
+ pub fn remote ( & self ) -> Option < & str > {
91
+ self . forwarded_for ( ) . or ( self . peer_addr ( ) )
92
+ }
93
+
94
+ fn forwarded_for ( & self ) -> Option < & str > {
95
+ if let Some ( header) = self . header ( & "Forwarded" . parse ( ) . unwrap ( ) ) {
96
+ header. as_str ( ) . split ( ";" ) . find_map ( |key_equals_value| {
97
+ let parts = key_equals_value. split ( "=" ) . collect :: < Vec < _ > > ( ) ;
98
+ if parts. len ( ) == 2 && parts[ 0 ] . eq_ignore_ascii_case ( "for" ) {
99
+ Some ( parts[ 1 ] )
100
+ } else {
101
+ None
102
+ }
103
+ } )
104
+ } else if let Some ( header) = self . header ( & "X-Forwarded-For" . parse ( ) . unwrap ( ) ) {
105
+ header. as_str ( ) . split ( "," ) . next ( )
106
+ } else {
107
+ None
56
108
}
57
109
}
58
110
@@ -523,6 +575,8 @@ impl Clone for Request {
523
575
receiver : self . receiver . clone ( ) ,
524
576
body : Body :: empty ( ) ,
525
577
local : TypeMap :: new ( ) ,
578
+ peer_addr : self . peer_addr . clone ( ) ,
579
+ local_addr : self . local_addr . clone ( ) ,
526
580
}
527
581
}
528
582
}
@@ -598,3 +652,94 @@ impl<'a> IntoIterator for &'a mut Request {
598
652
self . headers . iter_mut ( )
599
653
}
600
654
}
655
+
656
+ #[ cfg( test) ]
657
+ mod tests {
658
+ use super :: * ;
659
+
660
+ fn build_test_request ( ) -> Request {
661
+ Request :: new ( Method :: Get , "http://irrelevant/" . parse ( ) . unwrap ( ) )
662
+ }
663
+
664
+ fn set_x_forwarded_for ( request : & mut Request , client : & ' static str ) {
665
+ request
666
+ . insert_header (
667
+ "x-forwarded-for" ,
668
+ format ! ( "{},proxy.com,other-proxy.com" , client) ,
669
+ )
670
+ . unwrap ( ) ;
671
+ }
672
+
673
+ fn set_forwarded ( request : & mut Request , client : & ' static str ) {
674
+ request
675
+ . insert_header (
676
+ "Forwarded" ,
677
+ format ! ( "by=something.com;for={};host=host.com;proto=http" , client) ,
678
+ )
679
+ . unwrap ( ) ;
680
+ }
681
+
682
+ #[ test]
683
+ fn test_remote_and_forwarded_for_when_forwarded_is_properly_formatted ( ) {
684
+ let mut request = build_test_request ( ) ;
685
+ request. set_peer_addr ( Some ( "127.0.0.1:8000" ) ) ;
686
+ set_forwarded ( & mut request, "127.0.0.1:8001" ) ;
687
+
688
+ assert_eq ! ( request. forwarded_for( ) , Some ( "127.0.0.1:8001" ) ) ;
689
+ assert_eq ! ( request. remote( ) , Some ( "127.0.0.1:8001" ) ) ;
690
+ }
691
+
692
+ #[ test]
693
+ fn test_remote_and_forwarded_for_when_forwarded_is_improperly_formatted ( ) {
694
+ let mut request = build_test_request ( ) ;
695
+ request. set_peer_addr ( Some (
696
+ "127.0.0.1:8000" . parse :: < std:: net:: SocketAddr > ( ) . unwrap ( ) ,
697
+ ) ) ;
698
+
699
+ request
700
+ . insert_header ( "Forwarded" , "this is an improperly ;;; formatted header" )
701
+ . unwrap ( ) ;
702
+
703
+ assert_eq ! ( request. forwarded_for( ) , None ) ;
704
+ assert_eq ! ( request. remote( ) , Some ( "127.0.0.1:8000" ) ) ;
705
+ }
706
+
707
+ #[ test]
708
+ fn test_remote_and_forwarded_for_when_x_forwarded_for_is_set ( ) {
709
+ let mut request = build_test_request ( ) ;
710
+ request. set_peer_addr ( Some (
711
+ std:: path:: PathBuf :: from ( "/dev/random" ) . to_str ( ) . unwrap ( ) ,
712
+ ) ) ;
713
+ set_x_forwarded_for ( & mut request, "forwarded-host.com" ) ;
714
+
715
+ assert_eq ! ( request. forwarded_for( ) , Some ( "forwarded-host.com" ) ) ;
716
+ assert_eq ! ( request. remote( ) , Some ( "forwarded-host.com" ) ) ;
717
+ }
718
+
719
+ #[ test]
720
+ fn test_remote_and_forwarded_for_when_both_forwarding_headers_are_set ( ) {
721
+ let mut request = build_test_request ( ) ;
722
+ set_forwarded ( & mut request, "forwarded.com" ) ;
723
+ set_x_forwarded_for ( & mut request, "forwarded-for-client.com" ) ;
724
+ request. peer_addr = Some ( "127.0.0.1:8000" . into ( ) ) ;
725
+
726
+ assert_eq ! ( request. forwarded_for( ) , Some ( "forwarded.com" . into( ) ) ) ;
727
+ assert_eq ! ( request. remote( ) , Some ( "forwarded.com" . into( ) ) ) ;
728
+ }
729
+
730
+ #[ test]
731
+ fn test_remote_falling_back_to_peer_addr ( ) {
732
+ let mut request = build_test_request ( ) ;
733
+ request. peer_addr = Some ( "127.0.0.1:8000" . into ( ) ) ;
734
+
735
+ assert_eq ! ( request. forwarded_for( ) , None ) ;
736
+ assert_eq ! ( request. remote( ) , Some ( "127.0.0.1:8000" . into( ) ) ) ;
737
+ }
738
+
739
+ #[ test]
740
+ fn test_remote_and_forwarded_for_when_no_remote_available ( ) {
741
+ let request = build_test_request ( ) ;
742
+ assert_eq ! ( request. forwarded_for( ) , None ) ;
743
+ assert_eq ! ( request. remote( ) , None ) ;
744
+ }
745
+ }
0 commit comments