Skip to main content

mas_storage_pg/oauth2/
authorization_grant.rs

1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8use std::collections::BTreeMap;
9
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use mas_data_model::{
13    AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Clock, Pkce, Session,
14};
15use mas_iana::oauth::PkceCodeChallengeMethod;
16use mas_storage::oauth2::OAuth2AuthorizationGrantRepository;
17use oauth2_types::{requests::ResponseMode, scope::Scope};
18use rand::RngCore;
19use sqlx::{PgConnection, types::Json};
20use ulid::Ulid;
21use url::Url;
22use uuid::Uuid;
23
24use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
25
26/// An implementation of [`OAuth2AuthorizationGrantRepository`] for a PostgreSQL
27/// connection
28pub struct PgOAuth2AuthorizationGrantRepository<'c> {
29    conn: &'c mut PgConnection,
30}
31
32impl<'c> PgOAuth2AuthorizationGrantRepository<'c> {
33    /// Create a new [`PgOAuth2AuthorizationGrantRepository`] from an active
34    /// PostgreSQL connection
35    pub fn new(conn: &'c mut PgConnection) -> Self {
36        Self { conn }
37    }
38}
39
40struct GrantLookup {
41    oauth2_authorization_grant_id: Uuid,
42    created_at: DateTime<Utc>,
43    cancelled_at: Option<DateTime<Utc>>,
44    fulfilled_at: Option<DateTime<Utc>>,
45    exchanged_at: Option<DateTime<Utc>>,
46    scope: String,
47    state: Option<String>,
48    nonce: Option<String>,
49    redirect_uri: String,
50    response_mode: String,
51    response_type_code: bool,
52    response_type_id_token: bool,
53    authorization_code: Option<String>,
54    code_challenge: Option<String>,
55    code_challenge_method: Option<String>,
56    login_hint: Option<String>,
57    locale: Option<String>,
58    raw_parameters: Option<Json<BTreeMap<String, String>>>,
59    oauth2_client_id: Uuid,
60    oauth2_session_id: Option<Uuid>,
61}
62
63impl TryFrom<GrantLookup> for AuthorizationGrant {
64    type Error = DatabaseInconsistencyError;
65
66    fn try_from(value: GrantLookup) -> Result<Self, Self::Error> {
67        let id = value.oauth2_authorization_grant_id.into();
68        let scope: Scope = value.scope.parse().map_err(|e| {
69            DatabaseInconsistencyError::on("oauth2_authorization_grants")
70                .column("scope")
71                .row(id)
72                .source(e)
73        })?;
74
75        let stage = match (
76            value.fulfilled_at,
77            value.exchanged_at,
78            value.cancelled_at,
79            value.oauth2_session_id,
80        ) {
81            (None, None, None, None) => AuthorizationGrantStage::Pending,
82            (Some(fulfilled_at), None, None, Some(session_id)) => {
83                AuthorizationGrantStage::Fulfilled {
84                    session_id: session_id.into(),
85                    fulfilled_at,
86                }
87            }
88            (Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => {
89                AuthorizationGrantStage::Exchanged {
90                    session_id: session_id.into(),
91                    fulfilled_at,
92                    exchanged_at,
93                }
94            }
95            (None, None, Some(cancelled_at), None) => {
96                AuthorizationGrantStage::Cancelled { cancelled_at }
97            }
98            _ => {
99                return Err(
100                    DatabaseInconsistencyError::on("oauth2_authorization_grants")
101                        .column("stage")
102                        .row(id),
103                );
104            }
105        };
106
107        let pkce = match (value.code_challenge, value.code_challenge_method) {
108            (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
109                Some(Pkce {
110                    challenge_method: PkceCodeChallengeMethod::Plain,
111                    challenge,
112                })
113            }
114            (Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce {
115                challenge_method: PkceCodeChallengeMethod::S256,
116                challenge,
117            }),
118            (None, None) => None,
119            _ => {
120                return Err(
121                    DatabaseInconsistencyError::on("oauth2_authorization_grants")
122                        .column("code_challenge_method")
123                        .row(id),
124                );
125            }
126        };
127
128        let code: Option<AuthorizationCode> =
129            match (value.response_type_code, value.authorization_code, pkce) {
130                (false, None, None) => None,
131                (true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
132                _ => {
133                    return Err(
134                        DatabaseInconsistencyError::on("oauth2_authorization_grants")
135                            .column("authorization_code")
136                            .row(id),
137                    );
138                }
139            };
140
141        let redirect_uri = value.redirect_uri.parse().map_err(|e| {
142            DatabaseInconsistencyError::on("oauth2_authorization_grants")
143                .column("redirect_uri")
144                .row(id)
145                .source(e)
146        })?;
147
148        let response_mode = value.response_mode.parse().map_err(|e| {
149            DatabaseInconsistencyError::on("oauth2_authorization_grants")
150                .column("response_mode")
151                .row(id)
152                .source(e)
153        })?;
154
155        Ok(AuthorizationGrant {
156            id,
157            stage,
158            client_id: value.oauth2_client_id.into(),
159            code,
160            scope,
161            state: value.state,
162            nonce: value.nonce,
163            response_mode,
164            redirect_uri,
165            created_at: value.created_at,
166            response_type_id_token: value.response_type_id_token,
167            login_hint: value.login_hint,
168            locale: value.locale,
169            raw_parameters: value.raw_parameters.map(|Json(x)| x).unwrap_or_default(),
170        })
171    }
172}
173
174#[async_trait]
175impl OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'_> {
176    type Error = DatabaseError;
177
178    #[tracing::instrument(
179        name = "db.oauth2_authorization_grant.add",
180        skip_all,
181        fields(
182            db.query.text,
183            grant.id,
184            grant.scope = %scope,
185            %client.id,
186        ),
187        err,
188    )]
189    async fn add(
190        &mut self,
191        rng: &mut (dyn RngCore + Send),
192        clock: &dyn Clock,
193        client: &Client,
194        redirect_uri: Url,
195        scope: Scope,
196        code: Option<AuthorizationCode>,
197        state: Option<String>,
198        nonce: Option<String>,
199        response_mode: ResponseMode,
200        response_type_id_token: bool,
201        login_hint: Option<String>,
202        locale: Option<String>,
203        raw_parameters: BTreeMap<String, String>,
204    ) -> Result<AuthorizationGrant, Self::Error> {
205        let code_challenge = code
206            .as_ref()
207            .and_then(|c| c.pkce.as_ref())
208            .map(|p| &p.challenge);
209        let code_challenge_method = code
210            .as_ref()
211            .and_then(|c| c.pkce.as_ref())
212            .map(|p| p.challenge_method.to_string());
213        let code_str = code.as_ref().map(|c| &c.code);
214
215        let created_at = clock.now();
216        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
217        tracing::Span::current().record("grant.id", tracing::field::display(id));
218
219        sqlx::query!(
220            r#"
221                INSERT INTO oauth2_authorization_grants (
222                     oauth2_authorization_grant_id,
223                     oauth2_client_id,
224                     redirect_uri,
225                     scope,
226                     state,
227                     nonce,
228                     response_mode,
229                     code_challenge,
230                     code_challenge_method,
231                     response_type_code,
232                     response_type_id_token,
233                     authorization_code,
234                     login_hint,
235                     locale,
236                     raw_parameters,
237                     created_at
238                )
239                VALUES
240                    ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
241            "#,
242            Uuid::from(id),
243            Uuid::from(client.id),
244            redirect_uri.to_string(),
245            scope.to_string(),
246            state,
247            nonce,
248            response_mode.to_string(),
249            code_challenge,
250            code_challenge_method,
251            code.is_some(),
252            response_type_id_token,
253            code_str,
254            login_hint,
255            locale,
256            Json(&raw_parameters) as _,
257            created_at,
258        )
259        .traced()
260        .execute(&mut *self.conn)
261        .await?;
262
263        Ok(AuthorizationGrant {
264            id,
265            stage: AuthorizationGrantStage::Pending,
266            code,
267            redirect_uri,
268            client_id: client.id,
269            scope,
270            state,
271            nonce,
272            response_mode,
273            created_at,
274            response_type_id_token,
275            login_hint,
276            locale,
277            raw_parameters,
278        })
279    }
280
281    #[tracing::instrument(
282        name = "db.oauth2_authorization_grant.lookup",
283        skip_all,
284        fields(
285            db.query.text,
286            grant.id = %id,
287        ),
288        err,
289    )]
290    async fn lookup(&mut self, id: Ulid) -> Result<Option<AuthorizationGrant>, Self::Error> {
291        let res = sqlx::query_as!(
292            GrantLookup,
293            r#"
294                SELECT oauth2_authorization_grant_id
295                     , created_at
296                     , cancelled_at
297                     , fulfilled_at
298                     , exchanged_at
299                     , scope
300                     , state
301                     , redirect_uri
302                     , response_mode
303                     , nonce
304                     , oauth2_client_id
305                     , authorization_code
306                     , response_type_code
307                     , response_type_id_token
308                     , code_challenge
309                     , code_challenge_method
310                     , login_hint
311                     , locale
312                     , raw_parameters AS "raw_parameters: Json<BTreeMap<String, String>>"
313                     , oauth2_session_id
314                FROM
315                    oauth2_authorization_grants
316
317                WHERE oauth2_authorization_grant_id = $1
318            "#,
319            Uuid::from(id),
320        )
321        .traced()
322        .fetch_optional(&mut *self.conn)
323        .await?;
324
325        let Some(res) = res else { return Ok(None) };
326
327        Ok(Some(res.try_into()?))
328    }
329
330    #[tracing::instrument(
331        name = "db.oauth2_authorization_grant.find_by_code",
332        skip_all,
333        fields(
334            db.query.text,
335        ),
336        err,
337    )]
338    async fn find_by_code(
339        &mut self,
340        code: &str,
341    ) -> Result<Option<AuthorizationGrant>, Self::Error> {
342        let res = sqlx::query_as!(
343            GrantLookup,
344            r#"
345                SELECT oauth2_authorization_grant_id
346                     , created_at
347                     , cancelled_at
348                     , fulfilled_at
349                     , exchanged_at
350                     , scope
351                     , state
352                     , redirect_uri
353                     , response_mode
354                     , nonce
355                     , oauth2_client_id
356                     , authorization_code
357                     , response_type_code
358                     , response_type_id_token
359                     , code_challenge
360                     , code_challenge_method
361                     , login_hint
362                     , locale
363                     , raw_parameters AS "raw_parameters: Json<BTreeMap<String, String>>"
364                     , oauth2_session_id
365                FROM
366                    oauth2_authorization_grants
367
368                WHERE authorization_code = $1
369            "#,
370            code,
371        )
372        .traced()
373        .fetch_optional(&mut *self.conn)
374        .await?;
375
376        let Some(res) = res else { return Ok(None) };
377
378        Ok(Some(res.try_into()?))
379    }
380
381    #[tracing::instrument(
382        name = "db.oauth2_authorization_grant.fulfill",
383        skip_all,
384        fields(
385            db.query.text,
386            %grant.id,
387            client.id = %grant.client_id,
388            %session.id,
389        ),
390        err,
391    )]
392    async fn fulfill(
393        &mut self,
394        clock: &dyn Clock,
395        session: &Session,
396        grant: AuthorizationGrant,
397    ) -> Result<AuthorizationGrant, Self::Error> {
398        let fulfilled_at = clock.now();
399        let res = sqlx::query!(
400            r#"
401                UPDATE oauth2_authorization_grants
402                SET fulfilled_at = $2
403                  , oauth2_session_id = $3
404                WHERE oauth2_authorization_grant_id = $1
405            "#,
406            Uuid::from(grant.id),
407            fulfilled_at,
408            Uuid::from(session.id),
409        )
410        .traced()
411        .execute(&mut *self.conn)
412        .await?;
413
414        DatabaseError::ensure_affected_rows(&res, 1)?;
415
416        // XXX: check affected rows & new methods
417        let grant = grant
418            .fulfill(fulfilled_at, session)
419            .map_err(DatabaseError::to_invalid_operation)?;
420
421        Ok(grant)
422    }
423
424    #[tracing::instrument(
425        name = "db.oauth2_authorization_grant.exchange",
426        skip_all,
427        fields(
428            db.query.text,
429            %grant.id,
430            client.id = %grant.client_id,
431        ),
432        err,
433    )]
434    async fn exchange(
435        &mut self,
436        clock: &dyn Clock,
437        grant: AuthorizationGrant,
438    ) -> Result<AuthorizationGrant, Self::Error> {
439        let exchanged_at = clock.now();
440        let res = sqlx::query!(
441            r#"
442                UPDATE oauth2_authorization_grants
443                SET exchanged_at = $2
444                WHERE oauth2_authorization_grant_id = $1
445            "#,
446            Uuid::from(grant.id),
447            exchanged_at,
448        )
449        .traced()
450        .execute(&mut *self.conn)
451        .await?;
452
453        DatabaseError::ensure_affected_rows(&res, 1)?;
454
455        let grant = grant
456            .exchange(exchanged_at)
457            .map_err(DatabaseError::to_invalid_operation)?;
458
459        Ok(grant)
460    }
461
462    #[tracing::instrument(
463        name = "db.oauth2_authorization_grant.cleanup",
464        skip_all,
465        fields(
466            db.query.text,
467            since = since.map(tracing::field::display),
468            until = %until,
469            limit = limit,
470        ),
471        err,
472    )]
473    async fn cleanup(
474        &mut self,
475        since: Option<Ulid>,
476        until: Ulid,
477        limit: usize,
478    ) -> Result<(usize, Option<Ulid>), Self::Error> {
479        // `MAX(uuid)` isn't a thing in Postgres, so we can't just re-select the
480        // deleted rows and do a MAX on the `oauth2_authorization_grant_id`.
481        // Instead, we do the aggregation on the client side, which is a little
482        // less efficient, but good enough.
483        let res = sqlx::query_scalar!(
484            r#"
485                WITH to_delete AS (
486                    SELECT oauth2_authorization_grant_id
487                    FROM oauth2_authorization_grants
488                    WHERE ($1::uuid IS NULL OR oauth2_authorization_grant_id > $1)
489                    AND oauth2_authorization_grant_id <= $2
490                    ORDER BY oauth2_authorization_grant_id
491                    LIMIT $3
492                )
493                DELETE FROM oauth2_authorization_grants
494                USING to_delete
495                WHERE oauth2_authorization_grants.oauth2_authorization_grant_id = to_delete.oauth2_authorization_grant_id
496                RETURNING oauth2_authorization_grants.oauth2_authorization_grant_id
497            "#,
498            since.map(Uuid::from),
499            Uuid::from(until),
500            i64::try_from(limit).unwrap_or(i64::MAX)
501        )
502        .traced()
503        .fetch_all(&mut *self.conn)
504        .await?;
505
506        let count = res.len();
507        let max_id = res.into_iter().max();
508
509        Ok((count, max_id.map(Ulid::from)))
510    }
511}