1use std::{convert::Infallible, net::IpAddr, sync::Arc, time::Duration};
9
10use axum::extract::FromRequestParts;
11use governor::{RateLimiter, clock::QuantaClock, state::keyed::DashMapStateStore};
12use mas_config::RateLimitingConfig;
13use mas_data_model::{User, UserEmailAuthentication};
14use ulid::Ulid;
15
16use crate::ClientIp;
17
18#[derive(Debug, Clone, thiserror::Error)]
19pub enum AccountRecoveryLimitedError {
20 #[error("Too many account recovery requests for requester {0}")]
21 Requester(RequesterFingerprint),
22
23 #[error("Too many account recovery requests for e-mail {0}")]
24 Email(String),
25}
26
27#[derive(Debug, Clone, Copy, thiserror::Error)]
28pub enum PasswordCheckLimitedError {
29 #[error("Too many password checks for requester {0}")]
30 Requester(RequesterFingerprint),
31
32 #[error("Too many password checks for user {0}")]
33 User(Ulid),
34}
35
36#[derive(Debug, Clone, thiserror::Error)]
37pub enum RegistrationLimitedError {
38 #[error("Too many account registration requests for requester {0}")]
39 Requester(RequesterFingerprint),
40}
41
42#[derive(Debug, Clone, thiserror::Error)]
43pub enum EmailAuthenticationLimitedError {
44 #[error("Too many email authentication requests for requester {0}")]
45 Requester(RequesterFingerprint),
46
47 #[error("Too many email authentication requests for authentication session {0}")]
48 Authentication(Ulid),
49
50 #[error("Too many email authentication requests for email {0}")]
51 Email(String),
52}
53
/// Identifies who is making a request, for per-requester rate limiting.
///
/// Currently this is only the client IP address, when it is known. Because
/// equality and hashing are derived, every request without a known IP address
/// maps to the same fingerprint.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct RequesterFingerprint {
    /// The client IP address, or `None` if it could not be determined.
    ip: Option<IpAddr>,
}
59
60impl std::fmt::Display for RequesterFingerprint {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 if let Some(ip) = self.ip {
63 write!(f, "{ip}")
64 } else {
65 write!(f, "(NO CLIENT IP)")
66 }
67 }
68}
69
70impl RequesterFingerprint {
71 pub const EMPTY: Self = Self { ip: None };
74
75 #[must_use]
77 pub const fn new(ip: IpAddr) -> Self {
78 Self { ip: Some(ip) }
79 }
80}
81
82impl<S: Send + Sync> FromRequestParts<S> for RequesterFingerprint {
83 type Rejection = Infallible;
84
85 async fn from_request_parts(
86 parts: &mut axum::http::request::Parts,
87 _state: &S,
88 ) -> Result<Self, Self::Rejection> {
89 let ip = parts.extensions.get::<ClientIp>().and_then(|ip| ip.0);
90
91 if let Some(ip) = ip {
92 Ok(Self::new(ip))
93 } else {
94 tracing::warn!(
97 "Could not infer client IP address for an operation which rate-limits based on IP addresses"
98 );
99 Ok(Self::EMPTY)
100 }
101 }
102}
103
104#[derive(Debug, Clone)]
106pub struct Limiter {
107 inner: Arc<LimiterInner>,
108}
109
110type KeyedRateLimiter<K> = RateLimiter<K, DashMapStateStore<K>, QuantaClock>;
111
112#[derive(Debug)]
113struct LimiterInner {
114 account_recovery_per_requester: KeyedRateLimiter<RequesterFingerprint>,
115 account_recovery_per_email: KeyedRateLimiter<String>,
116 password_check_for_requester: KeyedRateLimiter<RequesterFingerprint>,
117 password_check_for_user: KeyedRateLimiter<Ulid>,
118 registration_per_requester: KeyedRateLimiter<RequesterFingerprint>,
119 email_authentication_per_requester: KeyedRateLimiter<RequesterFingerprint>,
120 email_authentication_per_email: KeyedRateLimiter<String>,
121 email_authentication_emails_per_session: KeyedRateLimiter<Ulid>,
122 email_authentication_attempt_per_session: KeyedRateLimiter<Ulid>,
123}
124
125impl LimiterInner {
126 fn new(config: &RateLimitingConfig) -> Option<Self> {
127 Some(Self {
128 account_recovery_per_requester: RateLimiter::keyed(
129 config.account_recovery.per_ip.to_quota()?,
130 ),
131 account_recovery_per_email: RateLimiter::keyed(
132 config.account_recovery.per_address.to_quota()?,
133 ),
134 password_check_for_requester: RateLimiter::keyed(config.login.per_ip.to_quota()?),
135 password_check_for_user: RateLimiter::keyed(config.login.per_account.to_quota()?),
136 registration_per_requester: RateLimiter::keyed(config.registration.to_quota()?),
137 email_authentication_per_email: RateLimiter::keyed(
138 config.email_authentication.per_address.to_quota()?,
139 ),
140 email_authentication_per_requester: RateLimiter::keyed(
141 config.email_authentication.per_ip.to_quota()?,
142 ),
143 email_authentication_emails_per_session: RateLimiter::keyed(
144 config.email_authentication.emails_per_session.to_quota()?,
145 ),
146 email_authentication_attempt_per_session: RateLimiter::keyed(
147 config.email_authentication.attempt_per_session.to_quota()?,
148 ),
149 })
150 }
151}
152
153impl Limiter {
154 #[must_use]
159 pub fn new(config: &RateLimitingConfig) -> Option<Self> {
160 Some(Self {
161 inner: Arc::new(LimiterInner::new(config)?),
162 })
163 }
164
165 pub fn start(&self) {
170 let this = self.clone();
172 tokio::spawn(async move {
173 let mut interval = tokio::time::interval(Duration::from_mins(1));
175 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
176
177 loop {
178 this.inner.account_recovery_per_email.retain_recent();
180 this.inner.account_recovery_per_requester.retain_recent();
181 this.inner.password_check_for_requester.retain_recent();
182 this.inner.password_check_for_user.retain_recent();
183 this.inner.registration_per_requester.retain_recent();
184 this.inner.email_authentication_per_email.retain_recent();
185 this.inner
186 .email_authentication_per_requester
187 .retain_recent();
188 this.inner
189 .email_authentication_emails_per_session
190 .retain_recent();
191 this.inner
192 .email_authentication_attempt_per_session
193 .retain_recent();
194
195 interval.tick().await;
196 }
197 });
198 }
199
200 pub fn check_account_recovery(
206 &self,
207 requester: RequesterFingerprint,
208 email_address: &str,
209 ) -> Result<(), AccountRecoveryLimitedError> {
210 self.inner
211 .account_recovery_per_requester
212 .check_key(&requester)
213 .map_err(|_| AccountRecoveryLimitedError::Requester(requester))?;
214
215 let canonical_email = email_address.to_lowercase();
219 self.inner
220 .account_recovery_per_email
221 .check_key(&canonical_email)
222 .map_err(|_| AccountRecoveryLimitedError::Email(canonical_email))?;
223
224 Ok(())
225 }
226
227 pub fn check_password(
233 &self,
234 key: RequesterFingerprint,
235 user: &User,
236 ) -> Result<(), PasswordCheckLimitedError> {
237 self.inner
238 .password_check_for_requester
239 .check_key(&key)
240 .map_err(|_| PasswordCheckLimitedError::Requester(key))?;
241
242 self.inner
243 .password_check_for_user
244 .check_key(&user.id)
245 .map_err(|_| PasswordCheckLimitedError::User(user.id))?;
246
247 Ok(())
248 }
249
250 pub fn check_registration(
256 &self,
257 requester: RequesterFingerprint,
258 ) -> Result<(), RegistrationLimitedError> {
259 self.inner
260 .registration_per_requester
261 .check_key(&requester)
262 .map_err(|_| RegistrationLimitedError::Requester(requester))?;
263
264 Ok(())
265 }
266
267 pub fn check_email_authentication_email(
274 &self,
275 requester: RequesterFingerprint,
276 email: &str,
277 ) -> Result<(), EmailAuthenticationLimitedError> {
278 self.inner
279 .email_authentication_per_requester
280 .check_key(&requester)
281 .map_err(|_| EmailAuthenticationLimitedError::Requester(requester))?;
282
283 let canonical_email = email.to_lowercase();
287 self.inner
288 .email_authentication_per_email
289 .check_key(&canonical_email)
290 .map_err(|_| EmailAuthenticationLimitedError::Email(email.to_owned()))?;
291 Ok(())
292 }
293
294 pub fn check_email_authentication_attempt(
300 &self,
301 authentication: &UserEmailAuthentication,
302 ) -> Result<(), EmailAuthenticationLimitedError> {
303 self.inner
304 .email_authentication_attempt_per_session
305 .check_key(&authentication.id)
306 .map_err(|_| EmailAuthenticationLimitedError::Authentication(authentication.id))
307 }
308
309 pub fn check_email_authentication_send_code(
316 &self,
317 requester: RequesterFingerprint,
318 authentication: &UserEmailAuthentication,
319 ) -> Result<(), EmailAuthenticationLimitedError> {
320 self.check_email_authentication_email(requester, &authentication.email)?;
321 self.inner
322 .email_authentication_emails_per_session
323 .check_key(&authentication.id)
324 .map_err(|_| EmailAuthenticationLimitedError::Authentication(authentication.id))
325 }
326}
327
#[cfg(test)]
mod tests {
    use mas_data_model::{Clock, User, clock::MockClock};
    use rand::SeedableRng;

    use super::*;

    #[test]
    fn test_password_check_limiter() {
        let now = MockClock::default().now();
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);

        let limiter = Limiter::new(&RateLimitingConfig::default()).unwrap();

        // 256 values of `a` times 3 values of `b`: 768 distinct IP addresses.
        let requesters: Vec<RequesterFingerprint> = (0..=255u8)
            .flat_map(|a| {
                (0..3u8).map(move |b| RequesterFingerprint::new(IpAddr::from([a, a, b, b])))
            })
            .collect();
        assert_eq!(requesters.len(), 768);

        // Builds a test user; IDs are drawn from the seeded RNG in call order.
        let mut new_user = |username: &str| User {
            id: Ulid::from_datetime_with_source(now.into(), &mut rng),
            username: username.to_owned(),
            sub: "123-456".to_owned(),
            created_at: now,
            locked_at: None,
            deactivated_at: None,
            can_request_admin: false,
            is_guest: true,
        };
        let alice = new_user("alice");
        let bob = new_user("bob");

        let allowed = |requester: RequesterFingerprint, user: &User| {
            limiter.check_password(requester, user).is_ok()
        };

        // The first requester gets three attempts on alice...
        for _ in 0..3 {
            assert!(allowed(requesters[0], &alice));
        }
        // ...then is locked out, whichever account it targets.
        assert!(!allowed(requesters[0], &alice));
        assert!(!allowed(requesters[0], &bob));

        // A different requester can still try alice's password.
        assert!(allowed(requesters[1], &alice));

        // Many requesters each use up their own per-requester quota on alice.
        for &requester in &requesters[2..600] {
            for _ in 0..3 {
                assert!(allowed(requester, &alice));
            }
            assert!(!allowed(requester, &alice));
        }

        // Alice's per-account quota is now almost exhausted.
        assert!(allowed(requesters[600], &alice));
        assert!(allowed(requesters[601], &alice));
        assert!(!allowed(requesters[602], &alice));

        // Bob's account is unaffected.
        assert!(allowed(requesters[603], &bob));
    }
}