1use std::{collections::HashMap, error::Error};
9
10use axum::{
11 extract::{
12 Form, FromRequest, FromRequestParts,
13 rejection::{FailedToDeserializeForm, FormRejection},
14 },
15 response::{IntoResponse, Response},
16};
17use axum_extra::typed_header::{TypedHeader, TypedHeaderRejectionReason};
18use headers::{Authorization, Header, HeaderMapExt, HeaderName, authorization::Bearer};
19use http::{HeaderMap, HeaderValue, Request, StatusCode, header::WWW_AUTHENTICATE};
20use mas_data_model::{Clock, Session};
21use mas_storage::{
22 RepositoryAccess,
23 oauth2::{OAuth2AccessTokenRepository, OAuth2SessionRepository},
24};
25use serde::{Deserialize, de::DeserializeOwned};
26use thiserror::Error;
27
28use crate::log_context::RecordAsRequester;
29
30#[derive(Debug, Deserialize)]
31struct AuthorizedForm<F> {
32 #[serde(default)]
33 access_token: Option<String>,
34
35 #[serde(flatten)]
36 inner: F,
37}
38
/// Where the bearer access token was found in the request, if anywhere.
#[derive(Debug)]
enum AccessToken {
    /// Token came from the `access_token` field of the form payload.
    Form(String),
    /// Token came from an `Authorization: Bearer` header.
    Header(String),
    /// The request carried no token at all.
    None,
}
45
46impl AccessToken {
47 async fn fetch<E>(
48 &self,
49 repo: &mut impl RepositoryAccess<Error = E>,
50 ) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError<E>> {
51 let token = match self {
52 AccessToken::Form(t) | AccessToken::Header(t) => t,
53 AccessToken::None => return Err(AuthorizationVerificationError::MissingToken),
54 };
55
56 let token = repo
57 .oauth2_access_token()
58 .find_by_token(token.as_str())
59 .await?
60 .ok_or(AuthorizationVerificationError::InvalidToken)?;
61
62 let session = repo
63 .oauth2_session()
64 .lookup(token.session_id)
65 .await?
66 .ok_or(AuthorizationVerificationError::InvalidToken)?;
67
68 session.maybe_record_as_requester();
69
70 Ok((token, session))
71 }
72}
73
74#[derive(Debug)]
75pub struct UserAuthorization<F = ()> {
76 access_token: AccessToken,
77 form: Option<F>,
78}
79
80impl<F: Send> UserAuthorization<F> {
81 pub async fn protected_form<E>(
90 self,
91 repo: &mut impl RepositoryAccess<Error = E>,
92 clock: &impl Clock,
93 ) -> Result<(Session, F), AuthorizationVerificationError<E>> {
94 let Some(form) = self.form else {
95 return Err(AuthorizationVerificationError::MissingForm);
96 };
97
98 let (token, session) = self.access_token.fetch(repo).await?;
99
100 if !token.is_valid(clock.now()) || !session.is_valid() {
101 return Err(AuthorizationVerificationError::InvalidToken);
102 }
103
104 Ok((session, form))
105 }
106
107 pub async fn protected<E>(
114 self,
115 repo: &mut impl RepositoryAccess<Error = E>,
116 clock: &impl Clock,
117 ) -> Result<Session, AuthorizationVerificationError<E>> {
118 let (token, session) = self.access_token.fetch(repo).await?;
119
120 if !token.is_valid(clock.now()) || !session.is_valid() {
121 return Err(AuthorizationVerificationError::InvalidToken);
122 }
123
124 if !token.is_used() {
125 repo.oauth2_access_token().mark_used(clock, token).await?;
127 }
128
129 Ok(session)
130 }
131}
132
133pub enum UserAuthorizationError {
134 InvalidHeader,
135 TokenInFormAndHeader,
136 BadForm(FailedToDeserializeForm),
137 Internal(Box<dyn Error>),
138}
139
140#[derive(Debug, Error)]
141pub enum AuthorizationVerificationError<E> {
142 #[error("missing token")]
143 MissingToken,
144
145 #[error("invalid token")]
146 InvalidToken,
147
148 #[error("missing form")]
149 MissingForm,
150
151 #[error(transparent)]
152 Internal(#[from] E),
153}
154
155enum BearerError {
156 InvalidRequest,
157 InvalidToken,
158 #[expect(dead_code)]
159 InsufficientScope {
160 scope: Option<HeaderValue>,
161 },
162}
163
164impl BearerError {
165 fn error(&self) -> HeaderValue {
166 match self {
167 BearerError::InvalidRequest => HeaderValue::from_static("invalid_request"),
168 BearerError::InvalidToken => HeaderValue::from_static("invalid_token"),
169 BearerError::InsufficientScope { .. } => HeaderValue::from_static("insufficient_scope"),
170 }
171 }
172
173 fn params(&self) -> HashMap<&'static str, HeaderValue> {
174 match self {
175 BearerError::InsufficientScope { scope: Some(scope) } => {
176 let mut m = HashMap::new();
177 m.insert("scope", scope.clone());
178 m
179 }
180 _ => HashMap::new(),
181 }
182 }
183}
184
185enum WwwAuthenticate {
186 #[expect(dead_code)]
187 Basic { realm: HeaderValue },
188 Bearer {
189 realm: Option<HeaderValue>,
190 error: BearerError,
191 error_description: Option<HeaderValue>,
192 },
193}
194
195impl Header for WwwAuthenticate {
196 fn name() -> &'static HeaderName {
197 &WWW_AUTHENTICATE
198 }
199
200 fn decode<'i, I>(_values: &mut I) -> Result<Self, headers::Error>
201 where
202 Self: Sized,
203 I: Iterator<Item = &'i http::HeaderValue>,
204 {
205 Err(headers::Error::invalid())
206 }
207
208 fn encode<E: Extend<http::HeaderValue>>(&self, values: &mut E) {
209 let (scheme, params) = match self {
210 WwwAuthenticate::Basic { realm } => {
211 let mut params = HashMap::new();
212 params.insert("realm", realm.clone());
213 ("Basic", params)
214 }
215 WwwAuthenticate::Bearer {
216 realm,
217 error,
218 error_description,
219 } => {
220 let mut params = error.params();
221 params.insert("error", error.error());
222
223 if let Some(realm) = realm {
224 params.insert("realm", realm.clone());
225 }
226
227 if let Some(error_description) = error_description {
228 params.insert("error_description", error_description.clone());
229 }
230
231 ("Bearer", params)
232 }
233 };
234
235 let params = params.into_iter().map(|(k, v)| format!(" {k}={v:?}"));
236 let value: String = std::iter::once(scheme.to_owned()).chain(params).collect();
237 let value = HeaderValue::from_str(&value).unwrap();
238 values.extend(std::iter::once(value));
239 }
240}
241
242impl IntoResponse for UserAuthorizationError {
243 fn into_response(self) -> Response {
244 match self {
245 Self::BadForm(_) | Self::InvalidHeader | Self::TokenInFormAndHeader => {
246 let mut headers = HeaderMap::new();
247
248 headers.typed_insert(WwwAuthenticate::Bearer {
249 realm: None,
250 error: BearerError::InvalidRequest,
251 error_description: None,
252 });
253 (StatusCode::BAD_REQUEST, headers).into_response()
254 }
255 Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
256 }
257 }
258}
259
260impl<E> IntoResponse for AuthorizationVerificationError<E>
261where
262 E: ToString,
263{
264 fn into_response(self) -> Response {
265 match self {
266 Self::MissingForm | Self::MissingToken => {
267 let mut headers = HeaderMap::new();
268
269 headers.typed_insert(WwwAuthenticate::Bearer {
270 realm: None,
271 error: BearerError::InvalidRequest,
272 error_description: None,
273 });
274 (StatusCode::BAD_REQUEST, headers).into_response()
275 }
276 Self::InvalidToken => {
277 let mut headers = HeaderMap::new();
278
279 headers.typed_insert(WwwAuthenticate::Bearer {
280 realm: None,
281 error: BearerError::InvalidToken,
282 error_description: None,
283 });
284 (StatusCode::BAD_REQUEST, headers).into_response()
285 }
286 Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
287 }
288 }
289}
290
291impl<S, F> FromRequest<S> for UserAuthorization<F>
292where
293 F: DeserializeOwned,
294 S: Send + Sync,
295{
296 type Rejection = UserAuthorizationError;
297
298 async fn from_request(
299 req: Request<axum::body::Body>,
300 state: &S,
301 ) -> Result<Self, Self::Rejection> {
302 let (mut parts, body) = req.into_parts();
303 let header =
304 TypedHeader::<Authorization<Bearer>>::from_request_parts(&mut parts, state).await;
305
306 let token_from_header = match header {
308 Ok(header) => Some(header.token().to_owned()),
309 Err(err) => match err.reason() {
310 TypedHeaderRejectionReason::Missing => None,
312 _ => return Err(UserAuthorizationError::InvalidHeader),
314 },
315 };
316
317 let req = Request::from_parts(parts, body);
318
319 let (token_from_form, form) =
321 match Form::<AuthorizedForm<F>>::from_request(req, state).await {
322 Ok(Form(form)) => (form.access_token, Some(form.inner)),
323 Err(FormRejection::InvalidFormContentType(_err)) => (None, None),
325 Err(FormRejection::FailedToDeserializeForm(err)) => {
327 return Err(UserAuthorizationError::BadForm(err));
328 }
329 Err(e) => return Err(UserAuthorizationError::Internal(Box::new(e))),
331 };
332
333 let access_token = match (token_from_header, token_from_form) {
334 (Some(_), Some(_)) => return Err(UserAuthorizationError::TokenInFormAndHeader),
336 (Some(t), None) => AccessToken::Header(t),
337 (None, Some(t)) => AccessToken::Form(t),
338 (None, None) => AccessToken::None,
339 };
340
341 Ok(UserAuthorization { access_token, form })
342 }
343}