mas_storage/oauth2/client.rs
1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8use std::collections::{BTreeMap, BTreeSet};
9
10use async_trait::async_trait;
11use mas_data_model::{Client, Clock};
12use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
13use mas_jose::jwk::PublicJsonWebKeySet;
14use oauth2_types::{oidc::ApplicationType, requests::GrantType};
15use rand_core::RngCore;
16use ulid::Ulid;
17use url::Url;
18
19use crate::{Page, Pagination, repository_impl};
20
/// Which family an OAuth 2.0 client belongs to.
///
/// Used by [`OAuth2ClientFilter`] to restrict listings to one family of
/// clients.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OAuth2ClientKind {
    /// A client declared statically in the configuration file
    Static,

    /// A client created at runtime through the dynamic client registration
    /// endpoint
    Dynamic,
}
31
32/// Filter parameters for listing OAuth 2.0 clients
33#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
34pub struct OAuth2ClientFilter<'a> {
35 kind: Option<OAuth2ClientKind>,
36 client_name: Option<&'a str>,
37 client_uri: Option<&'a str>,
38 grant_type: Option<&'a GrantType>,
39 has_active_sessions: Option<bool>,
40}
41
42impl<'a> OAuth2ClientFilter<'a> {
43 /// Create a new [`OAuth2ClientFilter`] with default values
44 #[must_use]
45 pub fn new() -> Self {
46 Self::default()
47 }
48
49 /// Only return static clients (those declared in the configuration)
50 #[must_use]
51 pub fn only_static_clients(mut self) -> Self {
52 self.kind = Some(OAuth2ClientKind::Static);
53 self
54 }
55
56 /// Only return dynamic clients (those registered via the
57 /// dynamic-client-registration endpoint)
58 #[must_use]
59 pub fn only_dynamic_clients(mut self) -> Self {
60 self.kind = Some(OAuth2ClientKind::Dynamic);
61 self
62 }
63
64 /// Get the client kind filter
65 ///
66 /// Returns [`None`] if no client kind filter was set
67 #[must_use]
68 pub fn kind(&self) -> Option<OAuth2ClientKind> {
69 self.kind
70 }
71
72 /// Only return clients whose `client_name` matches the given substring
73 /// (case-insensitive)
74 #[must_use]
75 pub fn matching_client_name(mut self, client_name: &'a str) -> Self {
76 self.client_name = Some(client_name);
77 self
78 }
79
80 /// Get the client name filter
81 ///
82 /// Returns [`None`] if no client name filter was set
83 #[must_use]
84 pub fn client_name(&self) -> Option<&'a str> {
85 self.client_name
86 }
87
88 /// Only return clients whose `client_uri` matches the given substring
89 /// (case-insensitive)
90 #[must_use]
91 pub fn matching_client_uri(mut self, client_uri: &'a str) -> Self {
92 self.client_uri = Some(client_uri);
93 self
94 }
95
96 /// Get the client URI filter
97 ///
98 /// Returns [`None`] if no client URI filter was set
99 #[must_use]
100 pub fn client_uri(&self) -> Option<&'a str> {
101 self.client_uri
102 }
103
104 /// Only return clients which support the given grant type
105 #[must_use]
106 pub fn with_grant_type(mut self, grant_type: &'a GrantType) -> Self {
107 self.grant_type = Some(grant_type);
108 self
109 }
110
111 /// Get the grant type filter
112 ///
113 /// Returns [`None`] if no grant type filter was set
114 #[must_use]
115 pub fn grant_type(&self) -> Option<&'a GrantType> {
116 self.grant_type
117 }
118
119 /// Only return clients which have (or don't have) at least one active
120 /// (non-finished) `OAuth2` session
121 #[must_use]
122 pub fn with_active_sessions(mut self, has_active_sessions: bool) -> Self {
123 self.has_active_sessions = Some(has_active_sessions);
124 self
125 }
126
127 /// Get the active-sessions filter
128 ///
129 /// Returns [`None`] if no active-sessions filter was set
130 #[must_use]
131 pub fn has_active_sessions(&self) -> Option<bool> {
132 self.has_active_sessions
133 }
134}
135
136/// An [`OAuth2ClientRepository`] helps interacting with [`Client`] saved in the
137/// storage backend
138#[async_trait]
139pub trait OAuth2ClientRepository: Send + Sync {
140 /// The error type returned by the repository
141 type Error;
142
143 /// Lookup an OAuth client by its ID
144 ///
145 /// Returns `None` if the client does not exist
146 ///
147 /// # Parameters
148 ///
149 /// * `id`: The ID of the client to lookup
150 ///
151 /// # Errors
152 ///
153 /// Returns [`Self::Error`] if the underlying repository fails
154 async fn lookup(&mut self, id: Ulid) -> Result<Option<Client>, Self::Error>;
155
156 /// Find an OAuth client by its client ID
157 async fn find_by_client_id(&mut self, client_id: &str) -> Result<Option<Client>, Self::Error> {
158 let Ok(id) = client_id.parse() else {
159 return Ok(None);
160 };
161 self.lookup(id).await
162 }
163
164 /// Find an OAuth client by its metadata digest
165 ///
166 /// Returns `None` if the client does not exist
167 ///
168 /// # Parameters
169 ///
170 /// * `digest`: The metadata digest (SHA-256 hash encoded in hex) of the
171 /// client to find
172 ///
173 /// # Errors
174 ///
175 /// Returns [`Self::Error`] if the underlying repository fails
176 async fn find_by_metadata_digest(
177 &mut self,
178 digest: &str,
179 ) -> Result<Option<Client>, Self::Error>;
180
181 /// Load a batch of OAuth clients by their IDs
182 ///
183 /// Returns a map of client IDs to clients. If a client does not exist, it
184 /// is not present in the map.
185 ///
186 /// # Parameters
187 ///
188 /// * `ids`: The IDs of the clients to load
189 ///
190 /// # Errors
191 ///
192 /// Returns [`Self::Error`] if the underlying repository fails
193 async fn load_batch(
194 &mut self,
195 ids: BTreeSet<Ulid>,
196 ) -> Result<BTreeMap<Ulid, Client>, Self::Error>;
197
198 /// Add a new OAuth client
199 ///
200 /// Returns the client that was added
201 ///
202 /// # Parameters
203 ///
204 /// * `rng`: The random number generator to use
205 /// * `clock`: The clock used to generate timestamps
206 /// * `redirect_uris`: The list of redirect URIs used by this client
207 /// * `metadata_digest`: The hash of the client metadata, if computed
208 /// * `encrypted_client_secret`: The encrypted client secret, if any
209 /// * `application_type`: The application type of this client
210 /// * `grant_types`: The list of grant types this client can use
211 /// * `client_name`: The human-readable name of this client, if given
212 /// * `logo_uri`: The URI of the logo of this client, if given
213 /// * `client_uri`: The URI of a website of this client, if given
214 /// * `policy_uri`: The URI of the privacy policy of this client, if given
215 /// * `tos_uri`: The URI of the terms of service of this client, if given
216 /// * `jwks_uri`: The URI of the JWKS of this client, if given
217 /// * `jwks`: The JWKS of this client, if given
218 /// * `id_token_signed_response_alg`: The algorithm used to sign the ID
219 /// token
220 /// * `userinfo_signed_response_alg`: The algorithm used to sign the user
221 /// info. If none, the user info endpoint will not sign the response
222 /// * `token_endpoint_auth_method`: The authentication method used by this
223 /// client when calling the token endpoint
224 /// * `token_endpoint_auth_signing_alg`: The algorithm used to sign the JWT
225 /// when using the `client_secret_jwt` or `private_key_jwt` authentication
226 /// methods
227 /// * `initiate_login_uri`: The URI used to initiate a login, if given
228 ///
229 /// # Errors
230 ///
231 /// Returns [`Self::Error`] if the underlying repository fails
232 #[expect(clippy::too_many_arguments)]
233 async fn add(
234 &mut self,
235 rng: &mut (dyn RngCore + Send),
236 clock: &dyn Clock,
237 redirect_uris: Vec<Url>,
238 metadata_digest: Option<String>,
239 encrypted_client_secret: Option<String>,
240 application_type: Option<ApplicationType>,
241 grant_types: Vec<GrantType>,
242 client_name: Option<String>,
243 logo_uri: Option<Url>,
244 client_uri: Option<Url>,
245 policy_uri: Option<Url>,
246 tos_uri: Option<Url>,
247 jwks_uri: Option<Url>,
248 jwks: Option<PublicJsonWebKeySet>,
249 id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
250 userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
251 token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
252 token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
253 initiate_login_uri: Option<Url>,
254 ) -> Result<Client, Self::Error>;
255
256 /// Add or replace a static client
257 ///
258 /// Returns the client that was added or replaced
259 ///
260 /// # Parameters
261 ///
262 /// * `client_id`: The client ID
263 /// * `client_auth_method`: The authentication method this client uses
264 /// * `encrypted_client_secret`: The encrypted client secret, if any
265 /// * `jwks`: The client JWKS, if any
266 /// * `jwks_uri`: The client JWKS URI, if any
267 /// * `redirect_uris`: The list of redirect URIs used by this client
268 ///
269 /// # Errors
270 ///
271 /// Returns [`Self::Error`] if the underlying repository fails
272 #[expect(clippy::too_many_arguments)]
273 async fn upsert_static(
274 &mut self,
275 client_id: Ulid,
276 client_name: Option<String>,
277 client_auth_method: OAuthClientAuthenticationMethod,
278 encrypted_client_secret: Option<String>,
279 jwks: Option<PublicJsonWebKeySet>,
280 jwks_uri: Option<Url>,
281 redirect_uris: Vec<Url>,
282 ) -> Result<Client, Self::Error>;
283
284 /// List all static clients
285 ///
286 /// # Errors
287 ///
288 /// Returns [`Self::Error`] if the underlying repository fails
289 async fn all_static(&mut self) -> Result<Vec<Client>, Self::Error>;
290
291 /// Delete a client
292 ///
293 /// # Parameters
294 ///
295 /// * `client`: The client to delete
296 ///
297 /// # Errors
298 ///
299 /// Returns [`Self::Error`] if the underlying repository fails, or if the
300 /// client does not exist
301 async fn delete(&mut self, client: Client) -> Result<(), Self::Error> {
302 self.delete_by_id(client.id).await
303 }
304
305 /// Delete a client by ID
306 ///
307 /// # Parameters
308 ///
309 /// * `id`: The ID of the client to delete
310 ///
311 /// # Errors
312 ///
313 /// Returns [`Self::Error`] if the underlying repository fails, or if the
314 /// client does not exist
315 async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
316
317 /// List [`Client`] with the given filter and pagination
318 ///
319 /// # Parameters
320 ///
321 /// * `filter`: The filter parameters
322 /// * `pagination`: The pagination parameters
323 ///
324 /// # Errors
325 ///
326 /// Returns [`Self::Error`] if the underlying repository fails
327 async fn list(
328 &mut self,
329 filter: OAuth2ClientFilter<'_>,
330 pagination: Pagination,
331 ) -> Result<Page<Client>, Self::Error>;
332
333 /// Count the [`Client`] with the given filter
334 ///
335 /// # Parameters
336 ///
337 /// * `filter`: The filter parameters
338 ///
339 /// # Errors
340 ///
341 /// Returns [`Self::Error`] if the underlying repository fails
342 async fn count(&mut self, filter: OAuth2ClientFilter<'_>) -> Result<usize, Self::Error>;
343}
344
345repository_impl!(OAuth2ClientRepository:
346 async fn lookup(&mut self, id: Ulid) -> Result<Option<Client>, Self::Error>;
347
348 async fn find_by_metadata_digest(
349 &mut self,
350 digest: &str,
351 ) -> Result<Option<Client>, Self::Error>;
352
353 async fn load_batch(
354 &mut self,
355 ids: BTreeSet<Ulid>,
356 ) -> Result<BTreeMap<Ulid, Client>, Self::Error>;
357
358 async fn add(
359 &mut self,
360 rng: &mut (dyn RngCore + Send),
361 clock: &dyn Clock,
362 redirect_uris: Vec<Url>,
363 metadata_digest: Option<String>,
364 encrypted_client_secret: Option<String>,
365 application_type: Option<ApplicationType>,
366 grant_types: Vec<GrantType>,
367 client_name: Option<String>,
368 logo_uri: Option<Url>,
369 client_uri: Option<Url>,
370 policy_uri: Option<Url>,
371 tos_uri: Option<Url>,
372 jwks_uri: Option<Url>,
373 jwks: Option<PublicJsonWebKeySet>,
374 id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
375 userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
376 token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
377 token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
378 initiate_login_uri: Option<Url>,
379 ) -> Result<Client, Self::Error>;
380
381 async fn upsert_static(
382 &mut self,
383 client_id: Ulid,
384 client_name: Option<String>,
385 client_auth_method: OAuthClientAuthenticationMethod,
386 encrypted_client_secret: Option<String>,
387 jwks: Option<PublicJsonWebKeySet>,
388 jwks_uri: Option<Url>,
389 redirect_uris: Vec<Url>,
390 ) -> Result<Client, Self::Error>;
391
392 async fn all_static(&mut self) -> Result<Vec<Client>, Self::Error>;
393
394 async fn delete(&mut self, client: Client) -> Result<(), Self::Error>;
395
396 async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
397
398 async fn list(
399 &mut self,
400 filter: OAuth2ClientFilter<'_>,
401 pagination: Pagination,
402 ) -> Result<Page<Client>, Self::Error>;
403
404 async fn count(&mut self, filter: OAuth2ClientFilter<'_>) -> Result<usize, Self::Error>;
405);