// mas_config/sections/database.rs

1// Copyright 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8use std::{num::NonZeroU32, time::Duration};
9
10use camino::Utf8PathBuf;
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use serde_with::serde_as;
14
15use super::ConfigurationSection;
16use crate::schema;
17
/// Default connection URI: the bare `postgresql://` scheme with nothing else.
///
/// Wrapped in `Option` because it is the default of the `uri: Option<String>`
/// field, so both must share the same type.
#[expect(clippy::unnecessary_wraps)]
fn default_connection_string() -> Option<String> {
    let uri = String::from("postgresql://");
    Some(uri)
}
22
/// Default upper bound on the number of pooled connections: 10.
fn default_max_connections() -> NonZeroU32 {
    // Evaluated at compile time, so the `None` arm can never be reached at runtime.
    const TEN: NonZeroU32 = match NonZeroU32::new(10) {
        Some(value) => value,
        None => panic!("10 is non-zero"),
    };
    TEN
}
26
/// Default time allowed for establishing a database connection: 30 seconds.
fn default_connect_timeout() -> Duration {
    const CONNECT_TIMEOUT_SECS: u64 = 30;
    Duration::from_secs(CONNECT_TIMEOUT_SECS)
}
30
/// Default maximum idle duration for individual connections: ten minutes.
///
/// Wrapped in `Option` to match the type of the `idle_timeout` field.
#[expect(clippy::unnecessary_wraps)]
fn default_idle_timeout() -> Option<Duration> {
    const TEN_MINUTES: Duration = Duration::from_secs(10 * 60);
    Some(TEN_MINUTES)
}
35
/// Default maximum lifetime of individual connections: thirty minutes.
///
/// Wrapped in `Option` to match the type of the `max_lifetime` field.
#[expect(clippy::unnecessary_wraps)]
fn default_max_lifetime() -> Option<Duration> {
    const THIRTY_MINUTES: Duration = Duration::from_secs(30 * 60);
    Some(THIRTY_MINUTES)
}
40
41impl Default for DatabaseConfig {
42    fn default() -> Self {
43        Self {
44            uri: default_connection_string(),
45            host: None,
46            port: None,
47            socket: None,
48            username: None,
49            password: None,
50            database: None,
51            ssl_mode: None,
52            ssl_ca: None,
53            ssl_ca_file: None,
54            ssl_certificate: None,
55            ssl_certificate_file: None,
56            ssl_key: None,
57            ssl_key_file: None,
58            max_connections: default_max_connections(),
59            min_connections: Default::default(),
60            connect_timeout: default_connect_timeout(),
61            idle_timeout: default_idle_timeout(),
62            max_lifetime: default_max_lifetime(),
63        }
64    }
65}
66
67/// Options for controlling the level of protection provided for PostgreSQL SSL
68/// connections.
69#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
70#[serde(rename_all = "kebab-case")]
71pub enum PgSslMode {
72    /// Only try a non-SSL connection.
73    Disable,
74
75    /// First try a non-SSL connection; if that fails, try an SSL connection.
76    Allow,
77
78    /// First try an SSL connection; if that fails, try a non-SSL connection.
79    Prefer,
80
81    /// Only try an SSL connection. If a root CA file is present, verify the
82    /// connection in the same way as if `VerifyCa` was specified.
83    Require,
84
85    /// Only try an SSL connection, and verify that the server certificate is
86    /// issued by a trusted certificate authority (CA).
87    VerifyCa,
88
89    /// Only try an SSL connection; verify that the server certificate is issued
90    /// by a trusted CA and that the requested server host name matches that
91    /// in the certificate.
92    VerifyFull,
93}
94
95/// Database connection configuration
96#[serde_as]
97#[derive(Debug, Serialize, Deserialize, JsonSchema)]
98pub struct DatabaseConfig {
99    /// Connection URI
100    ///
101    /// This must not be specified if `host`, `port`, `socket`, `username`,
102    /// `password`, or `database` are specified.
103    #[serde(skip_serializing_if = "Option::is_none")]
104    #[schemars(url, default = "default_connection_string")]
105    pub uri: Option<String>,
106
107    /// Name of host to connect to
108    ///
109    /// This must not be specified if `uri` is specified.
110    #[serde(skip_serializing_if = "Option::is_none")]
111    #[schemars(with = "Option::<schema::Hostname>")]
112    pub host: Option<String>,
113
114    /// Port number to connect at the server host
115    ///
116    /// This must not be specified if `uri` is specified.
117    #[serde(skip_serializing_if = "Option::is_none")]
118    #[schemars(range(min = 1, max = 65535))]
119    pub port: Option<u16>,
120
121    /// Directory containing the UNIX socket to connect to
122    ///
123    /// This must not be specified if `uri` is specified.
124    #[serde(skip_serializing_if = "Option::is_none")]
125    #[schemars(with = "Option<String>")]
126    pub socket: Option<Utf8PathBuf>,
127
128    /// PostgreSQL user name to connect as
129    ///
130    /// This must not be specified if `uri` is specified.
131    #[serde(skip_serializing_if = "Option::is_none")]
132    pub username: Option<String>,
133
134    /// Password to be used if the server demands password authentication
135    ///
136    /// This must not be specified if `uri` is specified.
137    #[serde(skip_serializing_if = "Option::is_none")]
138    pub password: Option<String>,
139
140    /// The database name
141    ///
142    /// This must not be specified if `uri` is specified.
143    #[serde(skip_serializing_if = "Option::is_none")]
144    pub database: Option<String>,
145
146    /// How to handle SSL connections
147    #[serde(skip_serializing_if = "Option::is_none")]
148    pub ssl_mode: Option<PgSslMode>,
149
150    /// The PEM-encoded root certificate for SSL connections
151    ///
152    /// This must not be specified if the `ssl_ca_file` option is specified.
153    #[serde(skip_serializing_if = "Option::is_none")]
154    pub ssl_ca: Option<String>,
155
156    /// Path to the root certificate for SSL connections
157    ///
158    /// This must not be specified if the `ssl_ca` option is specified.
159    #[serde(skip_serializing_if = "Option::is_none")]
160    #[schemars(with = "Option<String>")]
161    pub ssl_ca_file: Option<Utf8PathBuf>,
162
163    /// The PEM-encoded client certificate for SSL connections
164    ///
165    /// This must not be specified if the `ssl_certificate_file` option is
166    /// specified.
167    #[serde(skip_serializing_if = "Option::is_none")]
168    pub ssl_certificate: Option<String>,
169
170    /// Path to the client certificate for SSL connections
171    ///
172    /// This must not be specified if the `ssl_certificate` option is specified.
173    #[serde(skip_serializing_if = "Option::is_none")]
174    #[schemars(with = "Option<String>")]
175    pub ssl_certificate_file: Option<Utf8PathBuf>,
176
177    /// The PEM-encoded client key for SSL connections
178    ///
179    /// This must not be specified if the `ssl_key_file` option is specified.
180    #[serde(skip_serializing_if = "Option::is_none")]
181    pub ssl_key: Option<String>,
182
183    /// Path to the client key for SSL connections
184    ///
185    /// This must not be specified if the `ssl_key` option is specified.
186    #[serde(skip_serializing_if = "Option::is_none")]
187    #[schemars(with = "Option<String>")]
188    pub ssl_key_file: Option<Utf8PathBuf>,
189
190    /// Set the maximum number of connections the pool should maintain
191    #[serde(default = "default_max_connections")]
192    pub max_connections: NonZeroU32,
193
194    /// Set the minimum number of connections the pool should maintain
195    #[serde(default)]
196    pub min_connections: u32,
197
198    /// Set the amount of time to attempt connecting to the database
199    #[schemars(with = "u64")]
200    #[serde(default = "default_connect_timeout")]
201    #[serde_as(as = "serde_with::DurationSeconds<u64>")]
202    pub connect_timeout: Duration,
203
204    /// Set a maximum idle duration for individual connections
205    #[schemars(with = "Option<u64>")]
206    #[serde(
207        default = "default_idle_timeout",
208        skip_serializing_if = "Option::is_none"
209    )]
210    #[serde_as(as = "Option<serde_with::DurationSeconds<u64>>")]
211    pub idle_timeout: Option<Duration>,
212
213    /// Set the maximum lifetime of individual connections
214    #[schemars(with = "u64")]
215    #[serde(
216        default = "default_max_lifetime",
217        skip_serializing_if = "Option::is_none"
218    )]
219    #[serde_as(as = "Option<serde_with::DurationSeconds<u64>>")]
220    pub max_lifetime: Option<Duration>,
221}
222
223impl ConfigurationSection for DatabaseConfig {
224    const PATH: Option<&'static str> = Some("database");
225
226    fn validate(
227        &self,
228        figment: &figment::Figment,
229    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
230        let metadata = figment.find_metadata(Self::PATH.unwrap());
231        let annotate = |mut error: figment::Error| {
232            error.metadata = metadata.cloned();
233            error.profile = Some(figment::Profile::Default);
234            error.path = vec![Self::PATH.unwrap().to_owned()];
235            error
236        };
237
238        // Check that the user did not specify both `uri` and the split options at the
239        // same time
240        let has_split_options = self.host.is_some()
241            || self.port.is_some()
242            || self.socket.is_some()
243            || self.username.is_some()
244            || self.password.is_some()
245            || self.database.is_some();
246
247        if self.uri.is_some() && has_split_options {
248            return Err(annotate(figment::error::Error::from(
249                "uri must not be specified if host, port, socket, username, password, or database are specified".to_owned(),
250            )).into());
251        }
252
253        if self.ssl_ca.is_some() && self.ssl_ca_file.is_some() {
254            return Err(annotate(figment::error::Error::from(
255                "ssl_ca must not be specified if ssl_ca_file is specified".to_owned(),
256            ))
257            .into());
258        }
259
260        if self.ssl_certificate.is_some() && self.ssl_certificate_file.is_some() {
261            return Err(annotate(figment::error::Error::from(
262                "ssl_certificate must not be specified if ssl_certificate_file is specified"
263                    .to_owned(),
264            ))
265            .into());
266        }
267
268        if self.ssl_key.is_some() && self.ssl_key_file.is_some() {
269            return Err(annotate(figment::error::Error::from(
270                "ssl_key must not be specified if ssl_key_file is specified".to_owned(),
271            ))
272            .into());
273        }
274
275        if (self.ssl_key.is_some() || self.ssl_key_file.is_some())
276            ^ (self.ssl_certificate.is_some() || self.ssl_certificate_file.is_some())
277        {
278            return Err(annotate(figment::error::Error::from(
279                "both a ssl_certificate and a ssl_key must be set at the same time or none of them"
280                    .to_owned(),
281            ))
282            .into());
283        }
284
285        Ok(())
286    }
287}
#[cfg(test)]
mod tests {
    // The closures passed to `Jail::expect_with` return `figment::Error`, which is
    // large, and we can't change figment's API.
    #![expect(clippy::result_large_err)]

    use figment::{
        Figment, Jail,
        providers::{Format, Yaml},
    };

    use super::*;

    #[test]
    fn load_config() {
        Jail::expect_with(|jail| {
            jail.create_file(
                "config.yaml",
                r"
                    database:
                      uri: postgresql://user:password@host/database
                ",
            )?;

            let config = Figment::new()
                .merge(Yaml::file("config.yaml"))
                .extract_inner::<DatabaseConfig>("database")?;

            assert_eq!(
                config.uri.as_deref(),
                Some("postgresql://user:password@host/database")
            );

            Ok(())
        });
    }

    /// Pool settings absent from the file fall back to their defaults.
    #[test]
    fn load_config_defaults() {
        Jail::expect_with(|jail| {
            jail.create_file(
                "config.yaml",
                r"
                    database:
                      uri: postgresql://user:password@host/database
                ",
            )?;

            let config = Figment::new()
                .merge(Yaml::file("config.yaml"))
                .extract_inner::<DatabaseConfig>("database")?;

            assert_eq!(config.max_connections.get(), 10);
            assert_eq!(config.min_connections, 0);
            assert_eq!(config.connect_timeout, Duration::from_secs(30));
            assert_eq!(config.idle_timeout, Some(Duration::from_secs(10 * 60)));
            assert_eq!(config.max_lifetime, Some(Duration::from_secs(30 * 60)));
            assert!(config.ssl_mode.is_none());

            Ok(())
        });
    }

    /// The split connection options are accepted on their own.
    #[test]
    fn load_split_options() {
        Jail::expect_with(|jail| {
            jail.create_file(
                "config.yaml",
                r"
                    database:
                      host: db.example.com
                      port: 5433
                      username: mas
                      password: secret
                      database: mas
                      ssl_mode: verify-full
                ",
            )?;

            let figment = Figment::new().merge(Yaml::file("config.yaml"));
            let config = figment.extract_inner::<DatabaseConfig>("database")?;

            assert!(config.uri.is_none());
            assert_eq!(config.host.as_deref(), Some("db.example.com"));
            assert_eq!(config.port, Some(5433));
            assert!(matches!(config.ssl_mode, Some(PgSslMode::VerifyFull)));
            assert!(config.validate(&figment).is_ok());

            Ok(())
        });
    }

    /// `uri` cannot be combined with any split connection option.
    #[test]
    fn reject_uri_with_split_options() {
        Jail::expect_with(|jail| {
            jail.create_file(
                "config.yaml",
                r"
                    database:
                      uri: postgresql://user:password@host/database
                      host: other-host
                ",
            )?;

            let figment = Figment::new().merge(Yaml::file("config.yaml"));
            let config = figment.extract_inner::<DatabaseConfig>("database")?;

            let error = config.validate(&figment).unwrap_err();
            assert!(error.to_string().contains("uri must not be specified"));

            Ok(())
        });
    }

    /// A client key without a client certificate is rejected.
    #[test]
    fn reject_ssl_key_without_certificate() {
        Jail::expect_with(|jail| {
            jail.create_file(
                "config.yaml",
                r"
                    database:
                      ssl_key_file: /path/to/key.pem
                ",
            )?;

            let figment = Figment::new().merge(Yaml::file("config.yaml"));
            let config = figment.extract_inner::<DatabaseConfig>("database")?;

            let error = config.validate(&figment).unwrap_err();
            assert!(
                error
                    .to_string()
                    .contains("both a ssl_certificate and a ssl_key must be set")
            );

            Ok(())
        });
    }
}