Skip to main content

mas_storage_pg/
app_session.rs

1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
//! A module containing the PostgreSQL implementation of the repository for app sessions
9
10use async_trait::async_trait;
11use mas_data_model::{
12    Clock, CompatSession, CompatSessionState, Device, Session, SessionState, User,
13};
14use mas_storage::{
15    Page, Pagination,
16    app_session::{AppSession, AppSessionFilter, AppSessionRepository, AppSessionState},
17    compat::CompatSessionFilter,
18    oauth2::OAuth2SessionFilter,
19};
20use oauth2_types::scope::{Scope, ScopeToken};
21use opentelemetry_semantic_conventions::trace::DB_QUERY_TEXT;
22use sea_query::{
23    Alias, ColumnRef, CommonTableExpression, Expr, PostgresQueryBuilder, Query, UnionType,
24};
25use sea_query_binder::SqlxBinder;
26use sqlx::PgConnection;
27use tracing::Instrument;
28use ulid::Ulid;
29use uuid::Uuid;
30
31use crate::{
32    DatabaseError, ExecuteExt,
33    errors::DatabaseInconsistencyError,
34    filter::StatementExt,
35    iden::{CompatSessions, OAuth2Sessions},
36    pagination::QueryBuilderExt,
37};
38
39/// An implementation of [`AppSessionRepository`] for a PostgreSQL connection
40pub struct PgAppSessionRepository<'c> {
41    conn: &'c mut PgConnection,
42}
43
44impl<'c> PgAppSessionRepository<'c> {
45    /// Create a new [`PgAppSessionRepository`] from an active PostgreSQL
46    /// connection
47    pub fn new(conn: &'c mut PgConnection) -> Self {
48        Self { conn }
49    }
50}
51
52mod priv_ {
53    // The enum_def macro generates a public enum, which we don't want, because it
54    // triggers the missing docs warning
55
56    use std::net::IpAddr;
57
58    use chrono::{DateTime, Utc};
59    use mas_storage::pagination::Node;
60    use sea_query::enum_def;
61    use ulid::Ulid;
62    use uuid::Uuid;
63
64    #[derive(sqlx::FromRow)]
65    #[enum_def]
66    pub(super) struct AppSessionLookup {
67        pub(super) cursor: Uuid,
68        pub(super) compat_session_id: Option<Uuid>,
69        pub(super) oauth2_session_id: Option<Uuid>,
70        pub(super) oauth2_client_id: Option<Uuid>,
71        pub(super) user_session_id: Option<Uuid>,
72        pub(super) user_id: Option<Uuid>,
73        pub(super) scope_list: Option<Vec<String>>,
74        pub(super) device_id: Option<String>,
75        pub(super) human_name: Option<String>,
76        pub(super) created_at: DateTime<Utc>,
77        pub(super) finished_at: Option<DateTime<Utc>>,
78        pub(super) is_synapse_admin: Option<bool>,
79        pub(super) user_agent: Option<String>,
80        pub(super) last_active_at: Option<DateTime<Utc>>,
81        pub(super) last_active_ip: Option<IpAddr>,
82    }
83
84    impl Node<Ulid> for AppSessionLookup {
85        fn cursor(&self) -> Ulid {
86            self.cursor.into()
87        }
88    }
89}
90
91use priv_::{AppSessionLookup, AppSessionLookupIden};
92
93impl TryFrom<AppSessionLookup> for AppSession {
94    type Error = DatabaseError;
95
96    fn try_from(value: AppSessionLookup) -> Result<Self, Self::Error> {
97        // This is annoying to do, but we have to match on all the fields to determine
98        // whether it's a compat session or an oauth2 session
99        let AppSessionLookup {
100            cursor,
101            compat_session_id,
102            oauth2_session_id,
103            oauth2_client_id,
104            user_session_id,
105            user_id,
106            scope_list,
107            device_id,
108            human_name,
109            created_at,
110            finished_at,
111            is_synapse_admin,
112            user_agent,
113            last_active_at,
114            last_active_ip,
115        } = value;
116
117        let user_session_id = user_session_id.map(Ulid::from);
118
119        match (
120            compat_session_id,
121            oauth2_session_id,
122            oauth2_client_id,
123            user_id,
124            scope_list,
125            device_id,
126            is_synapse_admin,
127        ) {
128            (
129                Some(compat_session_id),
130                None,
131                None,
132                Some(user_id),
133                None,
134                device_id_opt,
135                Some(is_synapse_admin),
136            ) => {
137                let id = compat_session_id.into();
138                let device = device_id_opt
139                    .map(Device::try_from)
140                    .transpose()
141                    .map_err(|e| {
142                        DatabaseInconsistencyError::on("compat_sessions")
143                            .column("device_id")
144                            .row(id)
145                            .source(e)
146                    })?;
147
148                let state = match finished_at {
149                    None => CompatSessionState::Valid,
150                    Some(finished_at) => CompatSessionState::Finished { finished_at },
151                };
152
153                let session = CompatSession {
154                    id,
155                    state,
156                    user_id: user_id.into(),
157                    device,
158                    human_name,
159                    user_session_id,
160                    created_at,
161                    is_synapse_admin,
162                    user_agent,
163                    last_active_at,
164                    last_active_ip,
165                };
166
167                Ok(AppSession::Compat(Box::new(session)))
168            }
169
170            (
171                None,
172                Some(oauth2_session_id),
173                Some(oauth2_client_id),
174                user_id,
175                Some(scope_list),
176                None,
177                None,
178            ) => {
179                let id = oauth2_session_id.into();
180                let scope: Result<Scope, _> =
181                    scope_list.iter().map(|s| s.parse::<ScopeToken>()).collect();
182                let scope = scope.map_err(|e| {
183                    DatabaseInconsistencyError::on("oauth2_sessions")
184                        .column("scope")
185                        .row(id)
186                        .source(e)
187                })?;
188
189                let state = match value.finished_at {
190                    None => SessionState::Valid,
191                    Some(finished_at) => SessionState::Finished { finished_at },
192                };
193
194                let session = Session {
195                    id,
196                    state,
197                    created_at,
198                    client_id: oauth2_client_id.into(),
199                    user_id: user_id.map(Ulid::from),
200                    user_session_id,
201                    scope,
202                    user_agent,
203                    last_active_at,
204                    last_active_ip,
205                    human_name,
206                };
207
208                Ok(AppSession::OAuth2(Box::new(session)))
209            }
210
211            _ => Err(DatabaseInconsistencyError::on("sessions")
212                .row(cursor.into())
213                .into()),
214        }
215    }
216}
217
218/// Split a [`AppSessionFilter`] into two separate filters: a
219/// [`CompatSessionFilter`] and an [`OAuth2SessionFilter`].
220fn split_filter(
221    filter: AppSessionFilter<'_>,
222) -> (CompatSessionFilter<'_>, OAuth2SessionFilter<'_>) {
223    let mut compat_filter = CompatSessionFilter::new();
224    let mut oauth2_filter = OAuth2SessionFilter::new();
225
226    if let Some(user) = filter.user() {
227        compat_filter = compat_filter.for_user(user);
228        oauth2_filter = oauth2_filter.for_user(user);
229    }
230
231    match filter.state() {
232        Some(AppSessionState::Active) => {
233            compat_filter = compat_filter.active_only();
234            oauth2_filter = oauth2_filter.active_only();
235        }
236        Some(AppSessionState::Finished) => {
237            compat_filter = compat_filter.finished_only();
238            oauth2_filter = oauth2_filter.finished_only();
239        }
240        None => {}
241    }
242
243    if let Some(device) = filter.device() {
244        compat_filter = compat_filter.for_device(device);
245        oauth2_filter = oauth2_filter.for_device(device);
246    }
247
248    if let Some(browser_session) = filter.browser_session() {
249        compat_filter = compat_filter.for_browser_session(browser_session);
250        oauth2_filter = oauth2_filter.for_browser_session(browser_session);
251    }
252
253    if let Some(last_active_before) = filter.last_active_before() {
254        compat_filter = compat_filter.with_last_active_before(last_active_before);
255        oauth2_filter = oauth2_filter.with_last_active_before(last_active_before);
256    }
257
258    if let Some(last_active_after) = filter.last_active_after() {
259        compat_filter = compat_filter.with_last_active_after(last_active_after);
260        oauth2_filter = oauth2_filter.with_last_active_after(last_active_after);
261    }
262
263    if let Some(created_before) = filter.created_before() {
264        compat_filter = compat_filter.with_created_before(created_before);
265        oauth2_filter = oauth2_filter.with_created_before(created_before);
266    }
267
268    if let Some(created_after) = filter.created_after() {
269        compat_filter = compat_filter.with_created_after(created_after);
270        oauth2_filter = oauth2_filter.with_created_after(created_after);
271    }
272
273    (compat_filter, oauth2_filter)
274}
275
276#[async_trait]
277impl AppSessionRepository for PgAppSessionRepository<'_> {
278    type Error = DatabaseError;
279
280    #[tracing::instrument(
281        name = "db.app_session.list",
282        fields(
283            db.query.text,
284        ),
285        skip_all,
286        err,
287    )]
288    async fn list(
289        &mut self,
290        filter: AppSessionFilter<'_>,
291        pagination: Pagination,
292    ) -> Result<Page<AppSession>, Self::Error> {
293        let (compat_filter, oauth2_filter) = split_filter(filter);
294
295        let mut oauth2_session_select = Query::select()
296            .expr_as(
297                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
298                AppSessionLookupIden::Cursor,
299            )
300            .expr_as(Expr::cust("NULL"), AppSessionLookupIden::CompatSessionId)
301            .expr_as(
302                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
303                AppSessionLookupIden::Oauth2SessionId,
304            )
305            .expr_as(
306                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)),
307                AppSessionLookupIden::Oauth2ClientId,
308            )
309            .expr_as(
310                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)),
311                AppSessionLookupIden::UserSessionId,
312            )
313            .expr_as(
314                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)),
315                AppSessionLookupIden::UserId,
316            )
317            .expr_as(
318                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
319                AppSessionLookupIden::ScopeList,
320            )
321            .expr_as(Expr::cust("NULL"), AppSessionLookupIden::DeviceId)
322            .expr_as(
323                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::HumanName)),
324                AppSessionLookupIden::HumanName,
325            )
326            .expr_as(
327                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::CreatedAt)),
328                AppSessionLookupIden::CreatedAt,
329            )
330            .expr_as(
331                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)),
332                AppSessionLookupIden::FinishedAt,
333            )
334            .expr_as(Expr::cust("NULL"), AppSessionLookupIden::IsSynapseAdmin)
335            .expr_as(
336                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserAgent)),
337                AppSessionLookupIden::UserAgent,
338            )
339            .expr_as(
340                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt)),
341                AppSessionLookupIden::LastActiveAt,
342            )
343            .expr_as(
344                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveIp)),
345                AppSessionLookupIden::LastActiveIp,
346            )
347            .from(OAuth2Sessions::Table)
348            .apply_filter(oauth2_filter)
349            .clone();
350
351        let compat_session_select = Query::select()
352            .expr_as(
353                Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)),
354                AppSessionLookupIden::Cursor,
355            )
356            .expr_as(
357                Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)),
358                AppSessionLookupIden::CompatSessionId,
359            )
360            .expr_as(Expr::cust("NULL"), AppSessionLookupIden::Oauth2SessionId)
361            .expr_as(Expr::cust("NULL"), AppSessionLookupIden::Oauth2ClientId)
362            .expr_as(
363                Expr::col((CompatSessions::Table, CompatSessions::UserSessionId)),
364                AppSessionLookupIden::UserSessionId,
365            )
366            .expr_as(
367                Expr::col((CompatSessions::Table, CompatSessions::UserId)),
368                AppSessionLookupIden::UserId,
369            )
370            .expr_as(Expr::cust("NULL"), AppSessionLookupIden::ScopeList)
371            .expr_as(
372                Expr::col((CompatSessions::Table, CompatSessions::DeviceId)),
373                AppSessionLookupIden::DeviceId,
374            )
375            .expr_as(
376                Expr::col((CompatSessions::Table, CompatSessions::HumanName)),
377                AppSessionLookupIden::HumanName,
378            )
379            .expr_as(
380                Expr::col((CompatSessions::Table, CompatSessions::CreatedAt)),
381                AppSessionLookupIden::CreatedAt,
382            )
383            .expr_as(
384                Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)),
385                AppSessionLookupIden::FinishedAt,
386            )
387            .expr_as(
388                Expr::col((CompatSessions::Table, CompatSessions::IsSynapseAdmin)),
389                AppSessionLookupIden::IsSynapseAdmin,
390            )
391            .expr_as(
392                Expr::col((CompatSessions::Table, CompatSessions::UserAgent)),
393                AppSessionLookupIden::UserAgent,
394            )
395            .expr_as(
396                Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt)),
397                AppSessionLookupIden::LastActiveAt,
398            )
399            .expr_as(
400                Expr::col((CompatSessions::Table, CompatSessions::LastActiveIp)),
401                AppSessionLookupIden::LastActiveIp,
402            )
403            .from(CompatSessions::Table)
404            .apply_filter(compat_filter)
405            .clone();
406
407        let common_table_expression = CommonTableExpression::new()
408            .query(
409                oauth2_session_select
410                    .union(UnionType::All, compat_session_select)
411                    .clone(),
412            )
413            .table_name(Alias::new("sessions"))
414            .clone();
415
416        let with_clause = Query::with().cte(common_table_expression).clone();
417
418        let select = Query::select()
419            .column(ColumnRef::Asterisk)
420            .from(Alias::new("sessions"))
421            .generate_pagination(AppSessionLookupIden::Cursor, pagination)
422            .clone();
423
424        let (sql, arguments) = with_clause.query(select).build_sqlx(PostgresQueryBuilder);
425
426        let edges: Vec<AppSessionLookup> = sqlx::query_as_with(&sql, arguments)
427            .traced()
428            .fetch_all(&mut *self.conn)
429            .await?;
430
431        let page = pagination.process(edges).try_map(TryFrom::try_from)?;
432
433        Ok(page)
434    }
435
436    #[tracing::instrument(
437        name = "db.app_session.count",
438        fields(
439            db.query.text,
440        ),
441        skip_all,
442        err,
443    )]
444    async fn count(&mut self, filter: AppSessionFilter<'_>) -> Result<usize, Self::Error> {
445        let (compat_filter, oauth2_filter) = split_filter(filter);
446        let mut oauth2_session_select = Query::select()
447            .expr(Expr::cust("1"))
448            .from(OAuth2Sessions::Table)
449            .apply_filter(oauth2_filter)
450            .clone();
451
452        let compat_session_select = Query::select()
453            .expr(Expr::cust("1"))
454            .from(CompatSessions::Table)
455            .apply_filter(compat_filter)
456            .clone();
457
458        let common_table_expression = CommonTableExpression::new()
459            .query(
460                oauth2_session_select
461                    .union(UnionType::All, compat_session_select)
462                    .clone(),
463            )
464            .table_name(Alias::new("sessions"))
465            .clone();
466
467        let with_clause = Query::with().cte(common_table_expression).clone();
468
469        let select = Query::select()
470            .expr(Expr::cust("COUNT(*)"))
471            .from(Alias::new("sessions"))
472            .clone();
473
474        let (sql, arguments) = with_clause.query(select).build_sqlx(PostgresQueryBuilder);
475
476        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
477            .traced()
478            .fetch_one(&mut *self.conn)
479            .await?;
480
481        count
482            .try_into()
483            .map_err(DatabaseError::to_invalid_operation)
484    }
485
486    #[tracing::instrument(
487        name = "db.app_session.finish_sessions_to_replace_device",
488        fields(
489            db.query.text,
490            %user.id,
491            %device_id = device.as_str()
492        ),
493        skip_all,
494        err,
495    )]
496    async fn finish_sessions_to_replace_device(
497        &mut self,
498        clock: &dyn Clock,
499        user: &User,
500        device: &Device,
501    ) -> Result<bool, Self::Error> {
502        let mut affected = false;
503        // TODO need to invoke this from all the oauth2 login sites
504        let span = tracing::info_span!(
505            "db.app_session.finish_sessions_to_replace_device.compat_sessions",
506            { DB_QUERY_TEXT } = tracing::field::Empty,
507        );
508        let finished_at = clock.now();
509        let compat_affected = sqlx::query!(
510            "
511                UPDATE compat_sessions SET finished_at = $3 WHERE user_id = $1 AND device_id = $2 AND finished_at IS NULL
512            ",
513            Uuid::from(user.id),
514            device.as_str(),
515            finished_at
516        )
517        .record(&span)
518        .execute(&mut *self.conn)
519        .instrument(span)
520        .await?
521        .rows_affected();
522        affected |= compat_affected > 0;
523
524        if let Ok([stable_device_as_scope_token, unstable_device_as_scope_token]) =
525            device.to_scope_token()
526        {
527            let span = tracing::info_span!(
528                "db.app_session.finish_sessions_to_replace_device.oauth2_sessions",
529                { DB_QUERY_TEXT } = tracing::field::Empty,
530            );
531            let oauth2_affected = sqlx::query!(
532                "
533                    UPDATE oauth2_sessions
534                    SET finished_at = $4
535                    WHERE user_id = $1
536                      AND ($2 = ANY(scope_list) OR $3 = ANY(scope_list))
537                      AND finished_at IS NULL
538                ",
539                Uuid::from(user.id),
540                stable_device_as_scope_token.as_str(),
541                unstable_device_as_scope_token.as_str(),
542                finished_at
543            )
544            .record(&span)
545            .execute(&mut *self.conn)
546            .instrument(span)
547            .await?
548            .rows_affected();
549            affected |= oauth2_affected > 0;
550        }
551
552        Ok(affected)
553    }
554}
555
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use mas_data_model::{Clock, Device, clock::MockClock};
    use mas_storage::{
        Pagination, RepositoryAccess,
        app_session::{AppSession, AppSessionFilter},
        oauth2::OAuth2SessionRepository,
    };
    use oauth2_types::{
        requests::GrantType,
        scope::{OPENID, Scope},
    };
    use rand::SeedableRng;
    use rand_chacha::ChaChaRng;
    use sqlx::PgPool;

    use crate::PgRepository;

    /// Exercise listing and counting app sessions across the compat and
    /// `OAuth2` branches of the union: the active/finished state filters, the
    /// device filter and the user filter.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_app_repo(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        // Create a user
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();

        let all = AppSessionFilter::new().for_user(&user);
        let active = all.active_only();
        let finished = all.finished_only();
        let pagination = Pagination::first(10);

        assert_eq!(repo.app_session().count(all).await.unwrap(), 0);
        assert_eq!(repo.app_session().count(active).await.unwrap(), 0);
        assert_eq!(repo.app_session().count(finished).await.unwrap(), 0);

        let full_list = repo.app_session().list(all, pagination).await.unwrap();
        assert!(full_list.edges.is_empty());
        let active_list = repo.app_session().list(active, pagination).await.unwrap();
        assert!(active_list.edges.is_empty());
        let finished_list = repo.app_session().list(finished, pagination).await.unwrap();
        assert!(finished_list.edges.is_empty());

        // Start a compat session for that user
        let device = Device::generate(&mut rng);
        let compat_session = repo
            .compat_session()
            .add(&mut rng, &clock, &user, device.clone(), None, false, None)
            .await
            .unwrap();

        assert_eq!(repo.app_session().count(all).await.unwrap(), 1);
        assert_eq!(repo.app_session().count(active).await.unwrap(), 1);
        assert_eq!(repo.app_session().count(finished).await.unwrap(), 0);

        let full_list = repo.app_session().list(all, pagination).await.unwrap();
        assert_eq!(full_list.edges.len(), 1);
        assert_eq!(
            full_list.edges[0].node,
            AppSession::Compat(Box::new(compat_session.clone()))
        );
        let active_list = repo.app_session().list(active, pagination).await.unwrap();
        assert_eq!(active_list.edges.len(), 1);
        assert_eq!(
            active_list.edges[0].node,
            AppSession::Compat(Box::new(compat_session.clone()))
        );
        let finished_list = repo.app_session().list(finished, pagination).await.unwrap();
        assert!(finished_list.edges.is_empty());

        // Finish the session
        let compat_session = repo
            .compat_session()
            .finish(&clock, compat_session)
            .await
            .unwrap();

        assert_eq!(repo.app_session().count(all).await.unwrap(), 1);
        assert_eq!(repo.app_session().count(active).await.unwrap(), 0);
        assert_eq!(repo.app_session().count(finished).await.unwrap(), 1);

        let full_list = repo.app_session().list(all, pagination).await.unwrap();
        assert_eq!(full_list.edges.len(), 1);
        assert_eq!(
            full_list.edges[0].node,
            AppSession::Compat(Box::new(compat_session.clone()))
        );
        let active_list = repo.app_session().list(active, pagination).await.unwrap();
        assert!(active_list.edges.is_empty());
        let finished_list = repo.app_session().list(finished, pagination).await.unwrap();
        assert_eq!(finished_list.edges.len(), 1);
        assert_eq!(
            finished_list.edges[0].node,
            AppSession::Compat(Box::new(compat_session.clone()))
        );

        // Start an OAuth2 session
        let client = repo
            .oauth2_client()
            .add(
                &mut rng,
                &clock,
                vec!["https://example.com/redirect".parse().unwrap()],
                None,
                None,
                None,
                vec![GrantType::AuthorizationCode],
                Some("First client".to_owned()),
                Some("https://example.com/logo.png".parse().unwrap()),
                Some("https://example.com/".parse().unwrap()),
                Some("https://example.com/policy".parse().unwrap()),
                Some("https://example.com/tos".parse().unwrap()),
                Some("https://example.com/jwks.json".parse().unwrap()),
                None,
                None,
                None,
                None,
                None,
                Some("https://example.com/login".parse().unwrap()),
            )
            .await
            .unwrap();

        let device2 = Device::generate(&mut rng);
        let scope: Scope = [OPENID]
            .into_iter()
            .chain(device2.to_scope_token().unwrap().into_iter())
            .collect();

        // We're moving the clock forward by 1 minute between each session to ensure
        // we're getting consistent ordering in lists.
        clock.advance(Duration::try_minutes(1).unwrap());

        let oauth_session = repo
            .oauth2_session()
            .add(&mut rng, &clock, &client, Some(&user), None, scope)
            .await
            .unwrap();

        assert_eq!(repo.app_session().count(all).await.unwrap(), 2);
        assert_eq!(repo.app_session().count(active).await.unwrap(), 1);
        assert_eq!(repo.app_session().count(finished).await.unwrap(), 1);

        let full_list = repo.app_session().list(all, pagination).await.unwrap();
        assert_eq!(full_list.edges.len(), 2);
        assert_eq!(
            full_list.edges[0].node,
            AppSession::Compat(Box::new(compat_session.clone()))
        );
        assert_eq!(
            full_list.edges[1].node,
            AppSession::OAuth2(Box::new(oauth_session.clone()))
        );

        let active_list = repo.app_session().list(active, pagination).await.unwrap();
        assert_eq!(active_list.edges.len(), 1);
        assert_eq!(
            active_list.edges[0].node,
            AppSession::OAuth2(Box::new(oauth_session.clone()))
        );

        let finished_list = repo.app_session().list(finished, pagination).await.unwrap();
        assert_eq!(finished_list.edges.len(), 1);
        assert_eq!(
            finished_list.edges[0].node,
            AppSession::Compat(Box::new(compat_session.clone()))
        );

        // Finish the session
        let oauth_session = repo
            .oauth2_session()
            .finish(&clock, oauth_session)
            .await
            .unwrap();

        assert_eq!(repo.app_session().count(all).await.unwrap(), 2);
        assert_eq!(repo.app_session().count(active).await.unwrap(), 0);
        assert_eq!(repo.app_session().count(finished).await.unwrap(), 2);

        let full_list = repo.app_session().list(all, pagination).await.unwrap();
        assert_eq!(full_list.edges.len(), 2);
        assert_eq!(
            full_list.edges[0].node,
            AppSession::Compat(Box::new(compat_session.clone()))
        );
        assert_eq!(
            full_list.edges[1].node,
            AppSession::OAuth2(Box::new(oauth_session.clone()))
        );

        let active_list = repo.app_session().list(active, pagination).await.unwrap();
        assert!(active_list.edges.is_empty());

        let finished_list = repo.app_session().list(finished, pagination).await.unwrap();
        assert_eq!(finished_list.edges.len(), 2);
        assert_eq!(
            finished_list.edges[0].node,
            AppSession::Compat(Box::new(compat_session.clone()))
        );
        // Check the second entry of the *finished* list (this previously
        // re-checked `full_list` by mistake)
        assert_eq!(
            finished_list.edges[1].node,
            AppSession::OAuth2(Box::new(oauth_session.clone()))
        );

        // Query by device
        let filter = AppSessionFilter::new().for_device(&device);
        assert_eq!(repo.app_session().count(filter).await.unwrap(), 1);
        let list = repo.app_session().list(filter, pagination).await.unwrap();
        assert_eq!(list.edges.len(), 1);
        assert_eq!(
            list.edges[0].node,
            AppSession::Compat(Box::new(compat_session.clone()))
        );

        let filter = AppSessionFilter::new().for_device(&device2);
        assert_eq!(repo.app_session().count(filter).await.unwrap(), 1);
        let list = repo.app_session().list(filter, pagination).await.unwrap();
        assert_eq!(list.edges.len(), 1);
        assert_eq!(
            list.edges[0].node,
            AppSession::OAuth2(Box::new(oauth_session.clone()))
        );

        // Create a second user
        let user2 = repo
            .user()
            .add(&mut rng, &clock, "alice".to_owned())
            .await
            .unwrap();

        // If we list/count for this user, we should get nothing
        let filter = AppSessionFilter::new().for_user(&user2);
        assert_eq!(repo.app_session().count(filter).await.unwrap(), 0);
        let list = repo.app_session().list(filter, pagination).await.unwrap();
        assert!(list.edges.is_empty());
    }

    /// Test the created-at filters on [`AppSessionFilter`]: they should apply
    /// to both the compat and the `OAuth2` branches of the union.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_list_app_sessions_by_created_at(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        let user = repo
            .user()
            .add(&mut rng, &clock, "alice".to_owned())
            .await
            .unwrap();

        // Create a compat session
        let device = Device::generate(&mut rng);
        let compat_session = repo
            .compat_session()
            .add(&mut rng, &clock, &user, device, None, false, None)
            .await
            .unwrap();
        clock.advance(Duration::try_minutes(1).unwrap());

        // Capture a cutoff that sits between the compat session and the
        // OAuth2 one
        let cutoff = clock.now();
        clock.advance(Duration::try_minutes(1).unwrap());

        // Create an OAuth2 session for the same user
        let client = repo
            .oauth2_client()
            .add(
                &mut rng,
                &clock,
                vec!["https://example.com/redirect".parse().unwrap()],
                None,
                None,
                None,
                vec![GrantType::AuthorizationCode],
                Some("Test client".to_owned()),
                Some("https://example.com/logo.png".parse().unwrap()),
                Some("https://example.com/".parse().unwrap()),
                Some("https://example.com/policy".parse().unwrap()),
                Some("https://example.com/tos".parse().unwrap()),
                Some("https://example.com/jwks.json".parse().unwrap()),
                None,
                None,
                None,
                None,
                None,
                Some("https://example.com/login".parse().unwrap()),
            )
            .await
            .unwrap();
        let scope: Scope = [OPENID].into_iter().collect();
        let oauth_session = repo
            .oauth2_session()
            .add(&mut rng, &clock, &client, Some(&user), None, scope)
            .await
            .unwrap();

        let pagination = Pagination::first(10);

        // Sessions created before the cutoff: only the compat one
        let filter = AppSessionFilter::new().with_created_before(cutoff);
        let list = repo.app_session().list(filter, pagination).await.unwrap();
        assert_eq!(list.edges.len(), 1);
        assert_eq!(
            list.edges[0].node,
            AppSession::Compat(Box::new(compat_session.clone()))
        );
        assert_eq!(repo.app_session().count(filter).await.unwrap(), 1);

        // Sessions created after the cutoff: only the OAuth2 one
        let filter = AppSessionFilter::new().with_created_after(cutoff);
        let list = repo.app_session().list(filter, pagination).await.unwrap();
        assert_eq!(list.edges.len(), 1);
        assert_eq!(
            list.edges[0].node,
            AppSession::OAuth2(Box::new(oauth_session.clone()))
        );
        assert_eq!(repo.app_session().count(filter).await.unwrap(), 1);
    }
}