mas_storage/oauth2/session.rs
1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8use std::net::IpAddr;
9
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use mas_data_model::{BrowserSession, Client, Clock, Device, Session, User};
13use oauth2_types::scope::Scope;
14use rand_core::RngCore;
15use ulid::Ulid;
16
17use crate::{Pagination, pagination::Page, repository_impl, user::BrowserSessionFilter};
18
/// State an OAuth 2.0 session can be filtered on
///
/// Used by [`OAuth2SessionFilter`] to restrict a listing to either active or
/// finished sessions.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OAuth2SessionState {
    /// The session is active
    Active,
    /// The session has been finished
    Finished,
}

impl OAuth2SessionState {
    /// Returns `true` if the state is [`OAuth2SessionState::Active`]
    #[must_use]
    pub fn is_active(self) -> bool {
        matches!(self, Self::Active)
    }

    /// Returns `true` if the state is [`OAuth2SessionState::Finished`]
    #[must_use]
    pub fn is_finished(self) -> bool {
        matches!(self, Self::Finished)
    }
}
34
/// Kind of client an OAuth 2.0 session listing can be restricted to
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ClientKind {
    /// Static clients
    Static,
    /// Dynamic clients
    Dynamic,
}

impl ClientKind {
    /// Returns `true` if the kind is [`ClientKind::Static`]
    #[must_use]
    pub fn is_static(self) -> bool {
        matches!(self, Self::Static)
    }

    /// Returns `true` if the kind is [`ClientKind::Dynamic`]
    #[must_use]
    pub fn is_dynamic(self) -> bool {
        matches!(self, Self::Dynamic)
    }
}
46
47/// Filter parameters for listing OAuth 2.0 sessions
48#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
49pub struct OAuth2SessionFilter<'a> {
50 user: Option<&'a User>,
51 any_user: Option<bool>,
52 browser_session: Option<&'a BrowserSession>,
53 browser_session_filter: Option<BrowserSessionFilter<'a>>,
54 device: Option<&'a Device>,
55 client: Option<&'a Client>,
56 clients: Option<&'a [&'a Client]>,
57 client_kind: Option<ClientKind>,
58 state: Option<OAuth2SessionState>,
59 scope: Option<&'a Scope>,
60 last_active_before: Option<DateTime<Utc>>,
61 last_active_after: Option<DateTime<Utc>>,
62 created_before: Option<DateTime<Utc>>,
63 created_after: Option<DateTime<Utc>>,
64}
65
66impl<'a> OAuth2SessionFilter<'a> {
67 /// Create a new [`OAuth2SessionFilter`] with default values
68 #[must_use]
69 pub fn new() -> Self {
70 Self::default()
71 }
72
73 /// List sessions for a specific user
74 #[must_use]
75 pub fn for_user(mut self, user: &'a User) -> Self {
76 self.user = Some(user);
77 self
78 }
79
80 /// Get the user filter
81 ///
82 /// Returns [`None`] if no user filter was set
83 #[must_use]
84 pub fn user(&self) -> Option<&'a User> {
85 self.user
86 }
87
88 /// List sessions which belong to any user
89 #[must_use]
90 pub fn for_any_user(mut self) -> Self {
91 self.any_user = Some(true);
92 self
93 }
94
95 /// List sessions which belong to no user
96 #[must_use]
97 pub fn for_no_user(mut self) -> Self {
98 self.any_user = Some(false);
99 self
100 }
101
102 /// Get the 'any user' filter
103 ///
104 /// Returns [`None`] if no 'any user' filter was set
105 #[must_use]
106 pub fn any_user(&self) -> Option<bool> {
107 self.any_user
108 }
109
110 /// List sessions started by a specific browser session
111 #[must_use]
112 pub fn for_browser_session(mut self, browser_session: &'a BrowserSession) -> Self {
113 self.browser_session = Some(browser_session);
114 self
115 }
116
117 /// List sessions started by a set of browser sessions
118 #[must_use]
119 pub fn for_browser_sessions(
120 mut self,
121 browser_session_filter: BrowserSessionFilter<'a>,
122 ) -> Self {
123 self.browser_session_filter = Some(browser_session_filter);
124 self
125 }
126
127 /// Get the browser session filter
128 ///
129 /// Returns [`None`] if no browser session filter was set
130 #[must_use]
131 pub fn browser_session(&self) -> Option<&'a BrowserSession> {
132 self.browser_session
133 }
134
135 /// Get the browser sessions filter
136 ///
137 /// Returns [`None`] if no browser session filter was set
138 #[must_use]
139 pub fn browser_session_filter(&self) -> Option<BrowserSessionFilter<'a>> {
140 self.browser_session_filter
141 }
142
143 /// List sessions for a specific client
144 #[must_use]
145 pub fn for_client(mut self, client: &'a Client) -> Self {
146 self.client = Some(client);
147 self
148 }
149
150 /// Get the client filter
151 ///
152 /// Returns [`None`] if no client filter was set
153 #[must_use]
154 pub fn client(&self) -> Option<&'a Client> {
155 self.client
156 }
157
158 /// List sessions for a set of clients
159 ///
160 /// This filter is independent of [`Self::for_client`]: if both are set,
161 /// the conditions are `AND`-ed together. In practice an API caller uses
162 /// one or the other.
163 #[must_use]
164 pub fn for_clients(mut self, clients: &'a [&'a Client]) -> Self {
165 self.clients = Some(clients);
166 self
167 }
168
169 /// Get the multi-client filter
170 ///
171 /// Returns [`None`] if no multi-client filter was set
172 #[must_use]
173 pub fn clients(&self) -> Option<&'a [&'a Client]> {
174 self.clients
175 }
176
177 /// List only static clients
178 #[must_use]
179 pub fn only_static_clients(mut self) -> Self {
180 self.client_kind = Some(ClientKind::Static);
181 self
182 }
183
184 /// List only dynamic clients
185 #[must_use]
186 pub fn only_dynamic_clients(mut self) -> Self {
187 self.client_kind = Some(ClientKind::Dynamic);
188 self
189 }
190
191 /// Get the client kind filter
192 ///
193 /// Returns [`None`] if no client kind filter was set
194 #[must_use]
195 pub fn client_kind(&self) -> Option<ClientKind> {
196 self.client_kind
197 }
198
199 /// Only return sessions with a last active time before the given time
200 #[must_use]
201 pub fn with_last_active_before(mut self, last_active_before: DateTime<Utc>) -> Self {
202 self.last_active_before = Some(last_active_before);
203 self
204 }
205
206 /// Only return sessions with a last active time after the given time
207 #[must_use]
208 pub fn with_last_active_after(mut self, last_active_after: DateTime<Utc>) -> Self {
209 self.last_active_after = Some(last_active_after);
210 self
211 }
212
213 /// Get the last active before filter
214 ///
215 /// Returns [`None`] if no client filter was set
216 #[must_use]
217 pub fn last_active_before(&self) -> Option<DateTime<Utc>> {
218 self.last_active_before
219 }
220
221 /// Get the last active after filter
222 ///
223 /// Returns [`None`] if no client filter was set
224 #[must_use]
225 pub fn last_active_after(&self) -> Option<DateTime<Utc>> {
226 self.last_active_after
227 }
228
229 /// Only return sessions created before the given time
230 #[must_use]
231 pub fn with_created_before(mut self, created_before: DateTime<Utc>) -> Self {
232 self.created_before = Some(created_before);
233 self
234 }
235
236 /// Only return sessions created after the given time
237 #[must_use]
238 pub fn with_created_after(mut self, created_after: DateTime<Utc>) -> Self {
239 self.created_after = Some(created_after);
240 self
241 }
242
243 /// Get the created-before filter
244 ///
245 /// Returns [`None`] if no filter was set
246 #[must_use]
247 pub fn created_before(&self) -> Option<DateTime<Utc>> {
248 self.created_before
249 }
250
251 /// Get the created-after filter
252 ///
253 /// Returns [`None`] if no filter was set
254 #[must_use]
255 pub fn created_after(&self) -> Option<DateTime<Utc>> {
256 self.created_after
257 }
258
259 /// Only return active sessions
260 #[must_use]
261 pub fn active_only(mut self) -> Self {
262 self.state = Some(OAuth2SessionState::Active);
263 self
264 }
265
266 /// Only return finished sessions
267 #[must_use]
268 pub fn finished_only(mut self) -> Self {
269 self.state = Some(OAuth2SessionState::Finished);
270 self
271 }
272
273 /// Get the state filter
274 ///
275 /// Returns [`None`] if no state filter was set
276 #[must_use]
277 pub fn state(&self) -> Option<OAuth2SessionState> {
278 self.state
279 }
280
281 /// Only return sessions with the given scope
282 #[must_use]
283 pub fn with_scope(mut self, scope: &'a Scope) -> Self {
284 self.scope = Some(scope);
285 self
286 }
287
288 /// Get the scope filter
289 ///
290 /// Returns [`None`] if no scope filter was set
291 #[must_use]
292 pub fn scope(&self) -> Option<&'a Scope> {
293 self.scope
294 }
295
296 /// Only return sessions that have the given device in their scope
297 #[must_use]
298 pub fn for_device(mut self, device: &'a Device) -> Self {
299 self.device = Some(device);
300 self
301 }
302
303 /// Get the device filter
304 ///
305 /// Returns [`None`] if no device filter was set
306 #[must_use]
307 pub fn device(&self) -> Option<&'a Device> {
308 self.device
309 }
310}
311
312/// An [`OAuth2SessionRepository`] helps interacting with [`Session`]
313/// saved in the storage backend
314#[async_trait]
315pub trait OAuth2SessionRepository: Send + Sync {
316 /// The error type returned by the repository
317 type Error;
318
319 /// Lookup an [`Session`] by its ID
320 ///
321 /// Returns `None` if no [`Session`] was found
322 ///
323 /// # Parameters
324 ///
325 /// * `id`: The ID of the [`Session`] to lookup
326 ///
327 /// # Errors
328 ///
329 /// Returns [`Self::Error`] if the underlying repository fails
330 async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error>;
331
332 /// Create a new [`Session`] with the given parameters
333 ///
334 /// Returns the newly created [`Session`]
335 ///
336 /// # Parameters
337 ///
338 /// * `rng`: The random number generator to use
339 /// * `clock`: The clock used to generate timestamps
340 /// * `client`: The [`Client`] which created the [`Session`]
341 /// * `user`: The [`User`] for which the session should be created, if any
342 /// * `user_session`: The [`BrowserSession`] of the user which completed the
343 /// authorization, if any
344 /// * `scope`: The [`Scope`] of the [`Session`]
345 ///
346 /// # Errors
347 ///
348 /// Returns [`Self::Error`] if the underlying repository fails
349 async fn add(
350 &mut self,
351 rng: &mut (dyn RngCore + Send),
352 clock: &dyn Clock,
353 client: &Client,
354 user: Option<&User>,
355 user_session: Option<&BrowserSession>,
356 scope: Scope,
357 ) -> Result<Session, Self::Error>;
358
359 /// Create a new [`Session`] out of a [`Client`] and a [`BrowserSession`]
360 ///
361 /// Returns the newly created [`Session`]
362 ///
363 /// # Parameters
364 ///
365 /// * `rng`: The random number generator to use
366 /// * `clock`: The clock used to generate timestamps
367 /// * `client`: The [`Client`] which created the [`Session`]
368 /// * `user_session`: The [`BrowserSession`] of the user which completed the
369 /// authorization
370 /// * `scope`: The [`Scope`] of the [`Session`]
371 ///
372 /// # Errors
373 ///
374 /// Returns [`Self::Error`] if the underlying repository fails
375 async fn add_from_browser_session(
376 &mut self,
377 rng: &mut (dyn RngCore + Send),
378 clock: &dyn Clock,
379 client: &Client,
380 user_session: &BrowserSession,
381 scope: Scope,
382 ) -> Result<Session, Self::Error> {
383 self.add(
384 rng,
385 clock,
386 client,
387 Some(&user_session.user),
388 Some(user_session),
389 scope,
390 )
391 .await
392 }
393
394 /// Create a new [`Session`] for a [`Client`] using the client credentials
395 /// flow
396 ///
397 /// Returns the newly created [`Session`]
398 ///
399 /// # Parameters
400 ///
401 /// * `rng`: The random number generator to use
402 /// * `clock`: The clock used to generate timestamps
403 /// * `client`: The [`Client`] which created the [`Session`]
404 /// * `scope`: The [`Scope`] of the [`Session`]
405 ///
406 /// # Errors
407 ///
408 /// Returns [`Self::Error`] if the underlying repository fails
409 async fn add_from_client_credentials(
410 &mut self,
411 rng: &mut (dyn RngCore + Send),
412 clock: &dyn Clock,
413 client: &Client,
414 scope: Scope,
415 ) -> Result<Session, Self::Error> {
416 self.add(rng, clock, client, None, None, scope).await
417 }
418
419 /// Mark a [`Session`] as finished
420 ///
421 /// Returns the updated [`Session`]
422 ///
423 /// # Parameters
424 ///
425 /// * `clock`: The clock used to generate timestamps
426 /// * `session`: The [`Session`] to mark as finished
427 ///
428 /// # Errors
429 ///
430 /// Returns [`Self::Error`] if the underlying repository fails
431 async fn finish(&mut self, clock: &dyn Clock, session: Session)
432 -> Result<Session, Self::Error>;
433
434 /// Mark all the [`Session`] matching the given filter as finished
435 ///
436 /// Returns the number of sessions affected
437 ///
438 /// # Parameters
439 ///
440 /// * `clock`: The clock used to generate timestamps
441 /// * `filter`: The filter parameters
442 ///
443 /// # Errors
444 ///
445 /// Returns [`Self::Error`] if the underlying repository fails
446 async fn finish_bulk(
447 &mut self,
448 clock: &dyn Clock,
449 filter: OAuth2SessionFilter<'_>,
450 ) -> Result<usize, Self::Error>;
451
452 /// List [`Session`]s matching the given filter and pagination parameters
453 ///
454 /// # Parameters
455 ///
456 /// * `filter`: The filter parameters
457 /// * `pagination`: The pagination parameters
458 ///
459 /// # Errors
460 ///
461 /// Returns [`Self::Error`] if the underlying repository fails
462 async fn list(
463 &mut self,
464 filter: OAuth2SessionFilter<'_>,
465 pagination: Pagination,
466 ) -> Result<Page<Session>, Self::Error>;
467
468 /// Count [`Session`]s matching the given filter
469 ///
470 /// # Parameters
471 ///
472 /// * `filter`: The filter parameters
473 ///
474 /// # Errors
475 ///
476 /// Returns [`Self::Error`] if the underlying repository fails
477 async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error>;
478
479 /// Record a batch of [`Session`] activity
480 ///
481 /// # Parameters
482 ///
483 /// * `activity`: A list of tuples containing the session ID, the last
484 /// activity timestamp and the IP address of the client
485 ///
486 /// # Errors
487 ///
488 /// Returns [`Self::Error`] if the underlying repository fails
489 async fn record_batch_activity(
490 &mut self,
491 activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
492 ) -> Result<(), Self::Error>;
493
494 /// Record the user agent of a [`Session`]
495 ///
496 /// # Parameters
497 ///
498 /// * `session`: The [`Session`] to record the user agent for
499 /// * `user_agent`: The user agent to record
500 async fn record_user_agent(
501 &mut self,
502 session: Session,
503 user_agent: String,
504 ) -> Result<Session, Self::Error>;
505
506 /// Set the human name of a [`Session`]
507 ///
508 /// # Parameters
509 ///
510 /// * `session`: The [`Session`] to set the human name for
511 /// * `human_name`: The human name to set
512 async fn set_human_name(
513 &mut self,
514 session: Session,
515 human_name: Option<String>,
516 ) -> Result<Session, Self::Error>;
517
518 /// Cleanup finished [`Session`]s
519 ///
520 /// Deletes sessions finished between `since` and `until`. Returns the
521 /// number of deleted sessions and the timestamp of the last deleted
522 /// session for pagination.
523 ///
524 /// # Parameters
525 ///
526 /// * `since`: The earliest finish time to delete (exclusive). If `None`,
527 /// starts from the beginning.
528 /// * `until`: The latest finish time to delete (exclusive)
529 /// * `limit`: Maximum number of sessions to delete in this batch
530 ///
531 /// # Errors
532 ///
533 /// Returns [`Self::Error`] if the underlying repository fails
534 async fn cleanup_finished(
535 &mut self,
536 since: Option<DateTime<Utc>>,
537 until: DateTime<Utc>,
538 limit: usize,
539 ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error>;
540
541 /// Clear IP addresses from sessions inactive since the threshold
542 ///
543 /// Sets `last_active_ip` to `NULL` for sessions where `last_active_at` is
544 /// before the threshold. Returns the number of sessions affected and the
545 /// last `last_active_at` timestamp processed for pagination.
546 ///
547 /// # Parameters
548 ///
549 /// * `since`: Only process sessions with `last_active_at` at or after this
550 /// timestamp (exclusive). If `None`, starts from the beginning.
551 /// * `threshold`: Clear IPs for sessions with `last_active_at` before this
552 /// time
553 /// * `limit`: Maximum number of sessions to update in this batch
554 ///
555 /// # Errors
556 ///
557 /// Returns [`Self::Error`] if the underlying repository fails
558 async fn cleanup_inactive_ips(
559 &mut self,
560 since: Option<DateTime<Utc>>,
561 threshold: DateTime<Utc>,
562 limit: usize,
563 ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error>;
564}
565
// Hands the `OAuth2SessionRepository` method signatures to the crate's
// `repository_impl!` macro. The list below repeats every method of the trait,
// including the two which have a default body there.
// NOTE(review): the macro is defined outside this file; presumably it
// generates forwarding implementations of the trait, so this list has to be
// kept in sync with the trait definition — confirm against the macro.
repository_impl!(OAuth2SessionRepository:
    async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error>;

    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        client: &Client,
        user: Option<&User>,
        user_session: Option<&BrowserSession>,
        scope: Scope,
    ) -> Result<Session, Self::Error>;

    async fn add_from_browser_session(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        client: &Client,
        user_session: &BrowserSession,
        scope: Scope,
    ) -> Result<Session, Self::Error>;

    async fn add_from_client_credentials(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        client: &Client,
        scope: Scope,
    ) -> Result<Session, Self::Error>;

    async fn finish(&mut self, clock: &dyn Clock, session: Session)
        -> Result<Session, Self::Error>;

    async fn finish_bulk(
        &mut self,
        clock: &dyn Clock,
        filter: OAuth2SessionFilter<'_>,
    ) -> Result<usize, Self::Error>;

    async fn list(
        &mut self,
        filter: OAuth2SessionFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<Session>, Self::Error>;

    async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error>;

    async fn record_batch_activity(
        &mut self,
        activity: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
    ) -> Result<(), Self::Error>;

    async fn record_user_agent(
        &mut self,
        session: Session,
        user_agent: String,
    ) -> Result<Session, Self::Error>;

    async fn set_human_name(
        &mut self,
        session: Session,
        human_name: Option<String>,
    ) -> Result<Session, Self::Error>;

    async fn cleanup_finished(
        &mut self,
        since: Option<DateTime<Utc>>,
        until: DateTime<Utc>,
        limit: usize,
    ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error>;

    async fn cleanup_inactive_ips(
        &mut self,
        since: Option<DateTime<Utc>>,
        threshold: DateTime<Utc>,
        limit: usize,
    ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error>;
);