@@ -4,6 +4,7 @@ mod test;
4
4
mod resolver_config;
5
5
6
6
use std:: {
7
+ borrow:: Cow ,
7
8
cmp:: Ordering ,
8
9
collections:: HashSet ,
9
10
convert:: TryFrom ,
@@ -91,14 +92,11 @@ lazy_static! {
91
92
} ;
92
93
93
94
static ref ILLEGAL_DATABASE_CHARACTERS : HashSet <& ' static char > = {
94
- [ '/' , '\\' , ' ' , '"' , '$' , '.' ] . iter( ) . collect( )
95
+ [ '/' , '\\' , ' ' , '"' , '$' ] . iter( ) . collect( )
95
96
} ;
96
97
}
97
98
98
99
/// An enum representing the address of a MongoDB server.
99
- ///
100
- /// Currently this just supports addresses that can be connected to over TCP, but alternative
101
- /// address types may be supported in the future (e.g. Unix Domain Socket paths).
102
100
#[ derive( Clone , Debug , Eq , Serialize ) ]
103
101
#[ non_exhaustive]
104
102
pub enum ServerAddress {
@@ -112,6 +110,12 @@ pub enum ServerAddress {
112
110
/// The default is 27017.
113
111
port : Option < u16 > ,
114
112
} ,
113
+ /// A Unix Domain Socket path.
114
+ #[ cfg( unix) ]
115
+ Unix {
116
+ /// The path to the Unix Domain Socket.
117
+ path : PathBuf ,
118
+ } ,
115
119
}
116
120
117
121
impl < ' de > Deserialize < ' de > for ServerAddress {
@@ -144,6 +148,10 @@ impl PartialEq for ServerAddress {
144
148
port : other_port,
145
149
} ,
146
150
) => host == other_host && port. unwrap_or ( 27017 ) == other_port. unwrap_or ( 27017 ) ,
151
+ #[ cfg( unix) ]
152
+ ( Self :: Unix { path } , Self :: Unix { path : other_path } ) => path == other_path,
153
+ #[ cfg( unix) ]
154
+ _ => false ,
147
155
}
148
156
}
149
157
}
@@ -158,6 +166,8 @@ impl Hash for ServerAddress {
158
166
host. hash ( state) ;
159
167
port. unwrap_or ( 27017 ) . hash ( state) ;
160
168
}
169
+ #[ cfg( unix) ]
170
+ Self :: Unix { path } => path. hash ( state) ,
161
171
}
162
172
}
163
173
}
@@ -173,6 +183,15 @@ impl ServerAddress {
173
183
/// Parses an address string into a `ServerAddress`.
174
184
pub fn parse ( address : impl AsRef < str > ) -> Result < Self > {
175
185
let address = address. as_ref ( ) ;
186
+ // checks if the address is a unix domain socket
187
+ #[ cfg( unix) ]
188
+ {
189
+ if address. ends_with ( ".sock" ) {
190
+ return Ok ( ServerAddress :: Unix {
191
+ path : PathBuf :: from ( address) ,
192
+ } ) ;
193
+ }
194
+ }
176
195
let mut parts = address. split ( ':' ) ;
177
196
let hostname = match parts. next ( ) {
178
197
Some ( part) => {
@@ -243,18 +262,29 @@ impl ServerAddress {
243
262
"port" : port. map( |i| Bson :: Int32 ( i. into( ) ) ) . unwrap_or( Bson :: Null )
244
263
}
245
264
}
265
+ #[ cfg( unix) ]
266
+ Self :: Unix { path } => {
267
+ doc ! {
268
+ "host" : path. to_string_lossy( ) . as_ref( ) ,
269
+ "port" : Bson :: Null ,
270
+ }
271
+ }
246
272
}
247
273
}
248
274
249
- pub ( crate ) fn host ( & self ) -> & str {
275
+ pub ( crate ) fn host ( & self ) -> Cow < ' _ , str > {
250
276
match self {
251
- Self :: Tcp { host, .. } => host. as_str ( ) ,
277
+ Self :: Tcp { host, .. } => Cow :: Borrowed ( host. as_str ( ) ) ,
278
+ #[ cfg( unix) ]
279
+ Self :: Unix { path } => path. to_string_lossy ( ) ,
252
280
}
253
281
}
254
282
255
283
pub ( crate ) fn port ( & self ) -> Option < u16 > {
256
284
match self {
257
285
Self :: Tcp { port, .. } => * port,
286
+ #[ cfg( unix) ]
287
+ Self :: Unix { .. } => None ,
258
288
}
259
289
}
260
290
}
@@ -265,6 +295,8 @@ impl fmt::Display for ServerAddress {
265
295
Self :: Tcp { host, port } => {
266
296
write ! ( fmt, "{}:{}" , host, port. unwrap_or( DEFAULT_PORT ) )
267
297
}
298
+ #[ cfg( unix) ]
299
+ Self :: Unix { path } => write ! ( fmt, "{}" , path. display( ) ) ,
268
300
}
269
301
}
270
302
}
@@ -1580,10 +1612,26 @@ impl ConnectionString {
1580
1612
None => ( None , None ) ,
1581
1613
} ;
1582
1614
1583
- let host_list: Result < Vec < _ > > =
1584
- hosts_section. split ( ',' ) . map ( ServerAddress :: parse) . collect ( ) ;
1585
-
1586
- let host_list = host_list?;
1615
+ let mut host_list = Vec :: with_capacity ( hosts_section. len ( ) ) ;
1616
+ for host in hosts_section. split ( ',' ) {
1617
+ let address = if host. ends_with ( ".sock" ) {
1618
+ #[ cfg( unix) ]
1619
+ {
1620
+ ServerAddress :: parse ( percent_decode (
1621
+ host,
1622
+ "Unix domain sockets must be URL-encoded" ,
1623
+ ) ?)
1624
+ }
1625
+ #[ cfg( not( unix) ) ]
1626
+ return Err ( ErrorKind :: InvalidArgument {
1627
+ message : "Unix domain sockets are not supported on this platform" . to_string ( ) ,
1628
+ }
1629
+ . into ( ) ) ;
1630
+ } else {
1631
+ ServerAddress :: parse ( host)
1632
+ } ?;
1633
+ host_list. push ( address) ;
1634
+ }
1587
1635
1588
1636
let hosts = if srv {
1589
1637
if host_list. len ( ) != 1 {
@@ -1592,16 +1640,26 @@ impl ConnectionString {
1592
1640
}
1593
1641
. into ( ) ) ;
1594
1642
}
1595
- // Unwrap safety: the `len` check above guarantees this can't fail.
1596
- let ServerAddress :: Tcp { host, port } = host_list. into_iter ( ) . next ( ) . unwrap ( ) ;
1597
1643
1598
- if port. is_some ( ) {
1599
- return Err ( ErrorKind :: InvalidArgument {
1600
- message : "a port cannot be specified with 'mongodb+srv'" . into ( ) ,
1644
+ // Unwrap safety: the `len` check above guarantees this can't fail.
1645
+ match host_list. into_iter ( ) . next ( ) . unwrap ( ) {
1646
+ ServerAddress :: Tcp { host, port } => {
1647
+ if port. is_some ( ) {
1648
+ return Err ( ErrorKind :: InvalidArgument {
1649
+ message : "a port cannot be specified with 'mongodb+srv'" . into ( ) ,
1650
+ }
1651
+ . into ( ) ) ;
1652
+ }
1653
+ HostInfo :: DnsRecord ( host)
1654
+ }
1655
+ #[ cfg( unix) ]
1656
+ ServerAddress :: Unix { .. } => {
1657
+ return Err ( ErrorKind :: InvalidArgument {
1658
+ message : "unix sockets cannot be used with 'mongodb+srv'" . into ( ) ,
1659
+ }
1660
+ . into ( ) ) ;
1601
1661
}
1602
- . into ( ) ) ;
1603
1662
}
1604
- HostInfo :: DnsRecord ( host)
1605
1663
} else {
1606
1664
HostInfo :: HostIdentifiers ( host_list)
1607
1665
} ;
@@ -2299,18 +2357,39 @@ mod tests {
2299
2357
#[ test]
2300
2358
fn test_parse_address_with_from_str ( ) {
2301
2359
let x = "localhost:27017" . parse :: < ServerAddress > ( ) . unwrap ( ) ;
2302
- let ServerAddress :: Tcp { host, port } = x;
2303
- assert_eq ! ( host, "localhost" ) ;
2304
- assert_eq ! ( port, Some ( 27017 ) ) ;
2360
+ match x {
2361
+ ServerAddress :: Tcp { host, port } => {
2362
+ assert_eq ! ( host, "localhost" ) ;
2363
+ assert_eq ! ( port, Some ( 27017 ) ) ;
2364
+ }
2365
+ #[ cfg( unix) ]
2366
+ _ => panic ! ( "expected ServerAddress::Tcp" ) ,
2367
+ }
2305
2368
2306
2369
// Port defaults to 27017 (so this doesn't fail)
2307
2370
let x = "localhost" . parse :: < ServerAddress > ( ) . unwrap ( ) ;
2308
- let ServerAddress :: Tcp { host, port } = x;
2309
- assert_eq ! ( host, "localhost" ) ;
2310
- assert_eq ! ( port, None ) ;
2371
+ match x {
2372
+ ServerAddress :: Tcp { host, port } => {
2373
+ assert_eq ! ( host, "localhost" ) ;
2374
+ assert_eq ! ( port, None ) ;
2375
+ }
2376
+ #[ cfg( unix) ]
2377
+ _ => panic ! ( "expected ServerAddress::Tcp" ) ,
2378
+ }
2311
2379
2312
2380
let x = "localhost:not a number" . parse :: < ServerAddress > ( ) ;
2313
2381
assert ! ( x. is_err( ) ) ;
2382
+
2383
+ #[ cfg( unix) ]
2384
+ {
2385
+ let x = "/path/to/socket.sock" . parse :: < ServerAddress > ( ) . unwrap ( ) ;
2386
+ match x {
2387
+ ServerAddress :: Unix { path } => {
2388
+ assert_eq ! ( path. to_str( ) . unwrap( ) , "/path/to/socket.sock" ) ;
2389
+ }
2390
+ _ => panic ! ( "expected ServerAddress::Unix" ) ,
2391
+ }
2392
+ }
2314
2393
}
2315
2394
2316
2395
#[ cfg_attr( feature = "tokio-runtime" , tokio:: test) ]
0 commit comments