Skip to main content

mas_storage_pg/oauth2/
client.rs

1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8use std::{
9    collections::{BTreeMap, BTreeSet},
10    string::ToString,
11};
12
13use async_trait::async_trait;
14use mas_data_model::{Client, Clock, JwksOrJwksUri};
15use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
16use mas_jose::jwk::PublicJsonWebKeySet;
17use mas_storage::{
18    Page, Pagination,
19    oauth2::{OAuth2ClientFilter, OAuth2ClientKind, OAuth2ClientRepository},
20    pagination::Node,
21};
22use oauth2_types::{oidc::ApplicationType, requests::GrantType};
23use opentelemetry_semantic_conventions::attribute::DB_QUERY_TEXT;
24use rand::RngCore;
25use sea_query::{
26    Expr, PostgresQueryBuilder, Query, SimpleExpr, enum_def, extension::postgres::PgExpr as _,
27};
28use sea_query_binder::SqlxBinder;
29use sqlx::PgConnection;
30use tracing::{Instrument, info_span};
31use ulid::Ulid;
32use url::Url;
33use uuid::Uuid;
34
35use crate::{
36    DatabaseError, DatabaseInconsistencyError,
37    filter::{Filter, StatementExt},
38    iden::{OAuth2Clients, OAuth2Sessions},
39    pagination::QueryBuilderExt,
40    tracing::ExecuteExt,
41};
42
43/// An implementation of [`OAuth2ClientRepository`] for a PostgreSQL connection
44pub struct PgOAuth2ClientRepository<'c> {
45    conn: &'c mut PgConnection,
46}
47
48impl<'c> PgOAuth2ClientRepository<'c> {
49    /// Create a new [`PgOAuth2ClientRepository`] from an active PostgreSQL
50    /// connection
51    pub fn new(conn: &'c mut PgConnection) -> Self {
52        Self { conn }
53    }
54}
55
/// A raw row of the `oauth2_clients` table, as selected by the queries in this
/// module.
///
/// Field names must match the selected column names, since `query_as!` /
/// `sqlx::FromRow` map columns onto fields by name. Structured values (URLs,
/// enum-like strings, the JWKS document) are kept in their stored form here and
/// only parsed when the row is converted into a [`Client`] via `TryFrom`.
///
/// NOTE(review): `#[enum_def]` (sea_query) presumably generates a column
/// identifier type from these fields for query-builder code elsewhere in this
/// file; renaming a field would affect it too.
#[expect(clippy::struct_excessive_bools)]
#[derive(Debug, sqlx::FromRow)]
#[enum_def]
struct OAuth2ClientLookup {
    oauth2_client_id: Uuid,
    metadata_digest: Option<String>,
    encrypted_client_secret: Option<String>,
    // Parsed into an `ApplicationType` on conversion
    application_type: Option<String>,
    // Each entry is parsed as a URL on conversion
    redirect_uris: Vec<String>,
    // One flag per grant type that has a dedicated column; converted into the
    // client's `grant_types` list
    grant_type_authorization_code: bool,
    grant_type_refresh_token: bool,
    grant_type_client_credentials: bool,
    grant_type_device_code: bool,
    client_name: Option<String>,
    // URL columns, parsed on conversion
    logo_uri: Option<String>,
    client_uri: Option<String>,
    policy_uri: Option<String>,
    tos_uri: Option<String>,
    // At most one of `jwks_uri` / `jwks` may be set; a row with both is
    // rejected as inconsistent on conversion
    jwks_uri: Option<String>,
    jwks: Option<serde_json::Value>,
    // Algorithm / method names, parsed on conversion
    id_token_signed_response_alg: Option<String>,
    userinfo_signed_response_alg: Option<String>,
    token_endpoint_auth_method: Option<String>,
    token_endpoint_auth_signing_alg: Option<String>,
    initiate_login_uri: Option<String>,
    // TRUE for rows written by `upsert_static`, FALSE for rows written by `add`
    is_static: bool,
}
83
84impl Node<Ulid> for OAuth2ClientLookup {
85    fn cursor(&self) -> Ulid {
86        self.oauth2_client_id.into()
87    }
88}
89
90impl Filter for OAuth2ClientFilter<'_> {
91    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
92        sea_query::Condition::all()
93            .add_option(self.kind().map(|kind| {
94                let is_static = matches!(kind, OAuth2ClientKind::Static);
95                Expr::col((OAuth2Clients::Table, OAuth2Clients::IsStatic)).eq(is_static)
96            }))
97            .add_option(self.client_name().map(|client_name| {
98                Expr::col((OAuth2Clients::Table, OAuth2Clients::ClientName))
99                    .ilike(format!("%{client_name}%"))
100            }))
101            .add_option(self.client_uri().map(|client_uri| {
102                Expr::col((OAuth2Clients::Table, OAuth2Clients::ClientUri))
103                    .ilike(format!("%{client_uri}%"))
104            }))
105            .add_option(self.grant_type().map(|grant_type| -> SimpleExpr {
106                let column = match grant_type {
107                    GrantType::AuthorizationCode => OAuth2Clients::GrantTypeAuthorizationCode,
108                    GrantType::RefreshToken => OAuth2Clients::GrantTypeRefreshToken,
109                    GrantType::ClientCredentials => OAuth2Clients::GrantTypeClientCredentials,
110                    GrantType::DeviceCode => OAuth2Clients::GrantTypeDeviceCode,
111                    // The other grant types don't have a dedicated column, so no
112                    // client can declare them: the filter matches nothing.
113                    _ => return Expr::val(false).into(),
114                };
115                Expr::col((OAuth2Clients::Table, column)).eq(true)
116            }))
117            .add_option(self.has_active_sessions().map(|has| -> SimpleExpr {
118                let exists = Expr::exists(
119                    Query::select()
120                        .expr(Expr::cust("1"))
121                        .from(OAuth2Sessions::Table)
122                        .and_where(
123                            Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
124                                .equals((OAuth2Clients::Table, OAuth2Clients::OAuth2ClientId)),
125                        )
126                        .and_where(
127                            Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt))
128                                .is_null(),
129                        )
130                        .take(),
131                );
132                if has { exists } else { exists.not() }
133            }))
134    }
135}
136
137impl TryFrom<OAuth2ClientLookup> for Client {
138    type Error = DatabaseInconsistencyError;
139
140    fn try_from(value: OAuth2ClientLookup) -> Result<Self, Self::Error> {
141        let id = Ulid::from(value.oauth2_client_id);
142
143        let redirect_uris: Result<Vec<Url>, _> =
144            value.redirect_uris.iter().map(|s| s.parse()).collect();
145        let redirect_uris = redirect_uris.map_err(|e| {
146            DatabaseInconsistencyError::on("oauth2_clients")
147                .column("redirect_uris")
148                .row(id)
149                .source(e)
150        })?;
151
152        let application_type = value
153            .application_type
154            .map(|s| s.parse())
155            .transpose()
156            .map_err(|e| {
157                DatabaseInconsistencyError::on("oauth2_clients")
158                    .column("application_type")
159                    .row(id)
160                    .source(e)
161            })?;
162
163        let mut grant_types = Vec::new();
164        if value.grant_type_authorization_code {
165            grant_types.push(GrantType::AuthorizationCode);
166        }
167        if value.grant_type_refresh_token {
168            grant_types.push(GrantType::RefreshToken);
169        }
170        if value.grant_type_client_credentials {
171            grant_types.push(GrantType::ClientCredentials);
172        }
173        if value.grant_type_device_code {
174            grant_types.push(GrantType::DeviceCode);
175        }
176
177        let logo_uri = value.logo_uri.map(|s| s.parse()).transpose().map_err(|e| {
178            DatabaseInconsistencyError::on("oauth2_clients")
179                .column("logo_uri")
180                .row(id)
181                .source(e)
182        })?;
183
184        let client_uri = value
185            .client_uri
186            .map(|s| s.parse())
187            .transpose()
188            .map_err(|e| {
189                DatabaseInconsistencyError::on("oauth2_clients")
190                    .column("client_uri")
191                    .row(id)
192                    .source(e)
193            })?;
194
195        let policy_uri = value
196            .policy_uri
197            .map(|s| s.parse())
198            .transpose()
199            .map_err(|e| {
200                DatabaseInconsistencyError::on("oauth2_clients")
201                    .column("policy_uri")
202                    .row(id)
203                    .source(e)
204            })?;
205
206        let tos_uri = value.tos_uri.map(|s| s.parse()).transpose().map_err(|e| {
207            DatabaseInconsistencyError::on("oauth2_clients")
208                .column("tos_uri")
209                .row(id)
210                .source(e)
211        })?;
212
213        let id_token_signed_response_alg = value
214            .id_token_signed_response_alg
215            .map(|s| s.parse())
216            .transpose()
217            .map_err(|e| {
218                DatabaseInconsistencyError::on("oauth2_clients")
219                    .column("id_token_signed_response_alg")
220                    .row(id)
221                    .source(e)
222            })?;
223
224        let userinfo_signed_response_alg = value
225            .userinfo_signed_response_alg
226            .map(|s| s.parse())
227            .transpose()
228            .map_err(|e| {
229                DatabaseInconsistencyError::on("oauth2_clients")
230                    .column("userinfo_signed_response_alg")
231                    .row(id)
232                    .source(e)
233            })?;
234
235        let token_endpoint_auth_method = value
236            .token_endpoint_auth_method
237            .map(|s| s.parse())
238            .transpose()
239            .map_err(|e| {
240                DatabaseInconsistencyError::on("oauth2_clients")
241                    .column("token_endpoint_auth_method")
242                    .row(id)
243                    .source(e)
244            })?;
245
246        let token_endpoint_auth_signing_alg = value
247            .token_endpoint_auth_signing_alg
248            .map(|s| s.parse())
249            .transpose()
250            .map_err(|e| {
251                DatabaseInconsistencyError::on("oauth2_clients")
252                    .column("token_endpoint_auth_signing_alg")
253                    .row(id)
254                    .source(e)
255            })?;
256
257        let initiate_login_uri = value
258            .initiate_login_uri
259            .map(|s| s.parse())
260            .transpose()
261            .map_err(|e| {
262                DatabaseInconsistencyError::on("oauth2_clients")
263                    .column("initiate_login_uri")
264                    .row(id)
265                    .source(e)
266            })?;
267
268        let jwks = match (value.jwks, value.jwks_uri) {
269            (None, None) => None,
270            (Some(jwks), None) => {
271                let jwks = serde_json::from_value(jwks).map_err(|e| {
272                    DatabaseInconsistencyError::on("oauth2_clients")
273                        .column("jwks")
274                        .row(id)
275                        .source(e)
276                })?;
277                Some(JwksOrJwksUri::Jwks(jwks))
278            }
279            (None, Some(jwks_uri)) => {
280                let jwks_uri = jwks_uri.parse().map_err(|e| {
281                    DatabaseInconsistencyError::on("oauth2_clients")
282                        .column("jwks_uri")
283                        .row(id)
284                        .source(e)
285                })?;
286
287                Some(JwksOrJwksUri::JwksUri(jwks_uri))
288            }
289            _ => {
290                return Err(DatabaseInconsistencyError::on("oauth2_clients")
291                    .column("jwks(_uri)")
292                    .row(id));
293            }
294        };
295
296        Ok(Client {
297            id,
298            client_id: id.to_string(),
299            metadata_digest: value.metadata_digest,
300            encrypted_client_secret: value.encrypted_client_secret,
301            application_type,
302            redirect_uris,
303            grant_types,
304            client_name: value.client_name,
305            logo_uri,
306            client_uri,
307            policy_uri,
308            tos_uri,
309            jwks,
310            id_token_signed_response_alg,
311            userinfo_signed_response_alg,
312            token_endpoint_auth_method,
313            token_endpoint_auth_signing_alg,
314            initiate_login_uri,
315            is_static: value.is_static,
316        })
317    }
318}
319
320#[async_trait]
321impl OAuth2ClientRepository for PgOAuth2ClientRepository<'_> {
322    type Error = DatabaseError;
323
324    #[tracing::instrument(
325        name = "db.oauth2_client.lookup",
326        skip_all,
327        fields(
328            db.query.text,
329            oauth2_client.id = %id,
330        ),
331        err,
332    )]
333    async fn lookup(&mut self, id: Ulid) -> Result<Option<Client>, Self::Error> {
334        let res = sqlx::query_as!(
335            OAuth2ClientLookup,
336            r#"
337                SELECT oauth2_client_id
338                     , metadata_digest
339                     , encrypted_client_secret
340                     , application_type
341                     , redirect_uris
342                     , grant_type_authorization_code
343                     , grant_type_refresh_token
344                     , grant_type_client_credentials
345                     , grant_type_device_code
346                     , client_name
347                     , logo_uri
348                     , client_uri
349                     , policy_uri
350                     , tos_uri
351                     , jwks_uri
352                     , jwks
353                     , id_token_signed_response_alg
354                     , userinfo_signed_response_alg
355                     , token_endpoint_auth_method
356                     , token_endpoint_auth_signing_alg
357                     , initiate_login_uri
358                     , is_static
359                FROM oauth2_clients c
360
361                WHERE oauth2_client_id = $1
362            "#,
363            Uuid::from(id),
364        )
365        .traced()
366        .fetch_optional(&mut *self.conn)
367        .await?;
368
369        let Some(res) = res else { return Ok(None) };
370
371        Ok(Some(res.try_into()?))
372    }
373
374    #[tracing::instrument(
375        name = "db.oauth2_client.find_by_metadata_digest",
376        skip_all,
377        fields(
378            db.query.text,
379        ),
380        err,
381    )]
382    async fn find_by_metadata_digest(
383        &mut self,
384        digest: &str,
385    ) -> Result<Option<Client>, Self::Error> {
386        let res = sqlx::query_as!(
387            OAuth2ClientLookup,
388            r#"
389                SELECT oauth2_client_id
390                    , metadata_digest
391                    , encrypted_client_secret
392                    , application_type
393                    , redirect_uris
394                    , grant_type_authorization_code
395                    , grant_type_refresh_token
396                    , grant_type_client_credentials
397                    , grant_type_device_code
398                    , client_name
399                    , logo_uri
400                    , client_uri
401                    , policy_uri
402                    , tos_uri
403                    , jwks_uri
404                    , jwks
405                    , id_token_signed_response_alg
406                    , userinfo_signed_response_alg
407                    , token_endpoint_auth_method
408                    , token_endpoint_auth_signing_alg
409                    , initiate_login_uri
410                    , is_static
411                FROM oauth2_clients
412                WHERE metadata_digest = $1
413            "#,
414            digest,
415        )
416        .traced()
417        .fetch_optional(&mut *self.conn)
418        .await?;
419
420        let Some(res) = res else { return Ok(None) };
421
422        Ok(Some(res.try_into()?))
423    }
424
425    #[tracing::instrument(
426        name = "db.oauth2_client.load_batch",
427        skip_all,
428        fields(
429            db.query.text,
430        ),
431        err,
432    )]
433    async fn load_batch(
434        &mut self,
435        ids: BTreeSet<Ulid>,
436    ) -> Result<BTreeMap<Ulid, Client>, Self::Error> {
437        let ids: Vec<Uuid> = ids.into_iter().map(Uuid::from).collect();
438        let res = sqlx::query_as!(
439            OAuth2ClientLookup,
440            r#"
441                SELECT oauth2_client_id
442                     , metadata_digest
443                     , encrypted_client_secret
444                     , application_type
445                     , redirect_uris
446                     , grant_type_authorization_code
447                     , grant_type_refresh_token
448                     , grant_type_client_credentials
449                     , grant_type_device_code
450                     , client_name
451                     , logo_uri
452                     , client_uri
453                     , policy_uri
454                     , tos_uri
455                     , jwks_uri
456                     , jwks
457                     , id_token_signed_response_alg
458                     , userinfo_signed_response_alg
459                     , token_endpoint_auth_method
460                     , token_endpoint_auth_signing_alg
461                     , initiate_login_uri
462                     , is_static
463                FROM oauth2_clients c
464
465                WHERE oauth2_client_id = ANY($1::uuid[])
466            "#,
467            &ids,
468        )
469        .traced()
470        .fetch_all(&mut *self.conn)
471        .await?;
472
473        res.into_iter()
474            .map(|r| {
475                r.try_into()
476                    .map(|c: Client| (c.id, c))
477                    .map_err(DatabaseError::from)
478            })
479            .collect()
480    }
481
482    #[tracing::instrument(
483        name = "db.oauth2_client.add",
484        skip_all,
485        fields(
486            db.query.text,
487            client.id,
488            client.name = client_name
489        ),
490        err,
491    )]
492    async fn add(
493        &mut self,
494        rng: &mut (dyn RngCore + Send),
495        clock: &dyn Clock,
496        redirect_uris: Vec<Url>,
497        metadata_digest: Option<String>,
498        encrypted_client_secret: Option<String>,
499        application_type: Option<ApplicationType>,
500        grant_types: Vec<GrantType>,
501        client_name: Option<String>,
502        logo_uri: Option<Url>,
503        client_uri: Option<Url>,
504        policy_uri: Option<Url>,
505        tos_uri: Option<Url>,
506        jwks_uri: Option<Url>,
507        jwks: Option<PublicJsonWebKeySet>,
508        id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
509        userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
510        token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
511        token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
512        initiate_login_uri: Option<Url>,
513    ) -> Result<Client, Self::Error> {
514        let now = clock.now();
515        let id = Ulid::from_datetime_with_source(now.into(), rng);
516        tracing::Span::current().record("client.id", tracing::field::display(id));
517
518        let jwks_json = jwks
519            .as_ref()
520            .map(serde_json::to_value)
521            .transpose()
522            .map_err(DatabaseError::to_invalid_operation)?;
523
524        let redirect_uris_array = redirect_uris.iter().map(Url::to_string).collect::<Vec<_>>();
525
526        sqlx::query!(
527            r#"
528                INSERT INTO oauth2_clients
529                    ( oauth2_client_id
530                    , metadata_digest
531                    , encrypted_client_secret
532                    , application_type
533                    , redirect_uris
534                    , grant_type_authorization_code
535                    , grant_type_refresh_token
536                    , grant_type_client_credentials
537                    , grant_type_device_code
538                    , client_name
539                    , logo_uri
540                    , client_uri
541                    , policy_uri
542                    , tos_uri
543                    , jwks_uri
544                    , jwks
545                    , id_token_signed_response_alg
546                    , userinfo_signed_response_alg
547                    , token_endpoint_auth_method
548                    , token_endpoint_auth_signing_alg
549                    , initiate_login_uri
550                    , is_static
551                    )
552                VALUES
553                    ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13,
554                    $14, $15, $16, $17, $18, $19, $20, $21, FALSE)
555            "#,
556            Uuid::from(id),
557            metadata_digest,
558            encrypted_client_secret,
559            application_type.as_ref().map(ToString::to_string),
560            &redirect_uris_array,
561            grant_types.contains(&GrantType::AuthorizationCode),
562            grant_types.contains(&GrantType::RefreshToken),
563            grant_types.contains(&GrantType::ClientCredentials),
564            grant_types.contains(&GrantType::DeviceCode),
565            client_name,
566            logo_uri.as_ref().map(Url::as_str),
567            client_uri.as_ref().map(Url::as_str),
568            policy_uri.as_ref().map(Url::as_str),
569            tos_uri.as_ref().map(Url::as_str),
570            jwks_uri.as_ref().map(Url::as_str),
571            jwks_json,
572            id_token_signed_response_alg
573                .as_ref()
574                .map(ToString::to_string),
575            userinfo_signed_response_alg
576                .as_ref()
577                .map(ToString::to_string),
578            token_endpoint_auth_method.as_ref().map(ToString::to_string),
579            token_endpoint_auth_signing_alg
580                .as_ref()
581                .map(ToString::to_string),
582            initiate_login_uri.as_ref().map(Url::as_str),
583        )
584        .traced()
585        .execute(&mut *self.conn)
586        .await?;
587
588        let jwks = match (jwks, jwks_uri) {
589            (None, None) => None,
590            (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
591            (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
592            _ => return Err(DatabaseError::invalid_operation()),
593        };
594
595        Ok(Client {
596            id,
597            client_id: id.to_string(),
598            metadata_digest: None,
599            encrypted_client_secret,
600            application_type,
601            redirect_uris,
602            grant_types,
603            client_name,
604            logo_uri,
605            client_uri,
606            policy_uri,
607            tos_uri,
608            jwks,
609            id_token_signed_response_alg,
610            userinfo_signed_response_alg,
611            token_endpoint_auth_method,
612            token_endpoint_auth_signing_alg,
613            initiate_login_uri,
614            is_static: false,
615        })
616    }
617
618    #[tracing::instrument(
619        name = "db.oauth2_client.upsert_static",
620        skip_all,
621        fields(
622            db.query.text,
623            client.id = %client_id,
624        ),
625        err,
626    )]
627    async fn upsert_static(
628        &mut self,
629        client_id: Ulid,
630        client_name: Option<String>,
631        client_auth_method: OAuthClientAuthenticationMethod,
632        encrypted_client_secret: Option<String>,
633        jwks: Option<PublicJsonWebKeySet>,
634        jwks_uri: Option<Url>,
635        redirect_uris: Vec<Url>,
636    ) -> Result<Client, Self::Error> {
637        let jwks_json = jwks
638            .as_ref()
639            .map(serde_json::to_value)
640            .transpose()
641            .map_err(DatabaseError::to_invalid_operation)?;
642
643        let client_auth_method = client_auth_method.to_string();
644        let redirect_uris_array = redirect_uris.iter().map(Url::to_string).collect::<Vec<_>>();
645
646        sqlx::query!(
647            r#"
648                INSERT INTO oauth2_clients
649                    ( oauth2_client_id
650                    , encrypted_client_secret
651                    , redirect_uris
652                    , grant_type_authorization_code
653                    , grant_type_refresh_token
654                    , grant_type_client_credentials
655                    , grant_type_device_code
656                    , token_endpoint_auth_method
657                    , jwks
658                    , client_name
659                    , jwks_uri
660                    , is_static
661                    )
662                VALUES
663                    ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, TRUE)
664                ON CONFLICT (oauth2_client_id)
665                DO
666                    UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret
667                             , redirect_uris = EXCLUDED.redirect_uris
668                             , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code
669                             , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token
670                             , grant_type_client_credentials = EXCLUDED.grant_type_client_credentials
671                             , grant_type_device_code = EXCLUDED.grant_type_device_code
672                             , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method
673                             , jwks = EXCLUDED.jwks
674                             , client_name = EXCLUDED.client_name
675                             , jwks_uri = EXCLUDED.jwks_uri
676                             , is_static = TRUE
677            "#,
678            Uuid::from(client_id),
679            encrypted_client_secret,
680            &redirect_uris_array,
681            true,
682            true,
683            true,
684            true,
685            client_auth_method,
686            jwks_json,
687            client_name,
688            jwks_uri.as_ref().map(Url::as_str),
689        )
690        .traced()
691        .execute(&mut *self.conn)
692        .await?;
693
694        let jwks = match (jwks, jwks_uri) {
695            (None, None) => None,
696            (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
697            (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
698            _ => return Err(DatabaseError::invalid_operation()),
699        };
700
701        Ok(Client {
702            id: client_id,
703            client_id: client_id.to_string(),
704            metadata_digest: None,
705            encrypted_client_secret,
706            application_type: None,
707            redirect_uris,
708            grant_types: vec![
709                GrantType::AuthorizationCode,
710                GrantType::RefreshToken,
711                GrantType::ClientCredentials,
712            ],
713            client_name,
714            logo_uri: None,
715            client_uri: None,
716            policy_uri: None,
717            tos_uri: None,
718            jwks,
719            id_token_signed_response_alg: None,
720            userinfo_signed_response_alg: None,
721            token_endpoint_auth_method: None,
722            token_endpoint_auth_signing_alg: None,
723            initiate_login_uri: None,
724            is_static: true,
725        })
726    }
727
728    #[tracing::instrument(
729        name = "db.oauth2_client.all_static",
730        skip_all,
731        fields(
732            db.query.text,
733        ),
734        err,
735    )]
736    async fn all_static(&mut self) -> Result<Vec<Client>, Self::Error> {
737        let res = sqlx::query_as!(
738            OAuth2ClientLookup,
739            r#"
740                SELECT oauth2_client_id
741                     , metadata_digest
742                     , encrypted_client_secret
743                     , application_type
744                     , redirect_uris
745                     , grant_type_authorization_code
746                     , grant_type_refresh_token
747                     , grant_type_client_credentials
748                     , grant_type_device_code
749                     , client_name
750                     , logo_uri
751                     , client_uri
752                     , policy_uri
753                     , tos_uri
754                     , jwks_uri
755                     , jwks
756                     , id_token_signed_response_alg
757                     , userinfo_signed_response_alg
758                     , token_endpoint_auth_method
759                     , token_endpoint_auth_signing_alg
760                     , initiate_login_uri
761                     , is_static
762                FROM oauth2_clients c
763                WHERE is_static = TRUE
764            "#,
765        )
766        .traced()
767        .fetch_all(&mut *self.conn)
768        .await?;
769
770        res.into_iter()
771            .map(|r| r.try_into().map_err(DatabaseError::from))
772            .collect()
773    }
774
775    #[tracing::instrument(
776        name = "db.oauth2_client.delete_by_id",
777        skip_all,
778        fields(
779            db.query.text,
780            client.id = %id,
781        ),
782        err,
783    )]
784    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
785        // Delete the authorization grants
786        {
787            let span = info_span!(
788                "db.oauth2_client.delete_by_id.authorization_grants",
789                { DB_QUERY_TEXT } = tracing::field::Empty,
790            );
791
792            sqlx::query!(
793                r#"
794                    DELETE FROM oauth2_authorization_grants
795                    WHERE oauth2_client_id = $1
796                "#,
797                Uuid::from(id),
798            )
799            .record(&span)
800            .execute(&mut *self.conn)
801            .instrument(span)
802            .await?;
803        }
804
805        // Delete the OAuth 2 sessions related data
806        {
807            let span = info_span!(
808                "db.oauth2_client.delete_by_id.access_tokens",
809                { DB_QUERY_TEXT } = tracing::field::Empty,
810            );
811
812            sqlx::query!(
813                r#"
814                    DELETE FROM oauth2_access_tokens
815                    WHERE oauth2_session_id IN (
816                        SELECT oauth2_session_id
817                        FROM oauth2_sessions
818                        WHERE oauth2_client_id = $1
819                    )
820                "#,
821                Uuid::from(id),
822            )
823            .record(&span)
824            .execute(&mut *self.conn)
825            .instrument(span)
826            .await?;
827        }
828
829        {
830            let span = info_span!(
831                "db.oauth2_client.delete_by_id.refresh_tokens",
832                { DB_QUERY_TEXT } = tracing::field::Empty,
833            );
834
835            sqlx::query!(
836                r#"
837                    DELETE FROM oauth2_refresh_tokens
838                    WHERE oauth2_session_id IN (
839                        SELECT oauth2_session_id
840                        FROM oauth2_sessions
841                        WHERE oauth2_client_id = $1
842                    )
843                "#,
844                Uuid::from(id),
845            )
846            .record(&span)
847            .execute(&mut *self.conn)
848            .instrument(span)
849            .await?;
850        }
851
852        {
853            let span = info_span!(
854                "db.oauth2_client.delete_by_id.sessions",
855                { DB_QUERY_TEXT } = tracing::field::Empty,
856            );
857
858            sqlx::query!(
859                r#"
860                    DELETE FROM oauth2_sessions
861                    WHERE oauth2_client_id = $1
862                "#,
863                Uuid::from(id),
864            )
865            .record(&span)
866            .execute(&mut *self.conn)
867            .instrument(span)
868            .await?;
869        }
870
871        // Delete any personal access tokens & sessions owned
872        // by the client
873        {
874            let span = info_span!(
875                "db.oauth2_client.delete_by_id.personal_access_tokens",
876                { DB_QUERY_TEXT } = tracing::field::Empty,
877            );
878
879            sqlx::query!(
880                r#"
881                    DELETE FROM personal_access_tokens
882                    WHERE personal_session_id IN (
883                        SELECT personal_session_id
884                        FROM personal_sessions
885                        WHERE owner_oauth2_client_id = $1
886                    )
887                "#,
888                Uuid::from(id),
889            )
890            .record(&span)
891            .execute(&mut *self.conn)
892            .instrument(span)
893            .await?;
894        }
895        {
896            let span = info_span!(
897                "db.oauth2_client.delete_by_id.personal_sessions",
898                { DB_QUERY_TEXT } = tracing::field::Empty,
899            );
900
901            sqlx::query!(
902                r#"
903                    DELETE FROM personal_sessions
904                    WHERE owner_oauth2_client_id = $1
905                "#,
906                Uuid::from(id),
907            )
908            .record(&span)
909            .execute(&mut *self.conn)
910            .instrument(span)
911            .await?;
912        }
913
914        // Now delete the client itself
915        let res = sqlx::query!(
916            r#"
917                DELETE FROM oauth2_clients
918                WHERE oauth2_client_id = $1
919            "#,
920            Uuid::from(id),
921        )
922        .traced()
923        .execute(&mut *self.conn)
924        .await?;
925
926        DatabaseError::ensure_affected_rows(&res, 1)
927    }
928
929    #[tracing::instrument(
930        name = "db.oauth2_client.list",
931        skip_all,
932        fields(
933            db.query.text,
934        ),
935        err,
936    )]
937    async fn list(
938        &mut self,
939        filter: OAuth2ClientFilter<'_>,
940        pagination: Pagination,
941    ) -> Result<Page<Client>, Self::Error> {
942        let (sql, arguments) = Query::select()
943            .expr_as(
944                Expr::col((OAuth2Clients::Table, OAuth2Clients::OAuth2ClientId)),
945                OAuth2ClientLookupIden::Oauth2ClientId,
946            )
947            .expr_as(
948                Expr::cust("metadata_digest"),
949                OAuth2ClientLookupIden::MetadataDigest,
950            )
951            .expr_as(
952                Expr::cust("encrypted_client_secret"),
953                OAuth2ClientLookupIden::EncryptedClientSecret,
954            )
955            .expr_as(
956                Expr::cust("application_type"),
957                OAuth2ClientLookupIden::ApplicationType,
958            )
959            .expr_as(
960                Expr::col((OAuth2Clients::Table, OAuth2Clients::RedirectUris)),
961                OAuth2ClientLookupIden::RedirectUris,
962            )
963            .expr_as(
964                Expr::cust("grant_type_authorization_code"),
965                OAuth2ClientLookupIden::GrantTypeAuthorizationCode,
966            )
967            .expr_as(
968                Expr::cust("grant_type_refresh_token"),
969                OAuth2ClientLookupIden::GrantTypeRefreshToken,
970            )
971            .expr_as(
972                Expr::cust("grant_type_client_credentials"),
973                OAuth2ClientLookupIden::GrantTypeClientCredentials,
974            )
975            .expr_as(
976                Expr::cust("grant_type_device_code"),
977                OAuth2ClientLookupIden::GrantTypeDeviceCode,
978            )
979            .expr_as(
980                Expr::col((OAuth2Clients::Table, OAuth2Clients::ClientName)),
981                OAuth2ClientLookupIden::ClientName,
982            )
983            .expr_as(
984                Expr::col((OAuth2Clients::Table, OAuth2Clients::LogoUri)),
985                OAuth2ClientLookupIden::LogoUri,
986            )
987            .expr_as(
988                Expr::col((OAuth2Clients::Table, OAuth2Clients::ClientUri)),
989                OAuth2ClientLookupIden::ClientUri,
990            )
991            .expr_as(Expr::cust("policy_uri"), OAuth2ClientLookupIden::PolicyUri)
992            .expr_as(Expr::cust("tos_uri"), OAuth2ClientLookupIden::TosUri)
993            .expr_as(Expr::cust("jwks_uri"), OAuth2ClientLookupIden::JwksUri)
994            .expr_as(Expr::cust("jwks"), OAuth2ClientLookupIden::Jwks)
995            .expr_as(
996                Expr::cust("id_token_signed_response_alg"),
997                OAuth2ClientLookupIden::IdTokenSignedResponseAlg,
998            )
999            .expr_as(
1000                Expr::cust("userinfo_signed_response_alg"),
1001                OAuth2ClientLookupIden::UserinfoSignedResponseAlg,
1002            )
1003            .expr_as(
1004                Expr::cust("token_endpoint_auth_method"),
1005                OAuth2ClientLookupIden::TokenEndpointAuthMethod,
1006            )
1007            .expr_as(
1008                Expr::cust("token_endpoint_auth_signing_alg"),
1009                OAuth2ClientLookupIden::TokenEndpointAuthSigningAlg,
1010            )
1011            .expr_as(
1012                Expr::cust("initiate_login_uri"),
1013                OAuth2ClientLookupIden::InitiateLoginUri,
1014            )
1015            .expr_as(
1016                Expr::col((OAuth2Clients::Table, OAuth2Clients::IsStatic)),
1017                OAuth2ClientLookupIden::IsStatic,
1018            )
1019            .from(OAuth2Clients::Table)
1020            .apply_filter(filter)
1021            .generate_pagination(
1022                (OAuth2Clients::Table, OAuth2Clients::OAuth2ClientId),
1023                pagination,
1024            )
1025            .build_sqlx(PostgresQueryBuilder);
1026
1027        let edges: Vec<OAuth2ClientLookup> = sqlx::query_as_with(&sql, arguments)
1028            .traced()
1029            .fetch_all(&mut *self.conn)
1030            .await?;
1031
1032        let page = pagination.process(edges).try_map(Client::try_from)?;
1033
1034        Ok(page)
1035    }
1036
1037    #[tracing::instrument(
1038        name = "db.oauth2_client.count",
1039        skip_all,
1040        fields(
1041            db.query.text,
1042        ),
1043        err,
1044    )]
1045    async fn count(&mut self, filter: OAuth2ClientFilter<'_>) -> Result<usize, Self::Error> {
1046        let (sql, arguments) = Query::select()
1047            .expr(Expr::col((OAuth2Clients::Table, OAuth2Clients::OAuth2ClientId)).count())
1048            .from(OAuth2Clients::Table)
1049            .apply_filter(filter)
1050            .build_sqlx(PostgresQueryBuilder);
1051
1052        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
1053            .traced()
1054            .fetch_one(&mut *self.conn)
1055            .await?;
1056
1057        count
1058            .try_into()
1059            .map_err(DatabaseError::to_invalid_operation)
1060    }
1061}