Skip to main content

mas_policy/
model.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7//! Input and output types for policy evaluation.
8//!
9//! This is useful to generate JSON schemas for each input type, which can then
10//! be type-checked by Open Policy Agent.
11
12use std::net::IpAddr;
13
14use mas_data_model::{Client, User};
15use oauth2_types::{registration::VerifiedClientMetadata, scope::Scope};
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18
/// Violation variants identified by a well-known policy code (under the `code`
/// key).
// Internally tagged on `code`, and every variant is a unit variant, so a
// variant serializes as a single-key object such as
// `{"code": "username-too-short"}`. The kebab-case names derived by
// `rename_all` are duplicated as literals in `ViolationVariant::as_str`; keep
// the two in sync when adding a variant.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, JsonSchema, PartialEq, Eq)]
#[serde(tag = "code", rename_all = "kebab-case")]
pub enum ViolationVariant {
    // --- Username checks ---
    /// The username is too short.
    UsernameTooShort,

    /// The username is too long.
    UsernameTooLong,

    /// The username contains invalid characters.
    UsernameInvalidChars,

    /// The username contains only numeric characters.
    UsernameAllNumeric,

    /// The username is banned.
    UsernameBanned,

    /// The username is not allowed.
    UsernameNotAllowed,

    // --- Email checks ---
    /// The email domain is not allowed.
    EmailDomainNotAllowed,

    /// The email domain is banned.
    EmailDomainBanned,

    /// The email address is not allowed.
    EmailNotAllowed,

    /// The email address is banned.
    EmailBanned,

    // --- Session checks ---
    /// The user has reached their session limit.
    TooManySessions,
}
57
58impl ViolationVariant {
59    /// Returns the code as a string
60    #[must_use]
61    pub fn as_str(&self) -> &'static str {
62        match self {
63            Self::UsernameTooShort => "username-too-short",
64            Self::UsernameTooLong => "username-too-long",
65            Self::UsernameInvalidChars => "username-invalid-chars",
66            Self::UsernameAllNumeric => "username-all-numeric",
67            Self::UsernameBanned => "username-banned",
68            Self::UsernameNotAllowed => "username-not-allowed",
69            Self::EmailDomainNotAllowed => "email-domain-not-allowed",
70            Self::EmailDomainBanned => "email-domain-banned",
71            Self::EmailNotAllowed => "email-not-allowed",
72            Self::EmailBanned => "email-banned",
73            Self::TooManySessions => "too-many-sessions",
74        }
75    }
76}
77
/// A single violation of a policy.
#[derive(Serialize, Deserialize, Debug, JsonSchema)]
pub struct Violation {
    // Human-readable description of the violation. This is the text the
    // `Display` impl of `EvaluationResult` prints, joined with ", ".
    pub msg: String,
    // Optional URI reported alongside the violation.
    // NOTE(review): presumably where the user should be sent to resolve the
    // violation — confirm against the policies that set it.
    pub redirect_uri: Option<String>,
    // Optional field name reported alongside the violation.
    // NOTE(review): presumably the input/form field the violation is about
    // (e.g. "username") — confirm against the policies that set it.
    pub field: Option<String>,

    // Flattened because policies emit `code` as a top-level key next to `msg`,
    // not nested under a separate `variant` object.
    //
    // A side effect is that any extra fields a variant carried would also be
    // merged into this level. That is acceptable, and currently moot, since
    // every `ViolationVariant` is a unit variant. NOTE(review): a missing or
    // unrecognised `code` presumably deserializes to `None` here — confirm.
    #[serde(flatten)]
    pub variant: Option<ViolationVariant>,
}
92
/// The result of a policy evaluation.
///
/// When deserialized, the violations are read from the `result` key.
#[derive(Deserialize, Debug)]
pub struct EvaluationResult {
    /// Every violation reported by the policy. An empty list means the
    /// evaluation passed (see [`EvaluationResult::valid`]).
    #[serde(rename = "result")]
    pub violations: Vec<Violation>,
}
99
100impl std::fmt::Display for EvaluationResult {
101    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102        let mut first = true;
103        for violation in &self.violations {
104            if first {
105                first = false;
106            } else {
107                write!(f, ", ")?;
108            }
109            write!(f, "{}", violation.msg)?;
110        }
111        Ok(())
112    }
113}
114
115impl EvaluationResult {
116    /// Returns true if the policy evaluation was successful.
117    #[must_use]
118    pub fn valid(&self) -> bool {
119        self.violations.is_empty()
120    }
121}
122
/// Identity of the requester
// `Default` yields a requester with neither an IP address nor a user agent.
// `rename_all = "snake_case"` does not change anything here, because the
// field names are already snake_case.
#[derive(Serialize, Debug, Default, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub struct Requester {
    /// IP address of the entity making the request
    pub ip_address: Option<IpAddr>,

    /// User agent of the entity making the request
    pub user_agent: Option<String>,
}
133
// How a user registration is being performed, serialized as a plain string.
// Each variant is renamed explicitly: `rename_all = "kebab-case"` would turn
// `UpstreamOAuth2` into "upstream-o-auth2" instead of "upstream-oauth2".
#[derive(Serialize, Debug, JsonSchema)]
pub enum RegistrationMethod {
    // Serialized as "password".
    #[serde(rename = "password")]
    Password,

    // Serialized as "upstream-oauth2".
    #[serde(rename = "upstream-oauth2")]
    UpstreamOAuth2,
}
142
143/// Input for the user registration policy.
144#[derive(Serialize, Debug, JsonSchema)]
145#[serde(tag = "registration_method")]
146pub struct RegisterInput<'a> {
147    pub registration_method: RegistrationMethod,
148
149    pub username: &'a str,
150
151    #[serde(skip_serializing_if = "Option::is_none")]
152    pub email: Option<&'a str>,
153
154    pub requester: Requester,
155}
156
/// Input for the client registration policy.
#[derive(Serialize, Debug, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub struct ClientRegistrationInput<'a> {
    // The schema describes the metadata as a free-form JSON object (string
    // keys, arbitrary values) instead of deriving it from
    // `VerifiedClientMetadata`. NOTE(review): presumably because that type
    // does not implement `JsonSchema` — confirm.
    #[schemars(with = "std::collections::HashMap<String, serde_json::Value>")]
    pub client_metadata: &'a VerifiedClientMetadata,
    pub requester: Requester,
}
165
// OAuth 2.0 grant type being evaluated, serialized as a string:
// "authorization_code", "client_credentials", or the device code grant URN
// "urn:ietf:params:oauth:grant-type:device_code".
#[derive(Serialize, Debug, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum GrantType {
    AuthorizationCode,
    ClientCredentials,
    // Overrides `rename_all`, which would otherwise produce "device_code".
    #[serde(rename = "urn:ietf:params:oauth:grant-type:device_code")]
    DeviceCode,
}
174
/// Input for the authorization grant policy.
#[derive(Serialize, Debug, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub struct AuthorizationGrantInput<'a> {
    // `None` when no user is involved in the grant. NOTE(review): presumably
    // the client credentials case — confirm. The schema describes it as an
    // optional free-form JSON object rather than from `User`'s fields.
    #[schemars(with = "Option<std::collections::HashMap<String, serde_json::Value>>")]
    pub user: Option<&'a User>,

    /// How many sessions the user has.
    /// Not populated if it's not a user logging in.
    pub session_counts: Option<SessionCounts>,

    // The schema describes the client as a free-form JSON object rather than
    // from `Client`'s fields.
    #[schemars(with = "std::collections::HashMap<String, serde_json::Value>")]
    pub client: &'a Client,

    // The schema describes the scope as a string. NOTE(review): assumes
    // `Scope` serializes as a single space-delimited string — confirm.
    #[schemars(with = "String")]
    pub scope: &'a Scope,

    pub grant_type: GrantType,

    pub requester: Requester,
}
196
/// Input for the compatibility login policy.
#[derive(Serialize, Debug, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub struct CompatLoginInput<'a> {
    // Always present here, unlike `AuthorizationGrantInput::user`. The schema
    // describes it as a free-form JSON object rather than from `User`'s fields.
    #[schemars(with = "std::collections::HashMap<String, serde_json::Value>")]
    pub user: &'a User,

    /// How many sessions the user has.
    pub session_counts: SessionCounts,

    /// Whether a session will be replaced by this login
    pub session_replaced: bool,

    /// What type of login is being performed.
    /// This also determines whether the login is interactive.
    pub login: CompatLogin,

    pub requester: Requester,
}
216
// Kind of compatibility login, internally tagged on `type`. For example, the
// SSO variant serializes as `{"type": "m.login.sso", "redirect_uri": "..."}`,
// and the unit variants as `{"type": "m.login.token"}` and
// `{"type": "m.login.password"}`.
#[derive(Serialize, Debug, JsonSchema)]
#[serde(tag = "type")]
pub enum CompatLogin {
    /// Used as the interactive part of SSO login.
    #[serde(rename = "m.login.sso")]
    Sso { redirect_uri: String },

    /// Used as the final (non-interactive) stage of SSO login.
    #[serde(rename = "m.login.token")]
    Token,

    /// Non-interactive password-over-the-API login.
    #[serde(rename = "m.login.password")]
    Password,
}
232
/// Information about how many sessions the user has
// Counts are broken down by session kind. NOTE(review): `total` presumably
// equals `oauth2 + compat + personal`, but nothing here enforces that —
// confirm with the code that builds this value.
#[derive(Serialize, Debug, JsonSchema)]
pub struct SessionCounts {
    pub total: u64,

    pub oauth2: u64,
    pub compat: u64,
    pub personal: u64,
}
242
/// Input for the email add policy.
#[derive(Serialize, Debug, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub struct EmailInput<'a> {
    // The email address the policy is evaluated against.
    pub email: &'a str,

    pub requester: Requester,
}