mas_storage/upstream_oauth2/provider.rs
1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::marker::PhantomData;
8
9use async_trait::async_trait;
10use mas_data_model::{
11 Clock, UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports,
12 UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderOnBackchannelLogout,
13 UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderResponseMode,
14 UpstreamOAuthProviderTokenAuthMethod,
15};
16use mas_iana::jose::JsonWebSignatureAlg;
17use oauth2_types::scope::Scope;
18use rand_core::RngCore;
19use ulid::Ulid;
20use url::Url;
21
22use crate::{Pagination, pagination::Page, repository_impl};
23
24/// Structure which holds parameters when inserting or updating an upstream
25/// OAuth 2.0 provider
26pub struct UpstreamOAuthProviderParams {
27 /// The OIDC issuer of the provider
28 pub issuer: Option<String>,
29
30 /// A human-readable name for the provider
31 pub human_name: Option<String>,
32
33 /// A brand identifier, e.g. "apple" or "google"
34 pub brand_name: Option<String>,
35
36 /// The scope to request during the authorization flow
37 pub scope: Scope,
38
39 /// The token endpoint authentication method
40 pub token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod,
41
42 /// The JWT signing algorithm to use when then `client_secret_jwt` or
43 /// `private_key_jwt` authentication methods are used
44 pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
45
46 /// Expected signature for the JWT payload returned by the token
47 /// authentication endpoint.
48 ///
49 /// Defaults to `RS256`.
50 pub id_token_signed_response_alg: JsonWebSignatureAlg,
51
52 /// Whether to fetch the user profile from the userinfo endpoint,
53 /// or to rely on the data returned in the `id_token` from the
54 /// `token_endpoint`.
55 pub fetch_userinfo: bool,
56
57 /// Expected signature for the JWT payload returned by the userinfo
58 /// endpoint.
59 ///
60 /// If not specified, the response is expected to be an unsigned JSON
61 /// payload. Defaults to `None`.
62 pub userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
63
64 /// The client ID to use when authenticating to the upstream
65 pub client_id: String,
66
67 /// The encrypted client secret to use when authenticating to the upstream
68 pub encrypted_client_secret: Option<String>,
69
70 /// How claims should be imported from the upstream provider
71 pub claims_imports: UpstreamOAuthProviderClaimsImports,
72
73 /// The URL to use as the authorization endpoint. If `None`, the URL will be
74 /// discovered
75 pub authorization_endpoint_override: Option<Url>,
76
77 /// The URL to use as the token endpoint. If `None`, the URL will be
78 /// discovered
79 pub token_endpoint_override: Option<Url>,
80
81 /// The URL to use as the userinfo endpoint. If `None`, the URL will be
82 /// discovered
83 pub userinfo_endpoint_override: Option<Url>,
84
85 /// The URL to use when fetching JWKS. If `None`, the URL will be discovered
86 pub jwks_uri_override: Option<Url>,
87
88 /// How the provider metadata should be discovered
89 pub discovery_mode: UpstreamOAuthProviderDiscoveryMode,
90
91 /// How should PKCE be used
92 pub pkce_mode: UpstreamOAuthProviderPkceMode,
93
94 /// What response mode it should ask
95 pub response_mode: Option<UpstreamOAuthProviderResponseMode>,
96
97 /// Additional parameters to include in the authorization request
98 pub additional_authorization_parameters: Vec<(String, String)>,
99
100 /// Whether to forward the login hint to the upstream provider.
101 pub forward_login_hint: bool,
102
103 /// The position of the provider in the UI
104 pub ui_order: i32,
105
106 /// The behavior when receiving a backchannel logout notification
107 pub on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout,
108
109 /// Whether or not to require a registration token on `OAuth2` auth
110 pub registration_token_required: bool,
111}
112
/// Criteria used to narrow down a listing of upstream OAuth 2.0 providers
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub struct UpstreamOAuthProviderFilter<'a> {
    /// Optional constraint on whether the provider is enabled
    ///
    /// `Some(true)` keeps only enabled providers, `Some(false)` keeps only
    /// disabled ones, and `None` leaves the listing unfiltered
    enabled: Option<bool>,

    // Carries the `'a` parameter without storing borrowed data; presumably
    // reserved for future borrowed filter criteria.
    _lifetime: PhantomData<&'a ()>,
}

impl UpstreamOAuthProviderFilter<'_> {
    /// Build a filter that does not constrain anything
    #[must_use]
    pub fn new() -> Self {
        Self {
            enabled: None,
            _lifetime: PhantomData,
        }
    }

    /// Narrow the filter down to providers that are enabled
    #[must_use]
    pub const fn enabled_only(self) -> Self {
        Self {
            enabled: Some(true),
            ..self
        }
    }

    /// Narrow the filter down to providers that are disabled
    #[must_use]
    pub const fn disabled_only(self) -> Self {
        Self {
            enabled: Some(false),
            ..self
        }
    }

    /// The constraint on the enabled state, if any
    ///
    /// `None` means this filter does not look at the enabled state
    #[must_use]
    pub const fn enabled(&self) -> Option<bool> {
        let Self { enabled, .. } = *self;
        enabled
    }
}
153
154/// An [`UpstreamOAuthProviderRepository`] helps interacting with
155/// [`UpstreamOAuthProvider`] saved in the storage backend
156#[async_trait]
157pub trait UpstreamOAuthProviderRepository: Send + Sync {
158 /// The error type returned by the repository
159 type Error;
160
161 /// Lookup an upstream OAuth provider by its ID
162 ///
163 /// Returns `None` if the provider was not found
164 ///
165 /// # Parameters
166 ///
167 /// * `id`: The ID of the provider to lookup
168 ///
169 /// # Errors
170 ///
171 /// Returns [`Self::Error`] if the underlying repository fails
172 async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;
173
174 /// Add a new upstream OAuth provider
175 ///
176 /// Returns the newly created provider
177 ///
178 /// # Parameters
179 ///
180 /// * `rng`: A random number generator
181 /// * `clock`: The clock used to generate timestamps
182 /// * `params`: The parameters of the provider to add
183 ///
184 /// # Errors
185 ///
186 /// Returns [`Self::Error`] if the underlying repository fails
187 async fn add(
188 &mut self,
189 rng: &mut (dyn RngCore + Send),
190 clock: &dyn Clock,
191 params: UpstreamOAuthProviderParams,
192 ) -> Result<UpstreamOAuthProvider, Self::Error>;
193
194 /// Delete an upstream OAuth provider
195 ///
196 /// # Parameters
197 ///
198 /// * `provider`: The provider to delete
199 ///
200 /// # Errors
201 ///
202 /// Returns [`Self::Error`] if the underlying repository fails
203 async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error> {
204 self.delete_by_id(provider.id).await
205 }
206
207 /// Delete an upstream OAuth provider by its ID
208 ///
209 /// # Parameters
210 ///
211 /// * `id`: The ID of the provider to delete
212 ///
213 /// # Errors
214 ///
215 /// Returns [`Self::Error`] if the underlying repository fails
216 async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
217
218 /// Insert or update an upstream OAuth provider
219 ///
220 /// # Parameters
221 ///
222 /// * `clock`: The clock used to generate timestamps
223 /// * `id`: The ID of the provider to update
224 /// * `params`: The parameters of the provider to update
225 ///
226 /// # Errors
227 ///
228 /// Returns [`Self::Error`] if the underlying repository fails
229 async fn upsert(
230 &mut self,
231 clock: &dyn Clock,
232 id: Ulid,
233 params: UpstreamOAuthProviderParams,
234 ) -> Result<UpstreamOAuthProvider, Self::Error>;
235
236 /// Disable an upstream OAuth provider
237 ///
238 /// Returns the disabled provider
239 ///
240 /// # Parameters
241 ///
242 /// * `clock`: The clock used to generate timestamps
243 /// * `provider`: The provider to disable
244 ///
245 /// # Errors
246 ///
247 /// Returns [`Self::Error`] if the underlying repository fails
248 async fn disable(
249 &mut self,
250 clock: &dyn Clock,
251 provider: UpstreamOAuthProvider,
252 ) -> Result<UpstreamOAuthProvider, Self::Error>;
253
254 /// List [`UpstreamOAuthProvider`] with the given filter and pagination
255 ///
256 /// # Parameters
257 ///
258 /// * `filter`: The filter to apply
259 /// * `pagination`: The pagination parameters
260 ///
261 /// # Errors
262 ///
263 /// Returns [`Self::Error`] if the underlying repository fails
264 async fn list(
265 &mut self,
266 filter: UpstreamOAuthProviderFilter<'_>,
267 pagination: Pagination,
268 ) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;
269
270 /// Count the number of [`UpstreamOAuthProvider`] with the given filter
271 ///
272 /// # Parameters
273 ///
274 /// * `filter`: The filter to apply
275 ///
276 /// # Errors
277 ///
278 /// Returns [`Self::Error`] if the underlying repository fails
279 async fn count(
280 &mut self,
281 filter: UpstreamOAuthProviderFilter<'_>,
282 ) -> Result<usize, Self::Error>;
283
284 /// Get all enabled upstream OAuth providers
285 ///
286 /// # Errors
287 ///
288 /// Returns [`Self::Error`] if the underlying repository fails
289 async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
290}
291
// Wrapper implementations of `UpstreamOAuthProviderRepository`.
//
// NOTE(review): `repository_impl!` is defined elsewhere in this crate; it
// presumably generates impls that forward each listed method to an inner
// repository (e.g. for reference/boxed wrappers). Confirm against the macro.
// Every method of the trait declared in this file is listed here with the
// same signature. When the trait changes, this list presumably has to be
// kept in sync. `delete` is listed even though the trait gives it a default
// body, presumably so wrappers forward to the inner repository's own
// implementation instead of using the default.
repository_impl!(UpstreamOAuthProviderRepository:
    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;

    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        params: UpstreamOAuthProviderParams
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    async fn upsert(
        &mut self,
        clock: &dyn Clock,
        id: Ulid,
        params: UpstreamOAuthProviderParams
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error>;

    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;

    async fn disable(
        &mut self,
        clock: &dyn Clock,
        provider: UpstreamOAuthProvider
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    async fn list(
        &mut self,
        filter: UpstreamOAuthProviderFilter<'_>,
        pagination: Pagination
    ) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;

    async fn count(
        &mut self,
        filter: UpstreamOAuthProviderFilter<'_>
    ) -> Result<usize, Self::Error>;

    async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
);