Skip to content

Commit 3ea6bc1

Browse files
authored
RUST-802 Support Unix Domain Sockets (mongodb#908)
1 parent c242539 commit 3ea6bc1

File tree

15 files changed

+443
-135
lines changed

15 files changed

+443
-135
lines changed

src/client/auth.rs

-23
Original file line numberDiff line numberDiff line change
@@ -397,29 +397,6 @@ pub struct Credential {
397397
}
398398

399399
impl Credential {
400-
#[cfg(all(test, not(feature = "sync"), not(feature = "tokio-sync")))]
401-
pub(crate) fn into_document(mut self) -> Document {
402-
use crate::bson::Bson;
403-
404-
let mut doc = Document::new();
405-
406-
if let Some(s) = self.username.take() {
407-
doc.insert("username", s);
408-
}
409-
410-
if let Some(s) = self.password.take() {
411-
doc.insert("password", s);
412-
} else {
413-
doc.insert("password", Bson::Null);
414-
}
415-
416-
if let Some(s) = self.source.take() {
417-
doc.insert("db", s);
418-
}
419-
420-
doc
421-
}
422-
423400
pub(crate) fn resolved_source(&self) -> &str {
424401
self.mechanism
425402
.as_ref()

src/client/options.rs

+102-23
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ mod test;
44
mod resolver_config;
55

66
use std::{
7+
borrow::Cow,
78
cmp::Ordering,
89
collections::HashSet,
910
convert::TryFrom,
@@ -91,14 +92,11 @@ lazy_static! {
9192
};
9293

9394
static ref ILLEGAL_DATABASE_CHARACTERS: HashSet<&'static char> = {
94-
['/', '\\', ' ', '"', '$', '.'].iter().collect()
95+
['/', '\\', ' ', '"', '$'].iter().collect()
9596
};
9697
}
9798

9899
/// An enum representing the address of a MongoDB server.
99-
///
100-
/// Currently this just supports addresses that can be connected to over TCP, but alternative
101-
/// address types may be supported in the future (e.g. Unix Domain Socket paths).
102100
#[derive(Clone, Debug, Eq, Serialize)]
103101
#[non_exhaustive]
104102
pub enum ServerAddress {
@@ -112,6 +110,12 @@ pub enum ServerAddress {
112110
/// The default is 27017.
113111
port: Option<u16>,
114112
},
113+
/// A Unix Domain Socket path.
114+
#[cfg(unix)]
115+
Unix {
116+
/// The path to the Unix Domain Socket.
117+
path: PathBuf,
118+
},
115119
}
116120

117121
impl<'de> Deserialize<'de> for ServerAddress {
@@ -144,6 +148,10 @@ impl PartialEq for ServerAddress {
144148
port: other_port,
145149
},
146150
) => host == other_host && port.unwrap_or(27017) == other_port.unwrap_or(27017),
151+
#[cfg(unix)]
152+
(Self::Unix { path }, Self::Unix { path: other_path }) => path == other_path,
153+
#[cfg(unix)]
154+
_ => false,
147155
}
148156
}
149157
}
@@ -158,6 +166,8 @@ impl Hash for ServerAddress {
158166
host.hash(state);
159167
port.unwrap_or(27017).hash(state);
160168
}
169+
#[cfg(unix)]
170+
Self::Unix { path } => path.hash(state),
161171
}
162172
}
163173
}
@@ -173,6 +183,15 @@ impl ServerAddress {
173183
/// Parses an address string into a `ServerAddress`.
174184
pub fn parse(address: impl AsRef<str>) -> Result<Self> {
175185
let address = address.as_ref();
186+
// checks if the address is a unix domain socket
187+
#[cfg(unix)]
188+
{
189+
if address.ends_with(".sock") {
190+
return Ok(ServerAddress::Unix {
191+
path: PathBuf::from(address),
192+
});
193+
}
194+
}
176195
let mut parts = address.split(':');
177196
let hostname = match parts.next() {
178197
Some(part) => {
@@ -243,18 +262,29 @@ impl ServerAddress {
243262
"port": port.map(|i| Bson::Int32(i.into())).unwrap_or(Bson::Null)
244263
}
245264
}
265+
#[cfg(unix)]
266+
Self::Unix { path } => {
267+
doc! {
268+
"host": path.to_string_lossy().as_ref(),
269+
"port": Bson::Null,
270+
}
271+
}
246272
}
247273
}
248274

249-
pub(crate) fn host(&self) -> &str {
275+
pub(crate) fn host(&self) -> Cow<'_, str> {
250276
match self {
251-
Self::Tcp { host, .. } => host.as_str(),
277+
Self::Tcp { host, .. } => Cow::Borrowed(host.as_str()),
278+
#[cfg(unix)]
279+
Self::Unix { path } => path.to_string_lossy(),
252280
}
253281
}
254282

255283
pub(crate) fn port(&self) -> Option<u16> {
256284
match self {
257285
Self::Tcp { port, .. } => *port,
286+
#[cfg(unix)]
287+
Self::Unix { .. } => None,
258288
}
259289
}
260290
}
@@ -265,6 +295,8 @@ impl fmt::Display for ServerAddress {
265295
Self::Tcp { host, port } => {
266296
write!(fmt, "{}:{}", host, port.unwrap_or(DEFAULT_PORT))
267297
}
298+
#[cfg(unix)]
299+
Self::Unix { path } => write!(fmt, "{}", path.display()),
268300
}
269301
}
270302
}
@@ -1580,10 +1612,26 @@ impl ConnectionString {
15801612
None => (None, None),
15811613
};
15821614

1583-
let host_list: Result<Vec<_>> =
1584-
hosts_section.split(',').map(ServerAddress::parse).collect();
1585-
1586-
let host_list = host_list?;
1615+
let mut host_list = Vec::with_capacity(hosts_section.len());
1616+
for host in hosts_section.split(',') {
1617+
let address = if host.ends_with(".sock") {
1618+
#[cfg(unix)]
1619+
{
1620+
ServerAddress::parse(percent_decode(
1621+
host,
1622+
"Unix domain sockets must be URL-encoded",
1623+
)?)
1624+
}
1625+
#[cfg(not(unix))]
1626+
return Err(ErrorKind::InvalidArgument {
1627+
message: "Unix domain sockets are not supported on this platform".to_string(),
1628+
}
1629+
.into());
1630+
} else {
1631+
ServerAddress::parse(host)
1632+
}?;
1633+
host_list.push(address);
1634+
}
15871635

15881636
let hosts = if srv {
15891637
if host_list.len() != 1 {
@@ -1592,16 +1640,26 @@ impl ConnectionString {
15921640
}
15931641
.into());
15941642
}
1595-
// Unwrap safety: the `len` check above guarantees this can't fail.
1596-
let ServerAddress::Tcp { host, port } = host_list.into_iter().next().unwrap();
15971643

1598-
if port.is_some() {
1599-
return Err(ErrorKind::InvalidArgument {
1600-
message: "a port cannot be specified with 'mongodb+srv'".into(),
1644+
// Unwrap safety: the `len` check above guarantees this can't fail.
1645+
match host_list.into_iter().next().unwrap() {
1646+
ServerAddress::Tcp { host, port } => {
1647+
if port.is_some() {
1648+
return Err(ErrorKind::InvalidArgument {
1649+
message: "a port cannot be specified with 'mongodb+srv'".into(),
1650+
}
1651+
.into());
1652+
}
1653+
HostInfo::DnsRecord(host)
1654+
}
1655+
#[cfg(unix)]
1656+
ServerAddress::Unix { .. } => {
1657+
return Err(ErrorKind::InvalidArgument {
1658+
message: "unix sockets cannot be used with 'mongodb+srv'".into(),
1659+
}
1660+
.into());
16011661
}
1602-
.into());
16031662
}
1604-
HostInfo::DnsRecord(host)
16051663
} else {
16061664
HostInfo::HostIdentifiers(host_list)
16071665
};
@@ -2299,18 +2357,39 @@ mod tests {
22992357
#[test]
23002358
fn test_parse_address_with_from_str() {
23012359
let x = "localhost:27017".parse::<ServerAddress>().unwrap();
2302-
let ServerAddress::Tcp { host, port } = x;
2303-
assert_eq!(host, "localhost");
2304-
assert_eq!(port, Some(27017));
2360+
match x {
2361+
ServerAddress::Tcp { host, port } => {
2362+
assert_eq!(host, "localhost");
2363+
assert_eq!(port, Some(27017));
2364+
}
2365+
#[cfg(unix)]
2366+
_ => panic!("expected ServerAddress::Tcp"),
2367+
}
23052368

23062369
// Port defaults to 27017 (so this doesn't fail)
23072370
let x = "localhost".parse::<ServerAddress>().unwrap();
2308-
let ServerAddress::Tcp { host, port } = x;
2309-
assert_eq!(host, "localhost");
2310-
assert_eq!(port, None);
2371+
match x {
2372+
ServerAddress::Tcp { host, port } => {
2373+
assert_eq!(host, "localhost");
2374+
assert_eq!(port, None);
2375+
}
2376+
#[cfg(unix)]
2377+
_ => panic!("expected ServerAddress::Tcp"),
2378+
}
23112379

23122380
let x = "localhost:not a number".parse::<ServerAddress>();
23132381
assert!(x.is_err());
2382+
2383+
#[cfg(unix)]
2384+
{
2385+
let x = "/path/to/socket.sock".parse::<ServerAddress>().unwrap();
2386+
match x {
2387+
ServerAddress::Unix { path } => {
2388+
assert_eq!(path.to_str().unwrap(), "/path/to/socket.sock");
2389+
}
2390+
_ => panic!("expected ServerAddress::Unix"),
2391+
}
2392+
}
23142393
}
23152394

23162395
#[cfg_attr(feature = "tokio-runtime", tokio::test)]

0 commit comments

Comments
 (0)