Skip to main content

mas_policy/
lib.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7pub mod model;
8
9use std::sync::Arc;
10
11use arc_swap::ArcSwap;
12use mas_data_model::{SessionLimitConfig, Ulid};
13use opa_wasm::{
14    Runtime,
15    wasmtime::{Config, Engine, Module, OptLevel, Store},
16};
17use serde::Serialize;
18use thiserror::Error;
19use tokio::io::{AsyncRead, AsyncReadExt};
20
21pub use self::model::{
22    AuthorizationGrantInput, ClientRegistrationInput, CompatLoginInput, EmailInput,
23    EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation,
24    ViolationVariant,
25};
26
27#[derive(Debug, Error)]
28pub enum LoadError {
29    #[error("failed to read module")]
30    Read(#[from] tokio::io::Error),
31
32    #[error("failed to create WASM engine")]
33    Engine(#[source] opa_wasm::wasmtime::Error),
34
35    #[error("module compilation task crashed")]
36    CompilationTask(#[from] tokio::task::JoinError),
37
38    #[error("failed to compile WASM module")]
39    Compilation(#[source] anyhow::Error),
40
41    #[error("invalid policy data")]
42    InvalidData(#[source] anyhow::Error),
43
44    #[error("failed to instantiate a test instance")]
45    Instantiate(#[source] InstantiateError),
46}
47
48impl LoadError {
49    /// Creates an example of an invalid data error, used for API response
50    /// documentation
51    #[doc(hidden)]
52    #[must_use]
53    pub fn invalid_data_example() -> Self {
54        Self::InvalidData(anyhow::Error::msg("Failed to merge policy data objects"))
55    }
56}
57
58#[derive(Debug, Error)]
59pub enum InstantiateError {
60    #[error("failed to create WASM runtime")]
61    Runtime(#[source] anyhow::Error),
62
63    #[error("missing entrypoint {entrypoint}")]
64    MissingEntrypoint { entrypoint: String },
65
66    #[error("failed to load policy data")]
67    LoadData(#[source] anyhow::Error),
68}
69
/// Holds the entrypoint of each policy
///
/// Every instance created by the [`PolicyFactory`] is checked to expose all of
/// these entrypoints.
#[derive(Debug, Clone)]
pub struct Entrypoints {
    /// Entrypoint evaluated by [`Policy::evaluate_register`]
    pub register: String,
    /// Entrypoint evaluated by [`Policy::evaluate_client_registration`]
    pub client_registration: String,
    /// Entrypoint evaluated by [`Policy::evaluate_authorization_grant`]
    pub authorization_grant: String,
    /// Entrypoint evaluated by [`Policy::evaluate_compat_login`]
    pub compat_login: String,
    /// Entrypoint evaluated by [`Policy::evaluate_email`]
    pub email: String,
}

impl Entrypoints {
    /// Every configured entrypoint, in field declaration order.
    fn all(&self) -> [&str; 5] {
        let Self {
            register,
            client_registration,
            authorization_grant,
            compat_login,
            email,
        } = self;

        [
            register,
            client_registration,
            authorization_grant,
            compat_login,
            email,
        ]
        .map(String::as_str)
    }
}
91
92/// Global static data that stays the same for the life of the [`PolicyFactory`]
93#[derive(Debug)]
94pub struct Data {
95    base: BaseData,
96
97    // We will merge this in a custom way, so don't emit as part of the base
98    rest: Option<serde_json::Value>,
99}
100
101#[derive(Serialize, Debug)]
102struct BaseData {
103    server_name: String,
104
105    /// Limits on the number of application sessions that each user can have
106    session_limit: Option<SessionLimitConfig>,
107}
108
109impl Data {
110    #[must_use]
111    pub fn new(server_name: String, session_limit: Option<SessionLimitConfig>) -> Self {
112        Self {
113            base: BaseData {
114                server_name,
115                session_limit,
116            },
117
118            rest: None,
119        }
120    }
121
122    #[must_use]
123    pub fn with_rest(mut self, rest: serde_json::Value) -> Self {
124        self.rest = Some(rest);
125        self
126    }
127
128    fn to_value(&self) -> Result<serde_json::Value, anyhow::Error> {
129        let base = serde_json::to_value(&self.base)?;
130
131        if let Some(rest) = &self.rest {
132            merge_data(base, rest.clone())
133        } else {
134            Ok(base)
135        }
136    }
137}
138
139fn value_kind(value: &serde_json::Value) -> &'static str {
140    match value {
141        serde_json::Value::Object(_) => "object",
142        serde_json::Value::Array(_) => "array",
143        serde_json::Value::String(_) => "string",
144        serde_json::Value::Number(_) => "number",
145        serde_json::Value::Bool(_) => "boolean",
146        serde_json::Value::Null => "null",
147    }
148}
149
150fn merge_data(
151    mut left: serde_json::Value,
152    right: serde_json::Value,
153) -> Result<serde_json::Value, anyhow::Error> {
154    merge_data_rec(&mut left, right)?;
155    Ok(left)
156}
157
158fn merge_data_rec(
159    left: &mut serde_json::Value,
160    right: serde_json::Value,
161) -> Result<(), anyhow::Error> {
162    match (left, right) {
163        (serde_json::Value::Object(left), serde_json::Value::Object(right)) => {
164            for (key, value) in right {
165                if let Some(left_value) = left.get_mut(&key) {
166                    merge_data_rec(left_value, value)?;
167                } else {
168                    left.insert(key, value);
169                }
170            }
171        }
172        (serde_json::Value::Array(left), serde_json::Value::Array(right)) => {
173            left.extend(right);
174        }
175        // Other values override
176        (serde_json::Value::Number(left), serde_json::Value::Number(right)) => {
177            *left = right;
178        }
179        (serde_json::Value::Bool(left), serde_json::Value::Bool(right)) => {
180            *left = right;
181        }
182        (serde_json::Value::String(left), serde_json::Value::String(right)) => {
183            *left = right;
184        }
185
186        // Null gets overridden by anything
187        (left, right) if left.is_null() => *left = right,
188
189        // Null on the right makes the left value null
190        (left, right) if right.is_null() => *left = right,
191
192        (left, right) => anyhow::bail!(
193            "Cannot merge a {} into a {}",
194            value_kind(&right),
195            value_kind(left),
196        ),
197    }
198
199    Ok(())
200}
201
202/// Global dynamic data
203///
204/// Hint: there is an admin API to manage this, see
205/// `crates/handlers/src/admin/v1/policy_data/set.rs`
206struct DynamicData {
207    version: Option<Ulid>,
208    merged: serde_json::Value,
209}
210
211pub struct PolicyFactory {
212    engine: Engine,
213    module: Module,
214    data: Data,
215    dynamic_data: ArcSwap<DynamicData>,
216    entrypoints: Entrypoints,
217}
218
219impl PolicyFactory {
220    /// Load the policy from the given data source.
221    ///
222    /// # Errors
223    ///
224    /// Returns an error if the policy can't be loaded or instantiated.
225    #[tracing::instrument(name = "policy.load", skip(source))]
226    pub async fn load(
227        mut source: impl AsyncRead + std::marker::Unpin,
228        data: Data,
229        entrypoints: Entrypoints,
230    ) -> Result<Self, LoadError> {
231        let mut config = Config::default();
232        config.cranelift_opt_level(OptLevel::SpeedAndSize);
233
234        let engine = Engine::new(&config).map_err(LoadError::Engine)?;
235
236        // Read and compile the module
237        let mut buf = Vec::new();
238        source.read_to_end(&mut buf).await?;
239        // Compilation is CPU-bound, so spawn that in a blocking task
240        let (engine, module) = tokio::task::spawn_blocking(move || {
241            let module = Module::new(&engine, buf)?;
242            anyhow::Ok((engine, module))
243        })
244        .await?
245        .map_err(LoadError::Compilation)?;
246
247        let merged = data.to_value().map_err(LoadError::InvalidData)?;
248        let dynamic_data = ArcSwap::new(Arc::new(DynamicData {
249            version: None,
250            merged,
251        }));
252
253        let factory = Self {
254            engine,
255            module,
256            data,
257            dynamic_data,
258            entrypoints,
259        };
260
261        // Try to instantiate
262        factory
263            .instantiate()
264            .await
265            .map_err(LoadError::Instantiate)?;
266
267        Ok(factory)
268    }
269
270    /// Set the dynamic data for the policy.
271    ///
272    /// The `dynamic_data` object is merged with the static data given when the
273    /// policy was loaded.
274    ///
275    /// Returns `true` if the data was updated, `false` if the version
276    /// of the dynamic data was the same as the one we already have.
277    ///
278    /// # Errors
279    ///
280    /// Returns an error if the data can't be merged with the static data, or if
281    /// the policy can't be instantiated with the new data.
282    pub async fn set_dynamic_data(
283        &self,
284        dynamic_data: mas_data_model::PolicyData,
285    ) -> Result<bool, LoadError> {
286        // Check if the version of the dynamic data we have is the same as the one we're
287        // trying to set
288        if self.dynamic_data.load().version == Some(dynamic_data.id) {
289            // Don't do anything if the version is the same
290            return Ok(false);
291        }
292
293        let static_data = self.data.to_value().map_err(LoadError::InvalidData)?;
294        let merged = merge_data(static_data, dynamic_data.data).map_err(LoadError::InvalidData)?;
295
296        // Try to instantiate with the new data
297        self.instantiate_with_data(&merged)
298            .await
299            .map_err(LoadError::Instantiate)?;
300
301        // If instantiation succeeds, swap the data
302        self.dynamic_data.store(Arc::new(DynamicData {
303            version: Some(dynamic_data.id),
304            merged,
305        }));
306
307        Ok(true)
308    }
309
310    /// Create a new policy instance.
311    ///
312    /// # Errors
313    ///
314    /// Returns an error if the policy can't be instantiated with the current
315    /// dynamic data.
316    #[tracing::instrument(name = "policy.instantiate", skip_all)]
317    pub async fn instantiate(&self) -> Result<Policy, InstantiateError> {
318        let data = self.dynamic_data.load();
319        self.instantiate_with_data(&data.merged).await
320    }
321
322    async fn instantiate_with_data(
323        &self,
324        data: &serde_json::Value,
325    ) -> Result<Policy, InstantiateError> {
326        tracing::debug!("Instantiating policy with data={}", data);
327        let mut store = Store::new(&self.engine, ());
328        let runtime = Runtime::new(&mut store, &self.module)
329            .await
330            .map_err(InstantiateError::Runtime)?;
331
332        // Check that we have the required entrypoints
333        let policy_entrypoints = runtime.entrypoints();
334
335        for e in self.entrypoints.all() {
336            if !policy_entrypoints.contains(e) {
337                return Err(InstantiateError::MissingEntrypoint {
338                    entrypoint: e.to_owned(),
339                });
340            }
341        }
342
343        let instance = runtime
344            .with_data(&mut store, data)
345            .await
346            .map_err(InstantiateError::LoadData)?;
347
348        Ok(Policy {
349            store,
350            instance,
351            entrypoints: self.entrypoints.clone(),
352        })
353    }
354}
355
356pub struct Policy {
357    store: Store<()>,
358    instance: opa_wasm::Policy<opa_wasm::DefaultContext>,
359    entrypoints: Entrypoints,
360}
361
362#[derive(Debug, Error)]
363#[error("failed to evaluate policy")]
364pub enum EvaluationError {
365    Serialization(#[from] serde_json::Error),
366    Evaluation(#[from] anyhow::Error),
367}
368
369impl Policy {
370    /// Evaluate the 'email' entrypoint.
371    ///
372    /// # Errors
373    ///
374    /// Returns an error if the policy engine fails to evaluate the entrypoint.
375    #[tracing::instrument(
376        name = "policy.evaluate_email",
377        skip_all,
378        fields(
379            %input.email,
380        ),
381    )]
382    pub async fn evaluate_email(
383        &mut self,
384        input: EmailInput<'_>,
385    ) -> Result<EvaluationResult, EvaluationError> {
386        let [res]: [EvaluationResult; 1] = self
387            .instance
388            .evaluate(&mut self.store, &self.entrypoints.email, &input)
389            .await?;
390
391        Ok(res)
392    }
393
394    /// Evaluate the 'register' entrypoint.
395    ///
396    /// # Errors
397    ///
398    /// Returns an error if the policy engine fails to evaluate the entrypoint.
399    #[tracing::instrument(
400        name = "policy.evaluate.register",
401        skip_all,
402        fields(
403            ?input.registration_method,
404            input.username = input.username,
405            input.email = input.email,
406        ),
407    )]
408    pub async fn evaluate_register(
409        &mut self,
410        input: RegisterInput<'_>,
411    ) -> Result<EvaluationResult, EvaluationError> {
412        let [res]: [EvaluationResult; 1] = self
413            .instance
414            .evaluate(&mut self.store, &self.entrypoints.register, &input)
415            .await?;
416
417        Ok(res)
418    }
419
420    /// Evaluate the `client_registration` entrypoint.
421    ///
422    /// # Errors
423    ///
424    /// Returns an error if the policy engine fails to evaluate the entrypoint.
425    #[tracing::instrument(skip(self))]
426    pub async fn evaluate_client_registration(
427        &mut self,
428        input: ClientRegistrationInput<'_>,
429    ) -> Result<EvaluationResult, EvaluationError> {
430        let [res]: [EvaluationResult; 1] = self
431            .instance
432            .evaluate(
433                &mut self.store,
434                &self.entrypoints.client_registration,
435                &input,
436            )
437            .await?;
438
439        Ok(res)
440    }
441
442    /// Evaluate the `authorization_grant` entrypoint.
443    ///
444    /// # Errors
445    ///
446    /// Returns an error if the policy engine fails to evaluate the entrypoint.
447    #[tracing::instrument(
448        name = "policy.evaluate.authorization_grant",
449        skip_all,
450        fields(
451            %input.scope,
452            %input.client.id,
453        ),
454    )]
455    pub async fn evaluate_authorization_grant(
456        &mut self,
457        input: AuthorizationGrantInput<'_>,
458    ) -> Result<EvaluationResult, EvaluationError> {
459        let [res]: [EvaluationResult; 1] = self
460            .instance
461            .evaluate(
462                &mut self.store,
463                &self.entrypoints.authorization_grant,
464                &input,
465            )
466            .await?;
467
468        Ok(res)
469    }
470
471    /// Evaluate the `compat_login` entrypoint.
472    ///
473    /// # Errors
474    ///
475    /// Returns an error if the policy engine fails to evaluate the entrypoint.
476    #[tracing::instrument(
477        name = "policy.evaluate.compat_login",
478        skip_all,
479        fields(
480            %input.user.id,
481        ),
482    )]
483    pub async fn evaluate_compat_login(
484        &mut self,
485        input: CompatLoginInput<'_>,
486    ) -> Result<EvaluationResult, EvaluationError> {
487        let [res]: [EvaluationResult; 1] = self
488            .instance
489            .evaluate(&mut self.store, &self.entrypoints.compat_login, &input)
490            .await?;
491
492        Ok(res)
493    }
494}
495
#[cfg(test)]
mod tests {
    use std::time::SystemTime;

    use super::*;

    fn make_entrypoints() -> Entrypoints {
        Entrypoints {
            register: "register/violation".to_owned(),
            client_registration: "client_registration/violation".to_owned(),
            authorization_grant: "authorization_grant/violation".to_owned(),
            compat_login: "compat_login/violation".to_owned(),
            email: "email/violation".to_owned(),
        }
    }

    /// Load `policies/policy.wasm` (two directories above this crate) with the
    /// given static data.
    async fn load_factory(data: Data) -> PolicyFactory {
        // NOTE(review): presumably `std::path::Path` is on the clippy
        // disallowed list; it is fine for locating a test fixture
        #[allow(clippy::disallowed_types)]
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("policies")
            .join("policy.wasm");

        let file = tokio::fs::File::open(path).await.unwrap();

        PolicyFactory::load(file, data, make_entrypoints())
            .await
            .unwrap()
    }

    /// Password registration input for the user `hello` with the given email.
    fn password_registration(email: &str) -> RegisterInput<'_> {
        RegisterInput {
            registration_method: RegistrationMethod::Password,
            username: "hello",
            email: Some(email),
            requester: Requester {
                ip_address: None,
                user_agent: None,
            },
        }
    }

    #[tokio::test]
    async fn test_register() {
        let data = Data::new("example.com".to_owned(), None).with_rest(serde_json::json!({
            "allowed_domains": ["element.io", "*.element.io"],
            "banned_domains": ["staging.element.io"],
        }));
        let factory = load_factory(data).await;
        let mut policy = factory.instantiate().await.unwrap();

        // (email, whether the registration should be allowed)
        let cases = [
            ("hello@example.com", false),
            ("hello@foo.element.io", true),
            ("hello@staging.element.io", false),
        ];

        for (email, expected) in cases {
            let res = policy
                .evaluate_register(password_registration(email))
                .await
                .unwrap();
            assert_eq!(res.valid(), expected, "unexpected result for {email}");
        }
    }

    #[tokio::test]
    async fn test_dynamic_data() {
        let factory = load_factory(Data::new("example.com".to_owned(), None)).await;

        // Without dynamic data, the address is accepted
        let mut policy = factory.instantiate().await.unwrap();
        let res = policy
            .evaluate_register(password_registration("hello@example.com"))
            .await
            .unwrap();
        assert!(res.valid());

        // Ban every address containing "hello" through the dynamic data
        factory
            .set_dynamic_data(mas_data_model::PolicyData {
                id: Ulid::nil(),
                created_at: SystemTime::now().into(),
                data: serde_json::json!({
                    "emails": {
                        "banned_addresses": {
                            "substrings": ["hello"]
                        }
                    }
                }),
            })
            .await
            .unwrap();

        // Instances created afterwards see the new data
        let mut policy = factory.instantiate().await.unwrap();
        let res = policy
            .evaluate_register(password_registration("hello@example.com"))
            .await
            .unwrap();
        assert!(!res.valid());
    }

    #[tokio::test]
    async fn test_big_dynamic_data() {
        let factory = load_factory(Data::new("example.com".to_owned(), None)).await;

        // Roughly 1 MB of JSON: each element is a 5-digit string, i.e. 8
        // characters once the quotes and the comma are counted
        let substrings: Vec<String> = (0..(1024 * 1024 / 8))
            .map(|i| format!("{:05}", i % 100_000))
            .collect();

        factory
            .set_dynamic_data(mas_data_model::PolicyData {
                id: Ulid::nil(),
                created_at: SystemTime::now().into(),
                data: serde_json::json!({
                    "emails": { "banned_addresses": { "substrings": substrings } }
                }),
            })
            .await
            .unwrap();

        // 5-digit numbers must now be banned from email addresses
        let mut policy = factory.instantiate().await.unwrap();
        let res = policy
            .evaluate_register(password_registration("12345@example.com"))
            .await
            .unwrap();
        assert!(!res.valid());
    }

    #[test]
    fn test_merge() {
        use serde_json::json as j;

        // (left, right, expected merge result)
        let successful = [
            // Objects get their keys combined
            (
                j!({"hello": "world"}),
                j!({"foo": "bar"}),
                j!({"hello": "world", "foo": "bar"}),
            ),
            // Values of the same type get overridden
            (
                j!({"hello": "world"}),
                j!({"hello": "john"}),
                j!({"hello": "john"}),
            ),
            (
                j!({"hello": true}),
                j!({"hello": false}),
                j!({"hello": false}),
            ),
            (j!({"hello": 0}), j!({"hello": 42}), j!({"hello": 42})),
            // Arrays get concatenated
            (
                j!({"hello": ["world"]}),
                j!({"hello": ["john"]}),
                j!({"hello": ["world", "john"]}),
            ),
            // Null on the right overrides a value
            (
                j!({"hello": "world"}),
                j!({"hello": null}),
                j!({"hello": null}),
            ),
            // Null on the left gets overridden by a value
            (
                j!({"hello": null}),
                j!({"hello": "world"}),
                j!({"hello": "world"}),
            ),
            // Objects get deeply merged
            (
                j!({"a": {"b": {"c": "d"}}}),
                j!({"a": {"b": {"e": "f"}}}),
                j!({"a": {"b": {"c": "d", "e": "f"}}}),
            ),
        ];

        for (left, right, expected) in successful {
            assert_eq!(merge_data(left, right).unwrap(), expected);
        }

        // Values of different types can't be merged
        merge_data(j!({"hello": "world"}), j!({"hello": 123}))
            .expect_err("Can't merge different types");
    }
}