1pub mod model;
8
9use std::sync::Arc;
10
11use arc_swap::ArcSwap;
12use mas_data_model::{SessionLimitConfig, Ulid};
13use opa_wasm::{
14 Runtime,
15 wasmtime::{Config, Engine, Module, OptLevel, Store},
16};
17use serde::Serialize;
18use thiserror::Error;
19use tokio::io::{AsyncRead, AsyncReadExt};
20
21pub use self::model::{
22 AuthorizationGrantInput, ClientRegistrationInput, CompatLoginInput, EmailInput,
23 EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation,
24 ViolationVariant,
25};
26
27#[derive(Debug, Error)]
28pub enum LoadError {
29 #[error("failed to read module")]
30 Read(#[from] tokio::io::Error),
31
32 #[error("failed to create WASM engine")]
33 Engine(#[source] opa_wasm::wasmtime::Error),
34
35 #[error("module compilation task crashed")]
36 CompilationTask(#[from] tokio::task::JoinError),
37
38 #[error("failed to compile WASM module")]
39 Compilation(#[source] anyhow::Error),
40
41 #[error("invalid policy data")]
42 InvalidData(#[source] anyhow::Error),
43
44 #[error("failed to instantiate a test instance")]
45 Instantiate(#[source] InstantiateError),
46}
47
48impl LoadError {
49 #[doc(hidden)]
52 #[must_use]
53 pub fn invalid_data_example() -> Self {
54 Self::InvalidData(anyhow::Error::msg("Failed to merge policy data objects"))
55 }
56}
57
58#[derive(Debug, Error)]
59pub enum InstantiateError {
60 #[error("failed to create WASM runtime")]
61 Runtime(#[source] anyhow::Error),
62
63 #[error("missing entrypoint {entrypoint}")]
64 MissingEntrypoint { entrypoint: String },
65
66 #[error("failed to load policy data")]
67 LoadData(#[source] anyhow::Error),
68}
69
/// Names of the policy entrypoints, one per kind of evaluation.
#[derive(Debug, Clone)]
pub struct Entrypoints {
    /// Entrypoint evaluated for registrations.
    pub register: String,
    /// Entrypoint evaluated for client registrations.
    pub client_registration: String,
    /// Entrypoint evaluated for authorization grants.
    pub authorization_grant: String,
    /// Entrypoint evaluated for compatibility logins.
    pub compat_login: String,
    /// Entrypoint evaluated for email checks.
    pub email: String,
}
79
80impl Entrypoints {
81 fn all(&self) -> [&str; 5] {
82 [
83 self.register.as_str(),
84 self.client_registration.as_str(),
85 self.authorization_grant.as_str(),
86 self.compat_login.as_str(),
87 self.email.as_str(),
88 ]
89 }
90}
91
92#[derive(Debug)]
94pub struct Data {
95 base: BaseData,
96
97 rest: Option<serde_json::Value>,
99}
100
101#[derive(Serialize, Debug)]
102struct BaseData {
103 server_name: String,
104
105 session_limit: Option<SessionLimitConfig>,
107}
108
109impl Data {
110 #[must_use]
111 pub fn new(server_name: String, session_limit: Option<SessionLimitConfig>) -> Self {
112 Self {
113 base: BaseData {
114 server_name,
115 session_limit,
116 },
117
118 rest: None,
119 }
120 }
121
122 #[must_use]
123 pub fn with_rest(mut self, rest: serde_json::Value) -> Self {
124 self.rest = Some(rest);
125 self
126 }
127
128 fn to_value(&self) -> Result<serde_json::Value, anyhow::Error> {
129 let base = serde_json::to_value(&self.base)?;
130
131 if let Some(rest) = &self.rest {
132 merge_data(base, rest.clone())
133 } else {
134 Ok(base)
135 }
136 }
137}
138
139fn value_kind(value: &serde_json::Value) -> &'static str {
140 match value {
141 serde_json::Value::Object(_) => "object",
142 serde_json::Value::Array(_) => "array",
143 serde_json::Value::String(_) => "string",
144 serde_json::Value::Number(_) => "number",
145 serde_json::Value::Bool(_) => "boolean",
146 serde_json::Value::Null => "null",
147 }
148}
149
150fn merge_data(
151 mut left: serde_json::Value,
152 right: serde_json::Value,
153) -> Result<serde_json::Value, anyhow::Error> {
154 merge_data_rec(&mut left, right)?;
155 Ok(left)
156}
157
158fn merge_data_rec(
159 left: &mut serde_json::Value,
160 right: serde_json::Value,
161) -> Result<(), anyhow::Error> {
162 match (left, right) {
163 (serde_json::Value::Object(left), serde_json::Value::Object(right)) => {
164 for (key, value) in right {
165 if let Some(left_value) = left.get_mut(&key) {
166 merge_data_rec(left_value, value)?;
167 } else {
168 left.insert(key, value);
169 }
170 }
171 }
172 (serde_json::Value::Array(left), serde_json::Value::Array(right)) => {
173 left.extend(right);
174 }
175 (serde_json::Value::Number(left), serde_json::Value::Number(right)) => {
177 *left = right;
178 }
179 (serde_json::Value::Bool(left), serde_json::Value::Bool(right)) => {
180 *left = right;
181 }
182 (serde_json::Value::String(left), serde_json::Value::String(right)) => {
183 *left = right;
184 }
185
186 (left, right) if left.is_null() => *left = right,
188
189 (left, right) if right.is_null() => *left = right,
191
192 (left, right) => anyhow::bail!(
193 "Cannot merge a {} into a {}",
194 value_kind(&right),
195 value_kind(left),
196 ),
197 }
198
199 Ok(())
200}
201
202struct DynamicData {
207 version: Option<Ulid>,
208 merged: serde_json::Value,
209}
210
211pub struct PolicyFactory {
212 engine: Engine,
213 module: Module,
214 data: Data,
215 dynamic_data: ArcSwap<DynamicData>,
216 entrypoints: Entrypoints,
217}
218
219impl PolicyFactory {
220 #[tracing::instrument(name = "policy.load", skip(source))]
226 pub async fn load(
227 mut source: impl AsyncRead + std::marker::Unpin,
228 data: Data,
229 entrypoints: Entrypoints,
230 ) -> Result<Self, LoadError> {
231 let mut config = Config::default();
232 config.cranelift_opt_level(OptLevel::SpeedAndSize);
233
234 let engine = Engine::new(&config).map_err(LoadError::Engine)?;
235
236 let mut buf = Vec::new();
238 source.read_to_end(&mut buf).await?;
239 let (engine, module) = tokio::task::spawn_blocking(move || {
241 let module = Module::new(&engine, buf)?;
242 anyhow::Ok((engine, module))
243 })
244 .await?
245 .map_err(LoadError::Compilation)?;
246
247 let merged = data.to_value().map_err(LoadError::InvalidData)?;
248 let dynamic_data = ArcSwap::new(Arc::new(DynamicData {
249 version: None,
250 merged,
251 }));
252
253 let factory = Self {
254 engine,
255 module,
256 data,
257 dynamic_data,
258 entrypoints,
259 };
260
261 factory
263 .instantiate()
264 .await
265 .map_err(LoadError::Instantiate)?;
266
267 Ok(factory)
268 }
269
270 pub async fn set_dynamic_data(
283 &self,
284 dynamic_data: mas_data_model::PolicyData,
285 ) -> Result<bool, LoadError> {
286 if self.dynamic_data.load().version == Some(dynamic_data.id) {
289 return Ok(false);
291 }
292
293 let static_data = self.data.to_value().map_err(LoadError::InvalidData)?;
294 let merged = merge_data(static_data, dynamic_data.data).map_err(LoadError::InvalidData)?;
295
296 self.instantiate_with_data(&merged)
298 .await
299 .map_err(LoadError::Instantiate)?;
300
301 self.dynamic_data.store(Arc::new(DynamicData {
303 version: Some(dynamic_data.id),
304 merged,
305 }));
306
307 Ok(true)
308 }
309
310 #[tracing::instrument(name = "policy.instantiate", skip_all)]
317 pub async fn instantiate(&self) -> Result<Policy, InstantiateError> {
318 let data = self.dynamic_data.load();
319 self.instantiate_with_data(&data.merged).await
320 }
321
322 async fn instantiate_with_data(
323 &self,
324 data: &serde_json::Value,
325 ) -> Result<Policy, InstantiateError> {
326 tracing::debug!("Instantiating policy with data={}", data);
327 let mut store = Store::new(&self.engine, ());
328 let runtime = Runtime::new(&mut store, &self.module)
329 .await
330 .map_err(InstantiateError::Runtime)?;
331
332 let policy_entrypoints = runtime.entrypoints();
334
335 for e in self.entrypoints.all() {
336 if !policy_entrypoints.contains(e) {
337 return Err(InstantiateError::MissingEntrypoint {
338 entrypoint: e.to_owned(),
339 });
340 }
341 }
342
343 let instance = runtime
344 .with_data(&mut store, data)
345 .await
346 .map_err(InstantiateError::LoadData)?;
347
348 Ok(Policy {
349 store,
350 instance,
351 entrypoints: self.entrypoints.clone(),
352 })
353 }
354}
355
356pub struct Policy {
357 store: Store<()>,
358 instance: opa_wasm::Policy<opa_wasm::DefaultContext>,
359 entrypoints: Entrypoints,
360}
361
362#[derive(Debug, Error)]
363#[error("failed to evaluate policy")]
364pub enum EvaluationError {
365 Serialization(#[from] serde_json::Error),
366 Evaluation(#[from] anyhow::Error),
367}
368
369impl Policy {
370 #[tracing::instrument(
376 name = "policy.evaluate_email",
377 skip_all,
378 fields(
379 %input.email,
380 ),
381 )]
382 pub async fn evaluate_email(
383 &mut self,
384 input: EmailInput<'_>,
385 ) -> Result<EvaluationResult, EvaluationError> {
386 let [res]: [EvaluationResult; 1] = self
387 .instance
388 .evaluate(&mut self.store, &self.entrypoints.email, &input)
389 .await?;
390
391 Ok(res)
392 }
393
394 #[tracing::instrument(
400 name = "policy.evaluate.register",
401 skip_all,
402 fields(
403 ?input.registration_method,
404 input.username = input.username,
405 input.email = input.email,
406 ),
407 )]
408 pub async fn evaluate_register(
409 &mut self,
410 input: RegisterInput<'_>,
411 ) -> Result<EvaluationResult, EvaluationError> {
412 let [res]: [EvaluationResult; 1] = self
413 .instance
414 .evaluate(&mut self.store, &self.entrypoints.register, &input)
415 .await?;
416
417 Ok(res)
418 }
419
420 #[tracing::instrument(skip(self))]
426 pub async fn evaluate_client_registration(
427 &mut self,
428 input: ClientRegistrationInput<'_>,
429 ) -> Result<EvaluationResult, EvaluationError> {
430 let [res]: [EvaluationResult; 1] = self
431 .instance
432 .evaluate(
433 &mut self.store,
434 &self.entrypoints.client_registration,
435 &input,
436 )
437 .await?;
438
439 Ok(res)
440 }
441
442 #[tracing::instrument(
448 name = "policy.evaluate.authorization_grant",
449 skip_all,
450 fields(
451 %input.scope,
452 %input.client.id,
453 ),
454 )]
455 pub async fn evaluate_authorization_grant(
456 &mut self,
457 input: AuthorizationGrantInput<'_>,
458 ) -> Result<EvaluationResult, EvaluationError> {
459 let [res]: [EvaluationResult; 1] = self
460 .instance
461 .evaluate(
462 &mut self.store,
463 &self.entrypoints.authorization_grant,
464 &input,
465 )
466 .await?;
467
468 Ok(res)
469 }
470
471 #[tracing::instrument(
477 name = "policy.evaluate.compat_login",
478 skip_all,
479 fields(
480 %input.user.id,
481 ),
482 )]
483 pub async fn evaluate_compat_login(
484 &mut self,
485 input: CompatLoginInput<'_>,
486 ) -> Result<EvaluationResult, EvaluationError> {
487 let [res]: [EvaluationResult; 1] = self
488 .instance
489 .evaluate(&mut self.store, &self.entrypoints.compat_login, &input)
490 .await?;
491
492 Ok(res)
493 }
494}
495
#[cfg(test)]
mod tests {

    use std::time::SystemTime;

    use super::*;

    /// Entrypoint names used by every test in this module.
    fn make_entrypoints() -> Entrypoints {
        Entrypoints {
            register: "register/violation".to_owned(),
            client_registration: "client_registration/violation".to_owned(),
            authorization_grant: "authorization_grant/violation".to_owned(),
            compat_login: "compat_login/violation".to_owned(),
            email: "email/violation".to_owned(),
        }
    }

    /// Loads `policies/policy.wasm` (two directories above the crate
    /// manifest) with the given static data.
    async fn load_factory(data: Data) -> PolicyFactory {
        #[allow(clippy::disallowed_types)]
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("policies")
            .join("policy.wasm");

        let file = tokio::fs::File::open(path).await.unwrap();

        PolicyFactory::load(file, data, make_entrypoints())
            .await
            .unwrap()
    }

    /// Builds a password registration input for the user `hello` with the
    /// given email address and no requester information.
    fn password_registration(email: &str) -> RegisterInput<'_> {
        RegisterInput {
            registration_method: RegistrationMethod::Password,
            username: "hello",
            email: Some(email),
            requester: Requester {
                ip_address: None,
                user_agent: None,
            },
        }
    }

    /// Evaluates a password registration with the given email and tells
    /// whether the policy accepted it.
    async fn registration_is_valid(policy: &mut Policy, email: &str) -> bool {
        policy
            .evaluate_register(password_registration(email))
            .await
            .unwrap()
            .valid()
    }

    /// Wraps a JSON value in a policy data object with a nil ID.
    fn policy_data(data: serde_json::Value) -> mas_data_model::PolicyData {
        mas_data_model::PolicyData {
            id: Ulid::nil(),
            created_at: SystemTime::now().into(),
            data,
        }
    }

    #[tokio::test]
    async fn test_register() {
        let data = Data::new("example.com".to_owned(), None).with_rest(serde_json::json!({
            "allowed_domains": ["element.io", "*.element.io"],
            "banned_domains": ["staging.element.io"],
        }));

        let factory = load_factory(data).await;
        let mut policy = factory.instantiate().await.unwrap();

        // Domain not in the allow list
        assert!(!registration_is_valid(&mut policy, "hello@example.com").await);

        // Matches the wildcard allow list entry
        assert!(registration_is_valid(&mut policy, "hello@foo.element.io").await);

        // Explicitly banned domain
        assert!(!registration_is_valid(&mut policy, "hello@staging.element.io").await);
    }

    #[tokio::test]
    async fn test_dynamic_data() {
        let factory = load_factory(Data::new("example.com".to_owned(), None)).await;

        // Without dynamic data, the registration goes through
        let mut policy = factory.instantiate().await.unwrap();
        assert!(registration_is_valid(&mut policy, "hello@example.com").await);

        factory
            .set_dynamic_data(policy_data(serde_json::json!({
                "emails": {
                    "banned_addresses": {
                        "substrings": ["hello"]
                    }
                }
            })))
            .await
            .unwrap();

        // A fresh instance picks up the new data and rejects the address
        let mut policy = factory.instantiate().await.unwrap();
        assert!(!registration_is_valid(&mut policy, "hello@example.com").await);
    }

    #[tokio::test]
    async fn test_big_dynamic_data() {
        let factory = load_factory(Data::new("example.com".to_owned(), None)).await;

        // A large list of five-digit substrings
        let substrings: Vec<String> = (0..(1024 * 1024 / 8))
            .map(|i| format!("{:05}", i % 100_000))
            .collect();
        let json =
            serde_json::json!({ "emails": { "banned_addresses": { "substrings": substrings } } });

        factory.set_dynamic_data(policy_data(json)).await.unwrap();

        let mut policy = factory.instantiate().await.unwrap();
        assert!(!registration_is_valid(&mut policy, "12345@example.com").await);
    }

    #[test]
    fn test_merge() {
        use serde_json::json as j;

        // (left, right, expected result of merging right into left)
        let cases = [
            // Disjoint keys are combined
            (
                j!({"hello": "world"}),
                j!({"foo": "bar"}),
                j!({"hello": "world", "foo": "bar"}),
            ),
            // Scalars of the same type get overridden
            (
                j!({"hello": "world"}),
                j!({"hello": "john"}),
                j!({"hello": "john"}),
            ),
            (
                j!({"hello": true}),
                j!({"hello": false}),
                j!({"hello": false}),
            ),
            (j!({"hello": 0}), j!({"hello": 42}), j!({"hello": 42})),
            // Arrays are concatenated
            (
                j!({"hello": ["world"]}),
                j!({"hello": ["john"]}),
                j!({"hello": ["world", "john"]}),
            ),
            // Null on the right nulls out the left value
            (
                j!({"hello": "world"}),
                j!({"hello": null}),
                j!({"hello": null}),
            ),
            // Null on the left gets overridden
            (
                j!({"hello": null}),
                j!({"hello": "world"}),
                j!({"hello": "world"}),
            ),
            // Nested objects are merged recursively
            (
                j!({"a": {"b": {"c": "d"}}}),
                j!({"a": {"b": {"e": "f"}}}),
                j!({"a": {"b": {"c": "d", "e": "f"}}}),
            ),
        ];

        for (left, right, expected) in cases {
            assert_eq!(merge_data(left, right).unwrap(), expected);
        }

        merge_data(j!({"hello": "world"}), j!({"hello": 123}))
            .expect_err("Can't merge different types");
    }
}