1use std::{
9 collections::{BTreeMap, BTreeSet},
10 string::ToString,
11};
12
13use async_trait::async_trait;
14use mas_data_model::{Client, Clock, JwksOrJwksUri};
15use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
16use mas_jose::jwk::PublicJsonWebKeySet;
17use mas_storage::{
18 Page, Pagination,
19 oauth2::{OAuth2ClientFilter, OAuth2ClientKind, OAuth2ClientRepository},
20 pagination::Node,
21};
22use oauth2_types::{oidc::ApplicationType, requests::GrantType};
23use opentelemetry_semantic_conventions::attribute::DB_QUERY_TEXT;
24use rand::RngCore;
25use sea_query::{
26 Expr, PostgresQueryBuilder, Query, SimpleExpr, enum_def, extension::postgres::PgExpr as _,
27};
28use sea_query_binder::SqlxBinder;
29use sqlx::PgConnection;
30use tracing::{Instrument, info_span};
31use ulid::Ulid;
32use url::Url;
33use uuid::Uuid;
34
35use crate::{
36 DatabaseError, DatabaseInconsistencyError,
37 filter::{Filter, StatementExt},
38 iden::{OAuth2Clients, OAuth2Sessions},
39 pagination::QueryBuilderExt,
40 tracing::ExecuteExt,
41};
42
/// PostgreSQL-backed implementation of [`OAuth2ClientRepository`].
///
/// Holds a mutable borrow of a [`PgConnection`] for the lifetime `'c`; every
/// repository method runs its queries on that connection.
pub struct PgOAuth2ClientRepository<'c> {
    conn: &'c mut PgConnection,
}
47
48impl<'c> PgOAuth2ClientRepository<'c> {
49 pub fn new(conn: &'c mut PgConnection) -> Self {
52 Self { conn }
53 }
54}
55
/// Raw row of the `oauth2_clients` table, as selected by the queries in this
/// module, before it is validated and converted into a [`Client`].
///
/// `#[enum_def]` generates the companion `OAuth2ClientLookupIden` enum, which
/// is used to alias the selected columns in the dynamically-built `list`
/// query.
#[expect(clippy::struct_excessive_bools)]
#[derive(Debug, sqlx::FromRow)]
#[enum_def]
struct OAuth2ClientLookup {
    // Primary key; converted to a `Ulid` to form the client ID.
    oauth2_client_id: Uuid,
    metadata_digest: Option<String>,
    encrypted_client_secret: Option<String>,
    // Stored as text; parsed into the typed value during conversion.
    application_type: Option<String>,
    // Each entry is parsed as a URL during conversion.
    redirect_uris: Vec<String>,
    // One boolean column per supported grant type.
    grant_type_authorization_code: bool,
    grant_type_refresh_token: bool,
    grant_type_client_credentials: bool,
    grant_type_device_code: bool,
    client_name: Option<String>,
    logo_uri: Option<String>,
    client_uri: Option<String>,
    policy_uri: Option<String>,
    tos_uri: Option<String>,
    // `jwks_uri` and `jwks` are mutually exclusive: conversion fails if both
    // are set.
    jwks_uri: Option<String>,
    jwks: Option<serde_json::Value>,
    id_token_signed_response_alg: Option<String>,
    userinfo_signed_response_alg: Option<String>,
    token_endpoint_auth_method: Option<String>,
    token_endpoint_auth_signing_alg: Option<String>,
    initiate_login_uri: Option<String>,
    is_static: bool,
}
83
84impl Node<Ulid> for OAuth2ClientLookup {
85 fn cursor(&self) -> Ulid {
86 self.oauth2_client_id.into()
87 }
88}
89
90impl Filter for OAuth2ClientFilter<'_> {
91 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
92 sea_query::Condition::all()
93 .add_option(self.kind().map(|kind| {
94 let is_static = matches!(kind, OAuth2ClientKind::Static);
95 Expr::col((OAuth2Clients::Table, OAuth2Clients::IsStatic)).eq(is_static)
96 }))
97 .add_option(self.client_name().map(|client_name| {
98 Expr::col((OAuth2Clients::Table, OAuth2Clients::ClientName))
99 .ilike(format!("%{client_name}%"))
100 }))
101 .add_option(self.client_uri().map(|client_uri| {
102 Expr::col((OAuth2Clients::Table, OAuth2Clients::ClientUri))
103 .ilike(format!("%{client_uri}%"))
104 }))
105 .add_option(self.grant_type().map(|grant_type| -> SimpleExpr {
106 let column = match grant_type {
107 GrantType::AuthorizationCode => OAuth2Clients::GrantTypeAuthorizationCode,
108 GrantType::RefreshToken => OAuth2Clients::GrantTypeRefreshToken,
109 GrantType::ClientCredentials => OAuth2Clients::GrantTypeClientCredentials,
110 GrantType::DeviceCode => OAuth2Clients::GrantTypeDeviceCode,
111 _ => return Expr::val(false).into(),
114 };
115 Expr::col((OAuth2Clients::Table, column)).eq(true)
116 }))
117 .add_option(self.has_active_sessions().map(|has| -> SimpleExpr {
118 let exists = Expr::exists(
119 Query::select()
120 .expr(Expr::cust("1"))
121 .from(OAuth2Sessions::Table)
122 .and_where(
123 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
124 .equals((OAuth2Clients::Table, OAuth2Clients::OAuth2ClientId)),
125 )
126 .and_where(
127 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt))
128 .is_null(),
129 )
130 .take(),
131 );
132 if has { exists } else { exists.not() }
133 }))
134 }
135}
136
137impl TryFrom<OAuth2ClientLookup> for Client {
138 type Error = DatabaseInconsistencyError;
139
140 fn try_from(value: OAuth2ClientLookup) -> Result<Self, Self::Error> {
141 let id = Ulid::from(value.oauth2_client_id);
142
143 let redirect_uris: Result<Vec<Url>, _> =
144 value.redirect_uris.iter().map(|s| s.parse()).collect();
145 let redirect_uris = redirect_uris.map_err(|e| {
146 DatabaseInconsistencyError::on("oauth2_clients")
147 .column("redirect_uris")
148 .row(id)
149 .source(e)
150 })?;
151
152 let application_type = value
153 .application_type
154 .map(|s| s.parse())
155 .transpose()
156 .map_err(|e| {
157 DatabaseInconsistencyError::on("oauth2_clients")
158 .column("application_type")
159 .row(id)
160 .source(e)
161 })?;
162
163 let mut grant_types = Vec::new();
164 if value.grant_type_authorization_code {
165 grant_types.push(GrantType::AuthorizationCode);
166 }
167 if value.grant_type_refresh_token {
168 grant_types.push(GrantType::RefreshToken);
169 }
170 if value.grant_type_client_credentials {
171 grant_types.push(GrantType::ClientCredentials);
172 }
173 if value.grant_type_device_code {
174 grant_types.push(GrantType::DeviceCode);
175 }
176
177 let logo_uri = value.logo_uri.map(|s| s.parse()).transpose().map_err(|e| {
178 DatabaseInconsistencyError::on("oauth2_clients")
179 .column("logo_uri")
180 .row(id)
181 .source(e)
182 })?;
183
184 let client_uri = value
185 .client_uri
186 .map(|s| s.parse())
187 .transpose()
188 .map_err(|e| {
189 DatabaseInconsistencyError::on("oauth2_clients")
190 .column("client_uri")
191 .row(id)
192 .source(e)
193 })?;
194
195 let policy_uri = value
196 .policy_uri
197 .map(|s| s.parse())
198 .transpose()
199 .map_err(|e| {
200 DatabaseInconsistencyError::on("oauth2_clients")
201 .column("policy_uri")
202 .row(id)
203 .source(e)
204 })?;
205
206 let tos_uri = value.tos_uri.map(|s| s.parse()).transpose().map_err(|e| {
207 DatabaseInconsistencyError::on("oauth2_clients")
208 .column("tos_uri")
209 .row(id)
210 .source(e)
211 })?;
212
213 let id_token_signed_response_alg = value
214 .id_token_signed_response_alg
215 .map(|s| s.parse())
216 .transpose()
217 .map_err(|e| {
218 DatabaseInconsistencyError::on("oauth2_clients")
219 .column("id_token_signed_response_alg")
220 .row(id)
221 .source(e)
222 })?;
223
224 let userinfo_signed_response_alg = value
225 .userinfo_signed_response_alg
226 .map(|s| s.parse())
227 .transpose()
228 .map_err(|e| {
229 DatabaseInconsistencyError::on("oauth2_clients")
230 .column("userinfo_signed_response_alg")
231 .row(id)
232 .source(e)
233 })?;
234
235 let token_endpoint_auth_method = value
236 .token_endpoint_auth_method
237 .map(|s| s.parse())
238 .transpose()
239 .map_err(|e| {
240 DatabaseInconsistencyError::on("oauth2_clients")
241 .column("token_endpoint_auth_method")
242 .row(id)
243 .source(e)
244 })?;
245
246 let token_endpoint_auth_signing_alg = value
247 .token_endpoint_auth_signing_alg
248 .map(|s| s.parse())
249 .transpose()
250 .map_err(|e| {
251 DatabaseInconsistencyError::on("oauth2_clients")
252 .column("token_endpoint_auth_signing_alg")
253 .row(id)
254 .source(e)
255 })?;
256
257 let initiate_login_uri = value
258 .initiate_login_uri
259 .map(|s| s.parse())
260 .transpose()
261 .map_err(|e| {
262 DatabaseInconsistencyError::on("oauth2_clients")
263 .column("initiate_login_uri")
264 .row(id)
265 .source(e)
266 })?;
267
268 let jwks = match (value.jwks, value.jwks_uri) {
269 (None, None) => None,
270 (Some(jwks), None) => {
271 let jwks = serde_json::from_value(jwks).map_err(|e| {
272 DatabaseInconsistencyError::on("oauth2_clients")
273 .column("jwks")
274 .row(id)
275 .source(e)
276 })?;
277 Some(JwksOrJwksUri::Jwks(jwks))
278 }
279 (None, Some(jwks_uri)) => {
280 let jwks_uri = jwks_uri.parse().map_err(|e| {
281 DatabaseInconsistencyError::on("oauth2_clients")
282 .column("jwks_uri")
283 .row(id)
284 .source(e)
285 })?;
286
287 Some(JwksOrJwksUri::JwksUri(jwks_uri))
288 }
289 _ => {
290 return Err(DatabaseInconsistencyError::on("oauth2_clients")
291 .column("jwks(_uri)")
292 .row(id));
293 }
294 };
295
296 Ok(Client {
297 id,
298 client_id: id.to_string(),
299 metadata_digest: value.metadata_digest,
300 encrypted_client_secret: value.encrypted_client_secret,
301 application_type,
302 redirect_uris,
303 grant_types,
304 client_name: value.client_name,
305 logo_uri,
306 client_uri,
307 policy_uri,
308 tos_uri,
309 jwks,
310 id_token_signed_response_alg,
311 userinfo_signed_response_alg,
312 token_endpoint_auth_method,
313 token_endpoint_auth_signing_alg,
314 initiate_login_uri,
315 is_static: value.is_static,
316 })
317 }
318}
319
320#[async_trait]
321impl OAuth2ClientRepository for PgOAuth2ClientRepository<'_> {
322 type Error = DatabaseError;
323
324 #[tracing::instrument(
325 name = "db.oauth2_client.lookup",
326 skip_all,
327 fields(
328 db.query.text,
329 oauth2_client.id = %id,
330 ),
331 err,
332 )]
333 async fn lookup(&mut self, id: Ulid) -> Result<Option<Client>, Self::Error> {
334 let res = sqlx::query_as!(
335 OAuth2ClientLookup,
336 r#"
337 SELECT oauth2_client_id
338 , metadata_digest
339 , encrypted_client_secret
340 , application_type
341 , redirect_uris
342 , grant_type_authorization_code
343 , grant_type_refresh_token
344 , grant_type_client_credentials
345 , grant_type_device_code
346 , client_name
347 , logo_uri
348 , client_uri
349 , policy_uri
350 , tos_uri
351 , jwks_uri
352 , jwks
353 , id_token_signed_response_alg
354 , userinfo_signed_response_alg
355 , token_endpoint_auth_method
356 , token_endpoint_auth_signing_alg
357 , initiate_login_uri
358 , is_static
359 FROM oauth2_clients c
360
361 WHERE oauth2_client_id = $1
362 "#,
363 Uuid::from(id),
364 )
365 .traced()
366 .fetch_optional(&mut *self.conn)
367 .await?;
368
369 let Some(res) = res else { return Ok(None) };
370
371 Ok(Some(res.try_into()?))
372 }
373
374 #[tracing::instrument(
375 name = "db.oauth2_client.find_by_metadata_digest",
376 skip_all,
377 fields(
378 db.query.text,
379 ),
380 err,
381 )]
382 async fn find_by_metadata_digest(
383 &mut self,
384 digest: &str,
385 ) -> Result<Option<Client>, Self::Error> {
386 let res = sqlx::query_as!(
387 OAuth2ClientLookup,
388 r#"
389 SELECT oauth2_client_id
390 , metadata_digest
391 , encrypted_client_secret
392 , application_type
393 , redirect_uris
394 , grant_type_authorization_code
395 , grant_type_refresh_token
396 , grant_type_client_credentials
397 , grant_type_device_code
398 , client_name
399 , logo_uri
400 , client_uri
401 , policy_uri
402 , tos_uri
403 , jwks_uri
404 , jwks
405 , id_token_signed_response_alg
406 , userinfo_signed_response_alg
407 , token_endpoint_auth_method
408 , token_endpoint_auth_signing_alg
409 , initiate_login_uri
410 , is_static
411 FROM oauth2_clients
412 WHERE metadata_digest = $1
413 "#,
414 digest,
415 )
416 .traced()
417 .fetch_optional(&mut *self.conn)
418 .await?;
419
420 let Some(res) = res else { return Ok(None) };
421
422 Ok(Some(res.try_into()?))
423 }
424
425 #[tracing::instrument(
426 name = "db.oauth2_client.load_batch",
427 skip_all,
428 fields(
429 db.query.text,
430 ),
431 err,
432 )]
433 async fn load_batch(
434 &mut self,
435 ids: BTreeSet<Ulid>,
436 ) -> Result<BTreeMap<Ulid, Client>, Self::Error> {
437 let ids: Vec<Uuid> = ids.into_iter().map(Uuid::from).collect();
438 let res = sqlx::query_as!(
439 OAuth2ClientLookup,
440 r#"
441 SELECT oauth2_client_id
442 , metadata_digest
443 , encrypted_client_secret
444 , application_type
445 , redirect_uris
446 , grant_type_authorization_code
447 , grant_type_refresh_token
448 , grant_type_client_credentials
449 , grant_type_device_code
450 , client_name
451 , logo_uri
452 , client_uri
453 , policy_uri
454 , tos_uri
455 , jwks_uri
456 , jwks
457 , id_token_signed_response_alg
458 , userinfo_signed_response_alg
459 , token_endpoint_auth_method
460 , token_endpoint_auth_signing_alg
461 , initiate_login_uri
462 , is_static
463 FROM oauth2_clients c
464
465 WHERE oauth2_client_id = ANY($1::uuid[])
466 "#,
467 &ids,
468 )
469 .traced()
470 .fetch_all(&mut *self.conn)
471 .await?;
472
473 res.into_iter()
474 .map(|r| {
475 r.try_into()
476 .map(|c: Client| (c.id, c))
477 .map_err(DatabaseError::from)
478 })
479 .collect()
480 }
481
482 #[tracing::instrument(
483 name = "db.oauth2_client.add",
484 skip_all,
485 fields(
486 db.query.text,
487 client.id,
488 client.name = client_name
489 ),
490 err,
491 )]
492 async fn add(
493 &mut self,
494 rng: &mut (dyn RngCore + Send),
495 clock: &dyn Clock,
496 redirect_uris: Vec<Url>,
497 metadata_digest: Option<String>,
498 encrypted_client_secret: Option<String>,
499 application_type: Option<ApplicationType>,
500 grant_types: Vec<GrantType>,
501 client_name: Option<String>,
502 logo_uri: Option<Url>,
503 client_uri: Option<Url>,
504 policy_uri: Option<Url>,
505 tos_uri: Option<Url>,
506 jwks_uri: Option<Url>,
507 jwks: Option<PublicJsonWebKeySet>,
508 id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
509 userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
510 token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
511 token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
512 initiate_login_uri: Option<Url>,
513 ) -> Result<Client, Self::Error> {
514 let now = clock.now();
515 let id = Ulid::from_datetime_with_source(now.into(), rng);
516 tracing::Span::current().record("client.id", tracing::field::display(id));
517
518 let jwks_json = jwks
519 .as_ref()
520 .map(serde_json::to_value)
521 .transpose()
522 .map_err(DatabaseError::to_invalid_operation)?;
523
524 let redirect_uris_array = redirect_uris.iter().map(Url::to_string).collect::<Vec<_>>();
525
526 sqlx::query!(
527 r#"
528 INSERT INTO oauth2_clients
529 ( oauth2_client_id
530 , metadata_digest
531 , encrypted_client_secret
532 , application_type
533 , redirect_uris
534 , grant_type_authorization_code
535 , grant_type_refresh_token
536 , grant_type_client_credentials
537 , grant_type_device_code
538 , client_name
539 , logo_uri
540 , client_uri
541 , policy_uri
542 , tos_uri
543 , jwks_uri
544 , jwks
545 , id_token_signed_response_alg
546 , userinfo_signed_response_alg
547 , token_endpoint_auth_method
548 , token_endpoint_auth_signing_alg
549 , initiate_login_uri
550 , is_static
551 )
552 VALUES
553 ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13,
554 $14, $15, $16, $17, $18, $19, $20, $21, FALSE)
555 "#,
556 Uuid::from(id),
557 metadata_digest,
558 encrypted_client_secret,
559 application_type.as_ref().map(ToString::to_string),
560 &redirect_uris_array,
561 grant_types.contains(&GrantType::AuthorizationCode),
562 grant_types.contains(&GrantType::RefreshToken),
563 grant_types.contains(&GrantType::ClientCredentials),
564 grant_types.contains(&GrantType::DeviceCode),
565 client_name,
566 logo_uri.as_ref().map(Url::as_str),
567 client_uri.as_ref().map(Url::as_str),
568 policy_uri.as_ref().map(Url::as_str),
569 tos_uri.as_ref().map(Url::as_str),
570 jwks_uri.as_ref().map(Url::as_str),
571 jwks_json,
572 id_token_signed_response_alg
573 .as_ref()
574 .map(ToString::to_string),
575 userinfo_signed_response_alg
576 .as_ref()
577 .map(ToString::to_string),
578 token_endpoint_auth_method.as_ref().map(ToString::to_string),
579 token_endpoint_auth_signing_alg
580 .as_ref()
581 .map(ToString::to_string),
582 initiate_login_uri.as_ref().map(Url::as_str),
583 )
584 .traced()
585 .execute(&mut *self.conn)
586 .await?;
587
588 let jwks = match (jwks, jwks_uri) {
589 (None, None) => None,
590 (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
591 (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
592 _ => return Err(DatabaseError::invalid_operation()),
593 };
594
595 Ok(Client {
596 id,
597 client_id: id.to_string(),
598 metadata_digest: None,
599 encrypted_client_secret,
600 application_type,
601 redirect_uris,
602 grant_types,
603 client_name,
604 logo_uri,
605 client_uri,
606 policy_uri,
607 tos_uri,
608 jwks,
609 id_token_signed_response_alg,
610 userinfo_signed_response_alg,
611 token_endpoint_auth_method,
612 token_endpoint_auth_signing_alg,
613 initiate_login_uri,
614 is_static: false,
615 })
616 }
617
618 #[tracing::instrument(
619 name = "db.oauth2_client.upsert_static",
620 skip_all,
621 fields(
622 db.query.text,
623 client.id = %client_id,
624 ),
625 err,
626 )]
627 async fn upsert_static(
628 &mut self,
629 client_id: Ulid,
630 client_name: Option<String>,
631 client_auth_method: OAuthClientAuthenticationMethod,
632 encrypted_client_secret: Option<String>,
633 jwks: Option<PublicJsonWebKeySet>,
634 jwks_uri: Option<Url>,
635 redirect_uris: Vec<Url>,
636 ) -> Result<Client, Self::Error> {
637 let jwks_json = jwks
638 .as_ref()
639 .map(serde_json::to_value)
640 .transpose()
641 .map_err(DatabaseError::to_invalid_operation)?;
642
643 let client_auth_method = client_auth_method.to_string();
644 let redirect_uris_array = redirect_uris.iter().map(Url::to_string).collect::<Vec<_>>();
645
646 sqlx::query!(
647 r#"
648 INSERT INTO oauth2_clients
649 ( oauth2_client_id
650 , encrypted_client_secret
651 , redirect_uris
652 , grant_type_authorization_code
653 , grant_type_refresh_token
654 , grant_type_client_credentials
655 , grant_type_device_code
656 , token_endpoint_auth_method
657 , jwks
658 , client_name
659 , jwks_uri
660 , is_static
661 )
662 VALUES
663 ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, TRUE)
664 ON CONFLICT (oauth2_client_id)
665 DO
666 UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret
667 , redirect_uris = EXCLUDED.redirect_uris
668 , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code
669 , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token
670 , grant_type_client_credentials = EXCLUDED.grant_type_client_credentials
671 , grant_type_device_code = EXCLUDED.grant_type_device_code
672 , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method
673 , jwks = EXCLUDED.jwks
674 , client_name = EXCLUDED.client_name
675 , jwks_uri = EXCLUDED.jwks_uri
676 , is_static = TRUE
677 "#,
678 Uuid::from(client_id),
679 encrypted_client_secret,
680 &redirect_uris_array,
681 true,
682 true,
683 true,
684 true,
685 client_auth_method,
686 jwks_json,
687 client_name,
688 jwks_uri.as_ref().map(Url::as_str),
689 )
690 .traced()
691 .execute(&mut *self.conn)
692 .await?;
693
694 let jwks = match (jwks, jwks_uri) {
695 (None, None) => None,
696 (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
697 (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
698 _ => return Err(DatabaseError::invalid_operation()),
699 };
700
701 Ok(Client {
702 id: client_id,
703 client_id: client_id.to_string(),
704 metadata_digest: None,
705 encrypted_client_secret,
706 application_type: None,
707 redirect_uris,
708 grant_types: vec![
709 GrantType::AuthorizationCode,
710 GrantType::RefreshToken,
711 GrantType::ClientCredentials,
712 ],
713 client_name,
714 logo_uri: None,
715 client_uri: None,
716 policy_uri: None,
717 tos_uri: None,
718 jwks,
719 id_token_signed_response_alg: None,
720 userinfo_signed_response_alg: None,
721 token_endpoint_auth_method: None,
722 token_endpoint_auth_signing_alg: None,
723 initiate_login_uri: None,
724 is_static: true,
725 })
726 }
727
728 #[tracing::instrument(
729 name = "db.oauth2_client.all_static",
730 skip_all,
731 fields(
732 db.query.text,
733 ),
734 err,
735 )]
736 async fn all_static(&mut self) -> Result<Vec<Client>, Self::Error> {
737 let res = sqlx::query_as!(
738 OAuth2ClientLookup,
739 r#"
740 SELECT oauth2_client_id
741 , metadata_digest
742 , encrypted_client_secret
743 , application_type
744 , redirect_uris
745 , grant_type_authorization_code
746 , grant_type_refresh_token
747 , grant_type_client_credentials
748 , grant_type_device_code
749 , client_name
750 , logo_uri
751 , client_uri
752 , policy_uri
753 , tos_uri
754 , jwks_uri
755 , jwks
756 , id_token_signed_response_alg
757 , userinfo_signed_response_alg
758 , token_endpoint_auth_method
759 , token_endpoint_auth_signing_alg
760 , initiate_login_uri
761 , is_static
762 FROM oauth2_clients c
763 WHERE is_static = TRUE
764 "#,
765 )
766 .traced()
767 .fetch_all(&mut *self.conn)
768 .await?;
769
770 res.into_iter()
771 .map(|r| r.try_into().map_err(DatabaseError::from))
772 .collect()
773 }
774
    /// Delete a client together with the rows that belong to it.
    ///
    /// Runs one `DELETE` per dependent table, each in its own tracing span,
    /// and finally deletes the client row itself. The statements run one
    /// after the other on the repository's connection; this method does not
    /// open a transaction of its own.
    ///
    /// # Errors
    ///
    /// Fails if any of the statements fails, or if the final statement did
    /// not delete exactly one client row.
    #[tracing::instrument(
        name = "db.oauth2_client.delete_by_id",
        skip_all,
        fields(
            db.query.text,
            client.id = %id,
        ),
        err,
    )]
    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
        // Delete the authorization grants of the client
        {
            let span = info_span!(
                "db.oauth2_client.delete_by_id.authorization_grants",
                { DB_QUERY_TEXT } = tracing::field::Empty,
            );

            sqlx::query!(
                r#"
                    DELETE FROM oauth2_authorization_grants
                    WHERE oauth2_client_id = $1
                "#,
                Uuid::from(id),
            )
            .record(&span)
            .execute(&mut *self.conn)
            .instrument(span)
            .await?;
        }

        // Delete the access tokens of the client's sessions. This has to run
        // while the sessions still exist, as they are what links a token to
        // the client.
        {
            let span = info_span!(
                "db.oauth2_client.delete_by_id.access_tokens",
                { DB_QUERY_TEXT } = tracing::field::Empty,
            );

            sqlx::query!(
                r#"
                    DELETE FROM oauth2_access_tokens
                    WHERE oauth2_session_id IN (
                        SELECT oauth2_session_id
                        FROM oauth2_sessions
                        WHERE oauth2_client_id = $1
                    )
                "#,
                Uuid::from(id),
            )
            .record(&span)
            .execute(&mut *self.conn)
            .instrument(span)
            .await?;
        }

        // Same for the refresh tokens of the client's sessions
        {
            let span = info_span!(
                "db.oauth2_client.delete_by_id.refresh_tokens",
                { DB_QUERY_TEXT } = tracing::field::Empty,
            );

            sqlx::query!(
                r#"
                    DELETE FROM oauth2_refresh_tokens
                    WHERE oauth2_session_id IN (
                        SELECT oauth2_session_id
                        FROM oauth2_sessions
                        WHERE oauth2_client_id = $1
                    )
                "#,
                Uuid::from(id),
            )
            .record(&span)
            .execute(&mut *self.conn)
            .instrument(span)
            .await?;
        }

        // Now that their tokens are gone, delete the sessions themselves
        {
            let span = info_span!(
                "db.oauth2_client.delete_by_id.sessions",
                { DB_QUERY_TEXT } = tracing::field::Empty,
            );

            sqlx::query!(
                r#"
                    DELETE FROM oauth2_sessions
                    WHERE oauth2_client_id = $1
                "#,
                Uuid::from(id),
            )
            .record(&span)
            .execute(&mut *self.conn)
            .instrument(span)
            .await?;
        }

        // Delete the personal access tokens of the personal sessions owned
        // by the client, again while those sessions still exist
        {
            let span = info_span!(
                "db.oauth2_client.delete_by_id.personal_access_tokens",
                { DB_QUERY_TEXT } = tracing::field::Empty,
            );

            sqlx::query!(
                r#"
                    DELETE FROM personal_access_tokens
                    WHERE personal_session_id IN (
                        SELECT personal_session_id
                        FROM personal_sessions
                        WHERE owner_oauth2_client_id = $1
                    )
                "#,
                Uuid::from(id),
            )
            .record(&span)
            .execute(&mut *self.conn)
            .instrument(span)
            .await?;
        }
        // Then the personal sessions owned by the client
        {
            let span = info_span!(
                "db.oauth2_client.delete_by_id.personal_sessions",
                { DB_QUERY_TEXT } = tracing::field::Empty,
            );

            sqlx::query!(
                r#"
                    DELETE FROM personal_sessions
                    WHERE owner_oauth2_client_id = $1
                "#,
                Uuid::from(id),
            )
            .record(&span)
            .execute(&mut *self.conn)
            .instrument(span)
            .await?;
        }

        // Finally, delete the client itself; this one is recorded on the
        // span of the method
        let res = sqlx::query!(
            r#"
                DELETE FROM oauth2_clients
                WHERE oauth2_client_id = $1
            "#,
            Uuid::from(id),
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        // Exactly one client row must have been deleted
        DatabaseError::ensure_affected_rows(&res, 1)
    }
928
929 #[tracing::instrument(
930 name = "db.oauth2_client.list",
931 skip_all,
932 fields(
933 db.query.text,
934 ),
935 err,
936 )]
937 async fn list(
938 &mut self,
939 filter: OAuth2ClientFilter<'_>,
940 pagination: Pagination,
941 ) -> Result<Page<Client>, Self::Error> {
942 let (sql, arguments) = Query::select()
943 .expr_as(
944 Expr::col((OAuth2Clients::Table, OAuth2Clients::OAuth2ClientId)),
945 OAuth2ClientLookupIden::Oauth2ClientId,
946 )
947 .expr_as(
948 Expr::cust("metadata_digest"),
949 OAuth2ClientLookupIden::MetadataDigest,
950 )
951 .expr_as(
952 Expr::cust("encrypted_client_secret"),
953 OAuth2ClientLookupIden::EncryptedClientSecret,
954 )
955 .expr_as(
956 Expr::cust("application_type"),
957 OAuth2ClientLookupIden::ApplicationType,
958 )
959 .expr_as(
960 Expr::col((OAuth2Clients::Table, OAuth2Clients::RedirectUris)),
961 OAuth2ClientLookupIden::RedirectUris,
962 )
963 .expr_as(
964 Expr::cust("grant_type_authorization_code"),
965 OAuth2ClientLookupIden::GrantTypeAuthorizationCode,
966 )
967 .expr_as(
968 Expr::cust("grant_type_refresh_token"),
969 OAuth2ClientLookupIden::GrantTypeRefreshToken,
970 )
971 .expr_as(
972 Expr::cust("grant_type_client_credentials"),
973 OAuth2ClientLookupIden::GrantTypeClientCredentials,
974 )
975 .expr_as(
976 Expr::cust("grant_type_device_code"),
977 OAuth2ClientLookupIden::GrantTypeDeviceCode,
978 )
979 .expr_as(
980 Expr::col((OAuth2Clients::Table, OAuth2Clients::ClientName)),
981 OAuth2ClientLookupIden::ClientName,
982 )
983 .expr_as(
984 Expr::col((OAuth2Clients::Table, OAuth2Clients::LogoUri)),
985 OAuth2ClientLookupIden::LogoUri,
986 )
987 .expr_as(
988 Expr::col((OAuth2Clients::Table, OAuth2Clients::ClientUri)),
989 OAuth2ClientLookupIden::ClientUri,
990 )
991 .expr_as(Expr::cust("policy_uri"), OAuth2ClientLookupIden::PolicyUri)
992 .expr_as(Expr::cust("tos_uri"), OAuth2ClientLookupIden::TosUri)
993 .expr_as(Expr::cust("jwks_uri"), OAuth2ClientLookupIden::JwksUri)
994 .expr_as(Expr::cust("jwks"), OAuth2ClientLookupIden::Jwks)
995 .expr_as(
996 Expr::cust("id_token_signed_response_alg"),
997 OAuth2ClientLookupIden::IdTokenSignedResponseAlg,
998 )
999 .expr_as(
1000 Expr::cust("userinfo_signed_response_alg"),
1001 OAuth2ClientLookupIden::UserinfoSignedResponseAlg,
1002 )
1003 .expr_as(
1004 Expr::cust("token_endpoint_auth_method"),
1005 OAuth2ClientLookupIden::TokenEndpointAuthMethod,
1006 )
1007 .expr_as(
1008 Expr::cust("token_endpoint_auth_signing_alg"),
1009 OAuth2ClientLookupIden::TokenEndpointAuthSigningAlg,
1010 )
1011 .expr_as(
1012 Expr::cust("initiate_login_uri"),
1013 OAuth2ClientLookupIden::InitiateLoginUri,
1014 )
1015 .expr_as(
1016 Expr::col((OAuth2Clients::Table, OAuth2Clients::IsStatic)),
1017 OAuth2ClientLookupIden::IsStatic,
1018 )
1019 .from(OAuth2Clients::Table)
1020 .apply_filter(filter)
1021 .generate_pagination(
1022 (OAuth2Clients::Table, OAuth2Clients::OAuth2ClientId),
1023 pagination,
1024 )
1025 .build_sqlx(PostgresQueryBuilder);
1026
1027 let edges: Vec<OAuth2ClientLookup> = sqlx::query_as_with(&sql, arguments)
1028 .traced()
1029 .fetch_all(&mut *self.conn)
1030 .await?;
1031
1032 let page = pagination.process(edges).try_map(Client::try_from)?;
1033
1034 Ok(page)
1035 }
1036
1037 #[tracing::instrument(
1038 name = "db.oauth2_client.count",
1039 skip_all,
1040 fields(
1041 db.query.text,
1042 ),
1043 err,
1044 )]
1045 async fn count(&mut self, filter: OAuth2ClientFilter<'_>) -> Result<usize, Self::Error> {
1046 let (sql, arguments) = Query::select()
1047 .expr(Expr::col((OAuth2Clients::Table, OAuth2Clients::OAuth2ClientId)).count())
1048 .from(OAuth2Clients::Table)
1049 .apply_filter(filter)
1050 .build_sqlx(PostgresQueryBuilder);
1051
1052 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
1053 .traced()
1054 .fetch_one(&mut *self.conn)
1055 .await?;
1056
1057 count
1058 .try_into()
1059 .map_err(DatabaseError::to_invalid_operation)
1060 }
1061}