1#![deny(clippy::future_not_send)]
9#![allow(
10 clippy::unused_async,
12 clippy::too_many_arguments,
14 clippy::let_with_type_underscore,
17)]
18
19use std::{
20 convert::Infallible,
21 sync::{Arc, LazyLock},
22 time::Duration,
23};
24
25use axum::{
26 Extension, Router,
27 extract::{FromRef, FromRequestParts, OriginalUri, RawQuery, State},
28 http::Method,
29 response::{Html, IntoResponse},
30 routing::{get, post},
31};
32use headers::HeaderName;
33use hyper::{
34 StatusCode, Version,
35 header::{
36 ACCEPT, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_LANGUAGE, CONTENT_LENGTH, CONTENT_TYPE,
37 X_FRAME_OPTIONS,
38 },
39};
40use mas_axum_utils::{InternalError, cookies::CookieJar};
41use mas_data_model::SiteConfig;
42use mas_http::CorsLayerExt;
43use mas_keystore::{Encrypter, Keystore};
44use mas_matrix::HomeserverConnection;
45use mas_policy::Policy;
46use mas_router::{Route, UrlBuilder};
47use mas_storage::{BoxRepository, BoxRepositoryFactory};
48use mas_templates::{ErrorContext, NotFoundContext, TemplateContext, Templates};
49use opentelemetry::metrics::Meter;
50use sqlx::PgPool;
51use tower::util::AndThenLayer;
52use tower_http::{
53 cors::{Any, CorsLayer},
54 set_header::SetResponseHeaderLayer,
55};
56
57use self::{graphql::ExtraRouterParameters, passwords::PasswordManager};
58
59mod admin;
60mod compat;
61mod graphql;
62mod health;
63mod oauth2;
64pub mod passwords;
65pub mod upstream_oauth2;
66mod views;
67
68mod activity_tracker;
69mod captcha;
70#[cfg(test)]
71mod cleanup_tests;
72mod client_ip;
73mod preferred_language;
74mod rate_limit;
75mod session;
76#[cfg(test)]
77mod test_utils;
78
79static METER: LazyLock<Meter> = LazyLock::new(|| {
80 let scope = opentelemetry::InstrumentationScope::builder(env!("CARGO_PKG_NAME"))
81 .with_version(env!("CARGO_PKG_VERSION"))
82 .with_schema_url(opentelemetry_semantic_conventions::SCHEMA_URL)
83 .build();
84
85 opentelemetry::global::meter_with_scope(scope)
86});
87
/// Implement `From<$error>` for a route error type by boxing the error into
/// its `Internal` variant.
///
/// Two forms are accepted:
///
/// * `impl_from_error_for_route!(MyRouteError: SomeError)` targets an explicit
///   route error type;
/// * `impl_from_error_for_route!(SomeError)` targets `RouteError` in the module
///   where the macro is invoked (`self::RouteError`).
///
/// The target type must have an `Internal` variant that accepts `Box::new(e)`
/// for the given error type.
#[macro_export]
macro_rules! impl_from_error_for_route {
    ($route_error:ty : $error:ty) => {
        impl From<$error> for $route_error {
            fn from(e: $error) -> Self {
                Self::Internal(Box::new(e))
            }
        }
    };
    ($error:ty) => {
        // Recurse through `$crate` so this arm works even when the caller invoked
        // the macro by path and did not import it (a bare-name recursive call
        // would need the macro to be in scope at the call site).
        $crate::impl_from_error_for_route!(self::RouteError: $error);
    };
}
103
104pub use mas_axum_utils::{ErrorWrapper, cookies::CookieManager};
105use mas_data_model::{BoxClock, BoxRng};
106
107pub use self::{
108 activity_tracker::{ActivityTracker, Bound as BoundActivityTracker},
109 admin::router as admin_api_router,
110 client_ip::ClientIp,
111 graphql::{
112 GraphQLOperation, Schema as GraphQLSchema, schema as graphql_schema,
113 schema_builder as graphql_schema_builder,
114 },
115 preferred_language::PreferredLanguage,
116 rate_limit::{Limiter, RequesterFingerprint},
117 upstream_oauth2::cache::MetadataCache,
118};
119
120pub fn healthcheck_router<S>() -> Router<S>
121where
122 S: Clone + Send + Sync + 'static,
123 PgPool: FromRef<S>,
124{
125 Router::new().route(mas_router::Healthcheck::route(), get(self::health::get))
126}
127
128pub fn graphql_router<S>(playground: bool, undocumented_oauth2_access: bool) -> Router<S>
129where
130 S: Clone + Send + Sync + 'static,
131 graphql::Schema: FromRef<S>,
132 BoundActivityTracker: FromRequestParts<S>,
133 BoxRepository: FromRequestParts<S>,
134 BoxClock: FromRequestParts<S>,
135 Encrypter: FromRef<S>,
136 CookieJar: FromRequestParts<S>,
137 Limiter: FromRef<S>,
138 RequesterFingerprint: FromRequestParts<S>,
139{
140 let mut router = Router::new()
141 .route(
142 mas_router::GraphQL::route(),
143 get(self::graphql::get).post(self::graphql::post),
144 )
145 .layer(Extension(ExtraRouterParameters {
148 undocumented_oauth2_access,
149 }))
150 .layer(
151 CorsLayer::new()
152 .allow_origin(Any)
153 .allow_methods(Any)
154 .allow_otel_headers([
155 AUTHORIZATION,
156 ACCEPT,
157 ACCEPT_LANGUAGE,
158 CONTENT_LANGUAGE,
159 CONTENT_TYPE,
160 ]),
161 );
162
163 if playground {
164 router = router.route(
165 mas_router::GraphQLPlayground::route(),
166 get(self::graphql::playground),
167 );
168 }
169
170 router
171}
172
173pub fn discovery_router<S>() -> Router<S>
174where
175 S: Clone + Send + Sync + 'static,
176 Keystore: FromRef<S>,
177 SiteConfig: FromRef<S>,
178 UrlBuilder: FromRef<S>,
179 BoxClock: FromRequestParts<S>,
180 BoxRng: FromRequestParts<S>,
181{
182 Router::new()
183 .route(
184 mas_router::OidcConfiguration::route(),
185 get(self::oauth2::discovery::get),
186 )
187 .route(
188 mas_router::Webfinger::route(),
189 get(self::oauth2::webfinger::get),
190 )
191 .layer(
192 CorsLayer::new()
193 .allow_origin(Any)
194 .allow_methods(Any)
195 .allow_otel_headers([
196 AUTHORIZATION,
197 ACCEPT,
198 ACCEPT_LANGUAGE,
199 CONTENT_LANGUAGE,
200 CONTENT_TYPE,
201 ])
202 .max_age(Duration::from_hours(1)),
203 )
204}
205
206pub fn api_router<S>() -> Router<S>
207where
208 S: Clone + Send + Sync + 'static,
209 Keystore: FromRef<S>,
210 UrlBuilder: FromRef<S>,
211 BoxRepository: FromRequestParts<S>,
212 ActivityTracker: FromRequestParts<S>,
213 BoundActivityTracker: FromRequestParts<S>,
214 Encrypter: FromRef<S>,
215 reqwest::Client: FromRef<S>,
216 SiteConfig: FromRef<S>,
217 Templates: FromRef<S>,
218 Arc<dyn HomeserverConnection>: FromRef<S>,
219 BoxClock: FromRequestParts<S>,
220 BoxRng: FromRequestParts<S>,
221 Policy: FromRequestParts<S>,
222{
223 Router::new()
225 .route(
226 mas_router::OAuth2Keys::route(),
227 get(self::oauth2::keys::get),
228 )
229 .route(
230 mas_router::OidcUserinfo::route(),
231 get(self::oauth2::userinfo::get).post(self::oauth2::userinfo::get),
232 )
233 .route(
234 mas_router::OAuth2Introspection::route(),
235 post(self::oauth2::introspection::post),
236 )
237 .route(
238 mas_router::OAuth2Revocation::route(),
239 post(self::oauth2::revoke::post),
240 )
241 .route(
242 mas_router::OAuth2TokenEndpoint::route(),
243 post(self::oauth2::token::post),
244 )
245 .route(
246 mas_router::OAuth2RegistrationEndpoint::route(),
247 post(self::oauth2::registration::post),
248 )
249 .route(
250 mas_router::OAuth2DeviceAuthorizationEndpoint::route(),
251 post(self::oauth2::device::authorize::post),
252 )
253 .layer(
254 CorsLayer::new()
255 .allow_origin(Any)
256 .allow_methods(Any)
257 .allow_otel_headers([
258 AUTHORIZATION,
259 ACCEPT,
260 ACCEPT_LANGUAGE,
261 CONTENT_LANGUAGE,
262 CONTENT_TYPE,
263 HeaderName::from_static("x-requested-with"),
265 ])
266 .max_age(Duration::from_hours(1)),
267 )
268}
269
270pub fn compat_router<S>(templates: Templates) -> Router<S>
271where
272 S: Clone + Send + Sync + 'static,
273 UrlBuilder: FromRef<S>,
274 SiteConfig: FromRef<S>,
275 Arc<dyn HomeserverConnection>: FromRef<S>,
276 PasswordManager: FromRef<S>,
277 Limiter: FromRef<S>,
278 BoxRepositoryFactory: FromRef<S>,
279 BoundActivityTracker: FromRequestParts<S>,
280 RequesterFingerprint: FromRequestParts<S>,
281 BoxRepository: FromRequestParts<S>,
282 BoxClock: FromRequestParts<S>,
283 BoxRng: FromRequestParts<S>,
284 Policy: FromRequestParts<S>,
285{
286 let human_router = Router::new()
288 .route(
289 mas_router::CompatLoginSsoRedirect::route(),
290 get(self::compat::login_sso_redirect::get),
291 )
292 .route(
293 mas_router::CompatLoginSsoRedirectIdp::route(),
294 get(self::compat::login_sso_redirect::get),
295 )
296 .route(
297 mas_router::CompatLoginSsoRedirectSlash::route(),
298 get(self::compat::login_sso_redirect::get),
299 )
300 .layer(AndThenLayer::new(
301 async move |response: axum::response::Response| {
302 Ok::<_, Infallible>(recover_error(&templates, response))
303 },
304 ));
305
306 let api_router = Router::new()
308 .route(
309 mas_router::CompatLogin::route(),
310 get(self::compat::login::get).post(self::compat::login::post),
311 )
312 .route(
313 mas_router::CompatLogout::route(),
314 post(self::compat::logout::post),
315 )
316 .route(
317 mas_router::CompatLogoutAll::route(),
318 post(self::compat::logout_all::post),
319 )
320 .route(
321 mas_router::CompatRefresh::route(),
322 post(self::compat::refresh::post),
323 )
324 .layer(
325 CorsLayer::new()
326 .allow_origin(Any)
327 .allow_methods(Any)
328 .allow_otel_headers([
329 AUTHORIZATION,
330 ACCEPT,
331 ACCEPT_LANGUAGE,
332 CONTENT_LANGUAGE,
333 CONTENT_TYPE,
334 HeaderName::from_static("x-requested-with"),
335 ])
336 .max_age(Duration::from_hours(1)),
337 );
338
339 Router::new().merge(human_router).merge(api_router)
340}
341
342pub fn human_router<S>(templates: Templates) -> Router<S>
343where
344 S: Clone + Send + Sync + 'static,
345 UrlBuilder: FromRef<S>,
346 PreferredLanguage: FromRequestParts<S>,
347 BoxRepository: FromRequestParts<S>,
348 CookieJar: FromRequestParts<S>,
349 BoundActivityTracker: FromRequestParts<S>,
350 RequesterFingerprint: FromRequestParts<S>,
351 Encrypter: FromRef<S>,
352 Templates: FromRef<S>,
353 Keystore: FromRef<S>,
354 PasswordManager: FromRef<S>,
355 MetadataCache: FromRef<S>,
356 SiteConfig: FromRef<S>,
357 Limiter: FromRef<S>,
358 reqwest::Client: FromRef<S>,
359 Arc<dyn HomeserverConnection>: FromRef<S>,
360 BoxClock: FromRequestParts<S>,
361 BoxRng: FromRequestParts<S>,
362 Policy: FromRequestParts<S>,
363{
364 Router::new()
365 .route(
367 "/account",
368 get(
369 async |State(url_builder): State<UrlBuilder>, RawQuery(query): RawQuery| {
370 let prefix = url_builder.prefix().unwrap_or_default();
371 let route = mas_router::Account::route();
372 let destination = if let Some(query) = query {
373 format!("{prefix}{route}?{query}")
374 } else {
375 format!("{prefix}{route}")
376 };
377
378 axum::response::Redirect::to(&destination)
379 },
380 ),
381 )
382 .route(mas_router::Account::route(), get(self::views::app::get))
383 .route(
384 mas_router::AccountWildcard::route(),
385 get(self::views::app::get),
386 )
387 .route(
388 mas_router::AccountRecoveryFinish::route(),
389 get(self::views::app::get_anonymous),
390 )
391 .route(
392 mas_router::ChangePasswordDiscovery::route(),
393 get(async |State(url_builder): State<UrlBuilder>| {
394 url_builder.redirect(&mas_router::AccountPasswordChange)
395 }),
396 )
397 .route(mas_router::Index::route(), get(self::views::index::get))
398 .route(
399 mas_router::Login::route(),
400 get(self::views::login::get).post(self::views::login::post),
401 )
402 .route(mas_router::Logout::route(), post(self::views::logout::post))
403 .route(
404 mas_router::Register::route(),
405 get(self::views::register::get),
406 )
407 .route(
408 mas_router::PasswordRegister::route(),
409 get(self::views::register::password::get).post(self::views::register::password::post),
410 )
411 .route(
412 mas_router::RegisterVerifyEmail::route(),
413 get(self::views::register::steps::verify_email::get)
414 .post(self::views::register::steps::verify_email::post),
415 )
416 .route(
417 mas_router::RegisterToken::route(),
418 get(self::views::register::steps::registration_token::get)
419 .post(self::views::register::steps::registration_token::post),
420 )
421 .route(
422 mas_router::RegisterDisplayName::route(),
423 get(self::views::register::steps::display_name::get)
424 .post(self::views::register::steps::display_name::post),
425 )
426 .route(
427 mas_router::RegisterFinish::route(),
428 get(self::views::register::steps::finish::get),
429 )
430 .route(
431 mas_router::AccountRecoveryStart::route(),
432 get(self::views::recovery::start::get).post(self::views::recovery::start::post),
433 )
434 .route(
435 mas_router::AccountRecoveryProgress::route(),
436 get(self::views::recovery::progress::get).post(self::views::recovery::progress::post),
437 )
438 .route(
439 mas_router::OAuth2AuthorizationEndpoint::route(),
440 get(self::oauth2::authorization::get),
441 )
442 .route(
443 mas_router::Consent::route(),
444 get(self::oauth2::authorization::consent::get)
445 .post(self::oauth2::authorization::consent::post),
446 )
447 .route(
448 mas_router::CompatLoginSsoComplete::route(),
449 get(self::compat::login_sso_complete::get).post(self::compat::login_sso_complete::post),
450 )
451 .route(
452 mas_router::UpstreamOAuth2Authorize::route(),
453 get(self::upstream_oauth2::authorize::get),
454 )
455 .route(
456 mas_router::UpstreamOAuth2Callback::route(),
457 get(self::upstream_oauth2::callback::handler)
458 .post(self::upstream_oauth2::callback::handler),
459 )
460 .route(
461 mas_router::UpstreamOAuth2Link::route(),
462 get(self::upstream_oauth2::link::get).post(self::upstream_oauth2::link::post),
463 )
464 .route(
465 mas_router::UpstreamOAuth2BackchannelLogout::route(),
466 post(self::upstream_oauth2::backchannel_logout::post),
467 )
468 .route(
469 mas_router::DeviceCodeLink::route(),
470 get(self::oauth2::device::link::get).post(self::oauth2::device::link::post),
471 )
472 .route(
473 mas_router::DeviceCodeConsent::route(),
474 get(self::oauth2::device::consent::get).post(self::oauth2::device::consent::post),
475 )
476 .layer(AndThenLayer::new(
477 async move |response: axum::response::Response| {
478 Ok::<_, Infallible>(recover_error(&templates, response))
479 },
480 ))
481 .layer(SetResponseHeaderLayer::if_not_present(
482 X_FRAME_OPTIONS,
483 http::HeaderValue::from_static("DENY"),
484 ))
485}
486
487fn recover_error(
488 templates: &Templates,
489 response: axum::response::Response,
490) -> axum::response::Response {
491 let ext = response.extensions().get::<ErrorContext>();
493 if let Some(ctx) = ext
494 && let Ok(res) = templates.render_error(ctx)
495 {
496 let (mut parts, _original_body) = response.into_parts();
497 parts.headers.remove(CONTENT_TYPE);
498 parts.headers.remove(CONTENT_LENGTH);
499 return (parts, Html(res)).into_response();
500 }
501
502 response
503}
504
505pub async fn fallback(
511 State(templates): State<Templates>,
512 OriginalUri(uri): OriginalUri,
513 method: Method,
514 version: Version,
515 PreferredLanguage(locale): PreferredLanguage,
516) -> Result<impl IntoResponse, InternalError> {
517 let ctx = NotFoundContext::new(&method, version, &uri).with_language(locale);
518 let res = templates.render_not_found(&ctx)?;
521
522 Ok((StatusCode::NOT_FOUND, Html(res)))
523}