use std::{env, sync::Arc}; use crate::error::Result; use jsonwebtoken::{ Algorithm, DecodingKey, Validation, decode, decode_header, errors::ErrorKind, jwk::JwkSet, }; use serde::Deserialize; #[derive(Clone)] pub struct User { pub id: String, pub name: String, } pub struct JWT { reqwest_client: reqwest::Client, } pub trait AuthImpl { async fn for_protected(&self, token: &str) -> Result; } pub enum Auth { JWT(JWT), } #[derive(Deserialize)] struct Claims { sub: String, name: String, } impl JWT { pub fn new(reqwest_client: reqwest::Client) -> Self { Self { reqwest_client } } } impl AuthImpl for JWT { async fn for_protected(&self, token: &str) -> Result { let frontend_url = env::var("FRONTEND_BASE_URL").unwrap_or_else(|_| "http://localhost:5173".to_string()); let jwks = get_jwks( &self.reqwest_client, &format!("{frontend_url}/api/auth/jwks"), ) .await?; let header = decode_header(token)?; let kid = header.kid.ok_or(crate::error::Error::Unauthorized)?; let jwk = jwks.find(&kid).ok_or(crate::error::Error::Unauthorized)?; let decoding_key = DecodingKey::from_jwk(jwk)?; let mut validation = Validation::new(Algorithm::EdDSA); validation.set_issuer(&[&frontend_url]); validation.set_audience(&[&frontend_url]); let token_data = decode::(token, &decoding_key, &validation).map_err(|e| match e.kind() { ErrorKind::ExpiredSignature => crate::error::Error::TokenExpired, _ => crate::error::Error::Unauthorized, })?; Ok(User { id: token_data.claims.sub, name: token_data.claims.name, }) } } impl AuthImpl for Auth { async fn for_protected(&self, token: &str) -> Result { match self { Auth::JWT(jwt) => jwt.for_protected(token).await, } } } async fn get_jwks(client: &reqwest::Client, jwks_url: &str) -> Result { Ok(client.get(jwks_url).send().await?.json::().await?) }