1use async_trait::async_trait;
12use mas_data_model::{Clock, User};
13use mas_storage::user::{UserFilter, UserRepository};
14use rand::RngCore;
15use sea_query::{Expr, PostgresQueryBuilder, Query, SimpleExpr, extension::postgres::PgExpr as _};
16use sea_query_binder::SqlxBinder;
17use sqlx::PgConnection;
18use ulid::Ulid;
19use uuid::Uuid;
20
21use crate::{
22 DatabaseError,
23 filter::{Filter, StatementExt},
24 iden::{CompatSessions, OAuth2Sessions, Users},
25 pagination::QueryBuilderExt,
26 tracing::ExecuteExt,
27};
28
29mod email;
30mod password;
31mod recovery;
32mod registration;
33mod registration_token;
34mod session;
35mod terms;
36
37#[cfg(test)]
38mod tests;
39
40pub use self::{
41 email::PgUserEmailRepository, password::PgUserPasswordRepository,
42 recovery::PgUserRecoveryRepository, registration::PgUserRegistrationRepository,
43 registration_token::PgUserRegistrationTokenRepository, session::PgBrowserSessionRepository,
44 terms::PgUserTermsRepository,
45};
46
47pub struct PgUserRepository<'c> {
49 conn: &'c mut PgConnection,
50}
51
52impl<'c> PgUserRepository<'c> {
53 pub fn new(conn: &'c mut PgConnection) -> Self {
55 Self { conn }
56 }
57}
58
59mod priv_ {
60 #![allow(missing_docs)]
63
64 use chrono::{DateTime, Utc};
65 use mas_storage::pagination::Node;
66 use sea_query::enum_def;
67 use ulid::Ulid;
68 use uuid::Uuid;
69
70 #[derive(Debug, Clone, sqlx::FromRow)]
71 #[enum_def]
72 pub(super) struct UserLookup {
73 pub(super) user_id: Uuid,
74 pub(super) username: String,
75 pub(super) created_at: DateTime<Utc>,
76 pub(super) locked_at: Option<DateTime<Utc>>,
77 pub(super) deactivated_at: Option<DateTime<Utc>>,
78 pub(super) can_request_admin: bool,
79 pub(super) is_guest: bool,
80 }
81
82 impl Node<Ulid> for UserLookup {
83 fn cursor(&self) -> Ulid {
84 self.user_id.into()
85 }
86 }
87}
88
89use priv_::{UserLookup, UserLookupIden};
90
91impl From<UserLookup> for User {
92 fn from(value: UserLookup) -> Self {
93 let id = value.user_id.into();
94 Self {
95 id,
96 username: value.username,
97 sub: id.to_string(),
98 created_at: value.created_at,
99 locked_at: value.locked_at,
100 deactivated_at: value.deactivated_at,
101 can_request_admin: value.can_request_admin,
102 is_guest: value.is_guest,
103 }
104 }
105}
106
107impl Filter for UserFilter<'_> {
108 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
109 sea_query::Condition::all()
110 .add_option(self.state().map(|state| {
111 match state {
112 mas_storage::user::UserState::Deactivated => {
113 Expr::col((Users::Table, Users::DeactivatedAt)).is_not_null()
114 }
115 mas_storage::user::UserState::Locked => {
116 Expr::col((Users::Table, Users::LockedAt)).is_not_null()
117 }
118 mas_storage::user::UserState::Active => {
119 Expr::col((Users::Table, Users::LockedAt))
120 .is_null()
121 .and(Expr::col((Users::Table, Users::DeactivatedAt)).is_null())
122 }
123 }
124 }))
125 .add_option(self.can_request_admin().map(|can_request_admin| {
126 Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
127 }))
128 .add_option(
129 self.is_guest()
130 .map(|is_guest| Expr::col((Users::Table, Users::IsGuest)).eq(is_guest)),
131 )
132 .add_option(self.search().map(|search| {
133 Expr::col((Users::Table, Users::Username)).ilike(format!("%{search}%"))
134 }))
135 .add_option(self.active_oauth2_session_for_any_of_clients().map(
136 |clients| -> SimpleExpr {
137 let client_ids: Vec<SimpleExpr> = clients
138 .iter()
139 .map(|c| Expr::val(Uuid::from(*c)).into())
140 .collect();
141 Expr::exists(
142 Query::select()
143 .expr(Expr::cust("1"))
144 .from(OAuth2Sessions::Table)
145 .and_where(
146 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId))
147 .equals((Users::Table, Users::UserId)),
148 )
149 .and_where(
150 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt))
151 .is_null(),
152 )
153 .and_where(
154 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
155 .is_in(client_ids),
156 )
157 .take(),
158 )
159 },
160 ))
161 .add_option(self.has_active_oauth2_session().map(|has| -> SimpleExpr {
162 let exists = Expr::exists(
163 Query::select()
164 .expr(Expr::cust("1"))
165 .from(OAuth2Sessions::Table)
166 .and_where(
167 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId))
168 .equals((Users::Table, Users::UserId)),
169 )
170 .and_where(
171 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt))
172 .is_null(),
173 )
174 .take(),
175 );
176 if has { exists } else { exists.not() }
177 }))
178 .add_option(self.has_active_compat_session().map(|has| -> SimpleExpr {
179 let exists = Expr::exists(
180 Query::select()
181 .expr(Expr::cust("1"))
182 .from(CompatSessions::Table)
183 .and_where(
184 Expr::col((CompatSessions::Table, CompatSessions::UserId))
185 .equals((Users::Table, Users::UserId)),
186 )
187 .and_where(
188 Expr::col((CompatSessions::Table, CompatSessions::FinishedAt))
189 .is_null(),
190 )
191 .take(),
192 );
193 if has { exists } else { exists.not() }
194 }))
195 }
196}
197
198#[async_trait]
199impl UserRepository for PgUserRepository<'_> {
200 type Error = DatabaseError;
201
202 #[tracing::instrument(
203 name = "db.user.lookup",
204 skip_all,
205 fields(
206 db.query.text,
207 user.id = %id,
208 ),
209 err,
210 )]
211 async fn lookup(&mut self, id: Ulid) -> Result<Option<User>, Self::Error> {
212 let res = sqlx::query_as!(
213 UserLookup,
214 r#"
215 SELECT user_id
216 , username
217 , created_at
218 , locked_at
219 , deactivated_at
220 , can_request_admin
221 , is_guest
222 FROM users
223 WHERE user_id = $1
224 "#,
225 Uuid::from(id),
226 )
227 .traced()
228 .fetch_optional(&mut *self.conn)
229 .await?;
230
231 let Some(res) = res else { return Ok(None) };
232
233 Ok(Some(res.into()))
234 }
235
236 #[tracing::instrument(
237 name = "db.user.find_by_username",
238 skip_all,
239 fields(
240 db.query.text,
241 user.username = username,
242 ),
243 err,
244 )]
245 async fn find_by_username(&mut self, username: &str) -> Result<Option<User>, Self::Error> {
246 let res = sqlx::query_as!(
250 UserLookup,
251 r#"
252 SELECT user_id
253 , username
254 , created_at
255 , locked_at
256 , deactivated_at
257 , can_request_admin
258 , is_guest
259 FROM users
260 WHERE LOWER(username) = LOWER($1)
261 "#,
262 username,
263 )
264 .traced()
265 .fetch_all(&mut *self.conn)
266 .await?;
267
268 match &res[..] {
269 [user] => Ok(Some(user.clone().into())),
271 [] => Ok(None),
273 list => {
274 if let Some(user) = list.iter().find(|user| user.username == username) {
277 Ok(Some(user.clone().into()))
278 } else {
279 Ok(None)
281 }
282 }
283 }
284 }
285
286 #[tracing::instrument(
287 name = "db.user.add",
288 skip_all,
289 fields(
290 db.query.text,
291 user.username = username,
292 user.id,
293 ),
294 err,
295 )]
296 async fn add(
297 &mut self,
298 rng: &mut (dyn RngCore + Send),
299 clock: &dyn Clock,
300 username: String,
301 ) -> Result<User, Self::Error> {
302 let created_at = clock.now();
303 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
304 tracing::Span::current().record("user.id", tracing::field::display(id));
305
306 let res = sqlx::query!(
307 r#"
308 INSERT INTO users (user_id, username, created_at)
309 VALUES ($1, $2, $3)
310 ON CONFLICT (username) DO NOTHING
311 "#,
312 Uuid::from(id),
313 username,
314 created_at,
315 )
316 .traced()
317 .execute(&mut *self.conn)
318 .await?;
319
320 DatabaseError::ensure_affected_rows(&res, 1)?;
323
324 Ok(User {
325 id,
326 username,
327 sub: id.to_string(),
328 created_at,
329 locked_at: None,
330 deactivated_at: None,
331 can_request_admin: false,
332 is_guest: false,
333 })
334 }
335
336 #[tracing::instrument(
337 name = "db.user.exists",
338 skip_all,
339 fields(
340 db.query.text,
341 user.username = username,
342 ),
343 err,
344 )]
345 async fn exists(&mut self, username: &str) -> Result<bool, Self::Error> {
346 let exists = sqlx::query_scalar!(
347 r#"
348 SELECT EXISTS(
349 SELECT 1 FROM users WHERE LOWER(username) = LOWER($1)
350 ) AS "exists!"
351 "#,
352 username
353 )
354 .traced()
355 .fetch_one(&mut *self.conn)
356 .await?;
357
358 Ok(exists)
359 }
360
361 #[tracing::instrument(
362 name = "db.user.lock",
363 skip_all,
364 fields(
365 db.query.text,
366 %user.id,
367 ),
368 err,
369 )]
370 async fn lock(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
371 if user.locked_at.is_some() {
372 return Ok(user);
373 }
374
375 let locked_at = clock.now();
376 let res = sqlx::query!(
377 r#"
378 UPDATE users
379 SET locked_at = $1
380 WHERE user_id = $2
381 "#,
382 locked_at,
383 Uuid::from(user.id),
384 )
385 .traced()
386 .execute(&mut *self.conn)
387 .await?;
388
389 DatabaseError::ensure_affected_rows(&res, 1)?;
390
391 user.locked_at = Some(locked_at);
392
393 Ok(user)
394 }
395
396 #[tracing::instrument(
397 name = "db.user.unlock",
398 skip_all,
399 fields(
400 db.query.text,
401 %user.id,
402 ),
403 err,
404 )]
405 async fn unlock(&mut self, mut user: User) -> Result<User, Self::Error> {
406 if user.locked_at.is_none() {
407 return Ok(user);
408 }
409
410 let res = sqlx::query!(
411 r#"
412 UPDATE users
413 SET locked_at = NULL
414 WHERE user_id = $1
415 "#,
416 Uuid::from(user.id),
417 )
418 .traced()
419 .execute(&mut *self.conn)
420 .await?;
421
422 DatabaseError::ensure_affected_rows(&res, 1)?;
423
424 user.locked_at = None;
425
426 Ok(user)
427 }
428
429 #[tracing::instrument(
430 name = "db.user.deactivate",
431 skip_all,
432 fields(
433 db.query.text,
434 %user.id,
435 ),
436 err,
437 )]
438 async fn deactivate(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
439 if user.deactivated_at.is_some() {
440 return Ok(user);
441 }
442
443 let deactivated_at = clock.now();
444 let res = sqlx::query!(
445 r#"
446 UPDATE users
447 SET deactivated_at = $2
448 WHERE user_id = $1
449 AND deactivated_at IS NULL
450 "#,
451 Uuid::from(user.id),
452 deactivated_at,
453 )
454 .traced()
455 .execute(&mut *self.conn)
456 .await?;
457
458 DatabaseError::ensure_affected_rows(&res, 1)?;
459
460 user.deactivated_at = Some(deactivated_at);
461
462 Ok(user)
463 }
464
465 #[tracing::instrument(
466 name = "db.user.reactivate",
467 skip_all,
468 fields(
469 db.query.text,
470 %user.id,
471 ),
472 err,
473 )]
474 async fn reactivate(&mut self, mut user: User) -> Result<User, Self::Error> {
475 if user.deactivated_at.is_none() {
476 return Ok(user);
477 }
478
479 let res = sqlx::query!(
480 r#"
481 UPDATE users
482 SET deactivated_at = NULL
483 WHERE user_id = $1
484 "#,
485 Uuid::from(user.id),
486 )
487 .traced()
488 .execute(&mut *self.conn)
489 .await?;
490
491 DatabaseError::ensure_affected_rows(&res, 1)?;
492
493 user.deactivated_at = None;
494
495 Ok(user)
496 }
497
498 #[tracing::instrument(
499 name = "db.user.delete_unsupported_threepids",
500 skip_all,
501 fields(
502 db.query.text,
503 %user.id,
504 ),
505 err,
506 )]
507 async fn delete_unsupported_threepids(&mut self, user: &User) -> Result<usize, Self::Error> {
508 let res = sqlx::query!(
509 r#"
510 DELETE FROM user_unsupported_third_party_ids
511 WHERE user_id = $1
512 "#,
513 Uuid::from(user.id),
514 )
515 .traced()
516 .execute(&mut *self.conn)
517 .await?;
518
519 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
520 }
521
522 #[tracing::instrument(
523 name = "db.user.set_can_request_admin",
524 skip_all,
525 fields(
526 db.query.text,
527 %user.id,
528 user.can_request_admin = can_request_admin,
529 ),
530 err,
531 )]
532 async fn set_can_request_admin(
533 &mut self,
534 mut user: User,
535 can_request_admin: bool,
536 ) -> Result<User, Self::Error> {
537 let res = sqlx::query!(
538 r#"
539 UPDATE users
540 SET can_request_admin = $2
541 WHERE user_id = $1
542 "#,
543 Uuid::from(user.id),
544 can_request_admin,
545 )
546 .traced()
547 .execute(&mut *self.conn)
548 .await?;
549
550 DatabaseError::ensure_affected_rows(&res, 1)?;
551
552 user.can_request_admin = can_request_admin;
553
554 Ok(user)
555 }
556
557 #[tracing::instrument(
558 name = "db.user.list",
559 skip_all,
560 fields(
561 db.query.text,
562 ),
563 err,
564 )]
565 async fn list(
566 &mut self,
567 filter: UserFilter<'_>,
568 pagination: mas_storage::Pagination,
569 ) -> Result<mas_storage::Page<User>, Self::Error> {
570 let (sql, arguments) = Query::select()
571 .expr_as(
572 Expr::col((Users::Table, Users::UserId)),
573 UserLookupIden::UserId,
574 )
575 .expr_as(
576 Expr::col((Users::Table, Users::Username)),
577 UserLookupIden::Username,
578 )
579 .expr_as(
580 Expr::col((Users::Table, Users::CreatedAt)),
581 UserLookupIden::CreatedAt,
582 )
583 .expr_as(
584 Expr::col((Users::Table, Users::LockedAt)),
585 UserLookupIden::LockedAt,
586 )
587 .expr_as(
588 Expr::col((Users::Table, Users::DeactivatedAt)),
589 UserLookupIden::DeactivatedAt,
590 )
591 .expr_as(
592 Expr::col((Users::Table, Users::CanRequestAdmin)),
593 UserLookupIden::CanRequestAdmin,
594 )
595 .expr_as(
596 Expr::col((Users::Table, Users::IsGuest)),
597 UserLookupIden::IsGuest,
598 )
599 .from(Users::Table)
600 .apply_filter(filter)
601 .generate_pagination((Users::Table, Users::UserId), pagination)
602 .build_sqlx(PostgresQueryBuilder);
603
604 let edges: Vec<UserLookup> = sqlx::query_as_with(&sql, arguments)
605 .traced()
606 .fetch_all(&mut *self.conn)
607 .await?;
608
609 let page = pagination.process(edges).map(User::from);
610
611 Ok(page)
612 }
613
614 #[tracing::instrument(
615 name = "db.user.count",
616 skip_all,
617 fields(
618 db.query.text,
619 ),
620 err,
621 )]
622 async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error> {
623 let (sql, arguments) = Query::select()
624 .expr(Expr::col((Users::Table, Users::UserId)).count())
625 .from(Users::Table)
626 .apply_filter(filter)
627 .build_sqlx(PostgresQueryBuilder);
628
629 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
630 .traced()
631 .fetch_one(&mut *self.conn)
632 .await?;
633
634 count
635 .try_into()
636 .map_err(DatabaseError::to_invalid_operation)
637 }
638
639 #[tracing::instrument(
640 name = "db.user.acquire_lock_for_sync",
641 skip_all,
642 fields(
643 db.query.text,
644 user.id = %user.id,
645 ),
646 err,
647 )]
648 async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error> {
649 let lock_id = (u128::from(user.id) & 0xffff_ffff_ffff_ffff) as i64;
657
658 sqlx::query!(
661 r#"
662 SELECT pg_advisory_xact_lock($1)
663 "#,
664 lock_id,
665 )
666 .traced()
667 .execute(&mut *self.conn)
668 .await?;
669
670 Ok(())
671 }
672}