Skip to content

Commit 3de3d03

Browse files
committed
Reconfigure
1 parent 7e611f3 commit 3de3d03

File tree

1 file changed

+127
-159
lines changed

1 file changed

+127
-159
lines changed

src/imp/mbedtls.rs

Lines changed: 127 additions & 159 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,8 @@ fn cert_to_vec(certs_in: &[::Certificate]) -> Vec<MbedtlsCert> {
152152
certs
153153
}
154154

155-
struct TlsState {
155+
#[derive(Debug)]
156+
pub struct TlsStream<S> {
156157
ca_certs: *mut Vec<MbedtlsCert>,
157158
ca_cert_list: *mut CertList<'static>,
158159
cred_pk: *mut Pk,
@@ -162,20 +163,20 @@ struct TlsState {
162163
rng: *mut CtrDrbg<'static>,
163164
config: *mut Config<'static>,
164165
ctx: *mut Context<'static>,
166+
session: *mut Session<'static>,
167+
socket: *mut S,
165168
}
166169

167-
unsafe impl Sync for TlsState {}
168-
unsafe impl Send for TlsState {}
169-
170-
impl Clone for TlsState {
171-
fn clone(&self) -> TlsState {
172-
panic!("yolo")
173-
}
174-
}
170+
/// ???
171+
unsafe impl<S> Sync for TlsStream<S> {}
172+
unsafe impl<S> Send for TlsStream<S> {}
175173

176-
impl Drop for TlsState {
174+
impl<S> Drop for TlsStream<S> {
177175
fn drop(&mut self) {
178176
unsafe {
177+
Box::from_raw(self.session);
178+
Box::from_raw(self.socket);
179+
179180
Box::from_raw(self.ctx);
180181
Box::from_raw(self.config);
181182
Box::from_raw(self.rng);
@@ -202,122 +203,6 @@ impl Drop for TlsState {
202203
}
203204
}
204205

205-
impl TlsState {
206-
fn new_client(trust_roots: &[::Certificate],
207-
min_version: Option<Version>,
208-
max_version: Option<Version>,
209-
accept_invalid_certs: bool) -> TlsResult<Self> {
210-
211-
unsafe {
212-
let ca_vec = Box::into_raw(Box::new(cert_to_vec(trust_roots)));
213-
let ca_list = Box::into_raw(Box::new(CertList::from_vec(&mut *ca_vec).ok_or(TlsError::AesInvalidKeyLength)?));
214-
let entropy = Box::into_raw(Box::new(OsEntropy::new()));
215-
let rng = Box::into_raw(Box::new(CtrDrbg::new(&mut *entropy, None)?));
216-
let config = Box::into_raw(Box::new(Config::new(Endpoint::Client, Transport::Stream, Preset::Default)));
217-
(*config).set_rng(Some(&mut *rng));
218-
(*config).set_ca_list(Some(&mut *ca_list), None);
219-
220-
if accept_invalid_certs {
221-
(*config).set_authmode(mbedtls::ssl::config::AuthMode::None);
222-
}
223-
224-
if let Some(min_version) = min_version {
225-
(*config).set_min_version(min_version)?;
226-
}
227-
if let Some(max_version) = max_version {
228-
(*config).set_max_version(max_version)?;
229-
}
230-
231-
let ctx = Box::into_raw(Box::new(Context::new(&*config)?));
232-
233-
Ok(TlsState {
234-
ca_certs: ca_vec,
235-
ca_cert_list: ca_list,
236-
cred_pk: ::std::ptr::null_mut(),
237-
cred_certs: ::std::ptr::null_mut(),
238-
cred_cert_list: ::std::ptr::null_mut(),
239-
entropy: entropy,
240-
rng: rng,
241-
config: config,
242-
ctx: ctx,
243-
})
244-
}
245-
}
246-
247-
fn new_server(cert_chain: &[MbedtlsCert],
248-
key: &mut Pk,
249-
min_version: Option<Version>,
250-
max_version: Option<Version>) -> TlsResult<Self> {
251-
fn pk_clone(pk: &mut Pk) -> TlsResult<Pk> {
252-
let der = pk.write_private_der_vec()?;
253-
Pk::from_private_key(&der, None)
254-
}
255-
256-
unsafe {
257-
let pk = Box::into_raw(Box::new(pk_clone(key)?));
258-
let cert_chain = Box::into_raw(Box::new(cert_chain.to_vec()));
259-
let cert_list = Box::into_raw(Box::new(CertList::from_vec(&mut *cert_chain).ok_or(TlsError::CamelliaInvalidInputLength)?));
260-
let entropy = Box::into_raw(Box::new(OsEntropy::new()));
261-
let rng = Box::into_raw(Box::new(CtrDrbg::new(&mut *entropy, None)?));
262-
let config = Box::into_raw(Box::new(Config::new(Endpoint::Server, Transport::Stream, Preset::Default)));
263-
(*config).set_rng(Some(&mut *rng));
264-
(*config).push_cert(&mut *cert_list, &mut *pk)?;
265-
266-
if let Some(min_version) = min_version {
267-
(*config).set_min_version(min_version)?;
268-
}
269-
if let Some(max_version) = max_version {
270-
(*config).set_max_version(max_version)?;
271-
}
272-
273-
let ctx = Box::into_raw(Box::new(Context::new(&*config)?));
274-
275-
Ok(TlsState {
276-
ca_certs: ::std::ptr::null_mut(),
277-
ca_cert_list: ::std::ptr::null_mut(),
278-
cred_pk: pk,
279-
cred_certs: cert_chain,
280-
cred_cert_list: cert_list,
281-
entropy: entropy,
282-
rng: rng,
283-
config: config,
284-
ctx: ctx,
285-
})
286-
}
287-
}
288-
289-
fn establish<S: io::Read + io::Write>(&self, stream: S, hostname: Option<&str>) -> TlsResult<TlsStream<S>> {
290-
unsafe {
291-
let stream_ptr = Box::into_raw(Box::new(stream));
292-
let session = (*self.ctx).establish(&mut *stream_ptr, hostname)?;
293-
let yolo_session = Box::into_raw(Box::new(std::mem::transmute::<Session<'_>, Session<'static>>(session)));
294-
Ok(TlsStream {
295-
session: yolo_session,
296-
socket: stream_ptr,
297-
})
298-
}
299-
}
300-
}
301-
302-
#[derive(Debug)]
303-
pub struct TlsStream<S> {
304-
session: *mut Session<'static>,
305-
socket: *mut S,
306-
}
307-
308-
unsafe impl<S> Sync for TlsStream<S> {}
309-
unsafe impl<S> Send for TlsStream<S> {}
310-
311-
impl<S> Drop for TlsStream<S> {
312-
fn drop(&mut self) {
313-
//println!("Dropping TlsStream");
314-
unsafe {
315-
Box::from_raw(self.session);
316-
Box::from_raw(self.socket);
317-
}
318-
}
319-
}
320-
321206
#[derive(Debug)]
322207
pub struct MidHandshakeTlsStream<S> {
323208
stream: TlsStream<S>,
@@ -351,8 +236,12 @@ where
351236

352237
#[derive(Clone)]
353238
pub struct TlsConnector {
354-
state: TlsState,
355-
accept_bad_hostname: bool,
239+
min_protocol: Option<Protocol>,
240+
max_protocol: Option<Protocol>,
241+
root_certificates: Vec<::Certificate>,
242+
accept_invalid_certs: bool,
243+
accept_invalid_hostnames: bool,
244+
use_sni: bool,
356245
}
357246

358247
impl TlsConnector {
@@ -361,71 +250,150 @@ impl TlsConnector {
361250
return Err(Error::Custom("Client authentication not supported".to_owned()));
362251
}
363252

364-
let min_version = map_version(builder.min_protocol);
365-
let max_version = map_version(builder.max_protocol);
366-
367-
let state = if builder.root_certificates.len() > 0 {
368-
TlsState::new_client(&builder.root_certificates, min_version, max_version, builder.accept_invalid_certs)?
253+
let trust_roots = if builder.root_certificates.len() > 0 {
254+
builder.root_certificates.clone()
369255
} else {
370-
let trust_roots = load_ca_certs("/usr/share/ca-certificates/mozilla")?;
371-
TlsState::new_client(&trust_roots, min_version, max_version, builder.accept_invalid_certs)?
256+
load_ca_certs("/usr/share/ca-certificates/mozilla")?
372257
};
373258

374-
Ok(TlsConnector { state, accept_bad_hostname: builder.accept_invalid_certs })
259+
Ok(TlsConnector {
260+
min_protocol: builder.min_protocol,
261+
max_protocol: builder.max_protocol,
262+
root_certificates: trust_roots,
263+
accept_invalid_certs: builder.accept_invalid_certs,
264+
accept_invalid_hostnames: builder.accept_invalid_hostnames,
265+
use_sni: builder.use_sni
266+
})
375267
}
376268

377269
pub fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
378270
where
379271
S: io::Read + io::Write
380272
{
381-
let channel = if self.accept_bad_hostname {
382-
self.state.establish(stream, None)?
383-
} else {
384-
self.state.establish(stream, Some(domain))?
385-
};
273+
// If any of the ? fail then memory leaks ...
386274

387-
//println!("After establish, stream is {:?} at {:?}", channel.socket, channel.socket as *const S);
275+
unsafe {
276+
let ca_vec = Box::into_raw(Box::new(cert_to_vec(&self.root_certificates)));
277+
let ca_list = Box::into_raw(Box::new(CertList::from_vec(&mut *ca_vec).ok_or(TlsError::AesInvalidKeyLength)?));
278+
let entropy = Box::into_raw(Box::new(OsEntropy::new()));
279+
let rng = Box::into_raw(Box::new(CtrDrbg::new(&mut *entropy, None)?));
280+
let config = Box::into_raw(Box::new(Config::new(Endpoint::Client, Transport::Stream, Preset::Default)));
281+
(*config).set_rng(Some(&mut *rng));
282+
(*config).set_ca_list(Some(&mut *ca_list), None);
388283

389-
Ok(channel)
284+
if self.accept_invalid_certs {
285+
(*config).set_authmode(mbedtls::ssl::config::AuthMode::None);
286+
}
287+
288+
if let Some(min_version) = map_version(self.min_protocol) {
289+
(*config).set_min_version(min_version)?;
290+
}
291+
if let Some(max_version) = map_version(self.max_protocol) {
292+
(*config).set_max_version(max_version)?;
293+
}
294+
295+
let ctx = Box::into_raw(Box::new(Context::new(&*config)?));
296+
297+
let hostname = if self.accept_invalid_hostnames { None } else { Some(domain) };
298+
299+
let stream_ptr = Box::into_raw(Box::new(stream));
300+
let session = (*ctx).establish(&mut *stream_ptr, hostname)?;
301+
let session = Box::into_raw(Box::new(std::mem::transmute::<Session<'_>, Session<'static>>(session))); // yolo
302+
303+
Ok(TlsStream {
304+
ca_certs: ca_vec,
305+
ca_cert_list: ca_list,
306+
cred_pk: ::std::ptr::null_mut(),
307+
cred_certs: ::std::ptr::null_mut(),
308+
cred_cert_list: ::std::ptr::null_mut(),
309+
entropy: entropy,
310+
rng: rng,
311+
config: config,
312+
ctx: ctx,
313+
session: session,
314+
socket: stream_ptr,
315+
})
316+
}
390317
}
391318
}
392319

393320
#[derive(Clone)]
394321
pub struct TlsAcceptor {
395-
state: TlsState,
322+
identity: Pfx,
323+
min_protocol: Option<Protocol>,
324+
max_protocol: Option<Protocol>,
396325
}
397326

398327
impl TlsAcceptor {
399328
pub fn new(builder: &TlsAcceptorBuilder) -> Result<TlsAcceptor, Error> {
329+
Ok(TlsAcceptor {
330+
identity: (builder.identity.0).0.clone(),
331+
min_protocol: builder.min_protocol,
332+
max_protocol: builder.max_protocol
333+
})
334+
}
400335

401-
let mut keys = (builder.identity.0).0.private_keys()?;
402-
let certificates = (builder.identity.0).0.certificates()?;
336+
pub fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
337+
where
338+
S: io::Read + io::Write
339+
{
340+
let mut keys = self.identity.private_keys().map_err(Error::Pkcs12).map_err(HandshakeError::Failure)?;
341+
let certificates = self.identity.certificates().map_err(Error::Pkcs12).map_err(HandshakeError::Failure)?;
403342

404343
if keys.len() != 1 {
405-
return Err(Error::Custom("Unexpected number of keys in PKCS12 file".to_owned()))
344+
return Err(HandshakeError::Failure(Error::Custom("Unexpected number of keys in PKCS12 file".to_owned())))
406345
}
407346
if certificates.len() == 0 {
408-
return Err(Error::Custom("PKCS12 file is missing certificate chain".to_owned()))
347+
return Err(HandshakeError::Failure(Error::Custom("PKCS12 file is missing certificate chain".to_owned())))
409348
}
410349

411350
let mut cert_chain = vec![];
412-
413351
for cert in certificates {
414352
cert_chain.push(cert.0);
415353
}
416354

417-
let state = TlsState::new_server(&cert_chain, &mut keys[0].0,
418-
map_version(builder.min_protocol),
419-
map_version(builder.max_protocol)).map_err(Error::Normal)?;
355+
fn pk_clone(pk: &mut Pk) -> TlsResult<Pk> {
356+
let der = pk.write_private_der_vec()?;
357+
Pk::from_private_key(&der, None)
358+
}
420359

421-
Ok(TlsAcceptor { state })
422-
}
360+
unsafe {
361+
let pk = Box::into_raw(Box::new(pk_clone(&mut keys[0].0)?));
362+
let cert_chain = Box::into_raw(Box::new(cert_chain.to_vec()));
363+
let cert_list = Box::into_raw(Box::new(CertList::from_vec(&mut *cert_chain).ok_or(TlsError::CamelliaInvalidInputLength)?));
364+
let entropy = Box::into_raw(Box::new(OsEntropy::new()));
365+
let rng = Box::into_raw(Box::new(CtrDrbg::new(&mut *entropy, None)?));
366+
let config = Box::into_raw(Box::new(Config::new(Endpoint::Server, Transport::Stream, Preset::Default)));
367+
(*config).set_rng(Some(&mut *rng));
368+
(*config).push_cert(&mut *cert_list, &mut *pk)?;
423369

424-
pub fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
425-
where
426-
S: io::Read + io::Write
427-
{
428-
Ok(self.state.establish(stream, None)?)
370+
if let Some(min_version) = map_version(self.min_protocol) {
371+
(*config).set_min_version(min_version)?;
372+
}
373+
if let Some(max_version) = map_version(self.max_protocol) {
374+
(*config).set_max_version(max_version)?;
375+
}
376+
377+
let ctx = Box::into_raw(Box::new(Context::new(&*config)?));
378+
379+
let stream_ptr = Box::into_raw(Box::new(stream));
380+
let session = (*ctx).establish(&mut *stream_ptr, None)?;
381+
let session = Box::into_raw(Box::new(std::mem::transmute::<Session<'_>, Session<'static>>(session))); // yolo
382+
383+
Ok(TlsStream {
384+
ca_certs: ::std::ptr::null_mut(),
385+
ca_cert_list: ::std::ptr::null_mut(),
386+
cred_pk: pk,
387+
cred_certs: cert_chain,
388+
cred_cert_list: cert_list,
389+
entropy: entropy,
390+
rng: rng,
391+
config: config,
392+
ctx: ctx,
393+
session: session,
394+
socket: stream_ptr,
395+
})
396+
}
429397
}
430398
}
431399

0 commit comments

Comments
 (0)