Skip to content

Add unix domain socket for telemetry server #128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions examples/http_server/example_conf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@ telemetry:
# Enables telemetry server
enabled: true
# Telemetry server address.
# Can be either a TCP socket address (e.g., "127.0.0.1:8080")
# or a Unix domain socket path (e.g., "/tmp/telemetry.sock") on Unix systems.
addr: "127.0.0.1:0"
# Example Unix socket configuration (uncomment to use):
# addr: "/tmp/telemetry.sock"
# HTTP endpoints configuration.
endpoints:
Example endpoint:
Expand Down
8 changes: 6 additions & 2 deletions examples/http_server/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ mod settings;

use self::settings::{EndpointSettings, HttpServerSettings, ResponseSettings};
use anyhow::anyhow;
use foundations::addr::ListenAddr;
use foundations::cli::{Arg, ArgAction, Cli};
use foundations::settings::collections::Map;
use foundations::telemetry::{self, log, tracing, TelemetryConfig, TelemetryContext};
Expand Down Expand Up @@ -56,8 +57,11 @@ async fn main() -> BootstrapResult<()> {
custom_server_routes: vec![],
})?;

if let Some(tele_serv_addr) = tele_driver.server_addr() {
log::info!("Telemetry server is listening on http://{}", tele_serv_addr);
if let Some(addr) = tele_driver.server_addr() {
match addr {
ListenAddr::Tcp(addr) => log::info!("Telemetry server is listening on http://{addr}"),
ListenAddr::Unix(path) => log::info!("Telemetry server is listening on {path:?}"),
}
}

// Spawn TCP listeners for each endpoint. Note that `Map<EndpointsSettings>` is ordered, so
Expand Down
71 changes: 71 additions & 0 deletions foundations/src/addr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
//! Network address types that support both TCP and Unix domain sockets.
//!
//! This module provides the [`ListenAddr`] enum, a flexible address type that can represent
//! either TCP socket addresses or Unix domain socket paths.

#[cfg(feature = "settings")]
use crate::settings::Settings;
#[cfg(any(feature = "telemetry-server", feature = "settings"))]
use serde::Deserialize;
#[cfg(feature = "settings")]
use serde::Serialize;
use std::fmt;
use std::net::{Ipv4Addr, SocketAddr};

/// Address a server can listen on: either a TCP socket address or, on Unix
/// platforms, the filesystem path of a Unix domain socket.
///
/// When serde support is compiled in, the enum uses the untagged
/// representation. On deserialization the variants are tried in declaration
/// order, so input that is a valid socket address becomes [`ListenAddr::Tcp`]
/// and, on Unix platforms, input that is a valid path falls through to
/// [`ListenAddr::Unix`].
///
/// The type is comparable and hashable so it can be checked in assertions and
/// used as a key in maps and sets.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(
    any(feature = "telemetry-server", feature = "settings"),
    derive(Deserialize)
)]
#[cfg_attr(feature = "settings", derive(Serialize))]
#[cfg_attr(
    any(feature = "telemetry-server", feature = "settings"),
    serde(untagged)
)]
pub enum ListenAddr {
    /// TCP network socket address.
    Tcp(std::net::SocketAddr),
    /// Unix domain socket path. Only available on Unix platforms.
    #[cfg(unix)]
    Unix(std::path::PathBuf),
}

impl Default for ListenAddr {
fn default() -> Self {
ListenAddr::Tcp((Ipv4Addr::LOCALHOST, 0).into())
}
}

#[cfg(feature = "settings")]
impl From<crate::settings::net::SocketAddr> for ListenAddr {
    /// Converts the settings-layer socket address into its standard-library
    /// form and wraps it as a TCP listen address.
    fn from(addr: crate::settings::net::SocketAddr) -> Self {
        let tcp_addr: SocketAddr = addr.into();

        Self::Tcp(tcp_addr)
    }
}

impl From<SocketAddr> for ListenAddr {
fn from(addr: SocketAddr) -> Self {
ListenAddr::Tcp(addr)
}
}

#[cfg(unix)]
impl From<std::path::PathBuf> for ListenAddr {
fn from(path: std::path::PathBuf) -> Self {
ListenAddr::Unix(path)
}
}

impl fmt::Display for ListenAddr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ListenAddr::Tcp(addr) => write!(f, "{addr}"),
#[cfg(unix)]
ListenAddr::Unix(path) => write!(f, "{}", path.display()),
}
}
}

// Opts `ListenAddr` into the `Settings` trait when the "settings" feature is
// enabled. The impl body is empty, so only the trait's default items are used.
#[cfg(feature = "settings")]
impl Settings for ListenAddr {}
2 changes: 2 additions & 0 deletions foundations/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@

mod utils;

pub mod addr;

#[cfg(feature = "cli")]
pub mod cli;

Expand Down
11 changes: 6 additions & 5 deletions foundations/src/telemetry/driver.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#[cfg(feature = "telemetry-server")]
use crate::addr::ListenAddr;
use crate::utils::feature_use;
use crate::BootstrapResult;
use futures_util::future::BoxFuture;
Expand All @@ -9,7 +11,6 @@ use std::task::{Context, Poll};

feature_use!(cfg(feature = "telemetry-server"), {
use super::server::TelemetryServerFuture;
use std::net::SocketAddr;
});

/// A future that drives async telemetry functionality and that is returned
Expand All @@ -21,7 +22,7 @@ feature_use!(cfg(feature = "telemetry-server"), {
/// [security syscall-related]: `crate::security`
pub struct TelemetryDriver {
#[cfg(feature = "telemetry-server")]
server_addr: Option<SocketAddr>,
server_addr: Option<ListenAddr>,

#[cfg(feature = "telemetry-server")]
server_fut: Option<TelemetryServerFuture>,
Expand All @@ -36,7 +37,7 @@ impl TelemetryDriver {
) -> Self {
Self {
#[cfg(feature = "telemetry-server")]
server_addr: server_fut.as_ref().map(|fut| fut.local_addr()),
server_addr: server_fut.as_ref().and_then(|fut| fut.local_addr().ok()),

#[cfg(feature = "telemetry-server")]
server_fut,
Expand All @@ -49,8 +50,8 @@ impl TelemetryDriver {
///
/// Returns `None` if the server wasn't spawned.
#[cfg(feature = "telemetry-server")]
pub fn server_addr(&self) -> Option<SocketAddr> {
self.server_addr
pub fn server_addr(&self) -> Option<&ListenAddr> {
self.server_addr.as_ref()
}

/// Instructs the telemetry driver and server to perform an orderly shutdown when the given
Expand Down
165 changes: 141 additions & 24 deletions foundations/src/telemetry/server/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#[cfg(feature = "metrics")]
use super::metrics;
use super::settings::TelemetrySettings;
use crate::addr::ListenAddr;
use crate::telemetry::log;
use crate::BootstrapResult;
use anyhow::Context as _;
Expand All @@ -14,18 +15,127 @@ use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
#[cfg(unix)]
use tokio::net::{UnixListener, UnixStream};
use tokio::sync::watch;

mod router;

use router::Router;

/// A connection accepted by the telemetry server, over either supported transport.
///
/// Implements `AsyncRead` and `AsyncWrite` by delegating to the wrapped stream.
enum TelemetryStream {
/// Connection accepted from a TCP listener.
Tcp(TcpStream),
/// Connection accepted from a Unix domain socket listener (Unix platforms only).
#[cfg(unix)]
Unix(UnixStream),
}

impl AsyncRead for TelemetryStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match self.get_mut() {
TelemetryStream::Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
#[cfg(unix)]
TelemetryStream::Unix(stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}

/// Delegates every `AsyncWrite` operation to the wrapped transport stream.
///
/// The vectored-write methods are forwarded explicitly. Without the overrides
/// the trait defaults apply (`is_write_vectored` reports `false` and
/// `poll_write_vectored` writes only the first non-empty buffer), which would
/// hide the vectored-write support of the underlying streams from callers.
impl AsyncWrite for TelemetryStream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        match self.get_mut() {
            TelemetryStream::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
            #[cfg(unix)]
            TelemetryStream::Unix(stream) => Pin::new(stream).poll_write(cx, buf),
        }
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        match self.get_mut() {
            TelemetryStream::Tcp(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
            #[cfg(unix)]
            TelemetryStream::Unix(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
        }
    }

    fn is_write_vectored(&self) -> bool {
        match self {
            TelemetryStream::Tcp(stream) => stream.is_write_vectored(),
            #[cfg(unix)]
            TelemetryStream::Unix(stream) => stream.is_write_vectored(),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        match self.get_mut() {
            TelemetryStream::Tcp(stream) => Pin::new(stream).poll_flush(cx),
            #[cfg(unix)]
            TelemetryStream::Unix(stream) => Pin::new(stream).poll_flush(cx),
        }
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        match self.get_mut() {
            TelemetryStream::Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
            #[cfg(unix)]
            TelemetryStream::Unix(stream) => Pin::new(stream).poll_shutdown(cx),
        }
    }
}

/// Listener the telemetry server accepts connections on, for either supported transport.
enum TelemetryListener {
/// Listener bound to a TCP socket address.
Tcp(TcpListener),
/// Listener bound to a Unix domain socket (Unix platforms only).
#[cfg(unix)]
Unix(UnixListener),
}

impl TelemetryListener {
pub(crate) fn local_addr(&self) -> BootstrapResult<ListenAddr> {
match self {
TelemetryListener::Tcp(listener) => Ok(listener.local_addr()?.into()),
#[cfg(unix)]
TelemetryListener::Unix(listener) => match listener.local_addr()?.as_pathname() {
Some(path) => Ok(path.to_path_buf().into()),
None => Err(anyhow::anyhow!("unix socket listener has no pathname")),
},
}
}

pub(crate) async fn accept(&self) -> std::io::Result<TelemetryStream> {
match self {
TelemetryListener::Tcp(listener) => listener
.accept()
.await
.map(|(conn, _)| TelemetryStream::Tcp(conn)),
#[cfg(unix)]
TelemetryListener::Unix(listener) => listener
.accept()
.await
.map(|(conn, _)| TelemetryStream::Unix(conn)),
}
}

pub(crate) fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<TelemetryStream>> {
match self {
TelemetryListener::Tcp(listener) => match std::task::ready!(listener.poll_accept(cx)) {
Ok((conn, _)) => std::task::Poll::Ready(Ok(TelemetryStream::Tcp(conn))),
Err(e) => std::task::Poll::Ready(Err(e)),
},
#[cfg(unix)]
TelemetryListener::Unix(listener) => {
match std::task::ready!(listener.poll_accept(cx)) {
Ok((conn, _)) => std::task::Poll::Ready(Ok(TelemetryStream::Unix(conn))),
Err(e) => std::task::Poll::Ready(Err(e)),
}
}
}
}
}

pub use router::{
BoxError, TelemetryRouteHandler, TelemetryRouteHandlerFuture, TelemetryServerRoute,
};

pub(super) struct TelemetryServerFuture {
listener: TcpListener,
listener: TelemetryListener,
router: Router,
}

Expand All @@ -47,27 +157,38 @@ impl TelemetryServerFuture {
.map_err(|err| anyhow::anyhow!(err))?;
}

let addr = settings.server.addr;

#[cfg(feature = "settings")]
let addr = SocketAddr::from(addr);

let router = Router::new(custom_routes, settings);

let listener = {
let std_listener = std::net::TcpListener::from(
bind_socket(addr).with_context(|| format!("binding to socket {addr:?}"))?,
);

std_listener.set_nonblocking(true)?;
let router = Router::new(custom_routes, Arc::clone(&settings));

let listener = match &settings.server.addr {
ListenAddr::Tcp(addr) => {
let std_listener = std::net::TcpListener::from(
bind_socket(*addr)
.with_context(|| format!("binding to TCP socket {addr:?}"))?,
);
std_listener.set_nonblocking(true)?;
let tokio_listener = tokio::net::TcpListener::from_std(std_listener)?;
TelemetryListener::Tcp(tokio_listener)
}
#[cfg(unix)]
ListenAddr::Unix(path) => {
// Remove existing socket file if it exists to avoid bind errors
if path.exists() {
if let Err(e) = std::fs::remove_file(path) {
log::warn!("failed to remove existing Unix socket file"; "path" => %path.display(), "error" => e);
}
}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really want to silently delete existing sockets? I have a gut feeling that this could hide subtle problems

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be easy to get into a bad state during a crash and a restart because we can't reuse old sockets.


tokio::net::TcpListener::from_std(std_listener)?
let unix_listener = UnixListener::bind(path)
.with_context(|| format!("binding to Unix socket {path:?}"))?;
TelemetryListener::Unix(unix_listener)
}
};

Ok(Some(TelemetryServerFuture { listener, router }))
}
pub(super) fn local_addr(&self) -> SocketAddr {
self.listener.local_addr().unwrap()

pub(super) fn local_addr(&self) -> BootstrapResult<ListenAddr> {
self.listener.local_addr()
}

// Adapted from Hyper 0.14 Server stuff and axum::serve::serve.
Expand All @@ -87,15 +208,12 @@ impl TelemetryServerFuture {
let (close_tx, close_rx) = watch::channel(());
let listener = self.listener;

pin_mut!(listener);

loop {
let socket = tokio::select! {
conn = listener.accept() => match conn {
Ok((conn, _)) => TokioIo::new(conn),
Ok(conn) => TokioIo::new(conn),
Err(e) => {
log::warn!("failed to accept connection"; "error" => e);

continue;
}
},
Expand Down Expand Up @@ -140,11 +258,10 @@ impl Future for TelemetryServerFuture {
let this = &mut *self;

loop {
let socket = match ready!(Pin::new(&mut this.listener).poll_accept(cx)) {
Ok((conn, _)) => TokioIo::new(conn),
let socket = match ready!(this.listener.poll_accept(cx)) {
Ok(conn) => TokioIo::new(conn),
Err(e) => {
log::warn!("failed to accept connection"; "error" => e);

continue;
}
};
Expand Down
Loading
Loading