1use std::net::IpAddr;
9
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use mas_data_model::{BrowserSession, Client, Clock, Session, SessionState, User};
13use mas_storage::{
14 Page, Pagination,
15 oauth2::{OAuth2SessionFilter, OAuth2SessionRepository},
16 pagination::Node,
17};
18use oauth2_types::scope::{Scope, ScopeToken};
19use rand::RngCore;
20use sea_query::{
21 Condition, Expr, PgFunc, PostgresQueryBuilder, Query, SimpleExpr, enum_def,
22 extension::postgres::PgExpr,
23};
24use sea_query_binder::SqlxBinder;
25use sqlx::PgConnection;
26use ulid::Ulid;
27use uuid::Uuid;
28
29use crate::{
30 DatabaseError, DatabaseInconsistencyError,
31 filter::{Filter, StatementExt},
32 iden::{OAuth2Clients, OAuth2Sessions, UserSessions},
33 pagination::QueryBuilderExt,
34 tracing::ExecuteExt,
35 ulid_at::{max_ulid_at, min_ulid_at},
36};
37
38pub struct PgOAuth2SessionRepository<'c> {
40 conn: &'c mut PgConnection,
41}
42
43impl<'c> PgOAuth2SessionRepository<'c> {
44 pub fn new(conn: &'c mut PgConnection) -> Self {
47 Self { conn }
48 }
49}
50
51#[derive(sqlx::FromRow)]
52#[enum_def]
53struct OAuthSessionLookup {
54 oauth2_session_id: Uuid,
55 user_id: Option<Uuid>,
56 user_session_id: Option<Uuid>,
57 oauth2_client_id: Uuid,
58 scope_list: Vec<String>,
59 created_at: DateTime<Utc>,
60 finished_at: Option<DateTime<Utc>>,
61 user_agent: Option<String>,
62 last_active_at: Option<DateTime<Utc>>,
63 last_active_ip: Option<IpAddr>,
64 human_name: Option<String>,
65}
66
67impl Node<Ulid> for OAuthSessionLookup {
68 fn cursor(&self) -> Ulid {
69 self.oauth2_session_id.into()
70 }
71}
72
73impl TryFrom<OAuthSessionLookup> for Session {
74 type Error = DatabaseInconsistencyError;
75
76 fn try_from(value: OAuthSessionLookup) -> Result<Self, Self::Error> {
77 let id = Ulid::from(value.oauth2_session_id);
78 let scope: Result<Scope, _> = value
79 .scope_list
80 .iter()
81 .map(|s| s.parse::<ScopeToken>())
82 .collect();
83 let scope = scope.map_err(|e| {
84 DatabaseInconsistencyError::on("oauth2_sessions")
85 .column("scope")
86 .row(id)
87 .source(e)
88 })?;
89
90 let state = match value.finished_at {
91 None => SessionState::Valid,
92 Some(finished_at) => SessionState::Finished { finished_at },
93 };
94
95 Ok(Session {
96 id,
97 state,
98 created_at: value.created_at,
99 client_id: value.oauth2_client_id.into(),
100 user_id: value.user_id.map(Ulid::from),
101 user_session_id: value.user_session_id.map(Ulid::from),
102 scope,
103 user_agent: value.user_agent,
104 last_active_at: value.last_active_at,
105 last_active_ip: value.last_active_ip,
106 human_name: value.human_name,
107 })
108 }
109}
110
111impl Filter for OAuth2SessionFilter<'_> {
112 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
113 sea_query::Condition::all()
114 .add_option(self.user().map(|user| {
115 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
116 }))
117 .add_option(self.client().map(|client| {
118 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
119 .eq(Uuid::from(client.id))
120 }))
121 .add_option(self.clients().map(|clients| {
122 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
123 .is_in(clients.iter().map(|c| Uuid::from(c.id)))
124 }))
125 .add_option(self.client_kind().map(|client_kind| {
126 let static_clients = Query::select()
130 .expr(Expr::col((
131 OAuth2Clients::Table,
132 OAuth2Clients::OAuth2ClientId,
133 )))
134 .and_where(Expr::col((OAuth2Clients::Table, OAuth2Clients::IsStatic)).into())
135 .from(OAuth2Clients::Table)
136 .take();
137 if client_kind.is_static() {
138 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
139 .eq(Expr::any(static_clients))
140 } else {
141 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
142 .ne(Expr::all(static_clients))
143 }
144 }))
145 .add_option(self.device().map(|device| -> SimpleExpr {
146 if let Ok([stable_scope_token, unstable_scope_token]) = device.to_scope_token() {
147 Condition::any()
148 .add(
149 Expr::val(stable_scope_token.to_string()).eq(PgFunc::any(Expr::col((
150 OAuth2Sessions::Table,
151 OAuth2Sessions::ScopeList,
152 )))),
153 )
154 .add(Expr::val(unstable_scope_token.to_string()).eq(PgFunc::any(
155 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
156 )))
157 .into()
158 } else {
159 Expr::val(false).into()
161 }
162 }))
163 .add_option(self.browser_session().map(|browser_session| {
164 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
165 .eq(Uuid::from(browser_session.id))
166 }))
167 .add_option(self.browser_session_filter().map(|browser_session_filter| {
168 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)).in_subquery(
169 Query::select()
170 .expr(Expr::col((
171 UserSessions::Table,
172 UserSessions::UserSessionId,
173 )))
174 .apply_filter(browser_session_filter)
175 .from(UserSessions::Table)
176 .take(),
177 )
178 }))
179 .add_option(self.state().map(|state| {
180 if state.is_active() {
181 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
182 } else {
183 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
184 }
185 }))
186 .add_option(self.scope().map(|scope| {
187 let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
188 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
189 }))
190 .add_option(self.any_user().map(|any_user| {
191 if any_user {
192 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_not_null()
193 } else {
194 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_null()
195 }
196 }))
197 .add_option(self.last_active_after().map(|last_active_after| {
198 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
199 .gt(last_active_after)
200 }))
201 .add_option(self.last_active_before().map(|last_active_before| {
202 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
203 .lt(last_active_before)
204 }))
205 .add_option(self.created_after().map(|created_after| {
206 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId))
210 .gt(max_ulid_at(created_after))
211 }))
212 .add_option(self.created_before().map(|created_before| {
213 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId))
214 .lt(min_ulid_at(created_before))
215 }))
216 }
217}
218
219#[async_trait]
220impl OAuth2SessionRepository for PgOAuth2SessionRepository<'_> {
221 type Error = DatabaseError;
222
223 #[tracing::instrument(
224 name = "db.oauth2_session.lookup",
225 skip_all,
226 fields(
227 db.query.text,
228 session.id = %id,
229 ),
230 err,
231 )]
232 async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error> {
233 let res = sqlx::query_as!(
234 OAuthSessionLookup,
235 r#"
236 SELECT oauth2_session_id
237 , user_id
238 , user_session_id
239 , oauth2_client_id
240 , scope_list
241 , created_at
242 , finished_at
243 , user_agent
244 , last_active_at
245 , last_active_ip as "last_active_ip: IpAddr"
246 , human_name
247 FROM oauth2_sessions
248
249 WHERE oauth2_session_id = $1
250 "#,
251 Uuid::from(id),
252 )
253 .traced()
254 .fetch_optional(&mut *self.conn)
255 .await?;
256
257 let Some(session) = res else { return Ok(None) };
258
259 Ok(Some(session.try_into()?))
260 }
261
262 #[tracing::instrument(
263 name = "db.oauth2_session.add",
264 skip_all,
265 fields(
266 db.query.text,
267 %client.id,
268 session.id,
269 session.scope = %scope,
270 ),
271 err,
272 )]
273 async fn add(
274 &mut self,
275 rng: &mut (dyn RngCore + Send),
276 clock: &dyn Clock,
277 client: &Client,
278 user: Option<&User>,
279 user_session: Option<&BrowserSession>,
280 scope: Scope,
281 ) -> Result<Session, Self::Error> {
282 let created_at = clock.now();
283 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
284 tracing::Span::current().record("session.id", tracing::field::display(id));
285
286 let scope_list: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
287
288 sqlx::query!(
289 r#"
290 INSERT INTO oauth2_sessions
291 ( oauth2_session_id
292 , user_id
293 , user_session_id
294 , oauth2_client_id
295 , scope_list
296 , created_at
297 )
298 VALUES ($1, $2, $3, $4, $5, $6)
299 "#,
300 Uuid::from(id),
301 user.map(|u| Uuid::from(u.id)),
302 user_session.map(|s| Uuid::from(s.id)),
303 Uuid::from(client.id),
304 &scope_list,
305 created_at,
306 )
307 .traced()
308 .execute(&mut *self.conn)
309 .await?;
310
311 Ok(Session {
312 id,
313 state: SessionState::Valid,
314 created_at,
315 user_id: user.map(|u| u.id),
316 user_session_id: user_session.map(|s| s.id),
317 client_id: client.id,
318 scope,
319 user_agent: None,
320 last_active_at: None,
321 last_active_ip: None,
322 human_name: None,
323 })
324 }
325
326 #[tracing::instrument(
327 name = "db.oauth2_session.finish_bulk",
328 skip_all,
329 fields(
330 db.query.text,
331 ),
332 err,
333 )]
334 async fn finish_bulk(
335 &mut self,
336 clock: &dyn Clock,
337 filter: OAuth2SessionFilter<'_>,
338 ) -> Result<usize, Self::Error> {
339 let finished_at = clock.now();
340 let (sql, arguments) = Query::update()
341 .table(OAuth2Sessions::Table)
342 .value(OAuth2Sessions::FinishedAt, finished_at)
343 .apply_filter(filter)
344 .build_sqlx(PostgresQueryBuilder);
345
346 let res = sqlx::query_with(&sql, arguments)
347 .traced()
348 .execute(&mut *self.conn)
349 .await?;
350
351 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
352 }
353
354 #[tracing::instrument(
355 name = "db.oauth2_session.finish",
356 skip_all,
357 fields(
358 db.query.text,
359 %session.id,
360 %session.scope,
361 client.id = %session.client_id,
362 ),
363 err,
364 )]
365 async fn finish(
366 &mut self,
367 clock: &dyn Clock,
368 session: Session,
369 ) -> Result<Session, Self::Error> {
370 let finished_at = clock.now();
371 let res = sqlx::query!(
372 r#"
373 UPDATE oauth2_sessions
374 SET finished_at = $2
375 WHERE oauth2_session_id = $1
376 "#,
377 Uuid::from(session.id),
378 finished_at,
379 )
380 .traced()
381 .execute(&mut *self.conn)
382 .await?;
383
384 DatabaseError::ensure_affected_rows(&res, 1)?;
385
386 session
387 .finish(finished_at)
388 .map_err(DatabaseError::to_invalid_operation)
389 }
390
391 #[tracing::instrument(
392 name = "db.oauth2_session.list",
393 skip_all,
394 fields(
395 db.query.text,
396 ),
397 err,
398 )]
399 async fn list(
400 &mut self,
401 filter: OAuth2SessionFilter<'_>,
402 pagination: Pagination,
403 ) -> Result<Page<Session>, Self::Error> {
404 let (sql, arguments) = Query::select()
405 .expr_as(
406 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
407 OAuthSessionLookupIden::Oauth2SessionId,
408 )
409 .expr_as(
410 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)),
411 OAuthSessionLookupIden::UserId,
412 )
413 .expr_as(
414 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)),
415 OAuthSessionLookupIden::UserSessionId,
416 )
417 .expr_as(
418 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)),
419 OAuthSessionLookupIden::Oauth2ClientId,
420 )
421 .expr_as(
422 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
423 OAuthSessionLookupIden::ScopeList,
424 )
425 .expr_as(
426 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::CreatedAt)),
427 OAuthSessionLookupIden::CreatedAt,
428 )
429 .expr_as(
430 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)),
431 OAuthSessionLookupIden::FinishedAt,
432 )
433 .expr_as(
434 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserAgent)),
435 OAuthSessionLookupIden::UserAgent,
436 )
437 .expr_as(
438 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt)),
439 OAuthSessionLookupIden::LastActiveAt,
440 )
441 .expr_as(
442 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveIp)),
443 OAuthSessionLookupIden::LastActiveIp,
444 )
445 .expr_as(
446 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::HumanName)),
447 OAuthSessionLookupIden::HumanName,
448 )
449 .from(OAuth2Sessions::Table)
450 .apply_filter(filter)
451 .generate_pagination(
452 (OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId),
453 pagination,
454 )
455 .build_sqlx(PostgresQueryBuilder);
456
457 let edges: Vec<OAuthSessionLookup> = sqlx::query_as_with(&sql, arguments)
458 .traced()
459 .fetch_all(&mut *self.conn)
460 .await?;
461
462 let page = pagination.process(edges).try_map(Session::try_from)?;
463
464 Ok(page)
465 }
466
467 #[tracing::instrument(
468 name = "db.oauth2_session.count",
469 skip_all,
470 fields(
471 db.query.text,
472 ),
473 err,
474 )]
475 async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error> {
476 let (sql, arguments) = Query::select()
477 .expr(Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)).count())
478 .from(OAuth2Sessions::Table)
479 .apply_filter(filter)
480 .build_sqlx(PostgresQueryBuilder);
481
482 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
483 .traced()
484 .fetch_one(&mut *self.conn)
485 .await?;
486
487 count
488 .try_into()
489 .map_err(DatabaseError::to_invalid_operation)
490 }
491
492 #[tracing::instrument(
493 name = "db.oauth2_session.record_batch_activity",
494 skip_all,
495 fields(
496 db.query.text,
497 ),
498 err,
499 )]
500 async fn record_batch_activity(
501 &mut self,
502 mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
503 ) -> Result<(), Self::Error> {
504 activities.sort_unstable();
507 let mut ids = Vec::with_capacity(activities.len());
508 let mut last_activities = Vec::with_capacity(activities.len());
509 let mut ips = Vec::with_capacity(activities.len());
510
511 for (id, last_activity, ip) in activities {
512 ids.push(Uuid::from(id));
513 last_activities.push(last_activity);
514 ips.push(ip);
515 }
516
517 let res = sqlx::query!(
518 r#"
519 UPDATE oauth2_sessions
520 SET last_active_at = GREATEST(t.last_active_at, oauth2_sessions.last_active_at)
521 , last_active_ip = COALESCE(t.last_active_ip, oauth2_sessions.last_active_ip)
522 FROM (
523 SELECT *
524 FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
525 AS t(oauth2_session_id, last_active_at, last_active_ip)
526 ) AS t
527 WHERE oauth2_sessions.oauth2_session_id = t.oauth2_session_id
528 "#,
529 &ids,
530 &last_activities,
531 &ips as &[Option<IpAddr>],
532 )
533 .traced()
534 .execute(&mut *self.conn)
535 .await?;
536
537 DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
538
539 Ok(())
540 }
541
542 #[tracing::instrument(
543 name = "db.oauth2_session.record_user_agent",
544 skip_all,
545 fields(
546 db.query.text,
547 %session.id,
548 %session.scope,
549 client.id = %session.client_id,
550 session.user_agent = user_agent,
551 ),
552 err,
553 )]
554 async fn record_user_agent(
555 &mut self,
556 mut session: Session,
557 user_agent: String,
558 ) -> Result<Session, Self::Error> {
559 let res = sqlx::query!(
560 r#"
561 UPDATE oauth2_sessions
562 SET user_agent = $2
563 WHERE oauth2_session_id = $1
564 "#,
565 Uuid::from(session.id),
566 &*user_agent,
567 )
568 .traced()
569 .execute(&mut *self.conn)
570 .await?;
571
572 session.user_agent = Some(user_agent);
573
574 DatabaseError::ensure_affected_rows(&res, 1)?;
575
576 Ok(session)
577 }
578
579 #[tracing::instrument(
580 name = "repository.oauth2_session.set_human_name",
581 skip(self),
582 fields(
583 client.id = %session.client_id,
584 session.human_name = ?human_name,
585 ),
586 err,
587 )]
588 async fn set_human_name(
589 &mut self,
590 mut session: Session,
591 human_name: Option<String>,
592 ) -> Result<Session, Self::Error> {
593 let res = sqlx::query!(
594 r#"
595 UPDATE oauth2_sessions
596 SET human_name = $2
597 WHERE oauth2_session_id = $1
598 "#,
599 Uuid::from(session.id),
600 human_name.as_deref(),
601 )
602 .traced()
603 .execute(&mut *self.conn)
604 .await?;
605
606 session.human_name = human_name;
607
608 DatabaseError::ensure_affected_rows(&res, 1)?;
609
610 Ok(session)
611 }
612
613 #[tracing::instrument(
614 name = "db.oauth2_session.cleanup_finished",
615 skip_all,
616 fields(
617 db.query.text,
618 since = since.map(tracing::field::display),
619 until = %until,
620 limit = limit,
621 ),
622 err,
623 )]
624 async fn cleanup_finished(
625 &mut self,
626 since: Option<DateTime<Utc>>,
627 until: DateTime<Utc>,
628 limit: usize,
629 ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error> {
630 let res = sqlx::query!(
631 r#"
632 WITH
633 to_delete AS (
634 SELECT oauth2_session_id, finished_at
635 FROM oauth2_sessions
636 WHERE finished_at IS NOT NULL
637 AND ($1::timestamptz IS NULL OR finished_at >= $1)
638 AND finished_at < $2
639 ORDER BY finished_at ASC
640 LIMIT $3
641 FOR UPDATE
642 ),
643 deleted_refresh_tokens AS (
644 DELETE FROM oauth2_refresh_tokens USING to_delete
645 WHERE oauth2_refresh_tokens.oauth2_session_id = to_delete.oauth2_session_id
646 ),
647 deleted_access_tokens AS (
648 DELETE FROM oauth2_access_tokens USING to_delete
649 WHERE oauth2_access_tokens.oauth2_session_id = to_delete.oauth2_session_id
650 ),
651 deleted_sessions AS (
652 DELETE FROM oauth2_sessions USING to_delete
653 WHERE oauth2_sessions.oauth2_session_id = to_delete.oauth2_session_id
654 RETURNING oauth2_sessions.finished_at
655 )
656 SELECT COUNT(*) as "count!", MAX(finished_at) as last_finished_at FROM deleted_sessions
657 "#,
658 since,
659 until,
660 i64::try_from(limit).unwrap_or(i64::MAX),
661 )
662 .traced()
663 .fetch_one(&mut *self.conn)
664 .await?;
665
666 Ok((
667 res.count.try_into().unwrap_or(usize::MAX),
668 res.last_finished_at,
669 ))
670 }
671
672 #[tracing::instrument(
673 name = "db.oauth2_session.cleanup_inactive_ips",
674 skip_all,
675 fields(
676 db.query.text,
677 since = since.map(tracing::field::display),
678 threshold = %threshold,
679 limit = limit,
680 ),
681 err,
682 )]
683 async fn cleanup_inactive_ips(
684 &mut self,
685 since: Option<DateTime<Utc>>,
686 threshold: DateTime<Utc>,
687 limit: usize,
688 ) -> Result<(usize, Option<DateTime<Utc>>), Self::Error> {
689 let res = sqlx::query!(
690 r#"
691 WITH to_update AS (
692 SELECT oauth2_session_id, last_active_at
693 FROM oauth2_sessions
694 WHERE last_active_ip IS NOT NULL
695 AND last_active_at IS NOT NULL
696 AND ($1::timestamptz IS NULL OR last_active_at >= $1)
697 AND last_active_at < $2
698 ORDER BY last_active_at ASC
699 LIMIT $3
700 FOR UPDATE
701 ),
702 updated AS (
703 UPDATE oauth2_sessions
704 SET last_active_ip = NULL
705 FROM to_update
706 WHERE oauth2_sessions.oauth2_session_id = to_update.oauth2_session_id
707 RETURNING oauth2_sessions.last_active_at
708 )
709 SELECT COUNT(*) AS "count!", MAX(last_active_at) AS last_active_at FROM updated
710 "#,
711 since,
712 threshold,
713 i64::try_from(limit).unwrap_or(i64::MAX),
714 )
715 .traced()
716 .fetch_one(&mut *self.conn)
717 .await?;
718
719 Ok((
720 res.count.try_into().unwrap_or(usize::MAX),
721 res.last_active_at,
722 ))
723 }
724}