Skip to main content

mas_handlers/graphql/
mod.rs

1// Copyright 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8#![allow(clippy::module_name_repetitions)]
9
10use std::{net::IpAddr, ops::Deref, sync::Arc};
11
12use async_graphql::{
13    EmptySubscription, InputObject,
14    extensions::Tracing,
15    http::{GraphQLPlaygroundConfig, MultipartOptions, playground_source},
16    parser::types::{DocumentOperations, OperationType},
17};
18use axum::{
19    Extension, Json,
20    body::Body,
21    extract::{RawQuery, State as AxumState},
22    http::StatusCode,
23    response::{Html, IntoResponse, Response},
24};
25use axum_extra::typed_header::TypedHeader;
26use chrono::{DateTime, Utc};
27use futures_util::TryStreamExt;
28use headers::{Authorization, ContentType, HeaderValue, authorization::Bearer};
29use hyper::header::CACHE_CONTROL;
30use mas_axum_utils::{
31    InternalError, RecordAsRequester, SessionInfo, SessionInfoExt, cookies::CookieJar,
32    sentry::SentryEventID,
33};
34use mas_data_model::{
35    BoxClock, BoxRng, BrowserSession, Clock, Session, SiteConfig, SystemClock, User,
36};
37use mas_matrix::HomeserverConnection;
38use mas_policy::{InstantiateError, Policy, PolicyFactory};
39use mas_router::UrlBuilder;
40use mas_storage::{BoxRepository, BoxRepositoryFactory, RepositoryError};
41use opentelemetry_semantic_conventions::trace::{
42    GRAPHQL_DOCUMENT, GRAPHQL_OPERATION_NAME, GRAPHQL_OPERATION_TYPE,
43};
44use rand::{SeedableRng, thread_rng};
45use rand_chacha::ChaChaRng;
46use state::has_session_ended;
47use tracing::{Instrument, info_span};
48use ulid::Ulid;
49
50mod model;
51mod mutations;
52mod query;
53mod state;
54
55pub use self::state::{BoxState, State};
56use self::{
57    model::{CreationEvent, Node},
58    mutations::Mutation,
59    query::Query,
60};
61use crate::{
62    BoundActivityTracker, Limiter, RequesterFingerprint, impl_from_error_for_route,
63    passwords::PasswordManager,
64};
65
66#[cfg(test)]
67mod tests;
68
/// Extra parameters we get from the listener configuration, because they are
/// per-listener options. We pass them through request extensions.
#[derive(Debug, Clone)]
pub struct ExtraRouterParameters {
    /// Whether requests on this listener may authenticate with an `OAuth2`
    /// access token in an `Authorization: Bearer` header. When `false`, any
    /// request presenting a bearer token is rejected with an invalid-token
    /// error.
    pub undocumented_oauth2_access: bool,
}
75
/// The concrete implementation of [`state::State`] that [`schema`] builds and
/// attaches to the GraphQL schema as data.
struct GraphQLState {
    /// Used to open a new repository on every `repository()` call.
    repository_factory: BoxRepositoryFactory,
    /// Connection to the homeserver, handed out by reference.
    homeserver_connection: Arc<dyn HomeserverConnection>,
    /// Used to instantiate a new policy on every `policy()` call.
    policy_factory: Arc<PolicyFactory>,
    /// Site configuration, handed out by reference.
    site_config: SiteConfig,
    /// Cloned out on every `password_manager()` call.
    password_manager: PasswordManager,
    /// Handed out by reference through `url_builder()`.
    url_builder: UrlBuilder,
    /// Handed out by reference through `limiter()`.
    limiter: Limiter,
}
85
86#[async_trait::async_trait]
87impl state::State for GraphQLState {
88    async fn repository(&self) -> Result<BoxRepository, RepositoryError> {
89        self.repository_factory.create().await
90    }
91
92    async fn policy(&self) -> Result<Policy, InstantiateError> {
93        self.policy_factory.instantiate().await
94    }
95
96    fn password_manager(&self) -> PasswordManager {
97        self.password_manager.clone()
98    }
99
100    fn site_config(&self) -> &SiteConfig {
101        &self.site_config
102    }
103
104    fn homeserver_connection(&self) -> &dyn HomeserverConnection {
105        self.homeserver_connection.as_ref()
106    }
107
108    fn url_builder(&self) -> &UrlBuilder {
109        &self.url_builder
110    }
111
112    fn limiter(&self) -> &Limiter {
113        &self.limiter
114    }
115
116    fn clock(&self) -> BoxClock {
117        let clock = SystemClock::default();
118        Box::new(clock)
119    }
120
121    fn rng(&self) -> BoxRng {
122        #[expect(clippy::disallowed_methods)]
123        let rng = thread_rng();
124
125        let rng = ChaChaRng::from_rng(rng).expect("Failed to seed rng");
126        Box::new(rng)
127    }
128}
129
130#[must_use]
131pub fn schema(
132    repository_factory: BoxRepositoryFactory,
133    policy_factory: &Arc<PolicyFactory>,
134    homeserver_connection: impl HomeserverConnection + 'static,
135    site_config: SiteConfig,
136    password_manager: PasswordManager,
137    url_builder: UrlBuilder,
138    limiter: Limiter,
139) -> Schema {
140    let state = GraphQLState {
141        repository_factory,
142        policy_factory: Arc::clone(policy_factory),
143        homeserver_connection: Arc::new(homeserver_connection),
144        site_config,
145        password_manager,
146        url_builder,
147        limiter,
148    };
149    let state: BoxState = Box::new(state);
150
151    schema_builder().extension(Tracing).data(state).finish()
152}
153
154fn span_and_operation_for_graphql_request(
155    request: &mut async_graphql::Request,
156) -> (tracing::Span, GraphQLOperation) {
157    let span = info_span!(
158        "GraphQL operation",
159        "otel.name" = tracing::field::Empty,
160        "otel.kind" = "server",
161        { GRAPHQL_DOCUMENT } = request.query,
162        { GRAPHQL_OPERATION_NAME } = tracing::field::Empty,
163        { GRAPHQL_OPERATION_TYPE } = tracing::field::Empty,
164    );
165
166    let mut graphql_operation = GraphQLOperation {
167        operation_type: None,
168        operation_name: None,
169    };
170
171    // We need to clone the operation_name before parsing the query, else we're
172    // going to have a borrow conflict between request.parsed_query() and
173    // request.operation_name
174    let operation_name = request.operation_name.clone();
175    if let Ok(document) = request.parsed_query() {
176        match (&document.operations, operation_name) {
177            // A single anonymous operation, with no name requested: the
178            // document defines no name for it, so we only record the type.
179            (DocumentOperations::Single(operation), None) => {
180                span.record("otel.name", format!("GraphQL {}", operation.node.ty));
181                span.record(
182                    GRAPHQL_OPERATION_TYPE,
183                    tracing::field::display(operation.node.ty),
184                );
185                graphql_operation.operation_type = Some(operation.node.ty);
186            }
187
188            (DocumentOperations::Multiple(operations), Some(name)) => {
189                if let Some((name, operation)) = operations.get_key_value(name.as_str()) {
190                    span.record(
191                        "otel.name",
192                        format!("GraphQL {} {}", operation.node.ty, name),
193                    );
194                    span.record(
195                        GRAPHQL_OPERATION_TYPE,
196                        tracing::field::display(operation.node.ty),
197                    );
198                    span.record(GRAPHQL_OPERATION_NAME, tracing::field::display(name));
199                    graphql_operation.operation_type = Some(operation.node.ty);
200                    graphql_operation.operation_name = Some(name.to_string());
201                }
202            }
203
204            (DocumentOperations::Multiple(operations), None) if operations.len() == 1 => {
205                let mut iter = operations.iter();
206                let (name, operation) = iter.next().unwrap();
207                span.record(
208                    "otel.name",
209                    format!("GraphQL {} {}", operation.node.ty, name),
210                );
211                span.record(
212                    GRAPHQL_OPERATION_TYPE,
213                    tracing::field::display(operation.node.ty),
214                );
215                span.record(GRAPHQL_OPERATION_NAME, name.as_ref());
216                graphql_operation.operation_type = Some(operation.node.ty);
217                graphql_operation.operation_name = Some(name.to_string());
218            }
219
220            // Cases the executor rejects, so we don't record a misleading
221            // operation: a single anonymous operation with a name requested, or
222            // several named operations with no requested name (ambiguous).
223            (DocumentOperations::Single(_), Some(_)) | (DocumentOperations::Multiple(_), None) => {}
224        }
225    }
226
227    (span, graphql_operation)
228}
229
/// The GraphQL operation being executed, attached to the response extensions so
/// the HTTP logging middleware can record it on the request log line.
///
/// Both fields are `None` when the operation could not be determined, e.g.
/// because the document failed to parse or the requested operation does not
/// exist in it.
#[derive(Clone, Debug)]
pub struct GraphQLOperation {
    /// The type of the operation: query, mutation or subscription.
    pub operation_type: Option<OperationType>,
    /// The name of the operation, as defined in the query document; `None`
    /// for an anonymous operation.
    pub operation_name: Option<String>,
}
239
/// Errors returned by the GraphQL HTTP handlers before the request is
/// executed.
///
/// Each variant is rendered as a GraphQL-style `{"errors": [...]}` JSON body
/// with a variant-specific HTTP status code (see the `IntoResponse`
/// implementation).
#[derive(thiserror::Error, Debug)]
pub enum RouteError {
    /// An unexpected internal failure. Answered with
    /// `500 Internal Server Error`.
    #[error(transparent)]
    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),

    /// The session or user referenced by an access token could not be
    /// loaded. Answered with `500 Internal Server Error`.
    #[error("Loading of some database objects failed")]
    LoadFailed,

    /// Answered with `401 Unauthorized`. Returned when any of these hold:
    /// - bearer tokens are not accepted on this listener;
    /// - the access token is unknown or no longer valid;
    /// - its session or user is no longer valid.
    #[error("Invalid access token")]
    InvalidToken,

    /// The access token's session lacks the `urn:mas:graphql:*` scope.
    /// Answered with `401 Unauthorized`.
    #[error("Missing scope")]
    MissingScope,

    /// The incoming GraphQL request could not be parsed. Answered with
    /// `400 Bad Request`.
    #[error(transparent)]
    ParseRequest(#[from] async_graphql::ParseRequestError),
}

// Lets repository errors be propagated with `?` from functions returning a
// `RouteError` (presumably wrapped as `RouteError::Internal` — confirm against
// the macro definition).
impl_from_error_for_route!(mas_storage::RepositoryError);
259
260impl IntoResponse for RouteError {
261    fn into_response(self) -> Response {
262        let event_id = sentry::capture_error(&self);
263
264        let response = match self {
265            e @ (Self::Internal(_) | Self::LoadFailed) => {
266                let error = async_graphql::Error::new_with_source(e);
267                (
268                    StatusCode::INTERNAL_SERVER_ERROR,
269                    Json(serde_json::json!({"errors": [error]})),
270                )
271                    .into_response()
272            }
273
274            Self::InvalidToken => {
275                let error = async_graphql::Error::new("Invalid token");
276                (
277                    StatusCode::UNAUTHORIZED,
278                    Json(serde_json::json!({"errors": [error]})),
279                )
280                    .into_response()
281            }
282
283            Self::MissingScope => {
284                let error = async_graphql::Error::new("Missing urn:mas:graphql:* scope");
285                (
286                    StatusCode::UNAUTHORIZED,
287                    Json(serde_json::json!({"errors": [error]})),
288                )
289                    .into_response()
290            }
291
292            Self::ParseRequest(e) => {
293                let error = async_graphql::Error::new_with_source(e);
294                (
295                    StatusCode::BAD_REQUEST,
296                    Json(serde_json::json!({"errors": [error]})),
297                )
298                    .into_response()
299            }
300        };
301
302        (SentryEventID::from(event_id), response).into_response()
303    }
304}
305
306async fn get_requester(
307    undocumented_oauth2_access: bool,
308    clock: &impl Clock,
309    activity_tracker: &BoundActivityTracker,
310    mut repo: BoxRepository,
311    session_info: &SessionInfo,
312    user_agent: Option<String>,
313    token: Option<&str>,
314) -> Result<Requester, RouteError> {
315    let entity = if let Some(token) = token {
316        // If we haven't enabled undocumented_oauth2_access on the listener, we bail out
317        if !undocumented_oauth2_access {
318            return Err(RouteError::InvalidToken);
319        }
320
321        let token = repo
322            .oauth2_access_token()
323            .find_by_token(token)
324            .await?
325            .ok_or(RouteError::InvalidToken)?;
326
327        let session = repo
328            .oauth2_session()
329            .lookup(token.session_id)
330            .await?
331            .ok_or(RouteError::LoadFailed)?;
332
333        activity_tracker
334            .record_oauth2_session(clock, &session)
335            .await;
336
337        // Load the user if there is one
338        let user = if let Some(user_id) = session.user_id {
339            let user = repo
340                .user()
341                .lookup(user_id)
342                .await?
343                .ok_or(RouteError::LoadFailed)?;
344            Some(user)
345        } else {
346            None
347        };
348
349        // If there is a user for this session, check that it is not locked
350        let user_valid = user.as_ref().is_none_or(User::is_valid);
351
352        if !token.is_valid(clock.now()) || !session.is_valid() || !user_valid {
353            return Err(RouteError::InvalidToken);
354        }
355
356        if !session.scope.contains("urn:mas:graphql:*") {
357            return Err(RouteError::MissingScope);
358        }
359
360        if let Some(user) = &user {
361            user.maybe_record_as_requester();
362        }
363
364        RequestingEntity::OAuth2Session(Box::new((session, user)))
365    } else {
366        let maybe_session = session_info.load_active_session(&mut repo).await?;
367
368        if let Some(session) = maybe_session.as_ref() {
369            activity_tracker
370                .record_browser_session(clock, session)
371                .await;
372        }
373
374        RequestingEntity::from(maybe_session)
375    };
376
377    let requester = Requester {
378        entity,
379        ip_address: activity_tracker.ip(),
380        user_agent,
381    };
382
383    repo.cancel().await?;
384    Ok(requester)
385}
386
387pub async fn post(
388    AxumState(schema): AxumState<Schema>,
389    Extension(ExtraRouterParameters {
390        undocumented_oauth2_access,
391    }): Extension<ExtraRouterParameters>,
392    clock: BoxClock,
393    repo: BoxRepository,
394    activity_tracker: BoundActivityTracker,
395    cookie_jar: CookieJar,
396    content_type: Option<TypedHeader<ContentType>>,
397    authorization: Option<TypedHeader<Authorization<Bearer>>>,
398    user_agent: Option<TypedHeader<headers::UserAgent>>,
399    body: Body,
400) -> Result<impl IntoResponse, RouteError> {
401    let body = body.into_data_stream();
402    let token = authorization
403        .as_ref()
404        .map(|TypedHeader(Authorization(bearer))| bearer.token());
405    let user_agent = user_agent.map(|TypedHeader(h)| h.to_string());
406    let (session_info, mut cookie_jar) = cookie_jar.session_info();
407    let requester = get_requester(
408        undocumented_oauth2_access,
409        &clock,
410        &activity_tracker,
411        repo,
412        &session_info,
413        user_agent,
414        token,
415    )
416    .await?;
417
418    let content_type = content_type.map(|TypedHeader(h)| h.to_string());
419
420    let mut request = async_graphql::http::receive_body(
421        content_type,
422        body.map_err(std::io::Error::other).into_async_read(),
423        MultipartOptions::default(),
424    )
425    .await?
426    .data(requester); // XXX: this should probably return another error response?
427
428    let (span, operation) = span_and_operation_for_graphql_request(&mut request);
429    let mut response = schema.execute(request).instrument(span).await;
430
431    if has_session_ended(&mut response) {
432        let session_info = session_info.mark_session_ended();
433        cookie_jar = cookie_jar.update_session_info(&session_info);
434    }
435
436    let cache_control = response
437        .cache_control
438        .value()
439        .and_then(|v| HeaderValue::from_str(&v).ok())
440        .map(|h| [(CACHE_CONTROL, h)]);
441
442    let headers = response.http_headers.clone();
443
444    Ok((
445        headers,
446        cache_control,
447        cookie_jar,
448        Extension(operation),
449        Json(response),
450    ))
451}
452
453pub async fn get(
454    AxumState(schema): AxumState<Schema>,
455    Extension(ExtraRouterParameters {
456        undocumented_oauth2_access,
457    }): Extension<ExtraRouterParameters>,
458    clock: BoxClock,
459    repo: BoxRepository,
460    activity_tracker: BoundActivityTracker,
461    cookie_jar: CookieJar,
462    authorization: Option<TypedHeader<Authorization<Bearer>>>,
463    user_agent: Option<TypedHeader<headers::UserAgent>>,
464    RawQuery(query): RawQuery,
465) -> Result<impl IntoResponse, InternalError> {
466    let token = authorization
467        .as_ref()
468        .map(|TypedHeader(Authorization(bearer))| bearer.token());
469    let user_agent = user_agent.map(|TypedHeader(h)| h.to_string());
470    let (session_info, mut cookie_jar) = cookie_jar.session_info();
471    let requester = get_requester(
472        undocumented_oauth2_access,
473        &clock,
474        &activity_tracker,
475        repo,
476        &session_info,
477        user_agent,
478        token,
479    )
480    .await?;
481
482    let mut request =
483        async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(requester);
484
485    let (span, operation) = span_and_operation_for_graphql_request(&mut request);
486    let mut response = schema.execute(request).instrument(span).await;
487
488    if has_session_ended(&mut response) {
489        let session_info = session_info.mark_session_ended();
490        cookie_jar = cookie_jar.update_session_info(&session_info);
491    }
492
493    let cache_control = response
494        .cache_control
495        .value()
496        .and_then(|v| HeaderValue::from_str(&v).ok())
497        .map(|h| [(CACHE_CONTROL, h)]);
498
499    let headers = response.http_headers.clone();
500
501    Ok((
502        headers,
503        cache_control,
504        cookie_jar,
505        Extension(operation),
506        Json(response),
507    ))
508}
509
510pub async fn playground() -> impl IntoResponse {
511    Html(playground_source(
512        GraphQLPlaygroundConfig::new("/graphql").with_setting("request.credentials", "include"),
513    ))
514}
515
516pub type Schema = async_graphql::Schema<Query, Mutation, EmptySubscription>;
517pub type SchemaBuilder = async_graphql::SchemaBuilder<Query, Mutation, EmptySubscription>;
518
519#[must_use]
520pub fn schema_builder() -> SchemaBuilder {
521    async_graphql::Schema::build(Query::new(), Mutation::new(), EmptySubscription)
522        .register_output_type::<Node>()
523        .register_output_type::<CreationEvent>()
524}
525
526pub struct Requester {
527    entity: RequestingEntity,
528    ip_address: Option<IpAddr>,
529    user_agent: Option<String>,
530}
531
532impl Requester {
533    pub fn fingerprint(&self) -> RequesterFingerprint {
534        if let Some(ip) = self.ip_address {
535            RequesterFingerprint::new(ip)
536        } else {
537            RequesterFingerprint::EMPTY
538        }
539    }
540
541    pub fn for_policy(&self) -> mas_policy::Requester {
542        mas_policy::Requester {
543            ip_address: self.ip_address,
544            user_agent: self.user_agent.clone(),
545        }
546    }
547}
548
549impl Deref for Requester {
550    type Target = RequestingEntity;
551
552    fn deref(&self) -> &Self::Target {
553        &self.entity
554    }
555}
556
/// The identity of the requester.
///
/// The payloads are boxed, presumably to keep the enum itself small.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum RequestingEntity {
    /// No authenticated requester: no access token was presented, and there
    /// is no active browser session.
    #[default]
    Anonymous,

    /// The requester is a browser session, stored in a cookie.
    BrowserSession(Box<BrowserSession>),

    /// The requester is an `OAuth2` session, authenticated with an access
    /// token. Holds the session together with the user it belongs to, which
    /// is `None` for sessions not tied to a user.
    OAuth2Session(Box<(Session, Option<User>)>),
}
570
/// A resource that may belong to a user, used by
/// `RequestingEntity::is_owner_or_admin` to decide whether the requester may
/// access it.
trait OwnerId {
    /// The ID of the user owning this resource, or `None` if it has no owner.
    /// Resources without an owner are only accessible to admins.
    fn owner_id(&self) -> Option<Ulid>;
}

// A user owns itself.
impl OwnerId for User {
    fn owner_id(&self) -> Option<Ulid> {
        Some(self.id)
    }
}

// A browser session is owned by the user it belongs to.
impl OwnerId for BrowserSession {
    fn owner_id(&self) -> Option<Ulid> {
        Some(self.user.id)
    }
}

impl OwnerId for mas_data_model::UserEmail {
    fn owner_id(&self) -> Option<Ulid> {
        Some(self.user_id)
    }
}

// `OAuth2` sessions may have no associated user.
impl OwnerId for Session {
    fn owner_id(&self) -> Option<Ulid> {
        self.user_id
    }
}

impl OwnerId for mas_data_model::CompatSession {
    fn owner_id(&self) -> Option<Ulid> {
        Some(self.user_id)
    }
}

// Upstream OAuth links may have no associated user.
impl OwnerId for mas_data_model::UpstreamOAuthLink {
    fn owner_id(&self) -> Option<Ulid> {
        self.user_id
    }
}

/// A thin wrapper around a `Ulid` to implement `OwnerId` for it.
///
/// The wrapped ID is treated as the ID of the owning user.
pub struct UserId(Ulid);

impl OwnerId for UserId {
    fn owner_id(&self) -> Option<Ulid> {
        Some(self.0)
    }
}
619
620impl RequestingEntity {
621    fn browser_session(&self) -> Option<&BrowserSession> {
622        match self {
623            Self::BrowserSession(session) => Some(session),
624            Self::OAuth2Session(_) | Self::Anonymous => None,
625        }
626    }
627
628    fn user(&self) -> Option<&User> {
629        match self {
630            Self::BrowserSession(session) => Some(&session.user),
631            Self::OAuth2Session(tuple) => tuple.1.as_ref(),
632            Self::Anonymous => None,
633        }
634    }
635
636    fn oauth2_session(&self) -> Option<&Session> {
637        match self {
638            Self::OAuth2Session(tuple) => Some(&tuple.0),
639            Self::BrowserSession(_) | Self::Anonymous => None,
640        }
641    }
642
643    /// Returns true if the requester can access the resource.
644    fn is_owner_or_admin(&self, resource: &impl OwnerId) -> bool {
645        // If the requester is an admin, they can do anything.
646        if self.is_admin() {
647            return true;
648        }
649
650        // Otherwise, they must be the owner of the resource.
651        let Some(owner_id) = resource.owner_id() else {
652            return false;
653        };
654
655        let Some(user) = self.user() else {
656            return false;
657        };
658
659        user.id == owner_id
660    }
661
662    fn is_admin(&self) -> bool {
663        match self {
664            Self::OAuth2Session(tuple) => {
665                // TODO: is this the right scope?
666                // This has to be in sync with the policy
667                tuple.0.scope.contains("urn:mas:admin")
668            }
669            Self::BrowserSession(_) | Self::Anonymous => false,
670        }
671    }
672
673    fn is_unauthenticated(&self) -> bool {
674        matches!(self, Self::Anonymous)
675    }
676}
677
678impl From<BrowserSession> for RequestingEntity {
679    fn from(session: BrowserSession) -> Self {
680        Self::BrowserSession(Box::new(session))
681    }
682}
683
684impl<T> From<Option<T>> for RequestingEntity
685where
686    T: Into<RequestingEntity>,
687{
688    fn from(session: Option<T>) -> Self {
689        session.map(Into::into).unwrap_or_default()
690    }
691}
692
// NOTE: async-graphql's `InputObject` derive turns the `///` doc comments on
// this type and its fields into GraphQL schema descriptions, so they are part
// of the published schema — edit them with that in mind. Whether the bounds
// are inclusive or exclusive is decided by the code applying the filter, not
// here.
/// A filter for dates, with a lower bound and an upper bound
#[derive(InputObject, Default, Clone, Copy)]
pub struct DateFilter {
    /// The lower bound of the date range
    after: Option<DateTime<Utc>>,

    /// The upper bound of the date range
    before: Option<DateTime<Utc>>,
}