Skip to main content

mas_handlers/
lib.rs

1// Copyright 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8#![deny(clippy::future_not_send)]
9#![allow(
10    // Some axum handlers need that
11    clippy::unused_async,
12    // Because of how axum handlers work, we sometime have take many arguments
13    clippy::too_many_arguments,
14    // Code generated by tracing::instrument trigger this when returning an `impl Trait`
15    // See https://github.com/tokio-rs/tracing/issues/2613
16    clippy::let_with_type_underscore,
17)]
18
19use std::{
20    convert::Infallible,
21    sync::{Arc, LazyLock},
22    time::Duration,
23};
24
25use axum::{
26    Extension, Router,
27    extract::{FromRef, FromRequestParts, OriginalUri, RawQuery, State},
28    http::Method,
29    response::{Html, IntoResponse},
30    routing::{get, post},
31};
32use headers::HeaderName;
33use hyper::{
34    StatusCode, Version,
35    header::{
36        ACCEPT, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_LANGUAGE, CONTENT_LENGTH, CONTENT_TYPE,
37        X_FRAME_OPTIONS,
38    },
39};
40use mas_axum_utils::{InternalError, cookies::CookieJar};
41use mas_data_model::SiteConfig;
42use mas_http::CorsLayerExt;
43use mas_keystore::{Encrypter, Keystore};
44use mas_matrix::HomeserverConnection;
45use mas_policy::Policy;
46use mas_router::{Route, UrlBuilder};
47use mas_storage::{BoxRepository, BoxRepositoryFactory};
48use mas_templates::{ErrorContext, NotFoundContext, TemplateContext, Templates};
49use opentelemetry::metrics::Meter;
50use sqlx::PgPool;
51use tower::util::AndThenLayer;
52use tower_http::{
53    cors::{Any, CorsLayer},
54    set_header::SetResponseHeaderLayer,
55};
56
57use self::{graphql::ExtraRouterParameters, passwords::PasswordManager};
58
59mod admin;
60mod compat;
61mod graphql;
62mod health;
63mod oauth2;
64pub mod passwords;
65pub mod upstream_oauth2;
66mod views;
67
68mod activity_tracker;
69mod captcha;
70#[cfg(test)]
71mod cleanup_tests;
72mod client_ip;
73mod preferred_language;
74mod rate_limit;
75mod session;
76#[cfg(test)]
77mod test_utils;
78
79static METER: LazyLock<Meter> = LazyLock::new(|| {
80    let scope = opentelemetry::InstrumentationScope::builder(env!("CARGO_PKG_NAME"))
81        .with_version(env!("CARGO_PKG_VERSION"))
82        .with_schema_url(opentelemetry_semantic_conventions::SCHEMA_URL)
83        .build();
84
85    opentelemetry::global::meter_with_scope(scope)
86});
87
/// Implement `From<E>` for a route error type, for "internal server error"
/// kind of errors.
///
/// The error value is boxed as-is into the target type's `Internal` variant.
///
/// Two forms are accepted:
///
/// - `impl_from_error_for_route!(MyRouteError: SomeError)` implements
///   `From<SomeError>` for `MyRouteError`;
/// - `impl_from_error_for_route!(SomeError)` is shorthand for the above, using
///   the `RouteError` type of the module where the macro is invoked.
///
/// The target type must have an `Internal` tuple variant whose field can be
/// built from `Box::new(error)` (for example a
/// `Box<dyn std::error::Error + Send + Sync + 'static>`).
#[macro_export]
macro_rules! impl_from_error_for_route {
    ($route_error:ty : $error:ty) => {
        impl ::std::convert::From<$error> for $route_error {
            fn from(e: $error) -> Self {
                Self::Internal(::std::boxed::Box::new(e))
            }
        }
    };
    ($error:ty) => {
        // Recurse through `$crate` so the shorthand also works when the macro
        // is invoked by path (e.g. `crate::impl_from_error_for_route!(...)` or
        // from another crate) rather than brought into scope by name.
        $crate::impl_from_error_for_route!(self::RouteError: $error);
    };
}
103
104pub use mas_axum_utils::{ErrorWrapper, cookies::CookieManager};
105use mas_data_model::{BoxClock, BoxRng};
106
107pub use self::{
108    activity_tracker::{ActivityTracker, Bound as BoundActivityTracker},
109    admin::router as admin_api_router,
110    client_ip::ClientIp,
111    graphql::{
112        GraphQLOperation, Schema as GraphQLSchema, schema as graphql_schema,
113        schema_builder as graphql_schema_builder,
114    },
115    preferred_language::PreferredLanguage,
116    rate_limit::{Limiter, RequesterFingerprint},
117    upstream_oauth2::cache::MetadataCache,
118};
119
120pub fn healthcheck_router<S>() -> Router<S>
121where
122    S: Clone + Send + Sync + 'static,
123    PgPool: FromRef<S>,
124{
125    Router::new().route(mas_router::Healthcheck::route(), get(self::health::get))
126}
127
128pub fn graphql_router<S>(playground: bool, undocumented_oauth2_access: bool) -> Router<S>
129where
130    S: Clone + Send + Sync + 'static,
131    graphql::Schema: FromRef<S>,
132    BoundActivityTracker: FromRequestParts<S>,
133    BoxRepository: FromRequestParts<S>,
134    BoxClock: FromRequestParts<S>,
135    Encrypter: FromRef<S>,
136    CookieJar: FromRequestParts<S>,
137    Limiter: FromRef<S>,
138    RequesterFingerprint: FromRequestParts<S>,
139{
140    let mut router = Router::new()
141        .route(
142            mas_router::GraphQL::route(),
143            get(self::graphql::get).post(self::graphql::post),
144        )
145        // Pass the undocumented_oauth2_access parameter through the request extension, as it is
146        // per-listener
147        .layer(Extension(ExtraRouterParameters {
148            undocumented_oauth2_access,
149        }))
150        .layer(
151            CorsLayer::new()
152                .allow_origin(Any)
153                .allow_methods(Any)
154                .allow_otel_headers([
155                    AUTHORIZATION,
156                    ACCEPT,
157                    ACCEPT_LANGUAGE,
158                    CONTENT_LANGUAGE,
159                    CONTENT_TYPE,
160                ]),
161        );
162
163    if playground {
164        router = router.route(
165            mas_router::GraphQLPlayground::route(),
166            get(self::graphql::playground),
167        );
168    }
169
170    router
171}
172
173pub fn discovery_router<S>() -> Router<S>
174where
175    S: Clone + Send + Sync + 'static,
176    Keystore: FromRef<S>,
177    SiteConfig: FromRef<S>,
178    UrlBuilder: FromRef<S>,
179    BoxClock: FromRequestParts<S>,
180    BoxRng: FromRequestParts<S>,
181{
182    Router::new()
183        .route(
184            mas_router::OidcConfiguration::route(),
185            get(self::oauth2::discovery::get),
186        )
187        .route(
188            mas_router::Webfinger::route(),
189            get(self::oauth2::webfinger::get),
190        )
191        .layer(
192            CorsLayer::new()
193                .allow_origin(Any)
194                .allow_methods(Any)
195                .allow_otel_headers([
196                    AUTHORIZATION,
197                    ACCEPT,
198                    ACCEPT_LANGUAGE,
199                    CONTENT_LANGUAGE,
200                    CONTENT_TYPE,
201                ])
202                .max_age(Duration::from_hours(1)),
203        )
204}
205
206pub fn api_router<S>() -> Router<S>
207where
208    S: Clone + Send + Sync + 'static,
209    Keystore: FromRef<S>,
210    UrlBuilder: FromRef<S>,
211    BoxRepository: FromRequestParts<S>,
212    ActivityTracker: FromRequestParts<S>,
213    BoundActivityTracker: FromRequestParts<S>,
214    Encrypter: FromRef<S>,
215    reqwest::Client: FromRef<S>,
216    SiteConfig: FromRef<S>,
217    Templates: FromRef<S>,
218    Arc<dyn HomeserverConnection>: FromRef<S>,
219    BoxClock: FromRequestParts<S>,
220    BoxRng: FromRequestParts<S>,
221    Policy: FromRequestParts<S>,
222{
223    // All those routes are API-like, with a common CORS layer
224    Router::new()
225        .route(
226            mas_router::OAuth2Keys::route(),
227            get(self::oauth2::keys::get),
228        )
229        .route(
230            mas_router::OidcUserinfo::route(),
231            get(self::oauth2::userinfo::get).post(self::oauth2::userinfo::get),
232        )
233        .route(
234            mas_router::OAuth2Introspection::route(),
235            post(self::oauth2::introspection::post),
236        )
237        .route(
238            mas_router::OAuth2Revocation::route(),
239            post(self::oauth2::revoke::post),
240        )
241        .route(
242            mas_router::OAuth2TokenEndpoint::route(),
243            post(self::oauth2::token::post),
244        )
245        .route(
246            mas_router::OAuth2RegistrationEndpoint::route(),
247            post(self::oauth2::registration::post),
248        )
249        .route(
250            mas_router::OAuth2DeviceAuthorizationEndpoint::route(),
251            post(self::oauth2::device::authorize::post),
252        )
253        .layer(
254            CorsLayer::new()
255                .allow_origin(Any)
256                .allow_methods(Any)
257                .allow_otel_headers([
258                    AUTHORIZATION,
259                    ACCEPT,
260                    ACCEPT_LANGUAGE,
261                    CONTENT_LANGUAGE,
262                    CONTENT_TYPE,
263                    // Swagger will send this header, so we have to allow it to avoid CORS errors
264                    HeaderName::from_static("x-requested-with"),
265                ])
266                .max_age(Duration::from_hours(1)),
267        )
268}
269
270pub fn compat_router<S>(templates: Templates) -> Router<S>
271where
272    S: Clone + Send + Sync + 'static,
273    UrlBuilder: FromRef<S>,
274    SiteConfig: FromRef<S>,
275    Arc<dyn HomeserverConnection>: FromRef<S>,
276    PasswordManager: FromRef<S>,
277    Limiter: FromRef<S>,
278    BoxRepositoryFactory: FromRef<S>,
279    BoundActivityTracker: FromRequestParts<S>,
280    RequesterFingerprint: FromRequestParts<S>,
281    BoxRepository: FromRequestParts<S>,
282    BoxClock: FromRequestParts<S>,
283    BoxRng: FromRequestParts<S>,
284    Policy: FromRequestParts<S>,
285{
286    // A sub-router for human-facing routes with error handling
287    let human_router = Router::new()
288        .route(
289            mas_router::CompatLoginSsoRedirect::route(),
290            get(self::compat::login_sso_redirect::get),
291        )
292        .route(
293            mas_router::CompatLoginSsoRedirectIdp::route(),
294            get(self::compat::login_sso_redirect::get),
295        )
296        .route(
297            mas_router::CompatLoginSsoRedirectSlash::route(),
298            get(self::compat::login_sso_redirect::get),
299        )
300        .layer(AndThenLayer::new(
301            async move |response: axum::response::Response| {
302                Ok::<_, Infallible>(recover_error(&templates, response))
303            },
304        ));
305
306    // A sub-router for API-facing routes with CORS
307    let api_router = Router::new()
308        .route(
309            mas_router::CompatLogin::route(),
310            get(self::compat::login::get).post(self::compat::login::post),
311        )
312        .route(
313            mas_router::CompatLogout::route(),
314            post(self::compat::logout::post),
315        )
316        .route(
317            mas_router::CompatLogoutAll::route(),
318            post(self::compat::logout_all::post),
319        )
320        .route(
321            mas_router::CompatRefresh::route(),
322            post(self::compat::refresh::post),
323        )
324        .layer(
325            CorsLayer::new()
326                .allow_origin(Any)
327                .allow_methods(Any)
328                .allow_otel_headers([
329                    AUTHORIZATION,
330                    ACCEPT,
331                    ACCEPT_LANGUAGE,
332                    CONTENT_LANGUAGE,
333                    CONTENT_TYPE,
334                    HeaderName::from_static("x-requested-with"),
335                ])
336                .max_age(Duration::from_hours(1)),
337        );
338
339    Router::new().merge(human_router).merge(api_router)
340}
341
342pub fn human_router<S>(templates: Templates) -> Router<S>
343where
344    S: Clone + Send + Sync + 'static,
345    UrlBuilder: FromRef<S>,
346    PreferredLanguage: FromRequestParts<S>,
347    BoxRepository: FromRequestParts<S>,
348    CookieJar: FromRequestParts<S>,
349    BoundActivityTracker: FromRequestParts<S>,
350    RequesterFingerprint: FromRequestParts<S>,
351    Encrypter: FromRef<S>,
352    Templates: FromRef<S>,
353    Keystore: FromRef<S>,
354    PasswordManager: FromRef<S>,
355    MetadataCache: FromRef<S>,
356    SiteConfig: FromRef<S>,
357    Limiter: FromRef<S>,
358    reqwest::Client: FromRef<S>,
359    Arc<dyn HomeserverConnection>: FromRef<S>,
360    BoxClock: FromRequestParts<S>,
361    BoxRng: FromRequestParts<S>,
362    Policy: FromRequestParts<S>,
363{
364    Router::new()
365        // XXX: hard-coded redirect from /account to /account/
366        .route(
367            "/account",
368            get(
369                async |State(url_builder): State<UrlBuilder>, RawQuery(query): RawQuery| {
370                    let prefix = url_builder.prefix().unwrap_or_default();
371                    let route = mas_router::Account::route();
372                    let destination = if let Some(query) = query {
373                        format!("{prefix}{route}?{query}")
374                    } else {
375                        format!("{prefix}{route}")
376                    };
377
378                    axum::response::Redirect::to(&destination)
379                },
380            ),
381        )
382        .route(mas_router::Account::route(), get(self::views::app::get))
383        .route(
384            mas_router::AccountWildcard::route(),
385            get(self::views::app::get),
386        )
387        .route(
388            mas_router::AccountRecoveryFinish::route(),
389            get(self::views::app::get_anonymous),
390        )
391        .route(
392            mas_router::ChangePasswordDiscovery::route(),
393            get(async |State(url_builder): State<UrlBuilder>| {
394                url_builder.redirect(&mas_router::AccountPasswordChange)
395            }),
396        )
397        .route(mas_router::Index::route(), get(self::views::index::get))
398        .route(
399            mas_router::Login::route(),
400            get(self::views::login::get).post(self::views::login::post),
401        )
402        .route(mas_router::Logout::route(), post(self::views::logout::post))
403        .route(
404            mas_router::Register::route(),
405            get(self::views::register::get),
406        )
407        .route(
408            mas_router::PasswordRegister::route(),
409            get(self::views::register::password::get).post(self::views::register::password::post),
410        )
411        .route(
412            mas_router::RegisterVerifyEmail::route(),
413            get(self::views::register::steps::verify_email::get)
414                .post(self::views::register::steps::verify_email::post),
415        )
416        .route(
417            mas_router::RegisterToken::route(),
418            get(self::views::register::steps::registration_token::get)
419                .post(self::views::register::steps::registration_token::post),
420        )
421        .route(
422            mas_router::RegisterDisplayName::route(),
423            get(self::views::register::steps::display_name::get)
424                .post(self::views::register::steps::display_name::post),
425        )
426        .route(
427            mas_router::RegisterFinish::route(),
428            get(self::views::register::steps::finish::get),
429        )
430        .route(
431            mas_router::AccountRecoveryStart::route(),
432            get(self::views::recovery::start::get).post(self::views::recovery::start::post),
433        )
434        .route(
435            mas_router::AccountRecoveryProgress::route(),
436            get(self::views::recovery::progress::get).post(self::views::recovery::progress::post),
437        )
438        .route(
439            mas_router::OAuth2AuthorizationEndpoint::route(),
440            get(self::oauth2::authorization::get),
441        )
442        .route(
443            mas_router::Consent::route(),
444            get(self::oauth2::authorization::consent::get)
445                .post(self::oauth2::authorization::consent::post),
446        )
447        .route(
448            mas_router::CompatLoginSsoComplete::route(),
449            get(self::compat::login_sso_complete::get).post(self::compat::login_sso_complete::post),
450        )
451        .route(
452            mas_router::UpstreamOAuth2Authorize::route(),
453            get(self::upstream_oauth2::authorize::get),
454        )
455        .route(
456            mas_router::UpstreamOAuth2Callback::route(),
457            get(self::upstream_oauth2::callback::handler)
458                .post(self::upstream_oauth2::callback::handler),
459        )
460        .route(
461            mas_router::UpstreamOAuth2Link::route(),
462            get(self::upstream_oauth2::link::get).post(self::upstream_oauth2::link::post),
463        )
464        .route(
465            mas_router::UpstreamOAuth2BackchannelLogout::route(),
466            post(self::upstream_oauth2::backchannel_logout::post),
467        )
468        .route(
469            mas_router::DeviceCodeLink::route(),
470            get(self::oauth2::device::link::get).post(self::oauth2::device::link::post),
471        )
472        .route(
473            mas_router::DeviceCodeConsent::route(),
474            get(self::oauth2::device::consent::get).post(self::oauth2::device::consent::post),
475        )
476        .layer(AndThenLayer::new(
477            async move |response: axum::response::Response| {
478                Ok::<_, Infallible>(recover_error(&templates, response))
479            },
480        ))
481        .layer(SetResponseHeaderLayer::if_not_present(
482            X_FRAME_OPTIONS,
483            http::HeaderValue::from_static("DENY"),
484        ))
485}
486
487fn recover_error(
488    templates: &Templates,
489    response: axum::response::Response,
490) -> axum::response::Response {
491    // Error responses should have an ErrorContext attached to them
492    let ext = response.extensions().get::<ErrorContext>();
493    if let Some(ctx) = ext
494        && let Ok(res) = templates.render_error(ctx)
495    {
496        let (mut parts, _original_body) = response.into_parts();
497        parts.headers.remove(CONTENT_TYPE);
498        parts.headers.remove(CONTENT_LENGTH);
499        return (parts, Html(res)).into_response();
500    }
501
502    response
503}
504
505/// The fallback handler for all routes that don't match anything else.
506///
507/// # Errors
508///
509/// Returns an error if the template rendering fails.
510pub async fn fallback(
511    State(templates): State<Templates>,
512    OriginalUri(uri): OriginalUri,
513    method: Method,
514    version: Version,
515    PreferredLanguage(locale): PreferredLanguage,
516) -> Result<impl IntoResponse, InternalError> {
517    let ctx = NotFoundContext::new(&method, version, &uri).with_language(locale);
518    // XXX: this should look at the Accept header and return JSON if requested
519
520    let res = templates.render_not_found(&ctx)?;
521
522    Ok((StatusCode::NOT_FOUND, Html(res)))
523}