// mas_storage/upstream_oauth2/provider.rs

// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
6
7use std::marker::PhantomData;
8
9use async_trait::async_trait;
10use mas_data_model::{
11    Clock, UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports,
12    UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderOnBackchannelLogout,
13    UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderResponseMode,
14    UpstreamOAuthProviderTokenAuthMethod,
15};
16use mas_iana::jose::JsonWebSignatureAlg;
17use oauth2_types::scope::Scope;
18use rand_core::RngCore;
19use ulid::Ulid;
20use url::Url;
21
22use crate::{Pagination, pagination::Page, repository_impl};
23
24/// Structure which holds parameters when inserting or updating an upstream
25/// OAuth 2.0 provider
26pub struct UpstreamOAuthProviderParams {
27    /// The OIDC issuer of the provider
28    pub issuer: Option<String>,
29
30    /// A human-readable name for the provider
31    pub human_name: Option<String>,
32
33    /// A brand identifier, e.g. "apple" or "google"
34    pub brand_name: Option<String>,
35
36    /// The scope to request during the authorization flow
37    pub scope: Scope,
38
39    /// The token endpoint authentication method
40    pub token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod,
41
42    /// The JWT signing algorithm to use when then `client_secret_jwt` or
43    /// `private_key_jwt` authentication methods are used
44    pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
45
46    /// Expected signature for the JWT payload returned by the token
47    /// authentication endpoint.
48    ///
49    /// Defaults to `RS256`.
50    pub id_token_signed_response_alg: JsonWebSignatureAlg,
51
52    /// Whether to fetch the user profile from the userinfo endpoint,
53    /// or to rely on the data returned in the `id_token` from the
54    /// `token_endpoint`.
55    pub fetch_userinfo: bool,
56
57    /// Expected signature for the JWT payload returned by the userinfo
58    /// endpoint.
59    ///
60    /// If not specified, the response is expected to be an unsigned JSON
61    /// payload. Defaults to `None`.
62    pub userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
63
64    /// The client ID to use when authenticating to the upstream
65    pub client_id: String,
66
67    /// The encrypted client secret to use when authenticating to the upstream
68    pub encrypted_client_secret: Option<String>,
69
70    /// How claims should be imported from the upstream provider
71    pub claims_imports: UpstreamOAuthProviderClaimsImports,
72
73    /// The URL to use as the authorization endpoint. If `None`, the URL will be
74    /// discovered
75    pub authorization_endpoint_override: Option<Url>,
76
77    /// The URL to use as the token endpoint. If `None`, the URL will be
78    /// discovered
79    pub token_endpoint_override: Option<Url>,
80
81    /// The URL to use as the userinfo endpoint. If `None`, the URL will be
82    /// discovered
83    pub userinfo_endpoint_override: Option<Url>,
84
85    /// The URL to use when fetching JWKS. If `None`, the URL will be discovered
86    pub jwks_uri_override: Option<Url>,
87
88    /// How the provider metadata should be discovered
89    pub discovery_mode: UpstreamOAuthProviderDiscoveryMode,
90
91    /// How should PKCE be used
92    pub pkce_mode: UpstreamOAuthProviderPkceMode,
93
94    /// What response mode it should ask
95    pub response_mode: Option<UpstreamOAuthProviderResponseMode>,
96
97    /// Additional parameters to include in the authorization request
98    pub additional_authorization_parameters: Vec<(String, String)>,
99
100    /// Whether to forward the login hint to the upstream provider.
101    pub forward_login_hint: bool,
102
103    /// The position of the provider in the UI
104    pub ui_order: i32,
105
106    /// The behavior when receiving a backchannel logout notification
107    pub on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout,
108
109    /// Whether or not to require a registration token on `OAuth2` auth
110    pub registration_token_required: bool,
111}
112
/// Filter parameters for listing upstream OAuth 2.0 providers
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub struct UpstreamOAuthProviderFilter<'a> {
    /// Restricts the listing to providers whose enabled state matches this
    /// value
    ///
    /// When `None`, no restriction applies and every provider is returned
    enabled: Option<bool>,

    _lifetime: PhantomData<&'a ()>,
}

impl UpstreamOAuthProviderFilter<'_> {
    /// Build an [`UpstreamOAuthProviderFilter`] with no restriction set
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Restrict this filter to providers which are enabled
    #[must_use]
    pub const fn enabled_only(self) -> Self {
        Self {
            enabled: Some(true),
            ..self
        }
    }

    /// Restrict this filter to providers which are disabled
    #[must_use]
    pub const fn disabled_only(self) -> Self {
        Self {
            enabled: Some(false),
            ..self
        }
    }

    /// The enabled-state restriction carried by this filter
    ///
    /// `None` means no restriction was set
    #[must_use]
    pub const fn enabled(&self) -> Option<bool> {
        self.enabled
    }
}
153
154/// An [`UpstreamOAuthProviderRepository`] helps interacting with
155/// [`UpstreamOAuthProvider`] saved in the storage backend
156#[async_trait]
157pub trait UpstreamOAuthProviderRepository: Send + Sync {
158    /// The error type returned by the repository
159    type Error;
160
161    /// Lookup an upstream OAuth provider by its ID
162    ///
163    /// Returns `None` if the provider was not found
164    ///
165    /// # Parameters
166    ///
167    /// * `id`: The ID of the provider to lookup
168    ///
169    /// # Errors
170    ///
171    /// Returns [`Self::Error`] if the underlying repository fails
172    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;
173
174    /// Add a new upstream OAuth provider
175    ///
176    /// Returns the newly created provider
177    ///
178    /// # Parameters
179    ///
180    /// * `rng`: A random number generator
181    /// * `clock`: The clock used to generate timestamps
182    /// * `params`: The parameters of the provider to add
183    ///
184    /// # Errors
185    ///
186    /// Returns [`Self::Error`] if the underlying repository fails
187    async fn add(
188        &mut self,
189        rng: &mut (dyn RngCore + Send),
190        clock: &dyn Clock,
191        params: UpstreamOAuthProviderParams,
192    ) -> Result<UpstreamOAuthProvider, Self::Error>;
193
194    /// Delete an upstream OAuth provider
195    ///
196    /// # Parameters
197    ///
198    /// * `provider`: The provider to delete
199    ///
200    /// # Errors
201    ///
202    /// Returns [`Self::Error`] if the underlying repository fails
203    async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error> {
204        self.delete_by_id(provider.id).await
205    }
206
207    /// Delete an upstream OAuth provider by its ID
208    ///
209    /// # Parameters
210    ///
211    /// * `id`: The ID of the provider to delete
212    ///
213    /// # Errors
214    ///
215    /// Returns [`Self::Error`] if the underlying repository fails
216    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
217
218    /// Insert or update an upstream OAuth provider
219    ///
220    /// # Parameters
221    ///
222    /// * `clock`: The clock used to generate timestamps
223    /// * `id`: The ID of the provider to update
224    /// * `params`: The parameters of the provider to update
225    ///
226    /// # Errors
227    ///
228    /// Returns [`Self::Error`] if the underlying repository fails
229    async fn upsert(
230        &mut self,
231        clock: &dyn Clock,
232        id: Ulid,
233        params: UpstreamOAuthProviderParams,
234    ) -> Result<UpstreamOAuthProvider, Self::Error>;
235
236    /// Disable an upstream OAuth provider
237    ///
238    /// Returns the disabled provider
239    ///
240    /// # Parameters
241    ///
242    /// * `clock`: The clock used to generate timestamps
243    /// * `provider`: The provider to disable
244    ///
245    /// # Errors
246    ///
247    /// Returns [`Self::Error`] if the underlying repository fails
248    async fn disable(
249        &mut self,
250        clock: &dyn Clock,
251        provider: UpstreamOAuthProvider,
252    ) -> Result<UpstreamOAuthProvider, Self::Error>;
253
254    /// List [`UpstreamOAuthProvider`] with the given filter and pagination
255    ///
256    /// # Parameters
257    ///
258    /// * `filter`: The filter to apply
259    /// * `pagination`: The pagination parameters
260    ///
261    /// # Errors
262    ///
263    /// Returns [`Self::Error`] if the underlying repository fails
264    async fn list(
265        &mut self,
266        filter: UpstreamOAuthProviderFilter<'_>,
267        pagination: Pagination,
268    ) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;
269
270    /// Count the number of [`UpstreamOAuthProvider`] with the given filter
271    ///
272    /// # Parameters
273    ///
274    /// * `filter`: The filter to apply
275    ///
276    /// # Errors
277    ///
278    /// Returns [`Self::Error`] if the underlying repository fails
279    async fn count(
280        &mut self,
281        filter: UpstreamOAuthProviderFilter<'_>,
282    ) -> Result<usize, Self::Error>;
283
284    /// Get all enabled upstream OAuth providers
285    ///
286    /// # Errors
287    ///
288    /// Returns [`Self::Error`] if the underlying repository fails
289    async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
290}
291
// Hand every `UpstreamOAuthProviderRepository` method to the crate-local
// `repository_impl!` macro (imported from `crate` at the top of this file).
// NOTE(review): the macro definition is not visible here; presumably it
// generates forwarding implementations of the trait for wrapper types
// (e.g. boxed or borrowed repositories) — confirm against its definition.
//
// Each entry must mirror the corresponding trait method signature. Unlike
// the trait definition, the last parameter of each entry has no trailing
// comma — presumably because the macro's matcher does not accept one;
// confirm before reformatting. Only plain `//` comments may be placed in
// this invocation: `///` doc comments would become `#[doc]` attributes
// that the matcher would have to accept.
repository_impl!(UpstreamOAuthProviderRepository:
    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error>;

    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        params: UpstreamOAuthProviderParams
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    async fn upsert(
        &mut self,
        clock: &dyn Clock,
        id: Ulid,
        params: UpstreamOAuthProviderParams
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    // Listed even though the trait provides a default body for `delete`.
    async fn delete(&mut self, provider: UpstreamOAuthProvider) -> Result<(), Self::Error>;

    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;

    async fn disable(
        &mut self,
        clock: &dyn Clock,
        provider: UpstreamOAuthProvider
    ) -> Result<UpstreamOAuthProvider, Self::Error>;

    async fn list(
        &mut self,
        filter: UpstreamOAuthProviderFilter<'_>,
        pagination: Pagination
    ) -> Result<Page<UpstreamOAuthProvider>, Self::Error>;

    async fn count(
        &mut self,
        filter: UpstreamOAuthProviderFilter<'_>
    ) -> Result<usize, Self::Error>;

    async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error>;
);