1#![allow(clippy::module_name_repetitions)]
13
14use std::{
15 borrow::Cow,
16 collections::BTreeSet,
17 iter::FromIterator,
18 ops::{Deref, DerefMut},
19 str::FromStr,
20};
21
22use serde::{Deserialize, Serialize};
23use thiserror::Error;
24
25#[derive(Debug, Error, PartialEq, Eq, PartialOrd, Ord, Hash)]
27#[error("Invalid scope format")]
28pub struct InvalidScope;
29
/// A single scope token, i.e. one space-free element of a [`Scope`].
///
/// The string is either borrowed from a `'static` literal (see
/// [`ScopeToken::from_static`]) or owned, when produced by parsing.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ScopeToken(Cow<'static, str>);

impl ScopeToken {
    /// Wraps a `'static` string as a token **without validating it**.
    ///
    /// Validation is skipped so this constructor stays usable in `const`
    /// contexts; callers must make sure `token` only contains characters the
    /// [`FromStr`](std::str::FromStr) implementation would accept.
    #[must_use]
    pub const fn from_static(token: &'static str) -> Self {
        Self(Cow::Borrowed(token))
    }

    /// Borrows the token as a plain string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }
}
48
49pub const OPENID: ScopeToken = ScopeToken::from_static("openid");
53
54pub const PROFILE: ScopeToken = ScopeToken::from_static("profile");
58
59pub const EMAIL: ScopeToken = ScopeToken::from_static("email");
63
64pub const ADDRESS: ScopeToken = ScopeToken::from_static("address");
68
69pub const PHONE: ScopeToken = ScopeToken::from_static("phone");
73
74pub const OFFLINE_ACCESS: ScopeToken = ScopeToken::from_static("offline_access");
80
/// Returns whether `c` is an `NQCHAR` as defined by RFC 6749, appendix A:
///
/// ```text
/// NQCHAR = %x21 / %x23-5B / %x5D-7E
/// ```
///
/// That is any printable ASCII character except space (0x20), `"` (0x22) and
/// `\` (0x5C). ABNF ranges are inclusive, so `[` (0x5B) and `~` (0x7E) are
/// accepted.
fn nqchar(c: char) -> bool {
    // Inclusive range patterns (`..=`) mirror the ABNF exactly; half-open
    // ranges would wrongly exclude the upper bounds `[` and `~`.
    matches!(c, '\x21' | '\x23'..='\x5B' | '\x5D'..='\x7E')
}
88
89impl FromStr for ScopeToken {
90 type Err = InvalidScope;
91
92 fn from_str(s: &str) -> Result<Self, Self::Err> {
93 if !s.is_empty() && s.chars().all(nqchar) {
98 Ok(ScopeToken(Cow::Owned(s.into())))
99 } else {
100 Err(InvalidScope)
101 }
102 }
103}
104
105impl Deref for ScopeToken {
106 type Target = str;
107
108 fn deref(&self) -> &Self::Target {
109 &self.0
110 }
111}
112
113impl std::fmt::Display for ScopeToken {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 self.0.fmt(f)
116 }
117}
118
119#[derive(Debug, Clone, PartialEq, Eq)]
121pub struct Scope(BTreeSet<ScopeToken>);
122
123impl Deref for Scope {
124 type Target = BTreeSet<ScopeToken>;
125
126 fn deref(&self) -> &Self::Target {
127 &self.0
128 }
129}
130
131impl DerefMut for Scope {
132 fn deref_mut(&mut self) -> &mut Self::Target {
133 &mut self.0
134 }
135}
136
137impl FromStr for Scope {
138 type Err = InvalidScope;
139
140 fn from_str(s: &str) -> Result<Self, Self::Err> {
141 let scopes: Result<BTreeSet<ScopeToken>, InvalidScope> =
146 s.split(' ').map(ScopeToken::from_str).collect();
147
148 Ok(Self(scopes?))
149 }
150}
151
152impl Scope {
153 #[must_use]
155 pub fn is_empty(&self) -> bool {
156 self.0.is_empty()
158 }
159
160 #[must_use]
162 pub fn len(&self) -> usize {
163 self.0.len()
164 }
165
166 #[must_use]
168 pub fn contains(&self, token: &str) -> bool {
169 ScopeToken::from_str(token).is_ok_and(|token| self.0.contains(&token))
170 }
171
172 pub fn insert(&mut self, value: ScopeToken) -> bool {
176 self.0.insert(value)
177 }
178}
179
180impl std::fmt::Display for Scope {
181 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182 for (index, token) in self.0.iter().enumerate() {
183 if index == 0 {
184 write!(f, "{token}")?;
185 } else {
186 write!(f, " {token}")?;
187 }
188 }
189
190 Ok(())
191 }
192}
193
194impl Serialize for Scope {
195 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
196 where
197 S: serde::Serializer,
198 {
199 self.to_string().serialize(serializer)
200 }
201}
202
203impl<'de> Deserialize<'de> for Scope {
204 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
205 where
206 D: serde::Deserializer<'de>,
207 {
208 let scope: String = Deserialize::deserialize(deserializer)?;
210 Scope::from_str(&scope).map_err(serde::de::Error::custom)
211 }
212}
213
214impl FromIterator<ScopeToken> for Scope {
215 fn from_iter<T: IntoIterator<Item = ScopeToken>>(iter: T) -> Self {
216 Self(BTreeSet::from_iter(iter))
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_scope_token() {
        // A parsed (owned) token compares equal to the borrowed constant.
        assert_eq!(ScopeToken::from_str("openid"), Ok(OPENID));

        // `\` (0x5C) is outside the NQCHAR set.
        assert_eq!(ScopeToken::from_str("invalid\\scope"), Err(InvalidScope));
    }

    #[test]
    fn parse_scope() {
        let scope = Scope::from_str("openid profile address").unwrap();
        assert_eq!(scope.len(), 3);
        assert!(scope.contains("openid"));
        assert!(scope.contains("profile"));
        assert!(scope.contains("address"));
        assert!(!scope.contains("unknown"));

        assert!(
            Scope::from_str("").is_err(),
            "there should always be at least one token in the scope"
        );

        assert!(Scope::from_str("invalid\\scope").is_err());
        // Consecutive separators produce an empty token, which is invalid.
        // The literal needs two spaces between words to exercise that case.
        assert!(Scope::from_str("no  double  space").is_err());
        assert!(Scope::from_str(" no leading space").is_err());
        assert!(Scope::from_str("no trailing space ").is_err());

        let scope = Scope::from_str("openid").unwrap();
        assert_eq!(scope.len(), 1);
        assert!(scope.contains("openid"));
        assert!(!scope.contains("profile"));
        assert!(!scope.contains("address"));

        assert_eq!(
            Scope::from_str("order does not matter"),
            Scope::from_str("matter not order does"),
        );

        assert!(Scope::from_str("http://example.com").is_ok());
        assert!(Scope::from_str("urn:matrix:client:api:*").is_ok());
        assert!(Scope::from_str("urn:matrix:org.matrix.msc2967.client:api:*").is_ok());
    }

    #[test]
    fn display_scope() {
        // Tokens are rendered in sorted order, separated by single spaces.
        let scope = Scope::from_str("profile openid").unwrap();
        assert_eq!(scope.to_string(), "openid profile");
        assert_eq!(Scope::from_str(&scope.to_string()), Ok(scope));
    }
}