Skip to main content

mas_axum_utils/
user_authorization.rs

1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8use std::{collections::HashMap, error::Error};
9
10use axum::{
11    extract::{
12        Form, FromRequest, FromRequestParts,
13        rejection::{FailedToDeserializeForm, FormRejection},
14    },
15    response::{IntoResponse, Response},
16};
17use axum_extra::typed_header::{TypedHeader, TypedHeaderRejectionReason};
18use headers::{Authorization, Header, HeaderMapExt, HeaderName, authorization::Bearer};
19use http::{HeaderMap, HeaderValue, Request, StatusCode, header::WWW_AUTHENTICATE};
20use mas_data_model::{Clock, Session};
21use mas_storage::{
22    RepositoryAccess,
23    oauth2::{OAuth2AccessTokenRepository, OAuth2SessionRepository},
24};
25use serde::{Deserialize, de::DeserializeOwned};
26use thiserror::Error;
27
28use crate::log_context::RecordAsRequester;
29
30#[derive(Debug, Deserialize)]
31struct AuthorizedForm<F> {
32    #[serde(default)]
33    access_token: Option<String>,
34
35    #[serde(flatten)]
36    inner: F,
37}
38
/// Where the raw access token of a request came from, if anywhere.
#[derive(Debug)]
enum AccessToken {
    /// Token taken from the `Authorization: Bearer` header
    Header(String),
    /// Token taken from the `access_token` field of a form body
    Form(String),
    /// The request carried no token at all
    None,
}
45
46impl AccessToken {
47    async fn fetch<E>(
48        &self,
49        repo: &mut impl RepositoryAccess<Error = E>,
50    ) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError<E>> {
51        let token = match self {
52            AccessToken::Form(t) | AccessToken::Header(t) => t,
53            AccessToken::None => return Err(AuthorizationVerificationError::MissingToken),
54        };
55
56        let token = repo
57            .oauth2_access_token()
58            .find_by_token(token.as_str())
59            .await?
60            .ok_or(AuthorizationVerificationError::InvalidToken)?;
61
62        let session = repo
63            .oauth2_session()
64            .lookup(token.session_id)
65            .await?
66            .ok_or(AuthorizationVerificationError::InvalidToken)?;
67
68        session.maybe_record_as_requester();
69
70        Ok((token, session))
71    }
72}
73
74#[derive(Debug)]
75pub struct UserAuthorization<F = ()> {
76    access_token: AccessToken,
77    form: Option<F>,
78}
79
80impl<F: Send> UserAuthorization<F> {
81    // TODO: take scopes to validate as parameter
82    /// Verify a user authorization and return the session and the protected
83    /// form value
84    ///
85    /// # Errors
86    ///
87    /// Returns an error if the token is invalid, if the user session ended or
88    /// if the form is missing
89    pub async fn protected_form<E>(
90        self,
91        repo: &mut impl RepositoryAccess<Error = E>,
92        clock: &impl Clock,
93    ) -> Result<(Session, F), AuthorizationVerificationError<E>> {
94        let Some(form) = self.form else {
95            return Err(AuthorizationVerificationError::MissingForm);
96        };
97
98        let (token, session) = self.access_token.fetch(repo).await?;
99
100        if !token.is_valid(clock.now()) || !session.is_valid() {
101            return Err(AuthorizationVerificationError::InvalidToken);
102        }
103
104        Ok((session, form))
105    }
106
107    // TODO: take scopes to validate as parameter
108    /// Verify a user authorization and return the session
109    ///
110    /// # Errors
111    ///
112    /// Returns an error if the token is invalid or if the user session ended
113    pub async fn protected<E>(
114        self,
115        repo: &mut impl RepositoryAccess<Error = E>,
116        clock: &impl Clock,
117    ) -> Result<Session, AuthorizationVerificationError<E>> {
118        let (token, session) = self.access_token.fetch(repo).await?;
119
120        if !token.is_valid(clock.now()) || !session.is_valid() {
121            return Err(AuthorizationVerificationError::InvalidToken);
122        }
123
124        if !token.is_used() {
125            // Mark the token as used
126            repo.oauth2_access_token().mark_used(clock, token).await?;
127        }
128
129        Ok(session)
130    }
131}
132
133pub enum UserAuthorizationError {
134    InvalidHeader,
135    TokenInFormAndHeader,
136    BadForm(FailedToDeserializeForm),
137    Internal(Box<dyn Error>),
138}
139
140#[derive(Debug, Error)]
141pub enum AuthorizationVerificationError<E> {
142    #[error("missing token")]
143    MissingToken,
144
145    #[error("invalid token")]
146    InvalidToken,
147
148    #[error("missing form")]
149    MissingForm,
150
151    #[error(transparent)]
152    Internal(#[from] E),
153}
154
155enum BearerError {
156    InvalidRequest,
157    InvalidToken,
158    #[expect(dead_code)]
159    InsufficientScope {
160        scope: Option<HeaderValue>,
161    },
162}
163
164impl BearerError {
165    fn error(&self) -> HeaderValue {
166        match self {
167            BearerError::InvalidRequest => HeaderValue::from_static("invalid_request"),
168            BearerError::InvalidToken => HeaderValue::from_static("invalid_token"),
169            BearerError::InsufficientScope { .. } => HeaderValue::from_static("insufficient_scope"),
170        }
171    }
172
173    fn params(&self) -> HashMap<&'static str, HeaderValue> {
174        match self {
175            BearerError::InsufficientScope { scope: Some(scope) } => {
176                let mut m = HashMap::new();
177                m.insert("scope", scope.clone());
178                m
179            }
180            _ => HashMap::new(),
181        }
182    }
183}
184
185enum WwwAuthenticate {
186    #[expect(dead_code)]
187    Basic { realm: HeaderValue },
188    Bearer {
189        realm: Option<HeaderValue>,
190        error: BearerError,
191        error_description: Option<HeaderValue>,
192    },
193}
194
195impl Header for WwwAuthenticate {
196    fn name() -> &'static HeaderName {
197        &WWW_AUTHENTICATE
198    }
199
200    fn decode<'i, I>(_values: &mut I) -> Result<Self, headers::Error>
201    where
202        Self: Sized,
203        I: Iterator<Item = &'i http::HeaderValue>,
204    {
205        Err(headers::Error::invalid())
206    }
207
208    fn encode<E: Extend<http::HeaderValue>>(&self, values: &mut E) {
209        let (scheme, params) = match self {
210            WwwAuthenticate::Basic { realm } => {
211                let mut params = HashMap::new();
212                params.insert("realm", realm.clone());
213                ("Basic", params)
214            }
215            WwwAuthenticate::Bearer {
216                realm,
217                error,
218                error_description,
219            } => {
220                let mut params = error.params();
221                params.insert("error", error.error());
222
223                if let Some(realm) = realm {
224                    params.insert("realm", realm.clone());
225                }
226
227                if let Some(error_description) = error_description {
228                    params.insert("error_description", error_description.clone());
229                }
230
231                ("Bearer", params)
232            }
233        };
234
235        let params = params.into_iter().map(|(k, v)| format!(" {k}={v:?}"));
236        let value: String = std::iter::once(scheme.to_owned()).chain(params).collect();
237        let value = HeaderValue::from_str(&value).unwrap();
238        values.extend(std::iter::once(value));
239    }
240}
241
242impl IntoResponse for UserAuthorizationError {
243    fn into_response(self) -> Response {
244        match self {
245            Self::BadForm(_) | Self::InvalidHeader | Self::TokenInFormAndHeader => {
246                let mut headers = HeaderMap::new();
247
248                headers.typed_insert(WwwAuthenticate::Bearer {
249                    realm: None,
250                    error: BearerError::InvalidRequest,
251                    error_description: None,
252                });
253                (StatusCode::BAD_REQUEST, headers).into_response()
254            }
255            Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
256        }
257    }
258}
259
260impl<E> IntoResponse for AuthorizationVerificationError<E>
261where
262    E: ToString,
263{
264    fn into_response(self) -> Response {
265        match self {
266            Self::MissingForm | Self::MissingToken => {
267                let mut headers = HeaderMap::new();
268
269                headers.typed_insert(WwwAuthenticate::Bearer {
270                    realm: None,
271                    error: BearerError::InvalidRequest,
272                    error_description: None,
273                });
274                (StatusCode::BAD_REQUEST, headers).into_response()
275            }
276            Self::InvalidToken => {
277                let mut headers = HeaderMap::new();
278
279                headers.typed_insert(WwwAuthenticate::Bearer {
280                    realm: None,
281                    error: BearerError::InvalidToken,
282                    error_description: None,
283                });
284                (StatusCode::BAD_REQUEST, headers).into_response()
285            }
286            Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
287        }
288    }
289}
290
291impl<S, F> FromRequest<S> for UserAuthorization<F>
292where
293    F: DeserializeOwned,
294    S: Send + Sync,
295{
296    type Rejection = UserAuthorizationError;
297
298    async fn from_request(
299        req: Request<axum::body::Body>,
300        state: &S,
301    ) -> Result<Self, Self::Rejection> {
302        let (mut parts, body) = req.into_parts();
303        let header =
304            TypedHeader::<Authorization<Bearer>>::from_request_parts(&mut parts, state).await;
305
306        // Take the Authorization header
307        let token_from_header = match header {
308            Ok(header) => Some(header.token().to_owned()),
309            Err(err) => match err.reason() {
310                // If it's missing it is fine
311                TypedHeaderRejectionReason::Missing => None,
312                // If the header could not be parsed, return the error
313                _ => return Err(UserAuthorizationError::InvalidHeader),
314            },
315        };
316
317        let req = Request::from_parts(parts, body);
318
319        // Take the form value
320        let (token_from_form, form) =
321            match Form::<AuthorizedForm<F>>::from_request(req, state).await {
322                Ok(Form(form)) => (form.access_token, Some(form.inner)),
323                // If it is not a form, continue
324                Err(FormRejection::InvalidFormContentType(_err)) => (None, None),
325                // If the form could not be read, return a Bad Request error
326                Err(FormRejection::FailedToDeserializeForm(err)) => {
327                    return Err(UserAuthorizationError::BadForm(err));
328                }
329                // Other errors (body read twice, byte stream broke) return an internal error
330                Err(e) => return Err(UserAuthorizationError::Internal(Box::new(e))),
331            };
332
333        let access_token = match (token_from_header, token_from_form) {
334            // Ensure the token should not be in both the form and the access token
335            (Some(_), Some(_)) => return Err(UserAuthorizationError::TokenInFormAndHeader),
336            (Some(t), None) => AccessToken::Header(t),
337            (None, Some(t)) => AccessToken::Form(t),
338            (None, None) => AccessToken::None,
339        };
340
341        Ok(UserAuthorization { access_token, form })
342    }
343}