// api/auth/extractors.rs

1use axum::{
2    extract::{FromRequestParts},
3    http::{request::Parts, StatusCode},
4};
5use axum_extra::extract::TypedHeader;
6use headers::{Authorization, authorization::Bearer};
7use jsonwebtoken::{decode, DecodingKey, Validation, Algorithm};
8use std::collections::HashMap;
9use std::env;
10use crate::auth::claims::{Claims, AuthUser};
11
12/// Implements extraction of `AuthUser` from request headers.
13///
14/// This middleware checks for a valid Bearer token in the `Authorization` header,
15/// verifies the JWT using the secret from `JWT_SECRET` environment variable,
16/// and extracts the user claims into an `AuthUser` instance.
17///
18/// # Errors
19/// - Returns `401 Unauthorized` if the header is missing, malformed, or the token is invalid or expired.
20///
21/// # Example
22/// ```ignore
23/// use axum::response::IntoResponse;
24/// use api::auth::claims::AuthUser;
25///
26/// async fn protected_route(user: AuthUser) -> impl IntoResponse {
27///     // User is now available
28///     format!("User ID: {}", user.0.sub)
29/// }
30/// ```
31impl<S> FromRequestParts<S> for AuthUser
32where
33    S: Send + Sync,
34{
35    type Rejection = (StatusCode, &'static str);
36
37    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
38        // Try Authorization header first
39        if let Ok(TypedHeader(Authorization(bearer))) =
40            TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await
41        {
42            return decode_token(bearer.token());
43        }
44
45        // Fallback to query param `?token=...`
46        if let Some(query) = &parts.uri.query() {
47            let parsed: HashMap<String, String> = url::form_urlencoded::parse(query.as_bytes())
48                .into_owned()
49                .collect();
50
51            if let Some(token) = parsed.get("token") {
52                return decode_token(token);
53            }
54        }
55
56        Err((StatusCode::UNAUTHORIZED, "Missing or invalid Authorization header"))
57    }
58}
59
60fn decode_token(token: &str) -> Result<AuthUser, (StatusCode, &'static str)> {
61    let secret = env::var("JWT_SECRET").map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "JWT secret not set"))?;
62    let data = decode::<Claims>(
63        token,
64        &DecodingKey::from_secret(secret.as_bytes()),
65        &Validation::new(Algorithm::HS256),
66    )
67    .map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid or expired token"))?;
68
69    Ok(AuthUser(data.claims))
70}