1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{Clock, UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports};
10use mas_storage::{
11 Page, Pagination,
12 pagination::Node,
13 upstream_oauth2::{
14 UpstreamOAuthProviderFilter, UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
15 },
16};
17use opentelemetry_semantic_conventions::attribute::DB_QUERY_TEXT;
18use rand::RngCore;
19use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
20use sea_query_binder::SqlxBinder;
21use sqlx::{PgConnection, types::Json};
22use tracing::{Instrument, info_span};
23use ulid::Ulid;
24use uuid::Uuid;
25
26use crate::{
27 DatabaseError, DatabaseInconsistencyError,
28 filter::{Filter, StatementExt},
29 iden::UpstreamOAuthProviders,
30 pagination::QueryBuilderExt,
31 tracing::ExecuteExt,
32};
33
34pub struct PgUpstreamOAuthProviderRepository<'c> {
37 conn: &'c mut PgConnection,
38}
39
40impl<'c> PgUpstreamOAuthProviderRepository<'c> {
41 pub fn new(conn: &'c mut PgConnection) -> Self {
44 Self { conn }
45 }
46}
47
/// A raw row of the `upstream_oauth_providers` table, as decoded by sqlx.
///
/// Structured values (scope, algorithms, endpoint overrides, modes) are kept
/// as strings here and only parsed when converting into
/// [`UpstreamOAuthProvider`] through `TryFrom`. `#[enum_def]` additionally
/// generates the `ProviderLookupIden` enum, used as column aliases by the
/// dynamically-built `list` query; `FromRow` then maps columns to fields by
/// name.
#[derive(sqlx::FromRow)]
#[enum_def]
struct ProviderLookup {
    upstream_oauth_provider_id: Uuid,
    issuer: Option<String>,
    human_name: Option<String>,
    brand_name: Option<String>,
    // Parsed into a typed scope on conversion.
    scope: String,
    client_id: String,
    encrypted_client_secret: Option<String>,
    token_endpoint_signing_alg: Option<String>,
    token_endpoint_auth_method: String,
    id_token_signed_response_alg: String,
    fetch_userinfo: bool,
    userinfo_signed_response_alg: Option<String>,
    created_at: DateTime<Utc>,
    // `NULL` while the provider is enabled.
    disabled_at: Option<DateTime<Utc>>,
    claims_imports: Json<UpstreamOAuthProviderClaimsImports>,
    jwks_uri_override: Option<String>,
    authorization_endpoint_override: Option<String>,
    token_endpoint_override: Option<String>,
    userinfo_endpoint_override: Option<String>,
    discovery_mode: String,
    pkce_mode: String,
    response_mode: Option<String>,
    // JSON list of `(String, String)` pairs (presumably name/value); a NULL
    // column becomes an empty list on conversion.
    additional_parameters: Option<Json<Vec<(String, String)>>>,
    forward_login_hint: bool,
    on_backchannel_logout: String,
    registration_token_required: bool,
}
78
79impl Node<Ulid> for ProviderLookup {
80 fn cursor(&self) -> Ulid {
81 self.upstream_oauth_provider_id.into()
82 }
83}
84
85impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
86 type Error = DatabaseInconsistencyError;
87
88 fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> {
89 let id = value.upstream_oauth_provider_id.into();
90 let scope = value.scope.parse().map_err(|e| {
91 DatabaseInconsistencyError::on("upstream_oauth_providers")
92 .column("scope")
93 .row(id)
94 .source(e)
95 })?;
96 let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| {
97 DatabaseInconsistencyError::on("upstream_oauth_providers")
98 .column("token_endpoint_auth_method")
99 .row(id)
100 .source(e)
101 })?;
102 let token_endpoint_signing_alg = value
103 .token_endpoint_signing_alg
104 .map(|x| x.parse())
105 .transpose()
106 .map_err(|e| {
107 DatabaseInconsistencyError::on("upstream_oauth_providers")
108 .column("token_endpoint_signing_alg")
109 .row(id)
110 .source(e)
111 })?;
112 let id_token_signed_response_alg =
113 value.id_token_signed_response_alg.parse().map_err(|e| {
114 DatabaseInconsistencyError::on("upstream_oauth_providers")
115 .column("id_token_signed_response_alg")
116 .row(id)
117 .source(e)
118 })?;
119
120 let userinfo_signed_response_alg = value
121 .userinfo_signed_response_alg
122 .map(|x| x.parse())
123 .transpose()
124 .map_err(|e| {
125 DatabaseInconsistencyError::on("upstream_oauth_providers")
126 .column("userinfo_signed_response_alg")
127 .row(id)
128 .source(e)
129 })?;
130
131 let authorization_endpoint_override = value
132 .authorization_endpoint_override
133 .map(|x| x.parse())
134 .transpose()
135 .map_err(|e| {
136 DatabaseInconsistencyError::on("upstream_oauth_providers")
137 .column("authorization_endpoint_override")
138 .row(id)
139 .source(e)
140 })?;
141
142 let token_endpoint_override = value
143 .token_endpoint_override
144 .map(|x| x.parse())
145 .transpose()
146 .map_err(|e| {
147 DatabaseInconsistencyError::on("upstream_oauth_providers")
148 .column("token_endpoint_override")
149 .row(id)
150 .source(e)
151 })?;
152
153 let userinfo_endpoint_override = value
154 .userinfo_endpoint_override
155 .map(|x| x.parse())
156 .transpose()
157 .map_err(|e| {
158 DatabaseInconsistencyError::on("upstream_oauth_providers")
159 .column("userinfo_endpoint_override")
160 .row(id)
161 .source(e)
162 })?;
163
164 let jwks_uri_override = value
165 .jwks_uri_override
166 .map(|x| x.parse())
167 .transpose()
168 .map_err(|e| {
169 DatabaseInconsistencyError::on("upstream_oauth_providers")
170 .column("jwks_uri_override")
171 .row(id)
172 .source(e)
173 })?;
174
175 let discovery_mode = value.discovery_mode.parse().map_err(|e| {
176 DatabaseInconsistencyError::on("upstream_oauth_providers")
177 .column("discovery_mode")
178 .row(id)
179 .source(e)
180 })?;
181
182 let pkce_mode = value.pkce_mode.parse().map_err(|e| {
183 DatabaseInconsistencyError::on("upstream_oauth_providers")
184 .column("pkce_mode")
185 .row(id)
186 .source(e)
187 })?;
188
189 let response_mode = value
190 .response_mode
191 .map(|x| x.parse())
192 .transpose()
193 .map_err(|e| {
194 DatabaseInconsistencyError::on("upstream_oauth_providers")
195 .column("response_mode")
196 .row(id)
197 .source(e)
198 })?;
199
200 let additional_authorization_parameters = value
201 .additional_parameters
202 .map(|Json(x)| x)
203 .unwrap_or_default();
204
205 let on_backchannel_logout = value.on_backchannel_logout.parse().map_err(|e| {
206 DatabaseInconsistencyError::on("upstream_oauth_providers")
207 .column("on_backchannel_logout")
208 .row(id)
209 .source(e)
210 })?;
211
212 Ok(UpstreamOAuthProvider {
213 id,
214 issuer: value.issuer,
215 human_name: value.human_name,
216 brand_name: value.brand_name,
217 scope,
218 client_id: value.client_id,
219 encrypted_client_secret: value.encrypted_client_secret,
220 token_endpoint_auth_method,
221 token_endpoint_signing_alg,
222 id_token_signed_response_alg,
223 fetch_userinfo: value.fetch_userinfo,
224 userinfo_signed_response_alg,
225 created_at: value.created_at,
226 disabled_at: value.disabled_at,
227 claims_imports: value.claims_imports.0,
228 authorization_endpoint_override,
229 token_endpoint_override,
230 userinfo_endpoint_override,
231 jwks_uri_override,
232 discovery_mode,
233 pkce_mode,
234 response_mode,
235 additional_authorization_parameters,
236 forward_login_hint: value.forward_login_hint,
237 on_backchannel_logout,
238 registration_token_required: value.registration_token_required,
239 })
240 }
241}
242
243impl Filter for UpstreamOAuthProviderFilter<'_> {
244 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
245 sea_query::Condition::all().add_option(self.enabled().map(|enabled| {
246 Expr::col((
247 UpstreamOAuthProviders::Table,
248 UpstreamOAuthProviders::DisabledAt,
249 ))
250 .is_null()
251 .eq(enabled)
252 }))
253 }
254}
255
256#[async_trait]
257impl UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'_> {
258 type Error = DatabaseError;
259
260 #[tracing::instrument(
261 name = "db.upstream_oauth_provider.lookup",
262 skip_all,
263 fields(
264 db.query.text,
265 upstream_oauth_provider.id = %id,
266 ),
267 err,
268 )]
269 async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error> {
270 let res = sqlx::query_as!(
271 ProviderLookup,
272 r#"
273 SELECT
274 upstream_oauth_provider_id,
275 issuer,
276 human_name,
277 brand_name,
278 scope,
279 client_id,
280 encrypted_client_secret,
281 token_endpoint_signing_alg,
282 token_endpoint_auth_method,
283 id_token_signed_response_alg,
284 fetch_userinfo,
285 userinfo_signed_response_alg,
286 created_at,
287 disabled_at,
288 claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
289 jwks_uri_override,
290 authorization_endpoint_override,
291 token_endpoint_override,
292 userinfo_endpoint_override,
293 discovery_mode,
294 pkce_mode,
295 response_mode,
296 additional_parameters as "additional_parameters: Json<Vec<(String, String)>>",
297 forward_login_hint,
298 on_backchannel_logout,
299 registration_token_required
300 FROM upstream_oauth_providers
301 WHERE upstream_oauth_provider_id = $1
302 "#,
303 Uuid::from(id),
304 )
305 .traced()
306 .fetch_optional(&mut *self.conn)
307 .await?;
308
309 let res = res
310 .map(UpstreamOAuthProvider::try_from)
311 .transpose()
312 .map_err(DatabaseError::from)?;
313
314 Ok(res)
315 }
316
317 #[tracing::instrument(
318 name = "db.upstream_oauth_provider.add",
319 skip_all,
320 fields(
321 db.query.text,
322 upstream_oauth_provider.id,
323 upstream_oauth_provider.issuer = params.issuer,
324 upstream_oauth_provider.client_id = %params.client_id,
325 ),
326 err,
327 )]
328 async fn add(
329 &mut self,
330 rng: &mut (dyn RngCore + Send),
331 clock: &dyn Clock,
332 params: UpstreamOAuthProviderParams,
333 ) -> Result<UpstreamOAuthProvider, Self::Error> {
334 let created_at = clock.now();
335 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
336 tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id));
337
338 sqlx::query!(
339 r#"
340 INSERT INTO upstream_oauth_providers (
341 upstream_oauth_provider_id,
342 issuer,
343 human_name,
344 brand_name,
345 scope,
346 token_endpoint_auth_method,
347 token_endpoint_signing_alg,
348 id_token_signed_response_alg,
349 fetch_userinfo,
350 userinfo_signed_response_alg,
351 client_id,
352 encrypted_client_secret,
353 claims_imports,
354 authorization_endpoint_override,
355 token_endpoint_override,
356 userinfo_endpoint_override,
357 jwks_uri_override,
358 discovery_mode,
359 pkce_mode,
360 response_mode,
361 forward_login_hint,
362 on_backchannel_logout,
363 registration_token_required,
364 created_at
365 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11,
366 $12, $13, $14, $15, $16, $17, $18, $19, $20,
367 $21, $22, $23, $24)
368 "#,
369 Uuid::from(id),
370 params.issuer.as_deref(),
371 params.human_name.as_deref(),
372 params.brand_name.as_deref(),
373 params.scope.to_string(),
374 params.token_endpoint_auth_method.to_string(),
375 params
376 .token_endpoint_signing_alg
377 .as_ref()
378 .map(ToString::to_string),
379 params.id_token_signed_response_alg.to_string(),
380 params.fetch_userinfo,
381 params
382 .userinfo_signed_response_alg
383 .as_ref()
384 .map(ToString::to_string),
385 ¶ms.client_id,
386 params.encrypted_client_secret.as_deref(),
387 Json(¶ms.claims_imports) as _,
388 params
389 .authorization_endpoint_override
390 .as_ref()
391 .map(ToString::to_string),
392 params
393 .token_endpoint_override
394 .as_ref()
395 .map(ToString::to_string),
396 params
397 .userinfo_endpoint_override
398 .as_ref()
399 .map(ToString::to_string),
400 params.jwks_uri_override.as_ref().map(ToString::to_string),
401 params.discovery_mode.as_str(),
402 params.pkce_mode.as_str(),
403 params.response_mode.as_ref().map(ToString::to_string),
404 params.forward_login_hint,
405 params.on_backchannel_logout.as_str(),
406 params.registration_token_required,
407 created_at,
408 )
409 .traced()
410 .execute(&mut *self.conn)
411 .await?;
412
413 Ok(UpstreamOAuthProvider {
414 id,
415 issuer: params.issuer,
416 human_name: params.human_name,
417 brand_name: params.brand_name,
418 scope: params.scope,
419 client_id: params.client_id,
420 encrypted_client_secret: params.encrypted_client_secret,
421 token_endpoint_signing_alg: params.token_endpoint_signing_alg,
422 token_endpoint_auth_method: params.token_endpoint_auth_method,
423 id_token_signed_response_alg: params.id_token_signed_response_alg,
424 fetch_userinfo: params.fetch_userinfo,
425 userinfo_signed_response_alg: params.userinfo_signed_response_alg,
426 created_at,
427 disabled_at: None,
428 claims_imports: params.claims_imports,
429 authorization_endpoint_override: params.authorization_endpoint_override,
430 token_endpoint_override: params.token_endpoint_override,
431 userinfo_endpoint_override: params.userinfo_endpoint_override,
432 jwks_uri_override: params.jwks_uri_override,
433 discovery_mode: params.discovery_mode,
434 pkce_mode: params.pkce_mode,
435 response_mode: params.response_mode,
436 additional_authorization_parameters: params.additional_authorization_parameters,
437 on_backchannel_logout: params.on_backchannel_logout,
438 forward_login_hint: params.forward_login_hint,
439 registration_token_required: params.registration_token_required,
440 })
441 }
442
443 #[tracing::instrument(
444 name = "db.upstream_oauth_provider.delete_by_id",
445 skip_all,
446 fields(
447 db.query.text,
448 upstream_oauth_provider.id = %id,
449 ),
450 err,
451 )]
452 async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
453 {
456 let span = info_span!(
457 "db.oauth2_client.delete_by_id.authorization_sessions",
458 upstream_oauth_provider.id = %id,
459 { DB_QUERY_TEXT } = tracing::field::Empty,
460 );
461 sqlx::query!(
462 r#"
463 DELETE FROM upstream_oauth_authorization_sessions
464 WHERE upstream_oauth_provider_id = $1
465 "#,
466 Uuid::from(id),
467 )
468 .record(&span)
469 .execute(&mut *self.conn)
470 .instrument(span)
471 .await?;
472 }
473
474 {
477 let span = info_span!(
478 "db.oauth2_client.delete_by_id.links",
479 upstream_oauth_provider.id = %id,
480 { DB_QUERY_TEXT } = tracing::field::Empty,
481 );
482 sqlx::query!(
483 r#"
484 DELETE FROM upstream_oauth_links
485 WHERE upstream_oauth_provider_id = $1
486 "#,
487 Uuid::from(id),
488 )
489 .record(&span)
490 .execute(&mut *self.conn)
491 .instrument(span)
492 .await?;
493 }
494
495 let res = sqlx::query!(
496 r#"
497 DELETE FROM upstream_oauth_providers
498 WHERE upstream_oauth_provider_id = $1
499 "#,
500 Uuid::from(id),
501 )
502 .traced()
503 .execute(&mut *self.conn)
504 .await?;
505
506 DatabaseError::ensure_affected_rows(&res, 1)
507 }
508
509 #[tracing::instrument(
510 name = "db.upstream_oauth_provider.add",
511 skip_all,
512 fields(
513 db.query.text,
514 upstream_oauth_provider.id = %id,
515 upstream_oauth_provider.issuer = params.issuer,
516 upstream_oauth_provider.client_id = %params.client_id,
517 ),
518 err,
519 )]
520 async fn upsert(
521 &mut self,
522 clock: &dyn Clock,
523 id: Ulid,
524 params: UpstreamOAuthProviderParams,
525 ) -> Result<UpstreamOAuthProvider, Self::Error> {
526 let created_at = clock.now();
527
528 let created_at = sqlx::query_scalar!(
529 r#"
530 INSERT INTO upstream_oauth_providers (
531 upstream_oauth_provider_id,
532 issuer,
533 human_name,
534 brand_name,
535 scope,
536 token_endpoint_auth_method,
537 token_endpoint_signing_alg,
538 id_token_signed_response_alg,
539 fetch_userinfo,
540 userinfo_signed_response_alg,
541 client_id,
542 encrypted_client_secret,
543 claims_imports,
544 authorization_endpoint_override,
545 token_endpoint_override,
546 userinfo_endpoint_override,
547 jwks_uri_override,
548 discovery_mode,
549 pkce_mode,
550 response_mode,
551 additional_parameters,
552 forward_login_hint,
553 ui_order,
554 on_backchannel_logout,
555 registration_token_required,
556 created_at
557 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
558 $11, $12, $13, $14, $15, $16, $17, $18, $19, $20,
559 $21, $22, $23, $24, $25, $26)
560 ON CONFLICT (upstream_oauth_provider_id)
561 DO UPDATE
562 SET
563 issuer = EXCLUDED.issuer,
564 human_name = EXCLUDED.human_name,
565 brand_name = EXCLUDED.brand_name,
566 scope = EXCLUDED.scope,
567 token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,
568 token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,
569 id_token_signed_response_alg = EXCLUDED.id_token_signed_response_alg,
570 fetch_userinfo = EXCLUDED.fetch_userinfo,
571 userinfo_signed_response_alg = EXCLUDED.userinfo_signed_response_alg,
572 disabled_at = NULL,
573 client_id = EXCLUDED.client_id,
574 encrypted_client_secret = EXCLUDED.encrypted_client_secret,
575 claims_imports = EXCLUDED.claims_imports,
576 authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,
577 token_endpoint_override = EXCLUDED.token_endpoint_override,
578 userinfo_endpoint_override = EXCLUDED.userinfo_endpoint_override,
579 jwks_uri_override = EXCLUDED.jwks_uri_override,
580 discovery_mode = EXCLUDED.discovery_mode,
581 pkce_mode = EXCLUDED.pkce_mode,
582 response_mode = EXCLUDED.response_mode,
583 additional_parameters = EXCLUDED.additional_parameters,
584 forward_login_hint = EXCLUDED.forward_login_hint,
585 ui_order = EXCLUDED.ui_order,
586 on_backchannel_logout = EXCLUDED.on_backchannel_logout,
587 registration_token_required = EXCLUDED.registration_token_required
588 RETURNING created_at
589 "#,
590 Uuid::from(id),
591 params.issuer.as_deref(),
592 params.human_name.as_deref(),
593 params.brand_name.as_deref(),
594 params.scope.to_string(),
595 params.token_endpoint_auth_method.to_string(),
596 params
597 .token_endpoint_signing_alg
598 .as_ref()
599 .map(ToString::to_string),
600 params.id_token_signed_response_alg.to_string(),
601 params.fetch_userinfo,
602 params
603 .userinfo_signed_response_alg
604 .as_ref()
605 .map(ToString::to_string),
606 ¶ms.client_id,
607 params.encrypted_client_secret.as_deref(),
608 Json(¶ms.claims_imports) as _,
609 params
610 .authorization_endpoint_override
611 .as_ref()
612 .map(ToString::to_string),
613 params
614 .token_endpoint_override
615 .as_ref()
616 .map(ToString::to_string),
617 params
618 .userinfo_endpoint_override
619 .as_ref()
620 .map(ToString::to_string),
621 params.jwks_uri_override.as_ref().map(ToString::to_string),
622 params.discovery_mode.as_str(),
623 params.pkce_mode.as_str(),
624 params.response_mode.as_ref().map(ToString::to_string),
625 Json(¶ms.additional_authorization_parameters) as _,
626 params.forward_login_hint,
627 params.ui_order,
628 params.on_backchannel_logout.as_str(),
629 params.registration_token_required,
630 created_at,
631 )
632 .traced()
633 .fetch_one(&mut *self.conn)
634 .await?;
635
636 Ok(UpstreamOAuthProvider {
637 id,
638 issuer: params.issuer,
639 human_name: params.human_name,
640 brand_name: params.brand_name,
641 scope: params.scope,
642 client_id: params.client_id,
643 encrypted_client_secret: params.encrypted_client_secret,
644 token_endpoint_signing_alg: params.token_endpoint_signing_alg,
645 token_endpoint_auth_method: params.token_endpoint_auth_method,
646 id_token_signed_response_alg: params.id_token_signed_response_alg,
647 fetch_userinfo: params.fetch_userinfo,
648 userinfo_signed_response_alg: params.userinfo_signed_response_alg,
649 created_at,
650 disabled_at: None,
651 claims_imports: params.claims_imports,
652 authorization_endpoint_override: params.authorization_endpoint_override,
653 token_endpoint_override: params.token_endpoint_override,
654 userinfo_endpoint_override: params.userinfo_endpoint_override,
655 jwks_uri_override: params.jwks_uri_override,
656 discovery_mode: params.discovery_mode,
657 pkce_mode: params.pkce_mode,
658 response_mode: params.response_mode,
659 additional_authorization_parameters: params.additional_authorization_parameters,
660 forward_login_hint: params.forward_login_hint,
661 on_backchannel_logout: params.on_backchannel_logout,
662 registration_token_required: params.registration_token_required,
663 })
664 }
665
666 #[tracing::instrument(
667 name = "db.upstream_oauth_provider.disable",
668 skip_all,
669 fields(
670 db.query.text,
671 %upstream_oauth_provider.id,
672 ),
673 err,
674 )]
675 async fn disable(
676 &mut self,
677 clock: &dyn Clock,
678 mut upstream_oauth_provider: UpstreamOAuthProvider,
679 ) -> Result<UpstreamOAuthProvider, Self::Error> {
680 let disabled_at = clock.now();
681 let res = sqlx::query!(
682 r#"
683 UPDATE upstream_oauth_providers
684 SET disabled_at = $2
685 WHERE upstream_oauth_provider_id = $1
686 "#,
687 Uuid::from(upstream_oauth_provider.id),
688 disabled_at,
689 )
690 .traced()
691 .execute(&mut *self.conn)
692 .await?;
693
694 DatabaseError::ensure_affected_rows(&res, 1)?;
695
696 upstream_oauth_provider.disabled_at = Some(disabled_at);
697
698 Ok(upstream_oauth_provider)
699 }
700
701 #[tracing::instrument(
702 name = "db.upstream_oauth_provider.list",
703 skip_all,
704 fields(
705 db.query.text,
706 ),
707 err,
708 )]
709 async fn list(
710 &mut self,
711 filter: UpstreamOAuthProviderFilter<'_>,
712 pagination: Pagination,
713 ) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
714 let (sql, arguments) = Query::select()
715 .expr_as(
716 Expr::col((
717 UpstreamOAuthProviders::Table,
718 UpstreamOAuthProviders::UpstreamOAuthProviderId,
719 )),
720 ProviderLookupIden::UpstreamOauthProviderId,
721 )
722 .expr_as(
723 Expr::col((
724 UpstreamOAuthProviders::Table,
725 UpstreamOAuthProviders::Issuer,
726 )),
727 ProviderLookupIden::Issuer,
728 )
729 .expr_as(
730 Expr::col((
731 UpstreamOAuthProviders::Table,
732 UpstreamOAuthProviders::HumanName,
733 )),
734 ProviderLookupIden::HumanName,
735 )
736 .expr_as(
737 Expr::col((
738 UpstreamOAuthProviders::Table,
739 UpstreamOAuthProviders::BrandName,
740 )),
741 ProviderLookupIden::BrandName,
742 )
743 .expr_as(
744 Expr::col((UpstreamOAuthProviders::Table, UpstreamOAuthProviders::Scope)),
745 ProviderLookupIden::Scope,
746 )
747 .expr_as(
748 Expr::col((
749 UpstreamOAuthProviders::Table,
750 UpstreamOAuthProviders::ClientId,
751 )),
752 ProviderLookupIden::ClientId,
753 )
754 .expr_as(
755 Expr::col((
756 UpstreamOAuthProviders::Table,
757 UpstreamOAuthProviders::EncryptedClientSecret,
758 )),
759 ProviderLookupIden::EncryptedClientSecret,
760 )
761 .expr_as(
762 Expr::col((
763 UpstreamOAuthProviders::Table,
764 UpstreamOAuthProviders::TokenEndpointSigningAlg,
765 )),
766 ProviderLookupIden::TokenEndpointSigningAlg,
767 )
768 .expr_as(
769 Expr::col((
770 UpstreamOAuthProviders::Table,
771 UpstreamOAuthProviders::TokenEndpointAuthMethod,
772 )),
773 ProviderLookupIden::TokenEndpointAuthMethod,
774 )
775 .expr_as(
776 Expr::col((
777 UpstreamOAuthProviders::Table,
778 UpstreamOAuthProviders::IdTokenSignedResponseAlg,
779 )),
780 ProviderLookupIden::IdTokenSignedResponseAlg,
781 )
782 .expr_as(
783 Expr::col((
784 UpstreamOAuthProviders::Table,
785 UpstreamOAuthProviders::FetchUserinfo,
786 )),
787 ProviderLookupIden::FetchUserinfo,
788 )
789 .expr_as(
790 Expr::col((
791 UpstreamOAuthProviders::Table,
792 UpstreamOAuthProviders::UserinfoSignedResponseAlg,
793 )),
794 ProviderLookupIden::UserinfoSignedResponseAlg,
795 )
796 .expr_as(
797 Expr::col((
798 UpstreamOAuthProviders::Table,
799 UpstreamOAuthProviders::CreatedAt,
800 )),
801 ProviderLookupIden::CreatedAt,
802 )
803 .expr_as(
804 Expr::col((
805 UpstreamOAuthProviders::Table,
806 UpstreamOAuthProviders::DisabledAt,
807 )),
808 ProviderLookupIden::DisabledAt,
809 )
810 .expr_as(
811 Expr::col((
812 UpstreamOAuthProviders::Table,
813 UpstreamOAuthProviders::ClaimsImports,
814 )),
815 ProviderLookupIden::ClaimsImports,
816 )
817 .expr_as(
818 Expr::col((
819 UpstreamOAuthProviders::Table,
820 UpstreamOAuthProviders::JwksUriOverride,
821 )),
822 ProviderLookupIden::JwksUriOverride,
823 )
824 .expr_as(
825 Expr::col((
826 UpstreamOAuthProviders::Table,
827 UpstreamOAuthProviders::TokenEndpointOverride,
828 )),
829 ProviderLookupIden::TokenEndpointOverride,
830 )
831 .expr_as(
832 Expr::col((
833 UpstreamOAuthProviders::Table,
834 UpstreamOAuthProviders::AuthorizationEndpointOverride,
835 )),
836 ProviderLookupIden::AuthorizationEndpointOverride,
837 )
838 .expr_as(
839 Expr::col((
840 UpstreamOAuthProviders::Table,
841 UpstreamOAuthProviders::UserinfoEndpointOverride,
842 )),
843 ProviderLookupIden::UserinfoEndpointOverride,
844 )
845 .expr_as(
846 Expr::col((
847 UpstreamOAuthProviders::Table,
848 UpstreamOAuthProviders::DiscoveryMode,
849 )),
850 ProviderLookupIden::DiscoveryMode,
851 )
852 .expr_as(
853 Expr::col((
854 UpstreamOAuthProviders::Table,
855 UpstreamOAuthProviders::PkceMode,
856 )),
857 ProviderLookupIden::PkceMode,
858 )
859 .expr_as(
860 Expr::col((
861 UpstreamOAuthProviders::Table,
862 UpstreamOAuthProviders::ResponseMode,
863 )),
864 ProviderLookupIden::ResponseMode,
865 )
866 .expr_as(
867 Expr::col((
868 UpstreamOAuthProviders::Table,
869 UpstreamOAuthProviders::AdditionalParameters,
870 )),
871 ProviderLookupIden::AdditionalParameters,
872 )
873 .expr_as(
874 Expr::col((
875 UpstreamOAuthProviders::Table,
876 UpstreamOAuthProviders::ForwardLoginHint,
877 )),
878 ProviderLookupIden::ForwardLoginHint,
879 )
880 .expr_as(
881 Expr::col((
882 UpstreamOAuthProviders::Table,
883 UpstreamOAuthProviders::OnBackchannelLogout,
884 )),
885 ProviderLookupIden::OnBackchannelLogout,
886 )
887 .expr_as(
888 Expr::col((
889 UpstreamOAuthProviders::Table,
890 UpstreamOAuthProviders::RegistrationTokenRequired,
891 )),
892 ProviderLookupIden::RegistrationTokenRequired,
893 )
894 .from(UpstreamOAuthProviders::Table)
895 .apply_filter(filter)
896 .generate_pagination(
897 (
898 UpstreamOAuthProviders::Table,
899 UpstreamOAuthProviders::UpstreamOAuthProviderId,
900 ),
901 pagination,
902 )
903 .build_sqlx(PostgresQueryBuilder);
904
905 let edges: Vec<ProviderLookup> = sqlx::query_as_with(&sql, arguments)
906 .traced()
907 .fetch_all(&mut *self.conn)
908 .await?;
909
910 let page = pagination
911 .process(edges)
912 .try_map(UpstreamOAuthProvider::try_from)?;
913
914 return Ok(page);
915 }
916
917 #[tracing::instrument(
918 name = "db.upstream_oauth_provider.count",
919 skip_all,
920 fields(
921 db.query.text,
922 ),
923 err,
924 )]
925 async fn count(
926 &mut self,
927 filter: UpstreamOAuthProviderFilter<'_>,
928 ) -> Result<usize, Self::Error> {
929 let (sql, arguments) = Query::select()
930 .expr(
931 Expr::col((
932 UpstreamOAuthProviders::Table,
933 UpstreamOAuthProviders::UpstreamOAuthProviderId,
934 ))
935 .count(),
936 )
937 .from(UpstreamOAuthProviders::Table)
938 .apply_filter(filter)
939 .build_sqlx(PostgresQueryBuilder);
940
941 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
942 .traced()
943 .fetch_one(&mut *self.conn)
944 .await?;
945
946 count
947 .try_into()
948 .map_err(DatabaseError::to_invalid_operation)
949 }
950
951 #[tracing::instrument(
952 name = "db.upstream_oauth_provider.all_enabled",
953 skip_all,
954 fields(
955 db.query.text,
956 ),
957 err,
958 )]
959 async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
960 let res = sqlx::query_as!(
961 ProviderLookup,
962 r#"
963 SELECT
964 upstream_oauth_provider_id,
965 issuer,
966 human_name,
967 brand_name,
968 scope,
969 client_id,
970 encrypted_client_secret,
971 token_endpoint_signing_alg,
972 token_endpoint_auth_method,
973 id_token_signed_response_alg,
974 fetch_userinfo,
975 userinfo_signed_response_alg,
976 created_at,
977 disabled_at,
978 claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
979 jwks_uri_override,
980 authorization_endpoint_override,
981 token_endpoint_override,
982 userinfo_endpoint_override,
983 discovery_mode,
984 pkce_mode,
985 response_mode,
986 additional_parameters as "additional_parameters: Json<Vec<(String, String)>>",
987 forward_login_hint,
988 on_backchannel_logout,
989 registration_token_required
990
991 FROM upstream_oauth_providers
992 WHERE disabled_at IS NULL
993 ORDER BY ui_order ASC, upstream_oauth_provider_id ASC
994 "#,
995 )
996 .traced()
997 .fetch_all(&mut *self.conn)
998 .await?;
999
1000 let res: Result<Vec<_>, _> = res.into_iter().map(TryInto::try_into).collect();
1001 Ok(res?)
1002 }
1003}