Skip to main content

mas_config/sections/
upstream_oauth2.rs

1// Copyright 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8use std::collections::BTreeMap;
9
10use camino::Utf8PathBuf;
11use mas_iana::jose::JsonWebSignatureAlg;
12use schemars::JsonSchema;
13use serde::{Deserialize, Serialize, de::Error};
14use serde_with::{serde_as, skip_serializing_none};
15use ulid::Ulid;
16use url::Url;
17
18use crate::{ClientSecret, ClientSecretRaw, ConfigurationSection};
19
/// Upstream OAuth 2.0 providers configuration
// Loaded from the `upstream_oauth2` section (see `ConfigurationSection::PATH`
// in the impl below). An empty provider list is what `is_default` treats as
// the default configuration.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
pub struct UpstreamOAuth2Config {
    /// List of OAuth 2.0 providers
    pub providers: Vec<Provider>,
}
26
27impl UpstreamOAuth2Config {
28    /// Returns true if the configuration is the default one
29    pub(crate) fn is_default(&self) -> bool {
30        self.providers.is_empty()
31    }
32}
33
impl ConfigurationSection for UpstreamOAuth2Config {
    const PATH: Option<&'static str> = Some("upstream_oauth2");

    /// Checks cross-field constraints on every configured provider that the
    /// serde types cannot express on their own.
    ///
    /// Providers are checked in order and the first violation is returned.
    /// For each provider this enforces that:
    ///
    /// - `issuer` is set unless `discovery_mode` is `disabled`;
    /// - `client_secret` is present for `client_secret_basic`,
    ///   `client_secret_post` and `client_secret_jwt`, and absent otherwise;
    /// - `token_endpoint_auth_signing_alg` is present for `client_secret_jwt`
    ///   and `private_key_jwt`, and absent otherwise;
    /// - `sign_in_with_apple` is present for the `sign_in_with_apple` method,
    ///   and absent otherwise;
    /// - when `claims_imports.skip_confirmation` is `true`, the localpart
    ///   action is `require` and the email / displayname actions are not
    ///   `suggest`;
    /// - a localpart `on_conflict` of `add`, `replace` or `set` is only used
    ///   with a `force` or `require` localpart action.
    ///
    /// # Errors
    ///
    /// Returns a boxed [`figment::Error`] located at
    /// `upstream_oauth2.providers.<index>` and describing the first violated
    /// constraint. For claim-import errors, `with_path` is also called with
    /// the relevant `claims_imports.*` location.
    fn validate(
        &self,
        figment: &figment::Figment,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
        for (index, provider) in self.providers.iter().enumerate() {
            // Attaches the metadata, profile and path of this provider entry
            // to `error`, so the message points at the offending list item.
            // `Self::PATH` is the constant `Some(..)` declared above, so the
            // `unwrap`s cannot fail.
            let annotate = |mut error: figment::Error| {
                error.metadata = figment
                    .find_metadata(&format!("{root}.providers", root = Self::PATH.unwrap()))
                    .cloned();
                error.profile = Some(figment::Profile::Default);
                error.path = vec![
                    Self::PATH.unwrap().to_owned(),
                    "providers".to_owned(),
                    index.to_string(),
                ];
                error
            };

            // Every discovery mode except `disabled` needs an issuer to
            // discover from.
            if !matches!(provider.discovery_mode, DiscoveryMode::Disabled)
                && provider.issuer.is_none()
            {
                return Err(annotate(figment::Error::custom(
                    "The `issuer` field is required when discovery is enabled",
                ))
                .into());
            }

            // Only the shared-secret methods use `client_secret`: forbid it
            // for the others and require it for these.
            match provider.token_endpoint_auth_method {
                TokenAuthMethod::None
                | TokenAuthMethod::PrivateKeyJwt
                | TokenAuthMethod::SignInWithApple => {
                    if provider.client_secret.is_some() {
                        return Err(annotate(figment::Error::custom(
                            "Unexpected field `client_secret` for the selected authentication method",
                        )).into());
                    }
                }
                TokenAuthMethod::ClientSecretBasic
                | TokenAuthMethod::ClientSecretPost
                | TokenAuthMethod::ClientSecretJwt => {
                    if provider.client_secret.is_none() {
                        return Err(annotate(figment::Error::missing_field("client_secret")).into());
                    }
                }
            }

            // Only the JWT-based client authentication methods sign anything,
            // so only they may (and must) name a signing algorithm.
            match provider.token_endpoint_auth_method {
                TokenAuthMethod::None
                | TokenAuthMethod::ClientSecretBasic
                | TokenAuthMethod::ClientSecretPost
                | TokenAuthMethod::SignInWithApple => {
                    if provider.token_endpoint_auth_signing_alg.is_some() {
                        return Err(annotate(figment::Error::custom(
                            "Unexpected field `token_endpoint_auth_signing_alg` for the selected authentication method",
                        )).into());
                    }
                }
                TokenAuthMethod::ClientSecretJwt | TokenAuthMethod::PrivateKeyJwt => {
                    if provider.token_endpoint_auth_signing_alg.is_none() {
                        return Err(annotate(figment::Error::missing_field(
                            "token_endpoint_auth_signing_alg",
                        ))
                        .into());
                    }
                }
            }

            // The `sign_in_with_apple` settings belong exclusively to the
            // method of the same name.
            match provider.token_endpoint_auth_method {
                TokenAuthMethod::SignInWithApple => {
                    if provider.sign_in_with_apple.is_none() {
                        return Err(
                            annotate(figment::Error::missing_field("sign_in_with_apple")).into(),
                        );
                    }
                }

                _ => {
                    if provider.sign_in_with_apple.is_some() {
                        return Err(annotate(figment::Error::custom(
                            "Unexpected field `sign_in_with_apple` for the selected authentication method",
                        )).into());
                    }
                }
            }

            // These are the restrictions documented on
            // `ClaimsImports::skip_confirmation`: without the confirmation
            // screen the localpart must be `require`d and no attribute may be
            // merely `suggest`ed.
            if provider.claims_imports.skip_confirmation {
                if provider.claims_imports.localpart.action != ImportAction::Require {
                    return Err(annotate(figment::Error::custom(
                        "The field `action` must be `require` when `skip_confirmation` is set to `true`",
                    )).with_path("claims_imports.localpart").into());
                }

                if provider.claims_imports.email.action == ImportAction::Suggest {
                    return Err(annotate(figment::Error::custom(
                        "The field `action` must not be `suggest` when `skip_confirmation` is set to `true`",
                    )).with_path("claims_imports.email").into());
                }

                if provider.claims_imports.displayname.action == ImportAction::Suggest {
                    return Err(annotate(figment::Error::custom(
                        "The field `action` must not be `suggest` when `skip_confirmation` is set to `true`",
                    )).with_path("claims_imports.displayname").into());
                }
            }

            // Conflict strategies other than `fail` are only accepted together
            // with a `force` or `require` localpart action.
            if matches!(
                provider.claims_imports.localpart.on_conflict,
                OnConflict::Add | OnConflict::Replace | OnConflict::Set
            ) && !matches!(
                provider.claims_imports.localpart.action,
                ImportAction::Force | ImportAction::Require
            ) {
                return Err(annotate(figment::Error::custom(
                    "The field `action` must be either `force` or `require` when `on_conflict` is set to `add`, `replace` or `set`",
                )).with_path("claims_imports.localpart").into());
            }
        }

        Ok(())
    }
}
158
/// The response mode we ask the provider to use for the callback
// Serialized in snake_case (`query`, `form_post`). There is no default
// variant: `Provider::response_mode` is an `Option` and may be left unset.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum ResponseMode {
    /// `query`: The provider will send the response as a query string in the
    /// URL search parameters
    Query,

    /// `form_post`: The provider will send the response as a POST request with
    /// the response parameters in the request body
    ///
    /// <https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html>
    FormPost,
}
173
/// Authentication methods used against the OAuth 2.0 provider
// Serialized in snake_case, matching the names quoted in the variant docs.
// `UpstreamOAuth2Config::validate` enforces the companion fields each method
// requires or forbids: `client_secret`, `token_endpoint_auth_signing_alg` and
// `sign_in_with_apple`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum TokenAuthMethod {
    /// `none`: No authentication
    None,

    /// `client_secret_basic`: `client_id` and `client_secret` used as basic
    /// authorization credentials
    ClientSecretBasic,

    /// `client_secret_post`: `client_id` and `client_secret` sent in the
    /// request body
    ClientSecretPost,

    /// `client_secret_jwt`: a `client_assertion` sent in the request body and
    /// signed using the `client_secret`
    ClientSecretJwt,

    /// `private_key_jwt`: a `client_assertion` sent in the request body and
    /// signed by an asymmetric key
    PrivateKeyJwt,

    /// `sign_in_with_apple`: a special method for Signin with Apple
    SignInWithApple,
}
200
/// How to handle a claim
// Serialized in lowercase (`ignore`, `suggest`, `force`, `require`). `ignore`
// is the `Default`, which the `#[serde(default)]` `action` fields fall back to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum ImportAction {
    /// Ignore the claim
    #[default]
    Ignore,

    /// Suggest the claim value, but allow the user to change it
    Suggest,

    /// Force the claim value, but don't fail if it is missing
    Force,

    /// Force the claim value, and fail if it is missing
    Require,
}
218
219impl ImportAction {
220    #[expect(clippy::trivially_copy_pass_by_ref)]
221    const fn is_default(&self) -> bool {
222        matches!(self, ImportAction::Ignore)
223    }
224}
225
/// How to handle an existing localpart claim
// Serialized in lowercase. `UpstreamOAuth2Config::validate` accepts values
// other than `fail` only together with a `force` or `require` localpart
// action.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum OnConflict {
    /// Fails the upstream OAuth 2.0 login on conflict
    #[default]
    Fail,

    /// Adds the upstream OAuth 2.0 identity link, regardless of whether there
    /// is an existing link or not
    Add,

    /// Replace any existing upstream OAuth 2.0 identity link
    Replace,

    /// Adds the upstream OAuth 2.0 identity link *only* if there is no existing
    /// link for this provider on the matching user
    Set,
}
245
246impl OnConflict {
247    #[expect(clippy::trivially_copy_pass_by_ref)]
248    const fn is_default(&self) -> bool {
249        matches!(self, OnConflict::Fail)
250    }
251}
252
/// What should be done for the subject attribute
// Only the template is configurable: unlike most other preferences there is
// no `action` field for the subject.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct SubjectImportPreference {
    /// The Jinja2 template to use for the subject attribute
    ///
    /// If not provided, the default template is `{{ user.sub }}`
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}
262
263impl SubjectImportPreference {
264    const fn is_default(&self) -> bool {
265        self.template.is_none()
266    }
267}
268
/// What should be done for the localpart attribute
// `UpstreamOAuth2Config::validate` accepts an `on_conflict` other than `fail`
// only with a `force` or `require` action, and requires `require` when
// `claims_imports.skip_confirmation` is set.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct LocalpartImportPreference {
    /// How to handle the attribute
    #[serde(default, skip_serializing_if = "ImportAction::is_default")]
    pub action: ImportAction,

    /// The Jinja2 template to use for the localpart attribute
    ///
    /// If not provided, the default template is `{{ user.preferred_username }}`
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,

    /// How to handle conflicts on the claim, default value is `Fail`
    #[serde(default, skip_serializing_if = "OnConflict::is_default")]
    pub on_conflict: OnConflict,
}
286
287impl LocalpartImportPreference {
288    const fn is_default(&self) -> bool {
289        self.action.is_default() && self.template.is_none()
290    }
291}
292
/// What should be done for the displayname attribute
// `UpstreamOAuth2Config::validate` rejects the `suggest` action when
// `claims_imports.skip_confirmation` is `true`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct DisplaynameImportPreference {
    /// How to handle the attribute
    #[serde(default, skip_serializing_if = "ImportAction::is_default")]
    pub action: ImportAction,

    /// The Jinja2 template to use for the displayname attribute
    ///
    /// If not provided, the default template is `{{ user.name }}`
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}
306
307impl DisplaynameImportPreference {
308    const fn is_default(&self) -> bool {
309        self.action.is_default() && self.template.is_none()
310    }
311}
312
/// What should be done with the email attribute
// `UpstreamOAuth2Config::validate` rejects the `suggest` action when
// `claims_imports.skip_confirmation` is `true`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct EmailImportPreference {
    /// How to handle the claim
    #[serde(default, skip_serializing_if = "ImportAction::is_default")]
    pub action: ImportAction,

    /// The Jinja2 template to use for the email address attribute
    ///
    /// If not provided, the default template is `{{ user.email }}`
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}
326
327impl EmailImportPreference {
328    const fn is_default(&self) -> bool {
329        self.action.is_default() && self.template.is_none()
330    }
331}
332
/// What should be done for the account name attribute
// Template-only, like the subject preference: there is no `action` field.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct AccountNameImportPreference {
    /// The Jinja2 template to use for the account name. This name is only used
    /// for display purposes.
    ///
    /// If not provided, it will be ignored.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}
343
344impl AccountNameImportPreference {
345    const fn is_default(&self) -> bool {
346        self.template.is_none()
347    }
348}
349
/// How claims should be imported
// Every field is `#[serde(default)]` and is skipped when serializing while it
// still holds its default value.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct ClaimsImports {
    /// How to determine the subject of the user
    #[serde(default, skip_serializing_if = "SubjectImportPreference::is_default")]
    pub subject: SubjectImportPreference,

    /// Whether to skip the interactive screen prompting the user to confirm the
    /// attributes that are being imported. This requires `localpart.action` to
    /// be `require` and other attribute actions to be either `ignore`, `force`
    /// or `require`
    // The requirement above is enforced by `UpstreamOAuth2Config::validate`.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub skip_confirmation: bool,

    /// Import the localpart of the MXID
    #[serde(default, skip_serializing_if = "LocalpartImportPreference::is_default")]
    pub localpart: LocalpartImportPreference,

    /// Import the displayname of the user.
    #[serde(
        default,
        skip_serializing_if = "DisplaynameImportPreference::is_default"
    )]
    pub displayname: DisplaynameImportPreference,

    /// Import the email address of the user
    #[serde(default, skip_serializing_if = "EmailImportPreference::is_default")]
    pub email: EmailImportPreference,

    /// Set a human-readable name for the upstream account for display purposes
    #[serde(
        default,
        skip_serializing_if = "AccountNameImportPreference::is_default"
    )]
    pub account_name: AccountNameImportPreference,
}
386
387impl ClaimsImports {
388    const fn is_default(&self) -> bool {
389        self.subject.is_default()
390            && self.localpart.is_default()
391            && !self.skip_confirmation
392            && self.displayname.is_default()
393            && self.email.is_default()
394            && self.account_name.is_default()
395    }
396}
397
/// How to discover the provider's configuration
// Serialized in snake_case (`oidc`, `insecure`, `disabled`). Every mode
// except `disabled` requires `Provider::issuer`, which is checked by
// `UpstreamOAuth2Config::validate`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum DiscoveryMode {
    /// Use OIDC discovery with strict metadata verification
    #[default]
    Oidc,

    /// Use OIDC discovery with relaxed metadata verification
    Insecure,

    /// Use a static configuration
    Disabled,
}
412
413impl DiscoveryMode {
414    #[expect(clippy::trivially_copy_pass_by_ref)]
415    const fn is_default(&self) -> bool {
416        matches!(self, DiscoveryMode::Oidc)
417    }
418}
419
/// Whether to use proof key for code exchange (PKCE) when requesting and
/// exchanging the token.
// Serialized in snake_case (`auto`, `always`, `never`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum PkceMethod {
    /// Use PKCE if the provider supports it
    ///
    /// Defaults to no PKCE if provider discovery is disabled
    #[default]
    Auto,

    /// Always use PKCE with the S256 challenge method
    Always,

    /// Never use PKCE
    Never,
}
437
438impl PkceMethod {
439    #[expect(clippy::trivially_copy_pass_by_ref)]
440    const fn is_default(&self) -> bool {
441        matches!(self, PkceMethod::Auto)
442    }
443}
444
/// Serde `default` for boolean fields that are on unless configured otherwise,
/// such as `Provider::enabled`.
fn default_true() -> bool {
    true
}

/// `skip_serializing_if` counterpart of [`default_true`]: a flag that still
/// holds the default value can be left out of the serialized output.
#[expect(clippy::trivially_copy_pass_by_ref)]
fn is_default_true(flag: &bool) -> bool {
    *flag == default_true()
}
453
454fn is_signed_response_alg_default(signed_response_alg: &JsonWebSignatureAlg) -> bool {
455    *signed_response_alg == signed_response_alg_default()
456}
457
458fn signed_response_alg_default() -> JsonWebSignatureAlg {
459    JsonWebSignatureAlg::Rs256
460}
461
/// Additional parameters for the `sign_in_with_apple` token endpoint
/// authentication method
// NOTE(review): both key sources are optional, and nothing in this file checks
// that `private_key_file` or `private_key` is set. Presumably exactly one of
// them is expected; confirm where that is enforced.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SignInWithApple {
    /// The private key file used to sign the `id_token`
    // Described as a plain (optional) string in the JSON schema.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(with = "Option<String>")]
    pub private_key_file: Option<Utf8PathBuf>,

    /// The private key used to sign the `id_token`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub private_key: Option<String>,

    /// The Team ID of the Apple Developer Portal
    pub team_id: String,

    /// The key ID of the Apple Developer Portal
    pub key_id: String,
}
479
/// Scope requested from a provider when `Provider::scope` is not configured.
fn default_scope() -> String {
    String::from("openid")
}

/// `skip_serializing_if` predicate for `Provider::scope`. It is true when the
/// configured scope equals [`default_scope`].
fn is_default_scope(scope: &str) -> bool {
    default_scope().as_str() == scope
}
487
/// What to do when receiving an OIDC Backchannel logout request.
// Serialized in snake_case (`do_nothing`, `logout_browser_only`,
// `logout_all`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum OnBackchannelLogout {
    /// Do nothing
    #[default]
    DoNothing,

    /// Only log out the MAS 'browser session' started by this OIDC session
    LogoutBrowserOnly,

    /// Log out all sessions started by this OIDC session, including MAS
    /// 'browser sessions' and client sessions
    LogoutAll,
}
503
504impl OnBackchannelLogout {
505    #[expect(clippy::trivially_copy_pass_by_ref)]
506    const fn is_default(&self) -> bool {
507        matches!(self, OnBackchannelLogout::DoNothing)
508    }
509}
510
/// Configuration for one upstream OAuth 2 provider.
// NOTE(review): per the serde_with documentation, `serde_as` and
// `skip_serializing_none` must stay above the `derive` attribute.
#[serde_as]
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
// `enabled`, `fetch_userinfo`, `forward_login_hint` and
// `registration_token_required` are independent on/off switches.
#[expect(clippy::struct_excessive_bools)]
pub struct Provider {
    /// Whether this provider is enabled.
    ///
    /// Defaults to `true`
    #[serde(default = "default_true", skip_serializing_if = "is_default_true")]
    pub enabled: bool,

    /// An internal unique identifier for this provider
    // Described in the JSON schema as a string restricted to the ULID
    // (Crockford base32) alphabet.
    #[schemars(
        with = "String",
        regex(pattern = r"^[0123456789ABCDEFGHJKMNPQRSTVWXYZ]{26}$"),
        description = "A ULID as per https://github.com/ulid/spec"
    )]
    pub id: Ulid,

    /// The ID of the provider that was used by Synapse.
    /// In order to perform a Synapse-to-MAS migration, this must be specified.
    ///
    /// ## For providers that used OAuth 2.0 or OpenID Connect in Synapse
    ///
    /// ### For `oidc_providers`:
    /// This should be specified as `oidc-` followed by the ID that was
    /// configured as `idp_id` in one of the `oidc_providers` in the Synapse
    /// configuration.
    /// For example, if Synapse's configuration contained `idp_id: wombat` for
    /// this provider, then specify `oidc-wombat` here.
    ///
    /// ### For `oidc_config` (legacy):
    /// Specify `oidc` here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub synapse_idp_id: Option<String>,

    /// The OIDC issuer URL
    ///
    /// This is required if OIDC discovery is enabled (which is the default)
    // Enforced by `UpstreamOAuth2Config::validate` for every `discovery_mode`
    // except `disabled`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub issuer: Option<String>,

    /// A human-readable name for the provider, that will be shown to users
    #[serde(skip_serializing_if = "Option::is_none")]
    pub human_name: Option<String>,

    /// A brand identifier used to customise the UI, e.g. `apple`, `google`,
    /// `github`, etc.
    ///
    /// Values supported by the default template are:
    ///
    ///  - `apple`
    ///  - `google`
    ///  - `facebook`
    ///  - `github`
    ///  - `gitlab`
    ///  - `twitter`
    ///  - `discord`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub brand_name: Option<String>,

    /// The client ID to use when authenticating with the provider
    pub client_id: String,

    /// The client secret to use when authenticating with the provider
    ///
    /// Used by the `client_secret_basic`, `client_secret_post`, and
    /// `client_secret_jwt` methods
    // Flattened and converted through `ClientSecretRaw`. In the configuration
    // it is written either as an inline `client_secret` or as a
    // `client_secret_file` path; both forms appear in the tests below.
    // `UpstreamOAuth2Config::validate` checks its presence against
    // `token_endpoint_auth_method`.
    #[schemars(with = "ClientSecretRaw")]
    #[serde_as(as = "serde_with::TryFromInto<ClientSecretRaw>")]
    #[serde(flatten)]
    pub client_secret: Option<ClientSecret>,

    /// The method to authenticate the client with the provider
    // No serde default: every provider must state its method explicitly.
    pub token_endpoint_auth_method: TokenAuthMethod,

    /// Additional parameters for the `sign_in_with_apple` method
    // Must be set exactly when `token_endpoint_auth_method` is
    // `sign_in_with_apple` (checked in `UpstreamOAuth2Config::validate`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sign_in_with_apple: Option<SignInWithApple>,

    /// The JWS algorithm to use when authenticating the client with the
    /// provider
    ///
    /// Used by the `client_secret_jwt` and `private_key_jwt` methods
    // Required for, and only allowed with, those two methods (checked in
    // `UpstreamOAuth2Config::validate`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,

    /// Expected signature for the JWT payload returned by the token
    /// authentication endpoint.
    ///
    /// Defaults to `RS256`.
    // NOTE(review): the field name suggests this is the algorithm expected on
    // the `id_token`, not on a "token authentication endpoint" payload.
    // Confirm, and reword the doc above, which is also the schema description.
    #[serde(
        default = "signed_response_alg_default",
        skip_serializing_if = "is_signed_response_alg_default"
    )]
    pub id_token_signed_response_alg: JsonWebSignatureAlg,

    /// The scopes to request from the provider
    ///
    /// Defaults to `openid`.
    #[serde(default = "default_scope", skip_serializing_if = "is_default_scope")]
    pub scope: String,

    /// How to discover the provider's configuration
    ///
    /// Defaults to `oidc`, which uses OIDC discovery with strict metadata
    /// verification
    #[serde(default, skip_serializing_if = "DiscoveryMode::is_default")]
    pub discovery_mode: DiscoveryMode,

    /// Whether to use proof key for code exchange (PKCE) when requesting and
    /// exchanging the token.
    ///
    /// Defaults to `auto`, which uses PKCE if the provider supports it.
    #[serde(default, skip_serializing_if = "PkceMethod::is_default")]
    pub pkce_method: PkceMethod,

    /// Whether to fetch the user profile from the userinfo endpoint,
    /// or to rely on the data returned in the `id_token` from the
    /// `token_endpoint`.
    ///
    /// Defaults to `false`.
    #[serde(default)]
    pub fetch_userinfo: bool,

    /// Expected signature for the JWT payload returned by the userinfo
    /// endpoint.
    ///
    /// If not specified, the response is expected to be an unsigned JSON
    /// payload.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,

    /// The URL to use for the provider's authorization endpoint
    ///
    /// Defaults to the `authorization_endpoint` provided through discovery
    #[serde(skip_serializing_if = "Option::is_none")]
    pub authorization_endpoint: Option<Url>,

    /// The URL to use for the provider's userinfo endpoint
    ///
    /// Defaults to the `userinfo_endpoint` provided through discovery
    #[serde(skip_serializing_if = "Option::is_none")]
    pub userinfo_endpoint: Option<Url>,

    /// The URL to use for the provider's token endpoint
    ///
    /// Defaults to the `token_endpoint` provided through discovery
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_endpoint: Option<Url>,

    /// The URL to use for getting the provider's public keys
    ///
    /// Defaults to the `jwks_uri` provided through discovery
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks_uri: Option<Url>,

    /// The response mode we ask the provider to use for the callback
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_mode: Option<ResponseMode>,

    /// How claims should be imported from the `id_token` provided by the
    /// provider
    #[serde(default, skip_serializing_if = "ClaimsImports::is_default")]
    pub claims_imports: ClaimsImports,

    /// Additional parameters to include in the authorization request.
    ///
    /// Each value is a [`MiniJinja`] template. The template context
    /// exposes a `params` map containing the raw query parameters from
    /// the downstream authorization request. The map is empty when the
    /// upstream login was not initiated by a downstream OAuth/OIDC
    /// authorization request (e.g. account linking, direct login from
    /// the login page).
    ///
    /// [`MiniJinja`]: https://docs.rs/minijinja
    ///
    /// Templates that render to an empty string are dropped — so
    /// referencing a downstream parameter that wasn't supplied (e.g.
    /// `{{ params.login_hint }}`) results in no parameter being
    /// forwarded, rather than an empty one.
    ///
    /// Plain strings (without `{{ … }}`) are valid templates that render
    /// to themselves.
    ///
    /// Example:
    ///
    /// ```yaml
    /// additional_authorization_parameters:
    ///   login_hint: "{{ params.login_hint }}"
    ///   acr_values: "{{ params.acr_values }}"
    ///   kc_idp_hint: "saml"
    /// ```
    ///
    /// `params` exposes the entire raw query string of the downstream
    /// request (including `client_id`, `state`, `code_challenge`, …).
    /// Forward specific keys deliberately; don't blindly proxy the
    /// whole map.
    ///
    /// Order of keys is not preserved.
    // Stored in a `BTreeMap`, so keys are kept in sorted order rather than in
    // configuration order.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub additional_authorization_parameters: BTreeMap<String, String>,

    /// Whether the `login_hint` should be forwarded to the provider in the
    /// authorization request.
    ///
    /// Defaults to `false`.
    ///
    /// Deprecated: prefer adding
    /// `login_hint: "{{ params.login_hint }}"` to
    /// `additional_authorization_parameters` instead. When this flag is
    /// set, a `login_hint` template entry is injected automatically if
    /// one is not already present.
    // The injection described above does not happen in this file; this
    // struct only carries the flag.
    #[serde(default)]
    pub forward_login_hint: bool,

    /// What to do when receiving an OIDC Backchannel logout request.
    ///
    /// Defaults to `do_nothing`.
    #[serde(default, skip_serializing_if = "OnBackchannelLogout::is_default")]
    pub on_backchannel_logout: OnBackchannelLogout,

    /// Whether or not to require a registration token on `OAuth2` auth
    ///
    /// Defaults to `false`
    #[serde(default)]
    pub registration_token_required: bool,
}
740
741impl Provider {
742    /// Returns the client secret.
743    ///
744    /// If `client_secret_file` was given, the secret is read from that file.
745    ///
746    /// # Errors
747    ///
748    /// Returns an error when the client secret could not be read from file.
749    pub async fn client_secret(&self) -> anyhow::Result<Option<String>> {
750        Ok(match &self.client_secret {
751            Some(client_secret) => Some(client_secret.value().await?),
752            None => None,
753        })
754    }
755}
756
#[cfg(test)]
mod tests {
    // The closures passed to `Jail::expect_with` return `figment::Error`, which is
    // large, and we can't change figment's API.
    #![expect(clippy::result_large_err)]

    use std::str::FromStr;

    use figment::{
        Figment, Jail,
        providers::{Format, Yaml},
    };
    use tokio::{runtime::Handle, task};

    use super::*;

    /// Loads one provider per client authentication style and checks that
    /// inline and file-based client secrets are both parsed and resolved.
    ///
    /// Note that only deserialization is exercised: `validate` is not called.
    /// It would reject the last provider, which uses `private_key_jwt` without
    /// `token_endpoint_auth_signing_alg`.
    #[tokio::test]
    async fn load_config() {
        // `Jail::expect_with` is synchronous, so it runs on a blocking thread;
        // the async `client_secret()` calls are then driven with
        // `Handle::current().block_on`.
        task::spawn_blocking(|| {
            Jail::expect_with(|jail| {
                // NOTE(review): the `jwks` key on the last provider has no
                // matching `Provider` field and appears to be ignored during
                // deserialization; nothing below asserts on it.
                jail.create_file(
                    "config.yaml",
                    r#"
                      upstream_oauth2:
                        providers:
                          - id: 01GFWR28C4KNE04WG3HKXB7C9R
                            client_id: upstream-oauth2
                            token_endpoint_auth_method: none

                          - id: 01GFWR32NCQ12B8Z0J8CPXRRB6
                            client_id: upstream-oauth2
                            client_secret_file: secret
                            token_endpoint_auth_method: client_secret_basic

                          - id: 01GFWR3WHR93Y5HK389H28VHZ9
                            client_id: upstream-oauth2
                            client_secret: c1!3n753c237
                            token_endpoint_auth_method: client_secret_post

                          - id: 01GFWR43R2ZZ8HX9CVBNW9TJWG
                            client_id: upstream-oauth2
                            client_secret_file: secret
                            token_endpoint_auth_method: client_secret_jwt

                          - id: 01GFWR4BNFDCC4QDG6AMSP1VRR
                            client_id: upstream-oauth2
                            token_endpoint_auth_method: private_key_jwt
                            jwks:
                              keys:
                              - kid: "03e84aed4ef4431014e8617567864c4efaaaede9"
                                kty: "RSA"
                                alg: "RS256"
                                use: "sig"
                                e: "AQAB"
                                n: "ma2uRyBeSEOatGuDpCiV9oIxlDWix_KypDYuhQfEzqi_BiF4fV266OWfyjcABbam59aJMNvOnKW3u_eZM-PhMCBij5MZ-vcBJ4GfxDJeKSn-GP_dJ09rpDcILh8HaWAnPmMoi4DC0nrfE241wPISvZaaZnGHkOrfN_EnA5DligLgVUbrA5rJhQ1aSEQO_gf1raEOW3DZ_ACU3qhtgO0ZBG3a5h7BPiRs2sXqb2UCmBBgwyvYLDebnpE7AotF6_xBIlR-Cykdap3GHVMXhrIpvU195HF30ZoBU4dMd-AeG6HgRt4Cqy1moGoDgMQfbmQ48Hlunv9_Vi2e2CLvYECcBw"

                              - kid: "d01c1abe249269f72ef7ca2613a86c9f05e59567"
                                kty: "RSA"
                                alg: "RS256"
                                use: "sig"
                                e: "AQAB"
                                n: "0hukqytPwrj1RbMYhYoepCi3CN5k7DwYkTe_Cmb7cP9_qv4ok78KdvFXt5AnQxCRwBD7-qTNkkfMWO2RxUMBdQD0ED6tsSb1n5dp0XY8dSWiBDCX8f6Hr-KolOpvMLZKRy01HdAWcM6RoL9ikbjYHUEW1C8IJnw3MzVHkpKFDL354aptdNLaAdTCBvKzU9WpXo10g-5ctzSlWWjQuecLMQ4G1mNdsR1LHhUENEnOvgT8cDkX0fJzLbEbyBYkdMgKggyVPEB1bg6evG4fTKawgnf0IDSPxIU-wdS9wdSP9ZCJJPLi5CEp-6t6rE_sb2dGcnzjCGlembC57VwpkUvyMw"
                    "#,
                )?;
                // Backing file for the `client_secret_file: secret` entries.
                jail.create_file("secret", r"c1!3n753c237")?;

                let config = Figment::new()
                    .merge(Yaml::file("config.yaml"))
                    .extract_inner::<UpstreamOAuth2Config>("upstream_oauth2")?;

                assert_eq!(config.providers.len(), 5);

                assert_eq!(
                    config.providers[1].id,
                    Ulid::from_str("01GFWR32NCQ12B8Z0J8CPXRRB6").unwrap()
                );

                // `client_secret_file` becomes `ClientSecret::File`, an inline
                // `client_secret` becomes `ClientSecret::Value`, and omitting
                // both yields `None`.
                assert!(config.providers[0].client_secret.is_none());
                assert!(matches!(config.providers[1].client_secret, Some(ClientSecret::File(ref p)) if p == "secret"));
                assert!(matches!(config.providers[2].client_secret, Some(ClientSecret::Value(ref v)) if v == "c1!3n753c237"));
                assert!(matches!(config.providers[3].client_secret, Some(ClientSecret::File(ref p)) if p == "secret"));
                assert!(config.providers[4].client_secret.is_none());

                // Both forms resolve to the same secret value.
                Handle::current().block_on(async move {
                    assert_eq!(config.providers[1].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                    assert_eq!(config.providers[2].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                    assert_eq!(config.providers[3].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                });

                Ok(())
            });
        }).await.unwrap();
    }
}