Implementation for lambda_http #42

Open
edeetee opened this issue Oct 30, 2024 · 0 comments

edeetee commented Oct 30, 2024

I wrote this for some internal code and thought I would leave it here; feel free to use it.

```rust
use std::{
    net::{IpAddr, SocketAddr},
    sync::Arc,
    time::Duration,
};

use axum::http::Request;
use governor::{clock::QuantaInstant, middleware::NoOpMiddleware};
use lambda_http::{lambda_runtime::Config, request::RequestContext, RequestExt};
use tower_governor::{
    governor::{GovernorConfig, GovernorConfigBuilder},
    key_extractor::KeyExtractor,
    GovernorError,
};

/// Lambda sets `AWS_LAMBDA_FUNCTION_NAME` in every function's environment.
pub fn is_on_lambda() -> bool {
    std::env::var("AWS_LAMBDA_FUNCTION_NAME").is_ok()
}

/// Runtime configuration when running on Lambda, `None` when running locally.
pub fn lambda_config() -> Option<Config> {
    if is_on_lambda() {
        Some(Config::from_env())
    } else {
        None
    }
}

/// Keys requests by client IP: the source IP reported in the Lambda request
/// context when on Lambda, the socket peer address when running locally.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LocalOrLambdaPeerIpKeyExtractor;

impl KeyExtractor for LocalOrLambdaPeerIpKeyExtractor {
    type Key = IpAddr;

    fn extract<T>(&self, req: &Request<T>) -> Result<Self::Key, GovernorError> {
        if is_on_lambda() {
            extract_ip_from_lambda(req)
        } else {
            extract_ip_from_local(req)
        }
    }
}

fn extract_ip_from_lambda<T>(req: &Request<T>) -> Result<IpAddr, GovernorError> {
    // `request_context_ref` returns `None` instead of panicking (as
    // `request_context` does) when the request has no Lambda context.
    let ip_str = match req.request_context_ref() {
        Some(RequestContext::ApiGatewayV2(x)) => x.http.source_ip.as_deref(),
        Some(RequestContext::ApiGatewayV1(x)) => x.identity.source_ip.as_deref(),
        Some(RequestContext::WebSocket(x)) => x.identity.source_ip.as_deref(),
        // ALB events (and requests without a Lambda context) carry no source IP here.
        _ => None,
    }
    .ok_or(GovernorError::UnableToExtractKey)?;

    ip_str.parse().map_err(|_| {
        tracing::warn!("Failed to parse IP address: {}", ip_str);
        GovernorError::UnableToExtractKey
    })
}

/// Requires the local server to be started with
/// `into_make_service_with_connect_info::<SocketAddr>()`, which inserts
/// `ConnectInfo<SocketAddr>` into each request's extensions.
fn extract_ip_from_local<T>(req: &Request<T>) -> Result<IpAddr, GovernorError> {
    req.extensions()
        .get::<axum::extract::ConnectInfo<SocketAddr>>()
        .map(|addr| addr.ip())
        .ok_or(GovernorError::UnableToExtractKey)
}

/// Swaps `LocalOrLambdaPeerIpKeyExtractor` into a caller-supplied builder.
/// The quota itself is set by the caller on `conf_b`; for example
/// `.per_second(2).burst_size(5)` allows bursts of up to five requests per IP
/// and replenishes one element every two seconds.
pub fn axum_governor<K: KeyExtractor>(
    conf_b: &mut GovernorConfigBuilder<K, NoOpMiddleware<QuantaInstant>>,
) -> Arc<GovernorConfig<LocalOrLambdaPeerIpKeyExtractor, NoOpMiddleware<QuantaInstant>>> {
    // Wrapped in an Arc so the layer holding it stays cheap to clone.
    // `finish()` returns None if the period or burst size is zero.
    let governor_conf = Arc::new(
        conf_b
            .key_extractor(LocalOrLambdaPeerIpKeyExtractor)
            .finish()
            .expect("invalid governor config: period and burst size must be non-zero"),
    );

    let governor_limiter = governor_conf.limiter().clone();
    let interval = Duration::from_secs(60);
    // Background thread that periodically evicts per-IP state that has fully
    // replenished, so the map does not grow without bound. The state lives in
    // this process only, so on Lambda each execution environment has its own limits.
    std::thread::spawn(move || loop {
        std::thread::sleep(interval);
        tracing::info!("rate limiting storage size: {}", governor_limiter.len());
        governor_limiter.retain_recent();
    });

    governor_conf
}
```
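`extract_ip_from_lambda` yields no key for Application Load Balancer events, because the ALB request context carries no source IP. An ALB does forward the client address in the `X-Forwarded-For` header, appending the address it observed to the right-hand end of whatever value the client sent. The std-only sketch below shows one way a fallback could pick that entry. `client_ip_from_xff` and `resolve_client_ip` are hypothetical helpers, not part of `lambda_http`; in a real extractor the header would come from something like `req.headers().get("x-forwarded-for")` followed by `to_str()`.

```rust
use std::net::IpAddr;

/// Hypothetical helper: returns the right-most `X-Forwarded-For` entry, which
/// is the address an ALB appends itself. Entries to its left were supplied by
/// the client and can be forged.
fn client_ip_from_xff(header: &str) -> Option<IpAddr> {
    header
        .rsplit(',') // walk the comma-separated list from the right
        .next() // right-most entry
        .map(str::trim) // entries are usually separated by ", "
        .and_then(|entry| entry.parse::<IpAddr>().ok())
}

/// Hypothetical fallback order: prefer the request-context source IP
/// (API Gateway, WebSocket), otherwise the ALB-appended header entry.
fn resolve_client_ip(source_ip: Option<&str>, xff: Option<&str>) -> Option<IpAddr> {
    source_ip
        .and_then(|ip| ip.parse::<IpAddr>().ok())
        .or_else(|| xff.and_then(client_ip_from_xff))
}

fn main() {
    // API Gateway-style event: the request context carries the source IP.
    println!("{:?}", resolve_client_ip(Some("192.0.2.10"), None));
    // ALB-style event: no source IP, so the header's right-most entry is used.
    println!("{:?}", resolve_client_ip(None, Some("203.0.113.7, 198.51.100.20")));
    // Neither source yields a parseable address, so no key can be extracted.
    println!("{:?}", resolve_client_ip(None, Some("unknown")));
}
```

Taking the right-most entry rather than the left-most one matters for rate limiting: a client that sends its own `X-Forwarded-For` value could otherwise rotate fake addresses to dodge the per-IP quota.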
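The quota that `axum_governor` expects on `conf_b` (burst size plus replenish interval) and the `len`/`retain_recent` cleanup both come from `governor`, which implements the Generic Cell Rate Algorithm (GCRA). As a rough mental model only, here is a simplified, single-threaded keyed GCRA in plain Rust. `KeyedGcra` is a hypothetical type and does not mirror governor's lock-free internals or API: it stores one "theoretical arrival time" (TAT) per IP, admits a request while that TAT is at most `burst - 1` periods ahead of now, and `retain_recent` drops keys whose TAT has already passed, since their state is indistinguishable from a fresh key's.

```rust
use std::collections::HashMap;
use std::net::IpAddr;
use std::time::{Duration, Instant};

/// Hypothetical, simplified keyed GCRA limiter (single-threaded, std only).
struct KeyedGcra {
    period: Duration,              // time to replenish one element of the quota
    tolerance: Duration,           // how far ahead of `now` a key's TAT may run
    tat: HashMap<IpAddr, Instant>, // theoretical arrival time per client IP
}

impl KeyedGcra {
    fn new(period: Duration, burst: u32) -> Self {
        assert!(burst >= 1 && !period.is_zero(), "burst and period must be non-zero");
        KeyedGcra {
            period,
            // `burst` back-to-back requests fit if the TAT may sit up to
            // `burst - 1` periods in the future.
            tolerance: period * (burst - 1),
            tat: HashMap::new(),
        }
    }

    /// Returns true if a request from `key` at time `now` is allowed.
    fn check(&mut self, key: IpAddr, now: Instant) -> bool {
        // Unknown keys, or keys whose TAT is in the past, start from `now`.
        let tat = self.tat.get(&key).copied().unwrap_or(now).max(now);
        if tat > now + self.tolerance {
            return false; // quota exhausted; state is left unchanged
        }
        self.tat.insert(key, tat + self.period);
        true
    }

    /// Drops keys whose TAT has passed: their state equals a fresh key's.
    fn retain_recent(&mut self, now: Instant) {
        self.tat.retain(|_, tat| *tat > now);
    }

    fn len(&self) -> usize {
        self.tat.len()
    }
}

fn main() {
    let t0 = Instant::now();
    // Burst of 5, one element replenished every 2 seconds.
    let mut limiter = KeyedGcra::new(Duration::from_secs(2), 5);
    let ip: IpAddr = "192.0.2.1".parse().unwrap();

    let allowed = (0..6).filter(|_| limiter.check(ip, t0)).count();
    println!("allowed {allowed} of 6 simultaneous requests"); // prints "allowed 5 of 6 simultaneous requests"
    println!("after 2s: {}", limiter.check(ip, t0 + Duration::from_secs(2))); // prints "after 2s: true"

    limiter.retain_recent(t0 + Duration::from_secs(60));
    println!("keys tracked after cleanup: {}", limiter.len()); // prints "keys tracked after cleanup: 0"
}
```

With a 2-second period and a burst of 5, six simultaneous requests from one IP admit five, and one more is admitted two seconds later. Passing `now` explicitly keeps the sketch deterministic; a real limiter reads a clock instead.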
