Skip to main content

mas_storage/oauth2/
session.rs

// Copyright 2025, 2026 Element Creations Ltd.
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
7
8use std::net::IpAddr;
9
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use mas_data_model::{BrowserSession, Client, Clock, Device, Session, User};
13use oauth2_types::scope::Scope;
14use rand_core::RngCore;
15use ulid::Ulid;
16
17use crate::{Pagination, pagination::Page, repository_impl, user::BrowserSessionFilter};
18
/// State of an OAuth 2.0 session, as used by the session state filter
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OAuth2SessionState {
    /// The session is active
    Active,
    /// The session is finished
    Finished,
}

impl OAuth2SessionState {
    /// Returns `true` if the state is [`Self::Active`]
    #[must_use]
    pub fn is_active(self) -> bool {
        matches!(self, Self::Active)
    }

    /// Returns `true` if the state is [`Self::Finished`]
    #[must_use]
    pub fn is_finished(self) -> bool {
        matches!(self, Self::Finished)
    }
}
34
/// Kind of OAuth 2.0 client, as used by the client kind filter
// NOTE(review): presumably "static" means declared in configuration and
// "dynamic" means registered at runtime — confirm against the `Client` model.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ClientKind {
    /// A static client
    Static,
    /// A dynamic client
    Dynamic,
}

impl ClientKind {
    /// Returns `true` if the kind is [`Self::Static`]
    #[must_use]
    pub fn is_static(self) -> bool {
        matches!(self, Self::Static)
    }

    /// Returns `true` if the kind is [`Self::Dynamic`]
    #[must_use]
    pub fn is_dynamic(self) -> bool {
        matches!(self, Self::Dynamic)
    }
}
46
47/// Filter parameters for listing OAuth 2.0 sessions
48#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
49pub struct OAuth2SessionFilter<'a> {
50    user: Option<&'a User>,
51    any_user: Option<bool>,
52    browser_session: Option<&'a BrowserSession>,
53    browser_session_filter: Option<BrowserSessionFilter<'a>>,
54    device: Option<&'a Device>,
55    client: Option<&'a Client>,
56    clients: Option<&'a [&'a Client]>,
57    client_kind: Option<ClientKind>,
58    state: Option<OAuth2SessionState>,
59    scope: Option<&'a Scope>,
60    last_active_before: Option<DateTime<Utc>>,
61    last_active_after: Option<DateTime<Utc>>,
62    created_before: Option<DateTime<Utc>>,
63    created_after: Option<DateTime<Utc>>,
64}
65
66impl<'a> OAuth2SessionFilter<'a> {
67    /// Create a new [`OAuth2SessionFilter`] with default values
68    #[must_use]
69    pub fn new() -> Self {
70        Self::default()
71    }
72
73    /// List sessions for a specific user
74    #[must_use]
75    pub fn for_user(mut self, user: &'a User) -> Self {
76        self.user = Some(user);
77        self
78    }
79
80    /// Get the user filter
81    ///
82    /// Returns [`None`] if no user filter was set
83    #[must_use]
84    pub fn user(&self) -> Option<&'a User> {
85        self.user
86    }
87
88    /// List sessions which belong to any user
89    #[must_use]
90    pub fn for_any_user(mut self) -> Self {
91        self.any_user = Some(true);
92        self
93    }
94
95    /// List sessions which belong to no user
96    #[must_use]
97    pub fn for_no_user(mut self) -> Self {
98        self.any_user = Some(false);
99        self
100    }
101
102    /// Get the 'any user' filter
103    ///
104    /// Returns [`None`] if no 'any user' filter was set
105    #[must_use]
106    pub fn any_user(&self) -> Option<bool> {
107        self.any_user
108    }
109
110    /// List sessions started by a specific browser session
111    #[must_use]
112    pub fn for_browser_session(mut self, browser_session: &'a BrowserSession) -> Self {
113        self.browser_session = Some(browser_session);
114        self
115    }
116
117    /// List sessions started by a set of browser sessions
118    #[must_use]
119    pub fn for_browser_sessions(
120        mut self,
121        browser_session_filter: BrowserSessionFilter<'a>,
122    ) -> Self {
123        self.browser_session_filter = Some(browser_session_filter);
124        self
125    }
126
127    /// Get the browser session filter
128    ///
129    /// Returns [`None`] if no browser session filter was set
130    #[must_use]
131    pub fn browser_session(&self) -> Option<&'a BrowserSession> {
132        self.browser_session
133    }
134
135    /// Get the browser sessions filter
136    ///
137    /// Returns [`None`] if no browser session filter was set
138    #[must_use]
139    pub fn browser_session_filter(&self) -> Option<BrowserSessionFilter<'a>> {
140        self.browser_session_filter
141    }
142
143    /// List sessions for a specific client
144    #[must_use]
145    pub fn for_client(mut self, client: &'a Client) -> Self {
146        self.client = Some(client);
147        self
148    }
149
150    /// Get the client filter
151    ///
152    /// Returns [`None`] if no client filter was set
153    #[must_use]
154    pub fn client(&self) -> Option<&'a Client> {
155        self.client
156    }
157
158    /// List sessions for a set of clients
159    ///
160    /// This filter is independent of [`Self::for_client`]: if both are set,
161    /// the conditions are `AND`-ed together. In practice an API caller uses
162    /// one or the other.
163    #[must_use]
164    pub fn for_clients(mut self, clients: &'a [&'a Client]) -> Self {
165        self.clients = Some(clients);
166        self
167    }
168
169    /// Get the multi-client filter
170    ///
171    /// Returns [`None`] if no multi-client filter was set
172    #[must_use]
173    pub fn clients(&self) -> Option<&'a [&'a Client]> {
174        self.clients
175    }
176
177    /// List only static clients
178    #[must_use]
179    pub fn only_static_clients(mut self) -> Self {
180        self.client_kind = Some(ClientKind::Static);
181        self
182    }
183
184    /// List only dynamic clients
185    #[must_use]
186    pub fn only_dynamic_clients(mut self) -> Self {
187        self.client_kind = Some(ClientKind::Dynamic);
188        self
189    }
190
191    /// Get the client kind filter
192    ///
193    /// Returns [`None`] if no client kind filter was set
194    #[must_use]
195    pub fn client_kind(&self) -> Option<ClientKind> {
196        self.client_kind
197    }
198
199    /// Only return sessions with a last active time before the given time
200    #[must_use]
201    pub fn with_last_active_before(mut self, last_active_before: DateTime<Utc>) -> Self {
202        self.last_active_before = Some(last_active_before);
203        self
204    }
205
206    /// Only return sessions with a last active time after the given time
207    #[must_use]
208    pub fn with_last_active_after(mut self, last_active_after: DateTime<Utc>) -> Self {
209        self.last_active_after = Some(last_active_after);
210        self
211    }
212
213    /// Get the last active before filter
214    ///
215    /// Returns [`None`] if no client filter was set
216    #[must_use]
217    pub fn last_active_before(&self) -> Option<DateTime<Utc>> {
218        self.last_active_before
219    }
220
221    /// Get the last active after filter
222    ///
223    /// Returns [`None`] if no client filter was set
224    #[must_use]
225    pub fn last_active_after(&self) -> Option<DateTime<Utc>> {
226        self.last_active_after
227    }
228
229    /// Only return sessions created before the given time
230    #[must_use]
231    pub fn with_created_before(mut self, created_before: DateTime<Utc>) -> Self {
232        self.created_before = Some(created_before);
233        self
234    }
235
236    /// Only return sessions created after the given time
237    #[must_use]
238    pub fn with_created_after(mut self, created_after: DateTime<Utc>) -> Self {
239        self.created_after = Some(created_after);
240        self
241    }
242
243    /// Get the created-before filter
244    ///
245    /// Returns [`None`] if no filter was set
246    #[must_use]
247    pub fn created_before(&self) -> Option<DateTime<Utc>> {
248        self.created_before
249    }
250
251    /// Get the created-after filter
252    ///
253    /// Returns [`None`] if no filter was set
254    #[must_use]
255    pub fn created_after(&self) -> Option<DateTime<Utc>> {
256        self.created_after
257    }
258
259    /// Only return active sessions
260    #[must_use]
261    pub fn active_only(mut self) -> Self {
262        self.state = Some(OAuth2SessionState::Active);
263        self
264    }
265
266    /// Only return finished sessions
267    #[must_use]
268    pub fn finished_only(mut self) -> Self {
269        self.state = Some(OAuth2SessionState::Finished);
270        self
271    }
272
273    /// Get the state filter
274    ///
275    /// Returns [`None`] if no state filter was set
276    #[must_use]
277    pub fn state(&self) -> Option<OAuth2SessionState> {
278        self.state
279    }
280
281    /// Only return sessions with the given scope
282    #[must_use]
283    pub fn with_scope(mut self, scope: &'a Scope) -> Self {
284        self.scope = Some(scope);
285        self
286    }
287
288    /// Get the scope filter
289    ///
290    /// Returns [`None`] if no scope filter was set
291    #[must_use]
292    pub fn scope(&self) -> Option<&'a Scope> {
293        self.scope
294    }
295
296    /// Only return sessions that have the given device in their scope
297    #[must_use]
298    pub fn for_device(mut self, device: &'a Device) -> Self {
299        self.device = Some(device);
300        self
301    }
302
303    /// Get the device filter
304    ///
305    /// Returns [`None`] if no device filter was set
306    #[must_use]
307    pub fn device(&self) -> Option<&'a Device> {
308        self.device
309    }
310}
311
312/// An [`OAuth2SessionRepository`] helps interacting with [`Session`]
313/// saved in the storage backend
314#[async_trait]
315pub trait OAuth2SessionRepository: Send + Sync {
316    /// The error type returned by the repository
317    type Error;
318
319    /// Lookup an [`Session`] by its ID
320    ///
321    /// Returns `None` if no [`Session`] was found
322    ///
323    /// # Parameters
324    ///
325    /// * `id`: The ID of the [`Session`] to lookup
326    ///
327    /// # Errors
328    ///
329    /// Returns [`Self::Error`] if the underlying repository fails
330    async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error>;
331
332    /// Create a new [`Session`] with the given parameters
333    ///
334    /// Returns the newly created [`Session`]
335    ///
336    /// # Parameters
337    ///
338    /// * `rng`: The random number generator to use
339    /// * `clock`: The clock used to generate timestamps
340    /// * `client`: The [`Client`] which created the [`Session`]
341    /// * `user`: The [`User`] for which the session should be created, if any
342    /// * `user_session`: The [`BrowserSession`] of the user which completed the
343    ///   authorization, if any
344    /// * `scope`: The [`Scope`] of the [`Session`]
345    ///
346    /// # Errors
347    ///
348    /// Returns [`Self::Error`] if the underlying repository fails
349    async fn add(
350        &mut self,
351        rng: &mut (dyn RngCore + Send),
352        clock: &dyn Clock,
353        client: &Client,
354        user: Option<&User>,
355        user_session: Option<&BrowserSession>,
356        scope: Scope,
357    ) -> Result<Session, Self::Error>;
358
359    /// Create a new [`Session`] out of a [`Client`] and a [`BrowserSession`]
360    ///
361    /// Returns the newly created [`Session`]
362    ///
363    /// # Parameters
364    ///
365    /// * `rng`: The random number generator to use
366    /// * `clock`: The clock used to generate timestamps
367    /// * `client`: The [`Client`] which created the [`Session`]
368    /// * `user_session`: The [`BrowserSession`] of the user which completed the
369    ///   authorization
370    /// * `scope`: The [`Scope`] of the [`Session`]
371    ///
372    /// # Errors
373    ///
374    /// Returns [`Self::Error`] if the underlying repository fails
375    async fn add_from_browser_session(
376        &mut self,
377        rng: &mut (dyn RngCore + Send),
378        clock: &dyn Clock,
379        client: &Client,
380        user_session: &BrowserSession,
381        scope: Scope,
382    ) -> Result<Session, Self::Error> {
383        self.add(
384            rng,
385            clock,
386            client,
387            Some(&user_session.user),
388            Some(user_session),
389            scope,
390        )
391        .await
392    }
393
394    /// Create a new [`Session`] for a [`Client`] using the client credentials
395    /// flow
396    ///
397    /// Returns the newly created [`Session`]
398    ///
399    /// # Parameters
400    ///
401    /// * `rng`: The random number generator to use
402    /// * `clock`: The clock used to generate timestamps
403    /// * `client`: The [`Client`] which created the [`Session`]
404    /// * `scope`: The [`Scope`] of the [`Session`]
405    ///
406    /// # Errors
407    ///
408    /// Returns [`Self::Error`] if the underlying repository fails
409    async fn add_from_client_credentials(
410        &mut self,
411        rng: &mut (dyn RngCore + Send),
412        clock: &dyn Clock,
413        client: &Client,
414        scope: Scope,
415    ) -> Result<Session, Self::Error> {
416        self.add(rng, clock, client, None, None, scope).await
417    }
418
419    /// Mark a [`Session`] as finished
420    ///
421    /// Returns the updated [`Session`]
422    ///
423    /// # Parameters
424    ///
425    /// * `clock`: The clock used to generate timestamps
426    /// * `session`: The [`Session`] to mark as finished
427    ///
428    /// # Errors
429    ///
430    /// Returns [`Self::Error`] if the underlying repository fails
431    async fn finish(&mut self, clock: &dyn Clock, session: Session)
432    -> Result<Session, Self::Error>;
433
434    /// Mark all the [`Session`] matching the given filter as finished
435    ///
436    /// Returns the number of sessions affected
437    ///
438    /// # Parameters
439    ///
440    /// * `clock`: The clock used to generate timestamps
441    /// * `filter`: The filter parameters
442    ///
443    /// # Errors
444    ///
445    /// Returns [`Self::Error`] if the underlying repository fails
446    async fn finish_bulk(
447        &mut self,
448        clock: &dyn Clock,
449        filter: OAuth2SessionFilter<'_>,
450    ) -> Result<usize, Self::Error>;
451
452    /// List [`Session`]s matching the given filter and pagination parameters
453    ///
454    /// # Parameters
455    ///
456    /// * `filter`: The filter parameters
457    /// * `pagination`: The pagination parameters
458    ///
459    /// # Errors
460    ///
461    /// Returns [`Self::Error`] if the underlying repository fails
462    async fn list(
463        &mut self,
464        filter: OAuth2SessionFilter<'_>,
465        pagination: Pagination,
466    ) -> Result<Page<Session>, Self::Error>;
467
468    /// Count [`Session`]s matching the given filter
469    ///
470    /// # Parameters
471    ///
472    /// * `filter`: The filter parameters
473    ///
474    /// # Errors
475    ///
476    /// Returns [`Self::Error`] if the underlying repository fails
477    async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error>;
478
479    /// Record a batch of [`Session`] activity
480    ///
481    /// # Parameters
482    ///
483    /// * `activity`: A list of tuples containing the session ID, the last
484    ///   activity timestamp and the IP address of the client
485    ///
486    /// # Errors
487    ///
488    /// Returns [`Self::Error`] if the underlying repository fails
489    async fn record_batch_activity(
490        &mut self,
491        activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
492    ) -> Result<(), Self::Error>;
493
494    /// Record the user agent of a [`Session`]
495    ///
496    /// # Parameters
497    ///
498    /// * `session`: The [`Session`] to record the user agent for
499    /// * `user_agent`: The user agent to record
500    async fn record_user_agent(
501        &mut self,
502        session: Session,
503        user_agent: String,
504    ) -> Result<Session, Self::Error>;
505
506    /// Set the human name of a [`Session`]
507    ///
508    /// # Parameters
509    ///
510    /// * `session`: The [`Session`] to set the human name for
511    /// * `human_name`: The human name to set
512    async fn set_human_name(
513        &mut self,
514        session: Session,
515        human_name: Option<String>,
516    ) -> Result<Session, Self::Error>;
517
518    /// Cleanup finished [`Session`]s
519    ///
520    /// Deletes sessions finished between `since` and `until`. Returns the
521    /// number of deleted sessions and the timestamp of the last deleted
522    /// session for pagination.
523    ///
524    /// # Parameters
525    ///
526    /// * `since`: The earliest finish time to delete (exclusive). If `None`,
527    ///   starts from the beginning.
528    /// * `until`: The latest finish time to delete (exclusive)
529    /// * `limit`: Maximum number of sessions to delete in this batch
530    ///
531    /// # Errors
532    ///
533    /// Returns [`Self::Error`] if the underlying repository fails
534    async fn cleanup_finished(
535        &mut self,
536        since: Option<DateTime<Utc>>,
537        until: DateTime<Utc>,
538        limit: usize,
539    ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error>;
540
541    /// Clear IP addresses from sessions inactive since the threshold
542    ///
543    /// Sets `last_active_ip` to `NULL` for sessions where `last_active_at` is
544    /// before the threshold. Returns the number of sessions affected and the
545    /// last `last_active_at` timestamp processed for pagination.
546    ///
547    /// # Parameters
548    ///
549    /// * `since`: Only process sessions with `last_active_at` at or after this
550    ///   timestamp (exclusive). If `None`, starts from the beginning.
551    /// * `threshold`: Clear IPs for sessions with `last_active_at` before this
552    ///   time
553    /// * `limit`: Maximum number of sessions to update in this batch
554    ///
555    /// # Errors
556    ///
557    /// Returns [`Self::Error`] if the underlying repository fails
558    async fn cleanup_inactive_ips(
559        &mut self,
560        since: Option<DateTime<Utc>>,
561        threshold: DateTime<Utc>,
562        limit: usize,
563    ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error>;
564}
565
566repository_impl!(OAuth2SessionRepository:
567    async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error>;
568
569    async fn add(
570        &mut self,
571        rng: &mut (dyn RngCore + Send),
572        clock: &dyn Clock,
573        client: &Client,
574        user: Option<&User>,
575        user_session: Option<&BrowserSession>,
576        scope: Scope,
577    ) -> Result<Session, Self::Error>;
578
579    async fn add_from_browser_session(
580        &mut self,
581        rng: &mut (dyn RngCore + Send),
582        clock: &dyn Clock,
583        client: &Client,
584        user_session: &BrowserSession,
585        scope: Scope,
586    ) -> Result<Session, Self::Error>;
587
588    async fn add_from_client_credentials(
589        &mut self,
590        rng: &mut (dyn RngCore + Send),
591        clock: &dyn Clock,
592        client: &Client,
593        scope: Scope,
594    ) -> Result<Session, Self::Error>;
595
596    async fn finish(&mut self, clock: &dyn Clock, session: Session)
597        -> Result<Session, Self::Error>;
598
599    async fn finish_bulk(
600        &mut self,
601        clock: &dyn Clock,
602        filter: OAuth2SessionFilter<'_>,
603    ) -> Result<usize, Self::Error>;
604
605    async fn list(
606        &mut self,
607        filter: OAuth2SessionFilter<'_>,
608        pagination: Pagination,
609    ) -> Result<Page<Session>, Self::Error>;
610
611    async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error>;
612
613    async fn record_batch_activity(
614        &mut self,
615        activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
616    ) -> Result<(), Self::Error>;
617
618    async fn record_user_agent(
619        &mut self,
620        session: Session,
621        user_agent: String,
622    ) -> Result<Session, Self::Error>;
623
624    async fn set_human_name(
625        &mut self,
626        session: Session,
627        human_name: Option<String>,
628    ) -> Result<Session, Self::Error>;
629
630    async fn cleanup_finished(
631        &mut self,
632        since: Option<DateTime<Utc>>,
633        until: DateTime<Utc>,
634        limit: usize,
635    ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error>;
636
637    async fn cleanup_inactive_ips(
638        &mut self,
639        since: Option<DateTime<Utc>>,
640        threshold: DateTime<Utc>,
641        limit: usize,
642    ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error>;
643);