Skip to content

Commit 8df910f

Browse files
committed
refactor: better upstream key
1 parent adc393a commit 8df910f

File tree

5 files changed

+166
-156
lines changed

5 files changed

+166
-156
lines changed

capybara-core/src/pipeline/http/pipeline_access_log.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ mod tests {
457457

458458
{
459459
use tokio::time;
460-
time::sleep(time::Duration::from_millis(123)).await;
460+
time::sleep(Duration::from_millis(123)).await;
461461
}
462462

463463
assert!(p

capybara-core/src/proto.rs

+73-80
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,70 @@
1+
use async_trait::async_trait;
2+
use rustls::pki_types::ServerName;
13
use std::fmt::{Display, Formatter};
24
use std::net::{IpAddr, SocketAddr};
35
use std::str::FromStr;
46

5-
use async_trait::async_trait;
6-
use rustls::pki_types::ServerName;
7-
87
use capybara_util::cachestr::Cachestr;
98

109
use crate::{CapybaraError, Result};
1110

1211
#[derive(Clone, Hash, Eq, PartialEq)]
1312
pub enum UpstreamKey {
14-
Tcp(SocketAddr),
15-
Tls(SocketAddr, ServerName<'static>),
16-
TcpHP(Cachestr, u16),
17-
TlsHP(Cachestr, u16, ServerName<'static>),
13+
Tcp(Addr),
14+
Tls(Addr),
1815
Tag(Cachestr),
1916
}
2017

18+
#[derive(Clone, Hash, Eq, PartialEq)]
19+
pub enum Addr {
20+
SocketAddr(SocketAddr),
21+
Host(Cachestr, u16),
22+
}
23+
24+
impl Addr {
25+
fn parse_from(s: &str, default_port: Option<u16>) -> Result<Self> {
26+
let (host, port) = host_and_port(s)?;
27+
28+
let port = match port {
29+
None => {
30+
default_port.ok_or_else(|| CapybaraError::InvalidUpstream(s.to_string().into()))?
31+
}
32+
Some(port) => port,
33+
};
34+
35+
if let Ok(addr) = host.parse::<IpAddr>() {
36+
return Ok(Addr::SocketAddr(SocketAddr::new(addr, port)));
37+
}
38+
39+
Ok(Addr::Host(Cachestr::from(host), port))
40+
}
41+
}
42+
43+
impl Display for Addr {
44+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
45+
match self {
46+
Addr::SocketAddr(addr) => write!(f, "{}", addr),
47+
Addr::Host(host, port) => write!(f, "{}:{}", host, port),
48+
}
49+
}
50+
}
51+
52+
#[inline]
53+
fn host_and_port(s: &str) -> Result<(&str, Option<u16>)> {
54+
let mut sp = s.splitn(2, ':');
55+
56+
match sp.next() {
57+
None => Err(CapybaraError::InvalidUpstream(s.to_string().into())),
58+
Some(first) => match sp.next() {
59+
Some(second) => match second.parse::<u16>() {
60+
Ok(port) => Ok((first, Some(port))),
61+
Err(_) => Err(CapybaraError::InvalidUpstream(s.to_string().into())),
62+
},
63+
None => Ok((first, None)),
64+
},
65+
}
66+
}
67+
2168
impl FromStr for UpstreamKey {
2269
type Err = CapybaraError;
2370

@@ -31,29 +78,13 @@ impl FromStr for UpstreamKey {
3178
port == 443
3279
}
3380

34-
fn host_and_port(s: &str) -> Result<(&str, Option<u16>)> {
35-
let mut sp = s.splitn(2, ':');
36-
37-
match sp.next() {
38-
None => Err(CapybaraError::InvalidUpstream(s.to_string().into())),
39-
Some(first) => match sp.next() {
40-
Some(second) => match second.parse::<u16>() {
41-
Ok(port) => Ok((first, Some(port))),
42-
Err(_) => Err(CapybaraError::InvalidUpstream(s.to_string().into())),
43-
},
44-
None => Ok((first, None)),
45-
},
46-
}
47-
}
48-
4981
fn to_sni(sni: &str) -> Result<ServerName<'static>> {
5082
ServerName::try_from(sni)
5183
.map_err(|_| CapybaraError::InvalidTlsSni(sni.to_string().into()))
5284
.map(|it| it.to_owned())
5385
}
5486

5587
// FIXME: too many duplicated codes
56-
5788
if let Some(suffix) = s.strip_prefix("upstream://") {
5889
return if suffix.is_empty() {
5990
Err(CapybaraError::InvalidUpstream(s.to_string().into()))
@@ -63,74 +94,42 @@ impl FromStr for UpstreamKey {
6394
}
6495

6596
if let Some(suffix) = s.strip_prefix("tcp://") {
66-
let (host, port) = host_and_port(suffix)?;
67-
let port = port.ok_or_else(|| CapybaraError::InvalidUpstream(s.to_string().into()))?;
68-
return Ok(match host.parse::<IpAddr>() {
69-
Ok(ip) => UpstreamKey::Tcp(SocketAddr::new(ip, port)),
70-
Err(_) => UpstreamKey::TcpHP(Cachestr::from(host), port),
71-
});
97+
let addr = Addr::parse_from(suffix, None)?;
98+
return Ok(UpstreamKey::Tcp(addr));
7299
}
73100

74101
if let Some(suffix) = s.strip_prefix("tls://") {
75-
let (host, port) = host_and_port(suffix)?;
76-
let port = port.ok_or_else(|| CapybaraError::InvalidUpstream(s.to_string().into()))?;
77-
return Ok(match host.parse::<IpAddr>() {
78-
Ok(ip) => {
79-
let server_name = ServerName::from(ip);
80-
UpstreamKey::Tls(SocketAddr::new(ip, port), server_name)
81-
}
82-
Err(_) => UpstreamKey::TlsHP(Cachestr::from(host), port, to_sni(host)?),
83-
});
102+
let addr = Addr::parse_from(suffix, Some(443))?;
103+
return Ok(UpstreamKey::Tls(addr));
84104
}
85105

86106
if let Some(suffix) = s.strip_prefix("http://") {
87-
let (host, port) = host_and_port(suffix)?;
88-
let port = port.unwrap_or(80);
89-
return Ok(match host.parse::<IpAddr>() {
90-
Ok(ip) => UpstreamKey::Tcp(SocketAddr::new(ip, port)),
91-
Err(_) => UpstreamKey::TcpHP(Cachestr::from(host), port),
92-
});
107+
let addr = Addr::parse_from(suffix, Some(80))?;
108+
return Ok(UpstreamKey::Tcp(addr));
93109
}
94110

95111
if let Some(suffix) = s.strip_prefix("https://") {
96-
let (host, port) = host_and_port(suffix)?;
97-
let port = port.unwrap_or(443);
98-
return Ok(match host.parse::<IpAddr>() {
99-
Ok(ip) => {
100-
let server_name = ServerName::from(ip);
101-
UpstreamKey::Tls(SocketAddr::new(ip, port), server_name)
102-
}
103-
Err(_) => UpstreamKey::TlsHP(Cachestr::from(host), port, to_sni(host)?),
104-
});
112+
let addr = Addr::parse_from(suffix, Some(443))?;
113+
return Ok(UpstreamKey::Tls(addr));
105114
}
106115

107116
let (host, port) = host_and_port(s)?;
108117
let port = port.ok_or_else(|| CapybaraError::InvalidUpstream(s.to_string().into()))?;
109-
Ok(match host.parse::<IpAddr>() {
110-
Ok(ip) => UpstreamKey::Tcp(SocketAddr::new(ip, port)),
111-
Err(_) => UpstreamKey::TcpHP(Cachestr::from(host), port),
112-
})
118+
let addr = match host.parse::<IpAddr>() {
119+
Ok(ip) => Addr::SocketAddr(SocketAddr::new(ip, port)),
120+
Err(_) => Addr::Host(Cachestr::from(host), port),
121+
};
122+
123+
Ok(UpstreamKey::Tcp(addr))
113124
}
114125
}
115126

116127
impl Display for UpstreamKey {
117128
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
118129
match self {
119130
UpstreamKey::Tcp(addr) => write!(f, "tcp://{}", addr),
120-
UpstreamKey::Tls(addr, sni) => {
121-
if let ServerName::DnsName(name) = sni {
122-
return write!(f, "tls://{}?sni={}", addr, name.as_ref());
123-
}
124-
write!(f, "tls://{}", addr)
125-
}
126-
UpstreamKey::TcpHP(addr, port) => write!(f, "tcp://{}:{}", addr, port),
127-
UpstreamKey::TlsHP(addr, port, sni) => {
128-
if let ServerName::DnsName(name) = sni {
129-
return write!(f, "tls://{}:{}?sni={}", addr, port, name.as_ref());
130-
}
131-
write!(f, "tls://{}:{}", addr, port)
132-
}
133-
UpstreamKey::Tag(tag) => write!(f, "upstream://{}", tag.as_ref()),
131+
UpstreamKey::Tls(addr) => write!(f, "tls://{}", addr),
132+
UpstreamKey::Tag(tag) => write!(f, "upstream://{}", tag),
134133
}
135134
}
136135
}
@@ -182,18 +181,12 @@ mod tests {
182181
("https://127.0.0.1:8443", "tls://127.0.0.1:8443"),
183182
// schema+host
184183
("http://example.com", "tcp://example.com:80"),
185-
(
186-
"https://example.com",
187-
"tls://example.com:443?sni=example.com",
188-
),
184+
("https://example.com", "tls://example.com:443"),
189185
// schema+host+port
190186
("tcp://localhost:8080", "tcp://localhost:8080"),
191-
("tls://localhost:8443", "tls://localhost:8443?sni=localhost"),
187+
("tls://localhost:8443", "tls://localhost:8443"),
192188
("http://localhost:8080", "tcp://localhost:8080"),
193-
(
194-
"https://localhost:8443",
195-
"tls://localhost:8443?sni=localhost",
196-
),
189+
("https://localhost:8443", "tls://localhost:8443"),
197190
] {
198191
assert!(s.parse::<UpstreamKey>().is_ok_and(|it| {
199192
let actual = it.to_string();

capybara-core/src/upstream/misc.rs

+43-31
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1+
use rustls::pki_types::ServerName;
12
use std::fmt::{Display, Formatter};
23
use std::net::{IpAddr, SocketAddr};
3-
44
use tokio::net::TcpStream;
55

6-
use crate::proto::UpstreamKey;
6+
use crate::proto::{Addr, UpstreamKey};
77
use crate::resolver::DEFAULT_RESOLVER;
88
use crate::transport::{tcp, tls};
9-
use crate::Result;
9+
use crate::{CapybaraError, Result};
1010

1111
pub(crate) enum ClientStream {
1212
Tcp(TcpStream),
@@ -27,36 +27,48 @@ impl Display for ClientStream {
2727

2828
pub(crate) async fn establish(upstream: &UpstreamKey, buff_size: usize) -> Result<ClientStream> {
2929
let stream = match upstream {
30-
UpstreamKey::Tcp(addr) => ClientStream::Tcp(
31-
tcp::TcpStreamBuilder::new(*addr)
32-
.buff_size(buff_size)
33-
.build()?,
34-
),
35-
UpstreamKey::Tls(addr, sni) => {
36-
let stream = tcp::TcpStreamBuilder::new(*addr)
37-
.buff_size(buff_size)
38-
.build()?;
39-
let c = tls::TlsConnectorBuilder::new().build()?;
40-
ClientStream::Tls(c.connect(Clone::clone(sni), stream).await?)
41-
}
42-
UpstreamKey::TcpHP(domain, port) => {
43-
let ip = resolve(domain.as_ref()).await?;
44-
let addr = SocketAddr::new(ip, *port);
45-
ClientStream::Tcp(
46-
tcp::TcpStreamBuilder::new(addr)
30+
UpstreamKey::Tcp(addr) => match addr {
31+
Addr::SocketAddr(addr) => ClientStream::Tcp(
32+
tcp::TcpStreamBuilder::new(*addr)
4733
.buff_size(buff_size)
4834
.build()?,
49-
)
50-
}
51-
UpstreamKey::TlsHP(domain, port, sni) => {
52-
let ip = resolve(domain.as_ref()).await?;
53-
let addr = SocketAddr::new(ip, *port);
54-
let stream = tcp::TcpStreamBuilder::new(addr)
55-
.buff_size(buff_size)
56-
.build()?;
57-
let c = tls::TlsConnectorBuilder::new().build()?;
58-
let stream = c.connect(Clone::clone(sni), stream).await?;
59-
ClientStream::Tls(stream)
35+
),
36+
Addr::Host(host, port) => {
37+
let ip = resolve(host.as_ref()).await?;
38+
let addr = SocketAddr::new(ip, *port);
39+
ClientStream::Tcp(
40+
tcp::TcpStreamBuilder::new(addr)
41+
.buff_size(buff_size)
42+
.build()?,
43+
)
44+
}
45+
},
46+
UpstreamKey::Tls(addr) => {
47+
match addr {
48+
Addr::SocketAddr(addr) => {
49+
let stream = tcp::TcpStreamBuilder::new(*addr)
50+
.buff_size(buff_size)
51+
.build()?;
52+
let c = tls::TlsConnectorBuilder::new().build()?;
53+
54+
let sni = ServerName::from(addr.ip());
55+
ClientStream::Tls(c.connect(sni, stream).await?)
56+
}
57+
Addr::Host(host, port) => {
58+
let ip = resolve(host.as_ref()).await?;
59+
let addr = SocketAddr::new(ip, *port);
60+
let stream = tcp::TcpStreamBuilder::new(addr)
61+
.buff_size(buff_size)
62+
.build()?;
63+
let c = tls::TlsConnectorBuilder::new().build()?;
64+
// TODO: how to reduce creating times of sni?
65+
let sni = ServerName::try_from(host.as_ref())
66+
.map_err(|e| CapybaraError::Other(e.into()))?
67+
.to_owned();
68+
let stream = c.connect(sni, stream).await?;
69+
ClientStream::Tls(stream)
70+
}
71+
}
6072
}
6173
UpstreamKey::Tag(tag) => {
6274
todo!("establish with tag is not supported yet")

0 commit comments

Comments
 (0)