Skip to main content

mas_handlers/
rate_limit.rs

1// Copyright 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8use std::{convert::Infallible, net::IpAddr, sync::Arc, time::Duration};
9
10use axum::extract::FromRequestParts;
11use governor::{RateLimiter, clock::QuantaClock, state::keyed::DashMapStateStore};
12use mas_config::RateLimitingConfig;
13use mas_data_model::{User, UserEmailAuthentication};
14use ulid::Ulid;
15
16use crate::ClientIp;
17
18#[derive(Debug, Clone, thiserror::Error)]
19pub enum AccountRecoveryLimitedError {
20    #[error("Too many account recovery requests for requester {0}")]
21    Requester(RequesterFingerprint),
22
23    #[error("Too many account recovery requests for e-mail {0}")]
24    Email(String),
25}
26
27#[derive(Debug, Clone, Copy, thiserror::Error)]
28pub enum PasswordCheckLimitedError {
29    #[error("Too many password checks for requester {0}")]
30    Requester(RequesterFingerprint),
31
32    #[error("Too many password checks for user {0}")]
33    User(Ulid),
34}
35
36#[derive(Debug, Clone, thiserror::Error)]
37pub enum RegistrationLimitedError {
38    #[error("Too many account registration requests for requester {0}")]
39    Requester(RequesterFingerprint),
40}
41
42#[derive(Debug, Clone, thiserror::Error)]
43pub enum EmailAuthenticationLimitedError {
44    #[error("Too many email authentication requests for requester {0}")]
45    Requester(RequesterFingerprint),
46
47    #[error("Too many email authentication requests for authentication session {0}")]
48    Authentication(Ulid),
49
50    #[error("Too many email authentication requests for email {0}")]
51    Email(String),
52}
53
/// Key used to rate limit requests on a per-requester basis
///
/// The requester is currently identified only by its client IP address, which
/// may be unknown.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct RequesterFingerprint {
    /// Client IP address, when it could be determined
    ip: Option<IpAddr>,
}
59
60impl std::fmt::Display for RequesterFingerprint {
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        if let Some(ip) = self.ip {
63            write!(f, "{ip}")
64        } else {
65            write!(f, "(NO CLIENT IP)")
66        }
67    }
68}
69
70impl RequesterFingerprint {
71    /// An anonymous key with no IP address set. This should not be used in
72    /// production, and we should warn users if we can't find their client IPs.
73    pub const EMPTY: Self = Self { ip: None };
74
75    /// Create a new anonymous key with the given IP address
76    #[must_use]
77    pub const fn new(ip: IpAddr) -> Self {
78        Self { ip: Some(ip) }
79    }
80}
81
82impl<S: Send + Sync> FromRequestParts<S> for RequesterFingerprint {
83    type Rejection = Infallible;
84
85    async fn from_request_parts(
86        parts: &mut axum::http::request::Parts,
87        _state: &S,
88    ) -> Result<Self, Self::Rejection> {
89        let ip = parts.extensions.get::<ClientIp>().and_then(|ip| ip.0);
90
91        if let Some(ip) = ip {
92            Ok(Self::new(ip))
93        } else {
94            // If we can't infer the IP address, we'll just use an empty fingerprint and
95            // warn about it
96            tracing::warn!(
97                "Could not infer client IP address for an operation which rate-limits based on IP addresses"
98            );
99            Ok(Self::EMPTY)
100        }
101    }
102}
103
104/// Rate limiters for the different operations
105#[derive(Debug, Clone)]
106pub struct Limiter {
107    inner: Arc<LimiterInner>,
108}
109
110type KeyedRateLimiter<K> = RateLimiter<K, DashMapStateStore<K>, QuantaClock>;
111
112#[derive(Debug)]
113struct LimiterInner {
114    account_recovery_per_requester: KeyedRateLimiter<RequesterFingerprint>,
115    account_recovery_per_email: KeyedRateLimiter<String>,
116    password_check_for_requester: KeyedRateLimiter<RequesterFingerprint>,
117    password_check_for_user: KeyedRateLimiter<Ulid>,
118    registration_per_requester: KeyedRateLimiter<RequesterFingerprint>,
119    email_authentication_per_requester: KeyedRateLimiter<RequesterFingerprint>,
120    email_authentication_per_email: KeyedRateLimiter<String>,
121    email_authentication_emails_per_session: KeyedRateLimiter<Ulid>,
122    email_authentication_attempt_per_session: KeyedRateLimiter<Ulid>,
123}
124
125impl LimiterInner {
126    fn new(config: &RateLimitingConfig) -> Option<Self> {
127        Some(Self {
128            account_recovery_per_requester: RateLimiter::keyed(
129                config.account_recovery.per_ip.to_quota()?,
130            ),
131            account_recovery_per_email: RateLimiter::keyed(
132                config.account_recovery.per_address.to_quota()?,
133            ),
134            password_check_for_requester: RateLimiter::keyed(config.login.per_ip.to_quota()?),
135            password_check_for_user: RateLimiter::keyed(config.login.per_account.to_quota()?),
136            registration_per_requester: RateLimiter::keyed(config.registration.to_quota()?),
137            email_authentication_per_email: RateLimiter::keyed(
138                config.email_authentication.per_address.to_quota()?,
139            ),
140            email_authentication_per_requester: RateLimiter::keyed(
141                config.email_authentication.per_ip.to_quota()?,
142            ),
143            email_authentication_emails_per_session: RateLimiter::keyed(
144                config.email_authentication.emails_per_session.to_quota()?,
145            ),
146            email_authentication_attempt_per_session: RateLimiter::keyed(
147                config.email_authentication.attempt_per_session.to_quota()?,
148            ),
149        })
150    }
151}
152
153impl Limiter {
154    /// Creates a new `Limiter` based on a `RateLimitingConfig`.
155    ///
156    /// If the config is not valid, returns `None`.
157    /// (This should not happen if the config was validated, though.)
158    #[must_use]
159    pub fn new(config: &RateLimitingConfig) -> Option<Self> {
160        Some(Self {
161            inner: Arc::new(LimiterInner::new(config)?),
162        })
163    }
164
165    /// Start the rate limiter housekeeping task
166    ///
167    /// This task will periodically remove old entries from the rate limiters,
168    /// to make sure we don't build up a huge number of entries in memory.
169    pub fn start(&self) {
170        // Spawn a task that will periodically clean the rate limiters
171        let this = self.clone();
172        tokio::spawn(async move {
173            // Run the task every minute
174            let mut interval = tokio::time::interval(Duration::from_mins(1));
175            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
176
177            loop {
178                // Call the retain_recent method on each rate limiter
179                this.inner.account_recovery_per_email.retain_recent();
180                this.inner.account_recovery_per_requester.retain_recent();
181                this.inner.password_check_for_requester.retain_recent();
182                this.inner.password_check_for_user.retain_recent();
183                this.inner.registration_per_requester.retain_recent();
184                this.inner.email_authentication_per_email.retain_recent();
185                this.inner
186                    .email_authentication_per_requester
187                    .retain_recent();
188                this.inner
189                    .email_authentication_emails_per_session
190                    .retain_recent();
191                this.inner
192                    .email_authentication_attempt_per_session
193                    .retain_recent();
194
195                interval.tick().await;
196            }
197        });
198    }
199
200    /// Check if an account recovery can be performed
201    ///
202    /// # Errors
203    ///
204    /// Returns an error if the operation is rate limited.
205    pub fn check_account_recovery(
206        &self,
207        requester: RequesterFingerprint,
208        email_address: &str,
209    ) -> Result<(), AccountRecoveryLimitedError> {
210        self.inner
211            .account_recovery_per_requester
212            .check_key(&requester)
213            .map_err(|_| AccountRecoveryLimitedError::Requester(requester))?;
214
215        // Convert to lowercase to prevent bypassing the limit by enumerating different
216        // case variations.
217        // A case-folding transformation may be more proper.
218        let canonical_email = email_address.to_lowercase();
219        self.inner
220            .account_recovery_per_email
221            .check_key(&canonical_email)
222            .map_err(|_| AccountRecoveryLimitedError::Email(canonical_email))?;
223
224        Ok(())
225    }
226
227    /// Check if a password check can be performed
228    ///
229    /// # Errors
230    ///
231    /// Returns an error if the operation is rate limited
232    pub fn check_password(
233        &self,
234        key: RequesterFingerprint,
235        user: &User,
236    ) -> Result<(), PasswordCheckLimitedError> {
237        self.inner
238            .password_check_for_requester
239            .check_key(&key)
240            .map_err(|_| PasswordCheckLimitedError::Requester(key))?;
241
242        self.inner
243            .password_check_for_user
244            .check_key(&user.id)
245            .map_err(|_| PasswordCheckLimitedError::User(user.id))?;
246
247        Ok(())
248    }
249
250    /// Check if an account registration can be performed
251    ///
252    /// # Errors
253    ///
254    /// Returns an error if the operation is rate limited.
255    pub fn check_registration(
256        &self,
257        requester: RequesterFingerprint,
258    ) -> Result<(), RegistrationLimitedError> {
259        self.inner
260            .registration_per_requester
261            .check_key(&requester)
262            .map_err(|_| RegistrationLimitedError::Requester(requester))?;
263
264        Ok(())
265    }
266
267    /// Check if an email can be sent to the address for an email
268    /// authentication session
269    ///
270    /// # Errors
271    ///
272    /// Returns an error if the operation is rate limited.
273    pub fn check_email_authentication_email(
274        &self,
275        requester: RequesterFingerprint,
276        email: &str,
277    ) -> Result<(), EmailAuthenticationLimitedError> {
278        self.inner
279            .email_authentication_per_requester
280            .check_key(&requester)
281            .map_err(|_| EmailAuthenticationLimitedError::Requester(requester))?;
282
283        // Convert to lowercase to prevent bypassing the limit by enumerating different
284        // case variations.
285        // A case-folding transformation may be more proper.
286        let canonical_email = email.to_lowercase();
287        self.inner
288            .email_authentication_per_email
289            .check_key(&canonical_email)
290            .map_err(|_| EmailAuthenticationLimitedError::Email(email.to_owned()))?;
291        Ok(())
292    }
293
294    /// Check if an attempt can be done on an email authentication session
295    ///
296    /// # Errors
297    ///
298    /// Returns an error if the operation is rate limited.
299    pub fn check_email_authentication_attempt(
300        &self,
301        authentication: &UserEmailAuthentication,
302    ) -> Result<(), EmailAuthenticationLimitedError> {
303        self.inner
304            .email_authentication_attempt_per_session
305            .check_key(&authentication.id)
306            .map_err(|_| EmailAuthenticationLimitedError::Authentication(authentication.id))
307    }
308
309    /// Check if a new authentication code can be sent for an email
310    /// authentication session
311    ///
312    /// # Errors
313    ///
314    /// Returns an error if the operation is rate limited.
315    pub fn check_email_authentication_send_code(
316        &self,
317        requester: RequesterFingerprint,
318        authentication: &UserEmailAuthentication,
319    ) -> Result<(), EmailAuthenticationLimitedError> {
320        self.check_email_authentication_email(requester, &authentication.email)?;
321        self.inner
322            .email_authentication_emails_per_session
323            .check_key(&authentication.id)
324            .map_err(|_| EmailAuthenticationLimitedError::Authentication(authentication.id))
325    }
326}
327
#[cfg(test)]
mod tests {
    use mas_data_model::{Clock, User, clock::MockClock};
    use rand::SeedableRng;

    use super::*;

    #[test]
    fn test_password_check_limiter() {
        let now = MockClock::default().now();
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);

        let limiter = Limiter::new(&RateLimitingConfig::default()).unwrap();
        let allowed = |requester: RequesterFingerprint, user: &User| {
            limiter.check_password(requester, user).is_ok()
        };

        // 256 * 3 = 768 distinct requester IPs, enough to exhaust the
        // account-level quota without relying on a single IP
        let requesters: [RequesterFingerprint; 768] = (0..=255_u8)
            .flat_map(|a| (0..3_u8).map(move |b| RequesterFingerprint::new([a, a, b, b].into())))
            .collect::<Vec<_>>()
            .try_into()
            .unwrap();

        let mut new_user = |username: &str| User {
            id: Ulid::from_datetime_with_source(now.into(), &mut rng),
            username: username.to_owned(),
            sub: "123-456".to_owned(),
            created_at: now,
            locked_at: None,
            deactivated_at: None,
            can_request_admin: false,
            is_guest: true,
        };
        let alice = new_user("alice");
        let bob = new_user("bob");

        // The same IP address is allowed three times...
        for _ in 0..3 {
            assert!(allowed(requesters[0], &alice));
        }
        // ...but not a fourth time, whichever account is targeted
        assert!(!allowed(requesters[0], &alice));
        assert!(!allowed(requesters[0], &bob));

        // Another IP address still gets through: the account isn't locked yet
        assert!(allowed(requesters[1], &alice));

        // 4 cells out of 1800 are now consumed on alice; spread requests over
        // other IPs until the account-level limit is what kicks in
        for &requester in &requesters[2..600] {
            for _ in 0..3 {
                assert!(allowed(requester, &alice));
            }
            assert!(!allowed(requester, &alice));
        }

        // 4 + 598 * 3 = 1798 cells consumed: two more attempts fit, then the
        // account is limited even from a fresh IP
        assert!(allowed(requesters[600], &alice));
        assert!(allowed(requesters[601], &alice));
        assert!(!allowed(requesters[602], &alice));

        // The other account has its own, untouched quota
        assert!(allowed(requesters[603], &bob));
    }
}
402}