Skip to main content

mas_storage/oauth2/
client.rs

1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8use std::collections::{BTreeMap, BTreeSet};
9
10use async_trait::async_trait;
11use mas_data_model::{Client, Clock};
12use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
13use mas_jose::jwk::PublicJsonWebKeySet;
14use oauth2_types::{oidc::ApplicationType, requests::GrantType};
15use rand_core::RngCore;
16use ulid::Ulid;
17use url::Url;
18
19use crate::{Page, Pagination, repository_impl};
20
/// Where an OAuth 2.0 client comes from.
///
/// Used by [`OAuth2ClientFilter`] to narrow a listing down to one origin.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OAuth2ClientKind {
    /// A client declared statically in the configuration file
    Static,

    /// A client created at runtime through the dynamic client registration
    /// endpoint
    Dynamic,
}
31
32/// Filter parameters for listing OAuth 2.0 clients
33#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
34pub struct OAuth2ClientFilter<'a> {
35    kind: Option<OAuth2ClientKind>,
36    client_name: Option<&'a str>,
37    client_uri: Option<&'a str>,
38    grant_type: Option<&'a GrantType>,
39    has_active_sessions: Option<bool>,
40}
41
42impl<'a> OAuth2ClientFilter<'a> {
43    /// Create a new [`OAuth2ClientFilter`] with default values
44    #[must_use]
45    pub fn new() -> Self {
46        Self::default()
47    }
48
49    /// Only return static clients (those declared in the configuration)
50    #[must_use]
51    pub fn only_static_clients(mut self) -> Self {
52        self.kind = Some(OAuth2ClientKind::Static);
53        self
54    }
55
56    /// Only return dynamic clients (those registered via the
57    /// dynamic-client-registration endpoint)
58    #[must_use]
59    pub fn only_dynamic_clients(mut self) -> Self {
60        self.kind = Some(OAuth2ClientKind::Dynamic);
61        self
62    }
63
64    /// Get the client kind filter
65    ///
66    /// Returns [`None`] if no client kind filter was set
67    #[must_use]
68    pub fn kind(&self) -> Option<OAuth2ClientKind> {
69        self.kind
70    }
71
72    /// Only return clients whose `client_name` matches the given substring
73    /// (case-insensitive)
74    #[must_use]
75    pub fn matching_client_name(mut self, client_name: &'a str) -> Self {
76        self.client_name = Some(client_name);
77        self
78    }
79
80    /// Get the client name filter
81    ///
82    /// Returns [`None`] if no client name filter was set
83    #[must_use]
84    pub fn client_name(&self) -> Option<&'a str> {
85        self.client_name
86    }
87
88    /// Only return clients whose `client_uri` matches the given substring
89    /// (case-insensitive)
90    #[must_use]
91    pub fn matching_client_uri(mut self, client_uri: &'a str) -> Self {
92        self.client_uri = Some(client_uri);
93        self
94    }
95
96    /// Get the client URI filter
97    ///
98    /// Returns [`None`] if no client URI filter was set
99    #[must_use]
100    pub fn client_uri(&self) -> Option<&'a str> {
101        self.client_uri
102    }
103
104    /// Only return clients which support the given grant type
105    #[must_use]
106    pub fn with_grant_type(mut self, grant_type: &'a GrantType) -> Self {
107        self.grant_type = Some(grant_type);
108        self
109    }
110
111    /// Get the grant type filter
112    ///
113    /// Returns [`None`] if no grant type filter was set
114    #[must_use]
115    pub fn grant_type(&self) -> Option<&'a GrantType> {
116        self.grant_type
117    }
118
119    /// Only return clients which have (or don't have) at least one active
120    /// (non-finished) `OAuth2` session
121    #[must_use]
122    pub fn with_active_sessions(mut self, has_active_sessions: bool) -> Self {
123        self.has_active_sessions = Some(has_active_sessions);
124        self
125    }
126
127    /// Get the active-sessions filter
128    ///
129    /// Returns [`None`] if no active-sessions filter was set
130    #[must_use]
131    pub fn has_active_sessions(&self) -> Option<bool> {
132        self.has_active_sessions
133    }
134}
135
136/// An [`OAuth2ClientRepository`] helps interacting with [`Client`] saved in the
137/// storage backend
138#[async_trait]
139pub trait OAuth2ClientRepository: Send + Sync {
140    /// The error type returned by the repository
141    type Error;
142
143    /// Lookup an OAuth client by its ID
144    ///
145    /// Returns `None` if the client does not exist
146    ///
147    /// # Parameters
148    ///
149    /// * `id`: The ID of the client to lookup
150    ///
151    /// # Errors
152    ///
153    /// Returns [`Self::Error`] if the underlying repository fails
154    async fn lookup(&mut self, id: Ulid) -> Result<Option<Client>, Self::Error>;
155
156    /// Find an OAuth client by its client ID
157    async fn find_by_client_id(&mut self, client_id: &str) -> Result<Option<Client>, Self::Error> {
158        let Ok(id) = client_id.parse() else {
159            return Ok(None);
160        };
161        self.lookup(id).await
162    }
163
164    /// Find an OAuth client by its metadata digest
165    ///
166    /// Returns `None` if the client does not exist
167    ///
168    /// # Parameters
169    ///
170    /// * `digest`: The metadata digest (SHA-256 hash encoded in hex) of the
171    ///   client to find
172    ///
173    /// # Errors
174    ///
175    /// Returns [`Self::Error`] if the underlying repository fails
176    async fn find_by_metadata_digest(
177        &mut self,
178        digest: &str,
179    ) -> Result<Option<Client>, Self::Error>;
180
181    /// Load a batch of OAuth clients by their IDs
182    ///
183    /// Returns a map of client IDs to clients. If a client does not exist, it
184    /// is not present in the map.
185    ///
186    /// # Parameters
187    ///
188    /// * `ids`: The IDs of the clients to load
189    ///
190    /// # Errors
191    ///
192    /// Returns [`Self::Error`] if the underlying repository fails
193    async fn load_batch(
194        &mut self,
195        ids: BTreeSet<Ulid>,
196    ) -> Result<BTreeMap<Ulid, Client>, Self::Error>;
197
198    /// Add a new OAuth client
199    ///
200    /// Returns the client that was added
201    ///
202    /// # Parameters
203    ///
204    /// * `rng`: The random number generator to use
205    /// * `clock`: The clock used to generate timestamps
206    /// * `redirect_uris`: The list of redirect URIs used by this client
207    /// * `metadata_digest`: The hash of the client metadata, if computed
208    /// * `encrypted_client_secret`: The encrypted client secret, if any
209    /// * `application_type`: The application type of this client
210    /// * `grant_types`: The list of grant types this client can use
211    /// * `client_name`: The human-readable name of this client, if given
212    /// * `logo_uri`: The URI of the logo of this client, if given
213    /// * `client_uri`: The URI of a website of this client, if given
214    /// * `policy_uri`: The URI of the privacy policy of this client, if given
215    /// * `tos_uri`: The URI of the terms of service of this client, if given
216    /// * `jwks_uri`: The URI of the JWKS of this client, if given
217    /// * `jwks`: The JWKS of this client, if given
218    /// * `id_token_signed_response_alg`: The algorithm used to sign the ID
219    ///   token
220    /// * `userinfo_signed_response_alg`: The algorithm used to sign the user
221    ///   info. If none, the user info endpoint will not sign the response
222    /// * `token_endpoint_auth_method`: The authentication method used by this
223    ///   client when calling the token endpoint
224    /// * `token_endpoint_auth_signing_alg`: The algorithm used to sign the JWT
225    ///   when using the `client_secret_jwt` or `private_key_jwt` authentication
226    ///   methods
227    /// * `initiate_login_uri`: The URI used to initiate a login, if given
228    ///
229    /// # Errors
230    ///
231    /// Returns [`Self::Error`] if the underlying repository fails
232    #[expect(clippy::too_many_arguments)]
233    async fn add(
234        &mut self,
235        rng: &mut (dyn RngCore + Send),
236        clock: &dyn Clock,
237        redirect_uris: Vec<Url>,
238        metadata_digest: Option<String>,
239        encrypted_client_secret: Option<String>,
240        application_type: Option<ApplicationType>,
241        grant_types: Vec<GrantType>,
242        client_name: Option<String>,
243        logo_uri: Option<Url>,
244        client_uri: Option<Url>,
245        policy_uri: Option<Url>,
246        tos_uri: Option<Url>,
247        jwks_uri: Option<Url>,
248        jwks: Option<PublicJsonWebKeySet>,
249        id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
250        userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
251        token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
252        token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
253        initiate_login_uri: Option<Url>,
254    ) -> Result<Client, Self::Error>;
255
256    /// Add or replace a static client
257    ///
258    /// Returns the client that was added or replaced
259    ///
260    /// # Parameters
261    ///
262    /// * `client_id`: The client ID
263    /// * `client_auth_method`: The authentication method this client uses
264    /// * `encrypted_client_secret`: The encrypted client secret, if any
265    /// * `jwks`: The client JWKS, if any
266    /// * `jwks_uri`: The client JWKS URI, if any
267    /// * `redirect_uris`: The list of redirect URIs used by this client
268    ///
269    /// # Errors
270    ///
271    /// Returns [`Self::Error`] if the underlying repository fails
272    #[expect(clippy::too_many_arguments)]
273    async fn upsert_static(
274        &mut self,
275        client_id: Ulid,
276        client_name: Option<String>,
277        client_auth_method: OAuthClientAuthenticationMethod,
278        encrypted_client_secret: Option<String>,
279        jwks: Option<PublicJsonWebKeySet>,
280        jwks_uri: Option<Url>,
281        redirect_uris: Vec<Url>,
282    ) -> Result<Client, Self::Error>;
283
284    /// List all static clients
285    ///
286    /// # Errors
287    ///
288    /// Returns [`Self::Error`] if the underlying repository fails
289    async fn all_static(&mut self) -> Result<Vec<Client>, Self::Error>;
290
291    /// Delete a client
292    ///
293    /// # Parameters
294    ///
295    /// * `client`: The client to delete
296    ///
297    /// # Errors
298    ///
299    /// Returns [`Self::Error`] if the underlying repository fails, or if the
300    /// client does not exist
301    async fn delete(&mut self, client: Client) -> Result<(), Self::Error> {
302        self.delete_by_id(client.id).await
303    }
304
305    /// Delete a client by ID
306    ///
307    /// # Parameters
308    ///
309    /// * `id`: The ID of the client to delete
310    ///
311    /// # Errors
312    ///
313    /// Returns [`Self::Error`] if the underlying repository fails, or if the
314    /// client does not exist
315    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
316
317    /// List [`Client`] with the given filter and pagination
318    ///
319    /// # Parameters
320    ///
321    /// * `filter`: The filter parameters
322    /// * `pagination`: The pagination parameters
323    ///
324    /// # Errors
325    ///
326    /// Returns [`Self::Error`] if the underlying repository fails
327    async fn list(
328        &mut self,
329        filter: OAuth2ClientFilter<'_>,
330        pagination: Pagination,
331    ) -> Result<Page<Client>, Self::Error>;
332
333    /// Count the [`Client`] with the given filter
334    ///
335    /// # Parameters
336    ///
337    /// * `filter`: The filter parameters
338    ///
339    /// # Errors
340    ///
341    /// Returns [`Self::Error`] if the underlying repository fails
342    async fn count(&mut self, filter: OAuth2ClientFilter<'_>) -> Result<usize, Self::Error>;
343}
344
345repository_impl!(OAuth2ClientRepository:
346    async fn lookup(&mut self, id: Ulid) -> Result<Option<Client>, Self::Error>;
347
348    async fn find_by_metadata_digest(
349        &mut self,
350        digest: &str,
351    ) -> Result<Option<Client>, Self::Error>;
352
353    async fn load_batch(
354        &mut self,
355        ids: BTreeSet<Ulid>,
356    ) -> Result<BTreeMap<Ulid, Client>, Self::Error>;
357
358    async fn add(
359        &mut self,
360        rng: &mut (dyn RngCore + Send),
361        clock: &dyn Clock,
362        redirect_uris: Vec<Url>,
363        metadata_digest: Option<String>,
364        encrypted_client_secret: Option<String>,
365        application_type: Option<ApplicationType>,
366        grant_types: Vec<GrantType>,
367        client_name: Option<String>,
368        logo_uri: Option<Url>,
369        client_uri: Option<Url>,
370        policy_uri: Option<Url>,
371        tos_uri: Option<Url>,
372        jwks_uri: Option<Url>,
373        jwks: Option<PublicJsonWebKeySet>,
374        id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
375        userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
376        token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
377        token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
378        initiate_login_uri: Option<Url>,
379    ) -> Result<Client, Self::Error>;
380
381    async fn upsert_static(
382        &mut self,
383        client_id: Ulid,
384        client_name: Option<String>,
385        client_auth_method: OAuthClientAuthenticationMethod,
386        encrypted_client_secret: Option<String>,
387        jwks: Option<PublicJsonWebKeySet>,
388        jwks_uri: Option<Url>,
389        redirect_uris: Vec<Url>,
390    ) -> Result<Client, Self::Error>;
391
392    async fn all_static(&mut self) -> Result<Vec<Client>, Self::Error>;
393
394    async fn delete(&mut self, client: Client) -> Result<(), Self::Error>;
395
396    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
397
398    async fn list(
399        &mut self,
400        filter: OAuth2ClientFilter<'_>,
401        pagination: Pagination,
402    ) -> Result<Page<Client>, Self::Error>;
403
404    async fn count(&mut self, filter: OAuth2ClientFilter<'_>) -> Result<usize, Self::Error>;
405);