1use std::collections::BTreeMap;
9
10use camino::Utf8PathBuf;
11use mas_iana::jose::JsonWebSignatureAlg;
12use schemars::JsonSchema;
13use serde::{Deserialize, Serialize, de::Error};
14use serde_with::{serde_as, skip_serializing_none};
15use ulid::Ulid;
16use url::Url;
17
18use crate::{ClientSecret, ClientSecretRaw, ConfigurationSection};
19
20#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
22pub struct UpstreamOAuth2Config {
23 pub providers: Vec<Provider>,
25}
26
27impl UpstreamOAuth2Config {
28 pub(crate) fn is_default(&self) -> bool {
30 self.providers.is_empty()
31 }
32}
33
34impl ConfigurationSection for UpstreamOAuth2Config {
35 const PATH: Option<&'static str> = Some("upstream_oauth2");
36
37 fn validate(
38 &self,
39 figment: &figment::Figment,
40 ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
41 for (index, provider) in self.providers.iter().enumerate() {
42 let annotate = |mut error: figment::Error| {
43 error.metadata = figment
44 .find_metadata(&format!("{root}.providers", root = Self::PATH.unwrap()))
45 .cloned();
46 error.profile = Some(figment::Profile::Default);
47 error.path = vec![
48 Self::PATH.unwrap().to_owned(),
49 "providers".to_owned(),
50 index.to_string(),
51 ];
52 error
53 };
54
55 if !matches!(provider.discovery_mode, DiscoveryMode::Disabled)
56 && provider.issuer.is_none()
57 {
58 return Err(annotate(figment::Error::custom(
59 "The `issuer` field is required when discovery is enabled",
60 ))
61 .into());
62 }
63
64 match provider.token_endpoint_auth_method {
65 TokenAuthMethod::None
66 | TokenAuthMethod::PrivateKeyJwt
67 | TokenAuthMethod::SignInWithApple => {
68 if provider.client_secret.is_some() {
69 return Err(annotate(figment::Error::custom(
70 "Unexpected field `client_secret` for the selected authentication method",
71 )).into());
72 }
73 }
74 TokenAuthMethod::ClientSecretBasic
75 | TokenAuthMethod::ClientSecretPost
76 | TokenAuthMethod::ClientSecretJwt => {
77 if provider.client_secret.is_none() {
78 return Err(annotate(figment::Error::missing_field("client_secret")).into());
79 }
80 }
81 }
82
83 match provider.token_endpoint_auth_method {
84 TokenAuthMethod::None
85 | TokenAuthMethod::ClientSecretBasic
86 | TokenAuthMethod::ClientSecretPost
87 | TokenAuthMethod::SignInWithApple => {
88 if provider.token_endpoint_auth_signing_alg.is_some() {
89 return Err(annotate(figment::Error::custom(
90 "Unexpected field `token_endpoint_auth_signing_alg` for the selected authentication method",
91 )).into());
92 }
93 }
94 TokenAuthMethod::ClientSecretJwt | TokenAuthMethod::PrivateKeyJwt => {
95 if provider.token_endpoint_auth_signing_alg.is_none() {
96 return Err(annotate(figment::Error::missing_field(
97 "token_endpoint_auth_signing_alg",
98 ))
99 .into());
100 }
101 }
102 }
103
104 match provider.token_endpoint_auth_method {
105 TokenAuthMethod::SignInWithApple => {
106 if provider.sign_in_with_apple.is_none() {
107 return Err(
108 annotate(figment::Error::missing_field("sign_in_with_apple")).into(),
109 );
110 }
111 }
112
113 _ => {
114 if provider.sign_in_with_apple.is_some() {
115 return Err(annotate(figment::Error::custom(
116 "Unexpected field `sign_in_with_apple` for the selected authentication method",
117 )).into());
118 }
119 }
120 }
121
122 if provider.claims_imports.skip_confirmation {
123 if provider.claims_imports.localpart.action != ImportAction::Require {
124 return Err(annotate(figment::Error::custom(
125 "The field `action` must be `require` when `skip_confirmation` is set to `true`",
126 )).with_path("claims_imports.localpart").into());
127 }
128
129 if provider.claims_imports.email.action == ImportAction::Suggest {
130 return Err(annotate(figment::Error::custom(
131 "The field `action` must not be `suggest` when `skip_confirmation` is set to `true`",
132 )).with_path("claims_imports.email").into());
133 }
134
135 if provider.claims_imports.displayname.action == ImportAction::Suggest {
136 return Err(annotate(figment::Error::custom(
137 "The field `action` must not be `suggest` when `skip_confirmation` is set to `true`",
138 )).with_path("claims_imports.displayname").into());
139 }
140 }
141
142 if matches!(
143 provider.claims_imports.localpart.on_conflict,
144 OnConflict::Add | OnConflict::Replace | OnConflict::Set
145 ) && !matches!(
146 provider.claims_imports.localpart.action,
147 ImportAction::Force | ImportAction::Require
148 ) {
149 return Err(annotate(figment::Error::custom(
150 "The field `action` must be either `force` or `require` when `on_conflict` is set to `add`, `replace` or `set`",
151 )).with_path("claims_imports.localpart").into());
152 }
153 }
154
155 Ok(())
156 }
157}
158
/// Possible values of a provider's `response_mode` setting.
///
/// Serialized in snake_case (`query`, `form_post`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum ResponseMode {
    /// Serialized as `query`.
    Query,

    /// Serialized as `form_post`.
    FormPost,
}
173
/// How the client authenticates to the upstream token endpoint.
///
/// Serialized in snake_case. The method determines which of `client_secret`,
/// `token_endpoint_auth_signing_alg` and `sign_in_with_apple` configuration
/// validation requires or rejects on the provider.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum TokenAuthMethod {
    /// `none`: no client secret and no signing algorithm allowed.
    None,

    /// `client_secret_basic`: requires a client secret.
    ClientSecretBasic,

    /// `client_secret_post`: requires a client secret.
    ClientSecretPost,

    /// `client_secret_jwt`: requires a client secret and a signing algorithm.
    ClientSecretJwt,

    /// `private_key_jwt`: requires a signing algorithm, no client secret.
    PrivateKeyJwt,

    /// `sign_in_with_apple`: requires the `sign_in_with_apple` settings, no
    /// client secret and no signing algorithm.
    SignInWithApple,
}
200
201#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
203#[serde(rename_all = "lowercase")]
204pub enum ImportAction {
205 #[default]
207 Ignore,
208
209 Suggest,
211
212 Force,
214
215 Require,
217}
218
219impl ImportAction {
220 #[expect(clippy::trivially_copy_pass_by_ref)]
221 const fn is_default(&self) -> bool {
222 matches!(self, ImportAction::Ignore)
223 }
224}
225
226#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
228#[serde(rename_all = "lowercase")]
229pub enum OnConflict {
230 #[default]
232 Fail,
233
234 Add,
237
238 Replace,
240
241 Set,
244}
245
246impl OnConflict {
247 #[expect(clippy::trivially_copy_pass_by_ref)]
248 const fn is_default(&self) -> bool {
249 matches!(self, OnConflict::Fail)
250 }
251}
252
253#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
255pub struct SubjectImportPreference {
256 #[serde(default, skip_serializing_if = "Option::is_none")]
260 pub template: Option<String>,
261}
262
263impl SubjectImportPreference {
264 const fn is_default(&self) -> bool {
265 self.template.is_none()
266 }
267}
268
269#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
271pub struct LocalpartImportPreference {
272 #[serde(default, skip_serializing_if = "ImportAction::is_default")]
274 pub action: ImportAction,
275
276 #[serde(default, skip_serializing_if = "Option::is_none")]
280 pub template: Option<String>,
281
282 #[serde(default, skip_serializing_if = "OnConflict::is_default")]
284 pub on_conflict: OnConflict,
285}
286
287impl LocalpartImportPreference {
288 const fn is_default(&self) -> bool {
289 self.action.is_default() && self.template.is_none()
290 }
291}
292
293#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
295pub struct DisplaynameImportPreference {
296 #[serde(default, skip_serializing_if = "ImportAction::is_default")]
298 pub action: ImportAction,
299
300 #[serde(default, skip_serializing_if = "Option::is_none")]
304 pub template: Option<String>,
305}
306
307impl DisplaynameImportPreference {
308 const fn is_default(&self) -> bool {
309 self.action.is_default() && self.template.is_none()
310 }
311}
312
313#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
315pub struct EmailImportPreference {
316 #[serde(default, skip_serializing_if = "ImportAction::is_default")]
318 pub action: ImportAction,
319
320 #[serde(default, skip_serializing_if = "Option::is_none")]
324 pub template: Option<String>,
325}
326
327impl EmailImportPreference {
328 const fn is_default(&self) -> bool {
329 self.action.is_default() && self.template.is_none()
330 }
331}
332
333#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
335pub struct AccountNameImportPreference {
336 #[serde(default, skip_serializing_if = "Option::is_none")]
341 pub template: Option<String>,
342}
343
344impl AccountNameImportPreference {
345 const fn is_default(&self) -> bool {
346 self.template.is_none()
347 }
348}
349
350#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
352pub struct ClaimsImports {
353 #[serde(default, skip_serializing_if = "SubjectImportPreference::is_default")]
355 pub subject: SubjectImportPreference,
356
357 #[serde(default, skip_serializing_if = "std::ops::Not::not")]
362 pub skip_confirmation: bool,
363
364 #[serde(default, skip_serializing_if = "LocalpartImportPreference::is_default")]
366 pub localpart: LocalpartImportPreference,
367
368 #[serde(
370 default,
371 skip_serializing_if = "DisplaynameImportPreference::is_default"
372 )]
373 pub displayname: DisplaynameImportPreference,
374
375 #[serde(default, skip_serializing_if = "EmailImportPreference::is_default")]
377 pub email: EmailImportPreference,
378
379 #[serde(
381 default,
382 skip_serializing_if = "AccountNameImportPreference::is_default"
383 )]
384 pub account_name: AccountNameImportPreference,
385}
386
387impl ClaimsImports {
388 const fn is_default(&self) -> bool {
389 self.subject.is_default()
390 && self.localpart.is_default()
391 && !self.skip_confirmation
392 && self.displayname.is_default()
393 && self.email.is_default()
394 && self.account_name.is_default()
395 }
396}
397
398#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
400#[serde(rename_all = "snake_case")]
401pub enum DiscoveryMode {
402 #[default]
404 Oidc,
405
406 Insecure,
408
409 Disabled,
411}
412
413impl DiscoveryMode {
414 #[expect(clippy::trivially_copy_pass_by_ref)]
415 const fn is_default(&self) -> bool {
416 matches!(self, DiscoveryMode::Oidc)
417 }
418}
419
420#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
423#[serde(rename_all = "snake_case")]
424pub enum PkceMethod {
425 #[default]
429 Auto,
430
431 Always,
433
434 Never,
436}
437
438impl PkceMethod {
439 #[expect(clippy::trivially_copy_pass_by_ref)]
440 const fn is_default(&self) -> bool {
441 matches!(self, PkceMethod::Auto)
442 }
443}
444
/// Serde default for boolean settings that are on unless stated otherwise.
fn default_true() -> bool {
    true
}

/// Returns `true` when `value` equals [`default_true`], so serde can omit it.
#[expect(clippy::trivially_copy_pass_by_ref)]
fn is_default_true(value: &bool) -> bool {
    *value == default_true()
}
453
454fn is_signed_response_alg_default(signed_response_alg: &JsonWebSignatureAlg) -> bool {
455 *signed_response_alg == signed_response_alg_default()
456}
457
458fn signed_response_alg_default() -> JsonWebSignatureAlg {
459 JsonWebSignatureAlg::Rs256
460}
461
/// Settings used with the `sign_in_with_apple` token endpoint authentication
/// method.
///
/// Configuration validation requires this section exactly when a provider
/// uses `token_endpoint_auth_method: sign_in_with_apple`.
// NOTE(review): nothing in this file checks that exactly one of
// `private_key_file` / `private_key` is set — confirm the consumer handles
// the both-missing and both-present cases.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SignInWithApple {
    /// Path to a file holding the private key; omitted from serialized output
    /// when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    // The schema exposes the path as a plain string.
    #[schemars(with = "Option<String>")]
    pub private_key_file: Option<Utf8PathBuf>,

    /// The private key given inline; omitted from serialized output when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub private_key: Option<String>,

    /// Apple team ID.
    pub team_id: String,

    /// Apple key ID.
    pub key_id: String,
}
479
/// Scope used for a provider when none is configured.
const DEFAULT_SCOPE: &str = "openid";

/// Serde default for a provider's `scope`: [`DEFAULT_SCOPE`] as an owned string.
fn default_scope() -> String {
    DEFAULT_SCOPE.to_owned()
}

/// Returns `true` when `scope` is [`DEFAULT_SCOPE`], so serde can omit it.
///
/// Compares against the constant directly rather than calling
/// [`default_scope`], which would allocate a `String` on every check.
fn is_default_scope(scope: &str) -> bool {
    scope == DEFAULT_SCOPE
}
487
488#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
490#[serde(rename_all = "snake_case")]
491pub enum OnBackchannelLogout {
492 #[default]
494 DoNothing,
495
496 LogoutBrowserOnly,
498
499 LogoutAll,
502}
503
504impl OnBackchannelLogout {
505 #[expect(clippy::trivially_copy_pass_by_ref)]
506 const fn is_default(&self) -> bool {
507 matches!(self, OnBackchannelLogout::DoNothing)
508 }
509}
510
/// Configuration of a single upstream OAuth 2.0 provider, as listed under
/// `upstream_oauth2.providers`.
///
/// Constraints between fields (which fields a given
/// `token_endpoint_auth_method` requires, `claims_imports` consistency, etc.)
/// are checked by the section's `validate`, not during deserialization.
// `serde_as` and `skip_serializing_none` rewrite field attributes, so they must
// stay above the `derive`.
#[serde_as]
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[expect(clippy::struct_excessive_bools)]
pub struct Provider {
    /// Whether the provider is enabled. Defaults to `true`, and is omitted from
    /// serialized output when `true`.
    #[serde(default = "default_true", skip_serializing_if = "is_default_true")]
    pub enabled: bool,

    /// Identifier of the provider, as a ULID.
    #[schemars(
        with = "String",
        regex(pattern = r"^[0123456789ABCDEFGHJKMNPQRSTVWXYZ]{26}$"),
        description = "A ULID as per https://github.com/ulid/spec"
    )]
    pub id: Ulid,

    /// Optional identifier of this provider on the Synapse side.
    // NOTE(review): presumably the Synapse `idp_id` used when migrating — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub synapse_idp_id: Option<String>,

    /// Issuer of the provider. Required by validation unless `discovery_mode`
    /// is `disabled`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub issuer: Option<String>,

    /// Optional human-readable name of the provider.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub human_name: Option<String>,

    /// Optional brand name of the provider.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub brand_name: Option<String>,

    /// Client ID registered with the upstream provider.
    pub client_id: String,

    /// Client secret, given through flattened keys on the provider itself
    /// (e.g. `client_secret` or `client_secret_file`) and converted from
    /// `ClientSecretRaw`.
    ///
    /// Required for `client_secret_basic`, `client_secret_post` and
    /// `client_secret_jwt`; rejected for the other authentication methods.
    #[schemars(with = "ClientSecretRaw")]
    #[serde_as(as = "serde_with::TryFromInto<ClientSecretRaw>")]
    #[serde(flatten)]
    pub client_secret: Option<ClientSecret>,

    /// How to authenticate to the token endpoint.
    pub token_endpoint_auth_method: TokenAuthMethod,

    /// Sign in with Apple settings. Required for, and only allowed with, the
    /// `sign_in_with_apple` authentication method.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sign_in_with_apple: Option<SignInWithApple>,

    /// Signing algorithm for token endpoint authentication. Required for
    /// `client_secret_jwt` and `private_key_jwt`; rejected otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,

    /// Expected signing algorithm of ID tokens. Defaults to `RS256`, and is
    /// omitted from serialized output when it is the default.
    #[serde(
        default = "signed_response_alg_default",
        skip_serializing_if = "is_signed_response_alg_default"
    )]
    pub id_token_signed_response_alg: JsonWebSignatureAlg,

    /// Scope to request. Defaults to `openid`.
    #[serde(default = "default_scope", skip_serializing_if = "is_default_scope")]
    pub scope: String,

    /// Discovery mode. Defaults to `oidc`.
    #[serde(default, skip_serializing_if = "DiscoveryMode::is_default")]
    pub discovery_mode: DiscoveryMode,

    /// PKCE usage. Defaults to `auto`.
    #[serde(default, skip_serializing_if = "PkceMethod::is_default")]
    pub pkce_method: PkceMethod,

    /// Whether to fetch the userinfo endpoint. Defaults to `false`.
    #[serde(default)]
    pub fetch_userinfo: bool,

    /// Optional expected signing algorithm of userinfo responses.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,

    /// Optional authorization endpoint URL.
    // NOTE(review): presumably overrides the discovered value — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub authorization_endpoint: Option<Url>,

    /// Optional userinfo endpoint URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub userinfo_endpoint: Option<Url>,

    /// Optional token endpoint URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_endpoint: Option<Url>,

    /// Optional JWKS URI.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks_uri: Option<Url>,

    /// Optional response mode.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_mode: Option<ResponseMode>,

    /// How claims are imported from this provider.
    #[serde(default, skip_serializing_if = "ClaimsImports::is_default")]
    pub claims_imports: ClaimsImports,

    /// Extra key/value parameters (kept sorted by key); omitted when empty.
    // NOTE(review): presumably added to the authorization request — confirm.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub additional_authorization_parameters: BTreeMap<String, String>,

    /// Whether to forward the login hint. Defaults to `false`.
    #[serde(default)]
    pub forward_login_hint: bool,

    /// Action on back-channel logout. Defaults to `do_nothing`.
    #[serde(default, skip_serializing_if = "OnBackchannelLogout::is_default")]
    pub on_backchannel_logout: OnBackchannelLogout,

    /// Whether a registration token is required. Defaults to `false`.
    #[serde(default)]
    pub registration_token_required: bool,
}
740
741impl Provider {
742 pub async fn client_secret(&self) -> anyhow::Result<Option<String>> {
750 Ok(match &self.client_secret {
751 Some(client_secret) => Some(client_secret.value().await?),
752 None => None,
753 })
754 }
755}
756
#[cfg(test)]
mod tests {
    // NOTE(review): presumably silences `result_large_err` for the figment
    // `Result` returned by the `Jail` closure — confirm.
    #![expect(clippy::result_large_err)]

    use std::str::FromStr;

    use figment::{
        Figment, Jail,
        providers::{Format, Yaml},
    };
    use tokio::{runtime::Handle, task};

    use super::*;

    /// Loads one provider per token endpoint authentication method (except
    /// `sign_in_with_apple`) and checks how the client secret is parsed and
    /// resolved.
    #[tokio::test]
    async fn load_config() {
        // `Jail::expect_with` is synchronous, so it runs on a blocking thread.
        // The async `client_secret()` calls are then driven with
        // `Handle::block_on` from that thread.
        task::spawn_blocking(|| {
            Jail::expect_with(|jail| {
                // The last provider also carries a `jwks` key, which is not a
                // `Provider` field.
                // NOTE(review): this relies on unknown keys being ignored during
                // deserialization — confirm that is intended.
                jail.create_file(
                    "config.yaml",
                    r#"
                        upstream_oauth2:
                          providers:
                            - id: 01GFWR28C4KNE04WG3HKXB7C9R
                              client_id: upstream-oauth2
                              token_endpoint_auth_method: none

                            - id: 01GFWR32NCQ12B8Z0J8CPXRRB6
                              client_id: upstream-oauth2
                              client_secret_file: secret
                              token_endpoint_auth_method: client_secret_basic

                            - id: 01GFWR3WHR93Y5HK389H28VHZ9
                              client_id: upstream-oauth2
                              client_secret: c1!3n753c237
                              token_endpoint_auth_method: client_secret_post

                            - id: 01GFWR43R2ZZ8HX9CVBNW9TJWG
                              client_id: upstream-oauth2
                              client_secret_file: secret
                              token_endpoint_auth_method: client_secret_jwt

                            - id: 01GFWR4BNFDCC4QDG6AMSP1VRR
                              client_id: upstream-oauth2
                              token_endpoint_auth_method: private_key_jwt
                              jwks:
                                keys:
                                  - kid: "03e84aed4ef4431014e8617567864c4efaaaede9"
                                    kty: "RSA"
                                    alg: "RS256"
                                    use: "sig"
                                    e: "AQAB"
                                    n: "ma2uRyBeSEOatGuDpCiV9oIxlDWix_KypDYuhQfEzqi_BiF4fV266OWfyjcABbam59aJMNvOnKW3u_eZM-PhMCBij5MZ-vcBJ4GfxDJeKSn-GP_dJ09rpDcILh8HaWAnPmMoi4DC0nrfE241wPISvZaaZnGHkOrfN_EnA5DligLgVUbrA5rJhQ1aSEQO_gf1raEOW3DZ_ACU3qhtgO0ZBG3a5h7BPiRs2sXqb2UCmBBgwyvYLDebnpE7AotF6_xBIlR-Cykdap3GHVMXhrIpvU195HF30ZoBU4dMd-AeG6HgRt4Cqy1moGoDgMQfbmQ48Hlunv9_Vi2e2CLvYECcBw"

                                  - kid: "d01c1abe249269f72ef7ca2613a86c9f05e59567"
                                    kty: "RSA"
                                    alg: "RS256"
                                    use: "sig"
                                    e: "AQAB"
                                    n: "0hukqytPwrj1RbMYhYoepCi3CN5k7DwYkTe_Cmb7cP9_qv4ok78KdvFXt5AnQxCRwBD7-qTNkkfMWO2RxUMBdQD0ED6tsSb1n5dp0XY8dSWiBDCX8f6Hr-KolOpvMLZKRy01HdAWcM6RoL9ikbjYHUEW1C8IJnw3MzVHkpKFDL354aptdNLaAdTCBvKzU9WpXo10g-5ctzSlWWjQuecLMQ4G1mNdsR1LHhUENEnOvgT8cDkX0fJzLbEbyBYkdMgKggyVPEB1bg6evG4fTKawgnf0IDSPxIU-wdS9wdSP9ZCJJPLi5CEp-6t6rE_sb2dGcnzjCGlembC57VwpkUvyMw"
                    "#,
                )?;
                // Target of the `client_secret_file: secret` entries above.
                jail.create_file("secret", r"c1!3n753c237")?;

                // Only deserialization is exercised. `validate` is not called;
                // providers 4 and 5 would fail it, since they lack
                // `token_endpoint_auth_signing_alg`.
                let config = Figment::new()
                    .merge(Yaml::file("config.yaml"))
                    .extract_inner::<UpstreamOAuth2Config>("upstream_oauth2")?;

                assert_eq!(config.providers.len(), 5);

                assert_eq!(
                    config.providers[1].id,
                    Ulid::from_str("01GFWR32NCQ12B8Z0J8CPXRRB6").unwrap()
                );

                // Inline values and file references map to distinct variants.
                assert!(config.providers[0].client_secret.is_none());
                assert!(matches!(config.providers[1].client_secret, Some(ClientSecret::File(ref p)) if p == "secret"));
                assert!(matches!(config.providers[2].client_secret, Some(ClientSecret::Value(ref v)) if v == "c1!3n753c237"));
                assert!(matches!(config.providers[3].client_secret, Some(ClientSecret::File(ref p)) if p == "secret"));
                assert!(config.providers[4].client_secret.is_none());

                // Both variants resolve to the same secret string.
                Handle::current().block_on(async move {
                    assert_eq!(config.providers[1].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                    assert_eq!(config.providers[2].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                    assert_eq!(config.providers[3].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                });

                Ok(())
            });
        })
        .await
        .unwrap();
    }
}