1use std::net::IpAddr;
9
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use mas_data_model::{
13 Authentication, AuthenticationMethod, BrowserSession, Clock, Password,
14 UpstreamOAuthAuthorizationSession, User,
15};
16use mas_storage::{
17 Page, Pagination,
18 pagination::Node,
19 user::{BrowserSessionFilter, BrowserSessionRepository},
20};
21use rand::RngCore;
22use sea_query::{Expr, PostgresQueryBuilder, Query};
23use sea_query_binder::SqlxBinder;
24use sqlx::PgConnection;
25use ulid::Ulid;
26use uuid::Uuid;
27
28use crate::{
29 DatabaseError, DatabaseInconsistencyError,
30 filter::StatementExt,
31 iden::{UpstreamOAuthAuthorizationSessions, UserSessions, Users},
32 pagination::QueryBuilderExt,
33 tracing::ExecuteExt,
34 ulid_at::{max_ulid_at, min_ulid_at},
35};
36
37pub struct PgBrowserSessionRepository<'c> {
40 conn: &'c mut PgConnection,
41}
42
43impl<'c> PgBrowserSessionRepository<'c> {
44 pub fn new(conn: &'c mut PgConnection) -> Self {
47 Self { conn }
48 }
49}
50
/// One row of `user_sessions` joined with the owning row of `users`.
///
/// The field names double as column aliases. `sqlx::FromRow` maps result
/// columns onto fields by name. `#[sea_query::enum_def]` generates the
/// `SessionLookupIden` identifiers that `list` uses as `AS` aliases. The
/// static query in `lookup` spells out the same aliases. Renaming a field
/// therefore means updating both queries.
// Every field starts with `user_`, which `clippy::struct_field_names` would
// otherwise flag; the prefix is kept so the aliases are self-describing.
#[expect(clippy::struct_field_names)]
#[derive(sqlx::FromRow)]
#[sea_query::enum_def]
struct SessionLookup {
    // Columns coming from `user_sessions`
    user_session_id: Uuid,
    user_session_created_at: DateTime<Utc>,
    user_session_finished_at: Option<DateTime<Utc>>,
    user_session_user_agent: Option<String>,
    user_session_last_active_at: Option<DateTime<Utc>>,
    user_session_last_active_ip: Option<IpAddr>,
    // Columns coming from `users`
    user_id: Uuid,
    user_username: String,
    user_created_at: DateTime<Utc>,
    user_locked_at: Option<DateTime<Utc>>,
    user_deactivated_at: Option<DateTime<Utc>>,
    user_can_request_admin: bool,
    user_is_guest: bool,
}
69
70impl Node<Ulid> for SessionLookup {
71 fn cursor(&self) -> Ulid {
72 self.user_id.into()
73 }
74}
75
76impl TryFrom<SessionLookup> for BrowserSession {
77 type Error = DatabaseInconsistencyError;
78
79 fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
80 let id = Ulid::from(value.user_id);
81 let user = User {
82 id,
83 username: value.user_username,
84 sub: id.to_string(),
85 created_at: value.user_created_at,
86 locked_at: value.user_locked_at,
87 deactivated_at: value.user_deactivated_at,
88 can_request_admin: value.user_can_request_admin,
89 is_guest: value.user_is_guest,
90 };
91
92 Ok(BrowserSession {
93 id: value.user_session_id.into(),
94 user,
95 created_at: value.user_session_created_at,
96 finished_at: value.user_session_finished_at,
97 user_agent: value.user_session_user_agent,
98 last_active_at: value.user_session_last_active_at,
99 last_active_ip: value.user_session_last_active_ip,
100 })
101 }
102}
103
/// One row of `user_session_authentications`, as selected by
/// `get_last_authentication`.
///
/// `sqlx::query_as!` fills it by matching the selected column names to these
/// field names.
struct AuthenticationLookup {
    user_session_authentication_id: Uuid,
    created_at: DateTime<Utc>,
    // At most one of the two references below may be set. The conversion into
    // `Authentication` treats a row with both set as a database inconsistency.
    user_password_id: Option<Uuid>,
    upstream_oauth_authorization_session_id: Option<Uuid>,
}
110
111impl TryFrom<AuthenticationLookup> for Authentication {
112 type Error = DatabaseInconsistencyError;
113
114 fn try_from(value: AuthenticationLookup) -> Result<Self, Self::Error> {
115 let id = Ulid::from(value.user_session_authentication_id);
116 let authentication_method = match (
117 value.user_password_id.map(Into::into),
118 value
119 .upstream_oauth_authorization_session_id
120 .map(Into::into),
121 ) {
122 (Some(user_password_id), None) => AuthenticationMethod::Password { user_password_id },
123 (None, Some(upstream_oauth2_session_id)) => AuthenticationMethod::UpstreamOAuth2 {
124 upstream_oauth2_session_id,
125 },
126 (None, None) => AuthenticationMethod::Unknown,
127 _ => {
128 return Err(DatabaseInconsistencyError::on("user_session_authentications").row(id));
129 }
130 };
131
132 Ok(Authentication {
133 id,
134 created_at: value.created_at,
135 authentication_method,
136 })
137 }
138}
139
140impl crate::filter::Filter for BrowserSessionFilter<'_> {
141 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
142 sea_query::Condition::all()
143 .add_option(self.user().map(|user| {
144 Expr::col((UserSessions::Table, UserSessions::UserId)).eq(Uuid::from(user.id))
145 }))
146 .add_option(self.state().map(|state| {
147 if state.is_active() {
148 Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_null()
149 } else {
150 Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_not_null()
151 }
152 }))
153 .add_option(self.last_active_after().map(|last_active_after| {
154 Expr::col((UserSessions::Table, UserSessions::LastActiveAt)).gt(last_active_after)
155 }))
156 .add_option(self.last_active_before().map(|last_active_before| {
157 Expr::col((UserSessions::Table, UserSessions::LastActiveAt)).lt(last_active_before)
158 }))
159 .add_option(self.created_after().map(|created_after| {
160 Expr::col((UserSessions::Table, UserSessions::UserSessionId))
164 .gt(max_ulid_at(created_after))
165 }))
166 .add_option(self.created_before().map(|created_before| {
167 Expr::col((UserSessions::Table, UserSessions::UserSessionId))
168 .lt(min_ulid_at(created_before))
169 }))
170 .add_option(self.linked_to_upstream_sessions().map(|filter| {
171 Expr::col((UserSessions::Table, UserSessions::UserSessionId)).in_subquery(
172 Query::select()
173 .expr(Expr::col((
174 UpstreamOAuthAuthorizationSessions::Table,
175 UpstreamOAuthAuthorizationSessions::UserSessionId,
176 )))
177 .from(UpstreamOAuthAuthorizationSessions::Table)
178 .apply_filter(filter)
179 .take(),
180 )
181 }))
182 }
183}
184
#[async_trait]
impl BrowserSessionRepository for PgBrowserSessionRepository<'_> {
    type Error = DatabaseError;

    /// Looks up a browser session, together with its owning user, by ID.
    ///
    /// Returns `Ok(None)` when no `user_sessions` row has this ID.
    #[tracing::instrument(
        name = "db.browser_session.lookup",
        skip_all,
        fields(
            db.query.text,
            user_session.id = %id,
        ),
        err,
    )]
    async fn lookup(&mut self, id: Ulid) -> Result<Option<BrowserSession>, Self::Error> {
        // Aliases match the `SessionLookup` field names. The `: IpAddr` suffix
        // is a sqlx type override for the IP column.
        let res = sqlx::query_as!(
            SessionLookup,
            r#"
                SELECT s.user_session_id
                     , s.created_at AS "user_session_created_at"
                     , s.finished_at AS "user_session_finished_at"
                     , s.user_agent AS "user_session_user_agent"
                     , s.last_active_at AS "user_session_last_active_at"
                     , s.last_active_ip AS "user_session_last_active_ip: IpAddr"
                     , u.user_id
                     , u.username AS "user_username"
                     , u.created_at AS "user_created_at"
                     , u.locked_at AS "user_locked_at"
                     , u.deactivated_at AS "user_deactivated_at"
                     , u.can_request_admin AS "user_can_request_admin"
                     , u.is_guest AS "user_is_guest"
                FROM user_sessions s
                INNER JOIN users u
                    USING (user_id)
                WHERE s.user_session_id = $1
            "#,
            Uuid::from(id),
        )
        .traced()
        .fetch_optional(&mut *self.conn)
        .await?;

        let Some(res) = res else { return Ok(None) };

        Ok(Some(res.try_into()?))
    }

    /// Starts a new browser session for `user`.
    ///
    /// The session ID is a ULID built from the clock's current time and `rng`.
    /// The row is inserted with no activity recorded. The returned session is
    /// built locally rather than re-read from the database.
    #[tracing::instrument(
        name = "db.browser_session.add",
        skip_all,
        fields(
            db.query.text,
            %user.id,
            user_session.id,
        ),
        err,
    )]
    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        user: &User,
        user_agent: Option<String>,
    ) -> Result<BrowserSession, Self::Error> {
        let created_at = clock.now();
        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
        // The ID only exists now, so fill in the span field declared as empty above
        tracing::Span::current().record("user_session.id", tracing::field::display(id));

        sqlx::query!(
            r#"
                INSERT INTO user_sessions (user_session_id, user_id, created_at, user_agent)
                VALUES ($1, $2, $3, $4)
            "#,
            Uuid::from(id),
            Uuid::from(user.id),
            created_at,
            user_agent.as_deref(),
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        let session = BrowserSession {
            id,
            user: user.clone(),
            created_at,
            finished_at: None,
            user_agent,
            last_active_at: None,
            last_active_ip: None,
        };

        Ok(session)
    }

    /// Marks `user_session` as finished at the clock's current time and
    /// returns it with `finished_at` set.
    ///
    /// # Errors
    ///
    /// Fails if the update did not affect exactly one row, e.g. when no
    /// session has this ID.
    #[tracing::instrument(
        name = "db.browser_session.finish",
        skip_all,
        fields(
            db.query.text,
            %user_session.id,
        ),
        err,
    )]
    async fn finish(
        &mut self,
        clock: &dyn Clock,
        mut user_session: BrowserSession,
    ) -> Result<BrowserSession, Self::Error> {
        let finished_at = clock.now();
        // No `finished_at IS NULL` guard: finishing an already-finished
        // session overwrites its timestamp.
        let res = sqlx::query!(
            r#"
                UPDATE user_sessions
                SET finished_at = $1
                WHERE user_session_id = $2
            "#,
            finished_at,
            Uuid::from(user_session.id),
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        // The session was moved in, so on the error path below this local
        // change is simply dropped.
        user_session.finished_at = Some(finished_at);

        DatabaseError::ensure_affected_rows(&res, 1)?;

        Ok(user_session)
    }

    /// Marks every session matching `filter` as finished at the clock's
    /// current time.
    ///
    /// Returns the number of updated rows, saturated to `usize::MAX` if it
    /// does not fit.
    #[tracing::instrument(
        name = "db.browser_session.finish_bulk",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn finish_bulk(
        &mut self,
        clock: &dyn Clock,
        filter: BrowserSessionFilter<'_>,
    ) -> Result<usize, Self::Error> {
        let finished_at = clock.now();
        // Only `filter` restricts the affected rows. Without an active-state
        // criterion, already-finished sessions get `finished_at` overwritten.
        let (sql, arguments) = sea_query::Query::update()
            .table(UserSessions::Table)
            .value(UserSessions::FinishedAt, finished_at)
            .apply_filter(filter)
            .build_sqlx(PostgresQueryBuilder);

        let res = sqlx::query_with(&sql, arguments)
            .traced()
            .execute(&mut *self.conn)
            .await?;

        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
    }

    /// Lists the sessions matching `filter`, joined with their owning users,
    /// one page at a time.
    ///
    /// Pagination is keyed on `user_sessions.user_session_id`.
    // NOTE(review): page cursors come from `Node::cursor` on `SessionLookup`.
    // They must be derived from `user_session_id` to match the pagination
    // column used here.
    #[tracing::instrument(
        name = "db.browser_session.list",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn list(
        &mut self,
        filter: BrowserSessionFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<BrowserSession>, Self::Error> {
        // Every selected column is aliased to the matching `SessionLookup`
        // field, so the rows decode through its `FromRow` derive.
        let (sql, arguments) = sea_query::Query::select()
            .expr_as(
                Expr::col((UserSessions::Table, UserSessions::UserSessionId)),
                SessionLookupIden::UserSessionId,
            )
            .expr_as(
                Expr::col((UserSessions::Table, UserSessions::CreatedAt)),
                SessionLookupIden::UserSessionCreatedAt,
            )
            .expr_as(
                Expr::col((UserSessions::Table, UserSessions::FinishedAt)),
                SessionLookupIden::UserSessionFinishedAt,
            )
            .expr_as(
                Expr::col((UserSessions::Table, UserSessions::UserAgent)),
                SessionLookupIden::UserSessionUserAgent,
            )
            .expr_as(
                Expr::col((UserSessions::Table, UserSessions::LastActiveAt)),
                SessionLookupIden::UserSessionLastActiveAt,
            )
            .expr_as(
                Expr::col((UserSessions::Table, UserSessions::LastActiveIp)),
                SessionLookupIden::UserSessionLastActiveIp,
            )
            .expr_as(
                Expr::col((Users::Table, Users::UserId)),
                SessionLookupIden::UserId,
            )
            .expr_as(
                Expr::col((Users::Table, Users::Username)),
                SessionLookupIden::UserUsername,
            )
            .expr_as(
                Expr::col((Users::Table, Users::CreatedAt)),
                SessionLookupIden::UserCreatedAt,
            )
            .expr_as(
                Expr::col((Users::Table, Users::LockedAt)),
                SessionLookupIden::UserLockedAt,
            )
            .expr_as(
                Expr::col((Users::Table, Users::DeactivatedAt)),
                SessionLookupIden::UserDeactivatedAt,
            )
            .expr_as(
                Expr::col((Users::Table, Users::CanRequestAdmin)),
                SessionLookupIden::UserCanRequestAdmin,
            )
            .expr_as(
                Expr::col((Users::Table, Users::IsGuest)),
                SessionLookupIden::UserIsGuest,
            )
            .from(UserSessions::Table)
            .inner_join(
                Users::Table,
                Expr::col((UserSessions::Table, UserSessions::UserId))
                    .equals((Users::Table, Users::UserId)),
            )
            .apply_filter(filter)
            .generate_pagination(
                (UserSessions::Table, UserSessions::UserSessionId),
                pagination,
            )
            .build_sqlx(PostgresQueryBuilder);

        let edges: Vec<SessionLookup> = sqlx::query_as_with(&sql, arguments)
            .traced()
            .fetch_all(&mut *self.conn)
            .await?;

        let page = pagination
            .process(edges)
            .try_map(BrowserSession::try_from)?;

        Ok(page)
    }

    /// Counts the sessions matching `filter`.
    ///
    /// # Errors
    ///
    /// Fails if the database count cannot be converted to `usize`.
    #[tracing::instrument(
        name = "db.browser_session.count",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn count(&mut self, filter: BrowserSessionFilter<'_>) -> Result<usize, Self::Error> {
        // No join on `users`: the filter only references `user_sessions`
        // columns, plus a subquery on upstream OAuth sessions.
        let (sql, arguments) = sea_query::Query::select()
            .expr(Expr::col((UserSessions::Table, UserSessions::UserSessionId)).count())
            .from(UserSessions::Table)
            .apply_filter(filter)
            .build_sqlx(PostgresQueryBuilder);

        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
            .traced()
            .fetch_one(&mut *self.conn)
            .await?;

        count
            .try_into()
            .map_err(DatabaseError::to_invalid_operation)
    }

    /// Records that `user_session` was authenticated with `user_password`.
    ///
    /// Inserts a `user_session_authentications` row stamped with the clock's
    /// current time and a fresh ULID. Returns the matching [`Authentication`]
    /// without re-reading it.
    #[tracing::instrument(
        name = "db.browser_session.authenticate_with_password",
        skip_all,
        fields(
            db.query.text,
            %user_session.id,
            %user_password.id,
            user_session_authentication.id,
        ),
        err,
    )]
    async fn authenticate_with_password(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        user_session: &BrowserSession,
        user_password: &Password,
    ) -> Result<Authentication, Self::Error> {
        let created_at = clock.now();
        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
        tracing::Span::current().record(
            "user_session_authentication.id",
            tracing::field::display(id),
        );

        sqlx::query!(
            r#"
                INSERT INTO user_session_authentications
                    (user_session_authentication_id, user_session_id, created_at, user_password_id)
                VALUES ($1, $2, $3, $4)
            "#,
            Uuid::from(id),
            Uuid::from(user_session.id),
            created_at,
            Uuid::from(user_password.id),
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        Ok(Authentication {
            id,
            created_at,
            authentication_method: AuthenticationMethod::Password {
                user_password_id: user_password.id,
            },
        })
    }

    /// Records that `user_session` was authenticated through the upstream
    /// OAuth authorization session `upstream_oauth_session`.
    ///
    /// Inserts a `user_session_authentications` row stamped with the clock's
    /// current time and a fresh ULID. Returns the matching [`Authentication`]
    /// without re-reading it.
    #[tracing::instrument(
        name = "db.browser_session.authenticate_with_upstream",
        skip_all,
        fields(
            db.query.text,
            %user_session.id,
            %upstream_oauth_session.id,
            user_session_authentication.id,
        ),
        err,
    )]
    async fn authenticate_with_upstream(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        user_session: &BrowserSession,
        upstream_oauth_session: &UpstreamOAuthAuthorizationSession,
    ) -> Result<Authentication, Self::Error> {
        let created_at = clock.now();
        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
        tracing::Span::current().record(
            "user_session_authentication.id",
            tracing::field::display(id),
        );

        sqlx::query!(
            r#"
                INSERT INTO user_session_authentications
                    (user_session_authentication_id, user_session_id, created_at, upstream_oauth_authorization_session_id)
                VALUES ($1, $2, $3, $4)
            "#,
            Uuid::from(id),
            Uuid::from(user_session.id),
            created_at,
            Uuid::from(upstream_oauth_session.id),
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        Ok(Authentication {
            id,
            created_at,
            authentication_method: AuthenticationMethod::UpstreamOAuth2 {
                upstream_oauth2_session_id: upstream_oauth_session.id,
            },
        })
    }

    /// Returns the most recent authentication recorded for `user_session`
    /// (highest `created_at`), or `None` if there is none.
    ///
    /// # Errors
    ///
    /// Fails with an inconsistency error if the row references both a
    /// password and an upstream OAuth session.
    #[tracing::instrument(
        name = "db.browser_session.get_last_authentication",
        skip_all,
        fields(
            db.query.text,
            %user_session.id,
        ),
        err,
    )]
    async fn get_last_authentication(
        &mut self,
        user_session: &BrowserSession,
    ) -> Result<Option<Authentication>, Self::Error> {
        let authentication = sqlx::query_as!(
            AuthenticationLookup,
            r#"
                SELECT user_session_authentication_id
                     , created_at
                     , user_password_id
                     , upstream_oauth_authorization_session_id
                FROM user_session_authentications
                WHERE user_session_id = $1
                ORDER BY created_at DESC
                LIMIT 1
            "#,
            Uuid::from(user_session.id),
        )
        .traced()
        .fetch_optional(&mut *self.conn)
        .await?;

        let Some(authentication) = authentication else {
            return Ok(None);
        };

        let authentication = Authentication::try_from(authentication)?;
        Ok(Some(authentication))
    }

    /// Applies a batch of `(session ID, activity time, IP)` entries in a
    /// single `UPDATE ... FROM UNNEST(...)` statement.
    ///
    /// `GREATEST` keeps the later of the new and stored `last_active_at`, so
    /// it never moves backwards. `COALESCE` keeps the stored IP when an
    /// entry's IP is `None`.
    ///
    /// # Errors
    ///
    /// Fails if the number of updated rows differs from the number of
    /// entries, e.g. when an ID matches no session.
    // NOTE(review): duplicate IDs in one batch would also trip the row-count
    // check. Postgres updates each target row at most once per `UPDATE ...
    // FROM`. Callers presumably deduplicate — confirm.
    #[tracing::instrument(
        name = "db.browser_session.record_batch_activity",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn record_batch_activity(
        &mut self,
        mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
    ) -> Result<(), Self::Error> {
        // Sort so rows are processed in a deterministic ID order. Presumably
        // this makes concurrent batches lock rows in the same order and
        // avoids deadlocks — TODO confirm.
        activities.sort_unstable();
        // Split the tuples into parallel column arrays for `UNNEST`
        let mut ids = Vec::with_capacity(activities.len());
        let mut last_activities = Vec::with_capacity(activities.len());
        let mut ips = Vec::with_capacity(activities.len());

        for (id, last_activity, ip) in activities {
            ids.push(Uuid::from(id));
            last_activities.push(last_activity);
            ips.push(ip);
        }

        let res = sqlx::query!(
            r#"
                UPDATE user_sessions
                SET last_active_at = GREATEST(t.last_active_at, user_sessions.last_active_at)
                  , last_active_ip = COALESCE(t.last_active_ip, user_sessions.last_active_ip)
                FROM (
                    SELECT *
                    FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
                        AS t(user_session_id, last_active_at, last_active_ip)
                ) AS t
                WHERE user_sessions.user_session_id = t.user_session_id
            "#,
            &ids,
            &last_activities,
            // Explicit slice type for the nullable `inet[]` parameter
            // (presumably needed by the macro's type checking)
            &ips as &[Option<IpAddr>],
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;

        Ok(())
    }

    /// Deletes up to `limit` finished sessions whose `finished_at` lies in
    /// `[since, until)`, oldest first. There is no lower bound when `since`
    /// is `None`.
    ///
    /// Sessions still referenced by `oauth2_sessions` or `compat_sessions`
    /// are skipped. Their `user_session_authentications` rows are deleted in
    /// the same statement.
    ///
    /// Returns two values:
    /// - the number of deleted sessions, saturated to `usize::MAX`;
    /// - the greatest `finished_at` among them, or `None` if nothing was
    ///   deleted. Presumably this is passed back as `since` to resume the
    ///   cleanup in batches.
    #[tracing::instrument(
        name = "db.browser_session.cleanup_finished",
        skip_all,
        fields(
            db.query.text,
            since = since.map(tracing::field::display),
            until = %until,
            limit = limit,
        ),
        err,
    )]
    async fn cleanup_finished(
        &mut self,
        since: Option<DateTime<Utc>>,
        until: DateTime<Utc>,
        limit: usize,
    ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error> {
        // `FOR UPDATE OF us` locks the selected sessions for the rest of the
        // statement. The `count!` alias is a sqlx override marking the column
        // non-null.
        let res = sqlx::query!(
            r#"
                WITH
                    to_delete AS (
                        SELECT user_session_id, finished_at
                        FROM user_sessions us
                        WHERE us.finished_at IS NOT NULL
                          AND ($1::timestamptz IS NULL OR us.finished_at >= $1)
                          AND us.finished_at < $2
                          -- Only delete if no oauth2_sessions reference this user_session
                          AND NOT EXISTS (
                              SELECT 1 FROM oauth2_sessions os
                              WHERE os.user_session_id = us.user_session_id
                          )
                          -- Only delete if no compat_sessions reference this user_session
                          AND NOT EXISTS (
                              SELECT 1 FROM compat_sessions cs
                              WHERE cs.user_session_id = us.user_session_id
                          )
                        ORDER BY us.finished_at ASC
                        LIMIT $3
                        FOR UPDATE OF us
                    ),
                    deleted_authentications AS (
                        DELETE FROM user_session_authentications USING to_delete
                        WHERE user_session_authentications.user_session_id = to_delete.user_session_id
                    ),
                    deleted_sessions AS (
                        DELETE FROM user_sessions USING to_delete
                        WHERE user_sessions.user_session_id = to_delete.user_session_id
                        RETURNING user_sessions.finished_at
                    )
                SELECT COUNT(*) as "count!", MAX(finished_at) as last_finished_at FROM deleted_sessions
            "#,
            since,
            until,
            i64::try_from(limit).unwrap_or(i64::MAX),
        )
        .traced()
        .fetch_one(&mut *self.conn)
        .await?;

        Ok((
            res.count.try_into().unwrap_or(usize::MAX),
            res.last_finished_at,
        ))
    }

    /// Clears `last_active_ip` on up to `limit` sessions whose
    /// `last_active_at` lies in `[since, threshold)`, oldest first. There is
    /// no lower bound when `since` is `None`.
    ///
    /// Returns two values:
    /// - the number of updated sessions, saturated to `usize::MAX`;
    /// - the greatest `last_active_at` among them, or `None` if nothing was
    ///   updated.
    #[tracing::instrument(
        name = "db.browser_session.cleanup_inactive_ips",
        skip_all,
        fields(
            db.query.text,
            since = since.map(tracing::field::display),
            threshold = %threshold,
            limit = limit,
        ),
        err,
    )]
    async fn cleanup_inactive_ips(
        &mut self,
        since: Option<DateTime<Utc>>,
        threshold: DateTime<Utc>,
        limit: usize,
    ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error> {
        // Rows whose IP was already cleared are excluded by
        // `last_active_ip IS NOT NULL`. Re-running with `since` set to the
        // returned timestamp therefore does not touch them again.
        let res = sqlx::query!(
            r#"
                WITH to_update AS (
                    SELECT user_session_id, last_active_at
                    FROM user_sessions
                    WHERE last_active_ip IS NOT NULL
                      AND last_active_at IS NOT NULL
                      AND ($1::timestamptz IS NULL OR last_active_at >= $1)
                      AND last_active_at < $2
                    ORDER BY last_active_at ASC
                    LIMIT $3
                    FOR UPDATE
                ),
                updated AS (
                    UPDATE user_sessions
                    SET last_active_ip = NULL
                    FROM to_update
                    WHERE user_sessions.user_session_id = to_update.user_session_id
                    RETURNING user_sessions.last_active_at
                )
                SELECT COUNT(*) AS "count!", MAX(last_active_at) AS last_active_at FROM updated
            "#,
            since,
            threshold,
            i64::try_from(limit).unwrap_or(i64::MAX),
        )
        .traced()
        .fetch_one(&mut *self.conn)
        .await?;

        Ok((
            res.count.try_into().unwrap_or(usize::MAX),
            res.last_active_at,
        ))
    }
}