Skip to main content

mas_storage_pg/user/
mod.rs

1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8//! A module containing the PostgreSQL implementation of the user-related
9//! repositories
10
11use async_trait::async_trait;
12use mas_data_model::{Clock, User};
13use mas_storage::user::{UserFilter, UserRepository};
14use rand::RngCore;
15use sea_query::{Expr, PostgresQueryBuilder, Query, SimpleExpr, extension::postgres::PgExpr as _};
16use sea_query_binder::SqlxBinder;
17use sqlx::PgConnection;
18use ulid::Ulid;
19use uuid::Uuid;
20
21use crate::{
22    DatabaseError,
23    filter::{Filter, StatementExt},
24    iden::{CompatSessions, OAuth2Sessions, Users},
25    pagination::QueryBuilderExt,
26    tracing::ExecuteExt,
27};
28
29mod email;
30mod password;
31mod recovery;
32mod registration;
33mod registration_token;
34mod session;
35mod terms;
36
37#[cfg(test)]
38mod tests;
39
40pub use self::{
41    email::PgUserEmailRepository, password::PgUserPasswordRepository,
42    recovery::PgUserRecoveryRepository, registration::PgUserRegistrationRepository,
43    registration_token::PgUserRegistrationTokenRepository, session::PgBrowserSessionRepository,
44    terms::PgUserTermsRepository,
45};
46
47/// An implementation of [`UserRepository`] for a PostgreSQL connection
48pub struct PgUserRepository<'c> {
49    conn: &'c mut PgConnection,
50}
51
52impl<'c> PgUserRepository<'c> {
53    /// Create a new [`PgUserRepository`] from an active PostgreSQL connection
54    pub fn new(conn: &'c mut PgConnection) -> Self {
55        Self { conn }
56    }
57}
58
mod priv_ {
    // The enum_def macro generates a public, undocumented enum, which would
    // trigger the missing_docs lint; this private module exists so that the lint
    // can be silenced locally without affecting the rest of the crate.
    #![allow(missing_docs)]

    use chrono::{DateTime, Utc};
    use mas_storage::pagination::Node;
    use sea_query::enum_def;
    use ulid::Ulid;
    use uuid::Uuid;

    // A row of the `users` table, as decoded by the queries of the parent module
    // (via `sqlx::query_as!` and `sqlx::FromRow`).
    //
    // `#[enum_def]` also generates `UserLookupIden`, whose variants are used by
    // the parent module as column aliases when building dynamic queries with
    // sea-query. The field names here must therefore stay in sync with the
    // aliases used in those queries.
    #[derive(Debug, Clone, sqlx::FromRow)]
    #[enum_def]
    pub(super) struct UserLookup {
        pub(super) user_id: Uuid,
        pub(super) username: String,
        pub(super) created_at: DateTime<Utc>,
        // `None` while the user is not locked
        pub(super) locked_at: Option<DateTime<Utc>>,
        // `None` while the user is not deactivated
        pub(super) deactivated_at: Option<DateTime<Utc>>,
        pub(super) can_request_admin: bool,
        pub(super) is_guest: bool,
    }

    // Paginated listings use the user ID, converted to a ULID, as the cursor.
    impl Node<Ulid> for UserLookup {
        fn cursor(&self) -> Ulid {
            self.user_id.into()
        }
    }
}
88
89use priv_::{UserLookup, UserLookupIden};
90
91impl From<UserLookup> for User {
92    fn from(value: UserLookup) -> Self {
93        let id = value.user_id.into();
94        Self {
95            id,
96            username: value.username,
97            sub: id.to_string(),
98            created_at: value.created_at,
99            locked_at: value.locked_at,
100            deactivated_at: value.deactivated_at,
101            can_request_admin: value.can_request_admin,
102            is_guest: value.is_guest,
103        }
104    }
105}
106
107impl Filter for UserFilter<'_> {
108    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
109        sea_query::Condition::all()
110            .add_option(self.state().map(|state| {
111                match state {
112                    mas_storage::user::UserState::Deactivated => {
113                        Expr::col((Users::Table, Users::DeactivatedAt)).is_not_null()
114                    }
115                    mas_storage::user::UserState::Locked => {
116                        Expr::col((Users::Table, Users::LockedAt)).is_not_null()
117                    }
118                    mas_storage::user::UserState::Active => {
119                        Expr::col((Users::Table, Users::LockedAt))
120                            .is_null()
121                            .and(Expr::col((Users::Table, Users::DeactivatedAt)).is_null())
122                    }
123                }
124            }))
125            .add_option(self.can_request_admin().map(|can_request_admin| {
126                Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
127            }))
128            .add_option(
129                self.is_guest()
130                    .map(|is_guest| Expr::col((Users::Table, Users::IsGuest)).eq(is_guest)),
131            )
132            .add_option(self.search().map(|search| {
133                Expr::col((Users::Table, Users::Username)).ilike(format!("%{search}%"))
134            }))
135            .add_option(self.active_oauth2_session_for_any_of_clients().map(
136                |clients| -> SimpleExpr {
137                    let client_ids: Vec<SimpleExpr> = clients
138                        .iter()
139                        .map(|c| Expr::val(Uuid::from(*c)).into())
140                        .collect();
141                    Expr::exists(
142                        Query::select()
143                            .expr(Expr::cust("1"))
144                            .from(OAuth2Sessions::Table)
145                            .and_where(
146                                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId))
147                                    .equals((Users::Table, Users::UserId)),
148                            )
149                            .and_where(
150                                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt))
151                                    .is_null(),
152                            )
153                            .and_where(
154                                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
155                                    .is_in(client_ids),
156                            )
157                            .take(),
158                    )
159                },
160            ))
161            .add_option(self.has_active_oauth2_session().map(|has| -> SimpleExpr {
162                let exists = Expr::exists(
163                    Query::select()
164                        .expr(Expr::cust("1"))
165                        .from(OAuth2Sessions::Table)
166                        .and_where(
167                            Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId))
168                                .equals((Users::Table, Users::UserId)),
169                        )
170                        .and_where(
171                            Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt))
172                                .is_null(),
173                        )
174                        .take(),
175                );
176                if has { exists } else { exists.not() }
177            }))
178            .add_option(self.has_active_compat_session().map(|has| -> SimpleExpr {
179                let exists = Expr::exists(
180                    Query::select()
181                        .expr(Expr::cust("1"))
182                        .from(CompatSessions::Table)
183                        .and_where(
184                            Expr::col((CompatSessions::Table, CompatSessions::UserId))
185                                .equals((Users::Table, Users::UserId)),
186                        )
187                        .and_where(
188                            Expr::col((CompatSessions::Table, CompatSessions::FinishedAt))
189                                .is_null(),
190                        )
191                        .take(),
192                );
193                if has { exists } else { exists.not() }
194            }))
195    }
196}
197
198#[async_trait]
199impl UserRepository for PgUserRepository<'_> {
200    type Error = DatabaseError;
201
202    #[tracing::instrument(
203        name = "db.user.lookup",
204        skip_all,
205        fields(
206            db.query.text,
207            user.id = %id,
208        ),
209        err,
210    )]
211    async fn lookup(&mut self, id: Ulid) -> Result<Option<User>, Self::Error> {
212        let res = sqlx::query_as!(
213            UserLookup,
214            r#"
215                SELECT user_id
216                     , username
217                     , created_at
218                     , locked_at
219                     , deactivated_at
220                     , can_request_admin
221                     , is_guest
222                FROM users
223                WHERE user_id = $1
224            "#,
225            Uuid::from(id),
226        )
227        .traced()
228        .fetch_optional(&mut *self.conn)
229        .await?;
230
231        let Some(res) = res else { return Ok(None) };
232
233        Ok(Some(res.into()))
234    }
235
236    #[tracing::instrument(
237        name = "db.user.find_by_username",
238        skip_all,
239        fields(
240            db.query.text,
241            user.username = username,
242        ),
243        err,
244    )]
245    async fn find_by_username(&mut self, username: &str) -> Result<Option<User>, Self::Error> {
246        // We may have multiple users with the same username, but with a different
247        // casing. In this case, we want to return the one which matches the exact
248        // casing
249        let res = sqlx::query_as!(
250            UserLookup,
251            r#"
252                SELECT user_id
253                     , username
254                     , created_at
255                     , locked_at
256                     , deactivated_at
257                     , can_request_admin
258                     , is_guest
259                FROM users
260                WHERE LOWER(username) = LOWER($1)
261            "#,
262            username,
263        )
264        .traced()
265        .fetch_all(&mut *self.conn)
266        .await?;
267
268        match &res[..] {
269            // Happy path: there is only one user matching the username…
270            [user] => Ok(Some(user.clone().into())),
271            // …or none.
272            [] => Ok(None),
273            list => {
274                // If there are multiple users with the same username, we want to
275                // return the one which matches the exact casing
276                if let Some(user) = list.iter().find(|user| user.username == username) {
277                    Ok(Some(user.clone().into()))
278                } else {
279                    // If none match exactly, we prefer to return nothing
280                    Ok(None)
281                }
282            }
283        }
284    }
285
286    #[tracing::instrument(
287        name = "db.user.add",
288        skip_all,
289        fields(
290            db.query.text,
291            user.username = username,
292            user.id,
293        ),
294        err,
295    )]
296    async fn add(
297        &mut self,
298        rng: &mut (dyn RngCore + Send),
299        clock: &dyn Clock,
300        username: String,
301    ) -> Result<User, Self::Error> {
302        let created_at = clock.now();
303        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
304        tracing::Span::current().record("user.id", tracing::field::display(id));
305
306        let res = sqlx::query!(
307            r#"
308                INSERT INTO users (user_id, username, created_at)
309                VALUES ($1, $2, $3)
310                ON CONFLICT (username) DO NOTHING
311            "#,
312            Uuid::from(id),
313            username,
314            created_at,
315        )
316        .traced()
317        .execute(&mut *self.conn)
318        .await?;
319
320        // If the user already exists, want to return an error but not poison the
321        // transaction
322        DatabaseError::ensure_affected_rows(&res, 1)?;
323
324        Ok(User {
325            id,
326            username,
327            sub: id.to_string(),
328            created_at,
329            locked_at: None,
330            deactivated_at: None,
331            can_request_admin: false,
332            is_guest: false,
333        })
334    }
335
336    #[tracing::instrument(
337        name = "db.user.exists",
338        skip_all,
339        fields(
340            db.query.text,
341            user.username = username,
342        ),
343        err,
344    )]
345    async fn exists(&mut self, username: &str) -> Result<bool, Self::Error> {
346        let exists = sqlx::query_scalar!(
347            r#"
348                SELECT EXISTS(
349                    SELECT 1 FROM users WHERE LOWER(username) = LOWER($1)
350                ) AS "exists!"
351            "#,
352            username
353        )
354        .traced()
355        .fetch_one(&mut *self.conn)
356        .await?;
357
358        Ok(exists)
359    }
360
361    #[tracing::instrument(
362        name = "db.user.lock",
363        skip_all,
364        fields(
365            db.query.text,
366            %user.id,
367        ),
368        err,
369    )]
370    async fn lock(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
371        if user.locked_at.is_some() {
372            return Ok(user);
373        }
374
375        let locked_at = clock.now();
376        let res = sqlx::query!(
377            r#"
378                UPDATE users
379                SET locked_at = $1
380                WHERE user_id = $2
381            "#,
382            locked_at,
383            Uuid::from(user.id),
384        )
385        .traced()
386        .execute(&mut *self.conn)
387        .await?;
388
389        DatabaseError::ensure_affected_rows(&res, 1)?;
390
391        user.locked_at = Some(locked_at);
392
393        Ok(user)
394    }
395
396    #[tracing::instrument(
397        name = "db.user.unlock",
398        skip_all,
399        fields(
400            db.query.text,
401            %user.id,
402        ),
403        err,
404    )]
405    async fn unlock(&mut self, mut user: User) -> Result<User, Self::Error> {
406        if user.locked_at.is_none() {
407            return Ok(user);
408        }
409
410        let res = sqlx::query!(
411            r#"
412                UPDATE users
413                SET locked_at = NULL
414                WHERE user_id = $1
415            "#,
416            Uuid::from(user.id),
417        )
418        .traced()
419        .execute(&mut *self.conn)
420        .await?;
421
422        DatabaseError::ensure_affected_rows(&res, 1)?;
423
424        user.locked_at = None;
425
426        Ok(user)
427    }
428
429    #[tracing::instrument(
430        name = "db.user.deactivate",
431        skip_all,
432        fields(
433            db.query.text,
434            %user.id,
435        ),
436        err,
437    )]
438    async fn deactivate(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
439        if user.deactivated_at.is_some() {
440            return Ok(user);
441        }
442
443        let deactivated_at = clock.now();
444        let res = sqlx::query!(
445            r#"
446                UPDATE users
447                SET deactivated_at = $2
448                WHERE user_id = $1
449                  AND deactivated_at IS NULL
450            "#,
451            Uuid::from(user.id),
452            deactivated_at,
453        )
454        .traced()
455        .execute(&mut *self.conn)
456        .await?;
457
458        DatabaseError::ensure_affected_rows(&res, 1)?;
459
460        user.deactivated_at = Some(deactivated_at);
461
462        Ok(user)
463    }
464
465    #[tracing::instrument(
466        name = "db.user.reactivate",
467        skip_all,
468        fields(
469            db.query.text,
470            %user.id,
471        ),
472        err,
473    )]
474    async fn reactivate(&mut self, mut user: User) -> Result<User, Self::Error> {
475        if user.deactivated_at.is_none() {
476            return Ok(user);
477        }
478
479        let res = sqlx::query!(
480            r#"
481                UPDATE users
482                SET deactivated_at = NULL
483                WHERE user_id = $1
484            "#,
485            Uuid::from(user.id),
486        )
487        .traced()
488        .execute(&mut *self.conn)
489        .await?;
490
491        DatabaseError::ensure_affected_rows(&res, 1)?;
492
493        user.deactivated_at = None;
494
495        Ok(user)
496    }
497
498    #[tracing::instrument(
499        name = "db.user.delete_unsupported_threepids",
500        skip_all,
501        fields(
502            db.query.text,
503            %user.id,
504        ),
505        err,
506    )]
507    async fn delete_unsupported_threepids(&mut self, user: &User) -> Result<usize, Self::Error> {
508        let res = sqlx::query!(
509            r#"
510                DELETE FROM user_unsupported_third_party_ids
511                WHERE user_id = $1
512            "#,
513            Uuid::from(user.id),
514        )
515        .traced()
516        .execute(&mut *self.conn)
517        .await?;
518
519        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
520    }
521
522    #[tracing::instrument(
523        name = "db.user.set_can_request_admin",
524        skip_all,
525        fields(
526            db.query.text,
527            %user.id,
528            user.can_request_admin = can_request_admin,
529        ),
530        err,
531    )]
532    async fn set_can_request_admin(
533        &mut self,
534        mut user: User,
535        can_request_admin: bool,
536    ) -> Result<User, Self::Error> {
537        let res = sqlx::query!(
538            r#"
539                UPDATE users
540                SET can_request_admin = $2
541                WHERE user_id = $1
542            "#,
543            Uuid::from(user.id),
544            can_request_admin,
545        )
546        .traced()
547        .execute(&mut *self.conn)
548        .await?;
549
550        DatabaseError::ensure_affected_rows(&res, 1)?;
551
552        user.can_request_admin = can_request_admin;
553
554        Ok(user)
555    }
556
557    #[tracing::instrument(
558        name = "db.user.list",
559        skip_all,
560        fields(
561            db.query.text,
562        ),
563        err,
564    )]
565    async fn list(
566        &mut self,
567        filter: UserFilter<'_>,
568        pagination: mas_storage::Pagination,
569    ) -> Result<mas_storage::Page<User>, Self::Error> {
570        let (sql, arguments) = Query::select()
571            .expr_as(
572                Expr::col((Users::Table, Users::UserId)),
573                UserLookupIden::UserId,
574            )
575            .expr_as(
576                Expr::col((Users::Table, Users::Username)),
577                UserLookupIden::Username,
578            )
579            .expr_as(
580                Expr::col((Users::Table, Users::CreatedAt)),
581                UserLookupIden::CreatedAt,
582            )
583            .expr_as(
584                Expr::col((Users::Table, Users::LockedAt)),
585                UserLookupIden::LockedAt,
586            )
587            .expr_as(
588                Expr::col((Users::Table, Users::DeactivatedAt)),
589                UserLookupIden::DeactivatedAt,
590            )
591            .expr_as(
592                Expr::col((Users::Table, Users::CanRequestAdmin)),
593                UserLookupIden::CanRequestAdmin,
594            )
595            .expr_as(
596                Expr::col((Users::Table, Users::IsGuest)),
597                UserLookupIden::IsGuest,
598            )
599            .from(Users::Table)
600            .apply_filter(filter)
601            .generate_pagination((Users::Table, Users::UserId), pagination)
602            .build_sqlx(PostgresQueryBuilder);
603
604        let edges: Vec<UserLookup> = sqlx::query_as_with(&sql, arguments)
605            .traced()
606            .fetch_all(&mut *self.conn)
607            .await?;
608
609        let page = pagination.process(edges).map(User::from);
610
611        Ok(page)
612    }
613
614    #[tracing::instrument(
615        name = "db.user.count",
616        skip_all,
617        fields(
618            db.query.text,
619        ),
620        err,
621    )]
622    async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error> {
623        let (sql, arguments) = Query::select()
624            .expr(Expr::col((Users::Table, Users::UserId)).count())
625            .from(Users::Table)
626            .apply_filter(filter)
627            .build_sqlx(PostgresQueryBuilder);
628
629        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
630            .traced()
631            .fetch_one(&mut *self.conn)
632            .await?;
633
634        count
635            .try_into()
636            .map_err(DatabaseError::to_invalid_operation)
637    }
638
639    #[tracing::instrument(
640        name = "db.user.acquire_lock_for_sync",
641        skip_all,
642        fields(
643            db.query.text,
644            user.id = %user.id,
645        ),
646        err,
647    )]
648    async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error> {
649        // XXX: this lock isn't stictly scoped to users, but as we don't use many
650        // postgres advisory locks, it's fine for now. Later on, we could use row-level
651        // locks to make sure we don't get into trouble
652
653        // Convert the user ID to a u128 and grab the lower 64 bits
654        // As this includes 64bit of the random part of the ULID, it should be random
655        // enough to not collide
656        let lock_id = (u128::from(user.id) & 0xffff_ffff_ffff_ffff) as i64;
657
658        // Use a PG advisory lock, which will be released when the transaction is
659        // committed or rolled back
660        sqlx::query!(
661            r#"
662                SELECT pg_advisory_xact_lock($1)
663            "#,
664            lock_id,
665        )
666        .traced()
667        .execute(&mut *self.conn)
668        .await?;
669
670        Ok(())
671    }
672}