1use axum::{
2 extract::{FromRequestParts},
3 http::{request::Parts, StatusCode},
4};
5use axum_extra::extract::TypedHeader;
6use headers::{Authorization, authorization::Bearer};
7use jsonwebtoken::{decode, DecodingKey, Validation, Algorithm};
8use std::collections::HashMap;
9use std::env;
10use crate::auth::claims::{Claims, AuthUser};
11
12impl<S> FromRequestParts<S> for AuthUser
32where
33 S: Send + Sync,
34{
35 type Rejection = (StatusCode, &'static str);
36
37 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
38 if let Ok(TypedHeader(Authorization(bearer))) =
40 TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await
41 {
42 return decode_token(bearer.token());
43 }
44
45 if let Some(query) = &parts.uri.query() {
47 let parsed: HashMap<String, String> = url::form_urlencoded::parse(query.as_bytes())
48 .into_owned()
49 .collect();
50
51 if let Some(token) = parsed.get("token") {
52 return decode_token(token);
53 }
54 }
55
56 Err((StatusCode::UNAUTHORIZED, "Missing or invalid Authorization header"))
57 }
58}
59
60fn decode_token(token: &str) -> Result<AuthUser, (StatusCode, &'static str)> {
61 let secret = env::var("JWT_SECRET").map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "JWT secret not set"))?;
62 let data = decode::<Claims>(
63 token,
64 &DecodingKey::from_secret(secret.as_bytes()),
65 &Validation::new(Algorithm::HS256),
66 )
67 .map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid or expired token"))?;
68
69 Ok(AuthUser(data.claims))
70}