1use std::collections::BTreeMap;
9
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use mas_data_model::{
13 AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Clock, Pkce, Session,
14};
15use mas_iana::oauth::PkceCodeChallengeMethod;
16use mas_storage::oauth2::OAuth2AuthorizationGrantRepository;
17use oauth2_types::{requests::ResponseMode, scope::Scope};
18use rand::RngCore;
19use sqlx::{PgConnection, types::Json};
20use ulid::Ulid;
21use url::Url;
22use uuid::Uuid;
23
24use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
25
26pub struct PgOAuth2AuthorizationGrantRepository<'c> {
29 conn: &'c mut PgConnection,
30}
31
32impl<'c> PgOAuth2AuthorizationGrantRepository<'c> {
33 pub fn new(conn: &'c mut PgConnection) -> Self {
36 Self { conn }
37 }
38}
39
40struct GrantLookup {
41 oauth2_authorization_grant_id: Uuid,
42 created_at: DateTime<Utc>,
43 cancelled_at: Option<DateTime<Utc>>,
44 fulfilled_at: Option<DateTime<Utc>>,
45 exchanged_at: Option<DateTime<Utc>>,
46 scope: String,
47 state: Option<String>,
48 nonce: Option<String>,
49 redirect_uri: String,
50 response_mode: String,
51 response_type_code: bool,
52 response_type_id_token: bool,
53 authorization_code: Option<String>,
54 code_challenge: Option<String>,
55 code_challenge_method: Option<String>,
56 login_hint: Option<String>,
57 locale: Option<String>,
58 raw_parameters: Option<Json<BTreeMap<String, String>>>,
59 oauth2_client_id: Uuid,
60 oauth2_session_id: Option<Uuid>,
61}
62
63impl TryFrom<GrantLookup> for AuthorizationGrant {
64 type Error = DatabaseInconsistencyError;
65
66 fn try_from(value: GrantLookup) -> Result<Self, Self::Error> {
67 let id = value.oauth2_authorization_grant_id.into();
68 let scope: Scope = value.scope.parse().map_err(|e| {
69 DatabaseInconsistencyError::on("oauth2_authorization_grants")
70 .column("scope")
71 .row(id)
72 .source(e)
73 })?;
74
75 let stage = match (
76 value.fulfilled_at,
77 value.exchanged_at,
78 value.cancelled_at,
79 value.oauth2_session_id,
80 ) {
81 (None, None, None, None) => AuthorizationGrantStage::Pending,
82 (Some(fulfilled_at), None, None, Some(session_id)) => {
83 AuthorizationGrantStage::Fulfilled {
84 session_id: session_id.into(),
85 fulfilled_at,
86 }
87 }
88 (Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => {
89 AuthorizationGrantStage::Exchanged {
90 session_id: session_id.into(),
91 fulfilled_at,
92 exchanged_at,
93 }
94 }
95 (None, None, Some(cancelled_at), None) => {
96 AuthorizationGrantStage::Cancelled { cancelled_at }
97 }
98 _ => {
99 return Err(
100 DatabaseInconsistencyError::on("oauth2_authorization_grants")
101 .column("stage")
102 .row(id),
103 );
104 }
105 };
106
107 let pkce = match (value.code_challenge, value.code_challenge_method) {
108 (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
109 Some(Pkce {
110 challenge_method: PkceCodeChallengeMethod::Plain,
111 challenge,
112 })
113 }
114 (Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce {
115 challenge_method: PkceCodeChallengeMethod::S256,
116 challenge,
117 }),
118 (None, None) => None,
119 _ => {
120 return Err(
121 DatabaseInconsistencyError::on("oauth2_authorization_grants")
122 .column("code_challenge_method")
123 .row(id),
124 );
125 }
126 };
127
128 let code: Option<AuthorizationCode> =
129 match (value.response_type_code, value.authorization_code, pkce) {
130 (false, None, None) => None,
131 (true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
132 _ => {
133 return Err(
134 DatabaseInconsistencyError::on("oauth2_authorization_grants")
135 .column("authorization_code")
136 .row(id),
137 );
138 }
139 };
140
141 let redirect_uri = value.redirect_uri.parse().map_err(|e| {
142 DatabaseInconsistencyError::on("oauth2_authorization_grants")
143 .column("redirect_uri")
144 .row(id)
145 .source(e)
146 })?;
147
148 let response_mode = value.response_mode.parse().map_err(|e| {
149 DatabaseInconsistencyError::on("oauth2_authorization_grants")
150 .column("response_mode")
151 .row(id)
152 .source(e)
153 })?;
154
155 Ok(AuthorizationGrant {
156 id,
157 stage,
158 client_id: value.oauth2_client_id.into(),
159 code,
160 scope,
161 state: value.state,
162 nonce: value.nonce,
163 response_mode,
164 redirect_uri,
165 created_at: value.created_at,
166 response_type_id_token: value.response_type_id_token,
167 login_hint: value.login_hint,
168 locale: value.locale,
169 raw_parameters: value.raw_parameters.map(|Json(x)| x).unwrap_or_default(),
170 })
171 }
172}
173
174#[async_trait]
175impl OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'_> {
176 type Error = DatabaseError;
177
178 #[tracing::instrument(
179 name = "db.oauth2_authorization_grant.add",
180 skip_all,
181 fields(
182 db.query.text,
183 grant.id,
184 grant.scope = %scope,
185 %client.id,
186 ),
187 err,
188 )]
189 async fn add(
190 &mut self,
191 rng: &mut (dyn RngCore + Send),
192 clock: &dyn Clock,
193 client: &Client,
194 redirect_uri: Url,
195 scope: Scope,
196 code: Option<AuthorizationCode>,
197 state: Option<String>,
198 nonce: Option<String>,
199 response_mode: ResponseMode,
200 response_type_id_token: bool,
201 login_hint: Option<String>,
202 locale: Option<String>,
203 raw_parameters: BTreeMap<String, String>,
204 ) -> Result<AuthorizationGrant, Self::Error> {
205 let code_challenge = code
206 .as_ref()
207 .and_then(|c| c.pkce.as_ref())
208 .map(|p| &p.challenge);
209 let code_challenge_method = code
210 .as_ref()
211 .and_then(|c| c.pkce.as_ref())
212 .map(|p| p.challenge_method.to_string());
213 let code_str = code.as_ref().map(|c| &c.code);
214
215 let created_at = clock.now();
216 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
217 tracing::Span::current().record("grant.id", tracing::field::display(id));
218
219 sqlx::query!(
220 r#"
221 INSERT INTO oauth2_authorization_grants (
222 oauth2_authorization_grant_id,
223 oauth2_client_id,
224 redirect_uri,
225 scope,
226 state,
227 nonce,
228 response_mode,
229 code_challenge,
230 code_challenge_method,
231 response_type_code,
232 response_type_id_token,
233 authorization_code,
234 login_hint,
235 locale,
236 raw_parameters,
237 created_at
238 )
239 VALUES
240 ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
241 "#,
242 Uuid::from(id),
243 Uuid::from(client.id),
244 redirect_uri.to_string(),
245 scope.to_string(),
246 state,
247 nonce,
248 response_mode.to_string(),
249 code_challenge,
250 code_challenge_method,
251 code.is_some(),
252 response_type_id_token,
253 code_str,
254 login_hint,
255 locale,
256 Json(&raw_parameters) as _,
257 created_at,
258 )
259 .traced()
260 .execute(&mut *self.conn)
261 .await?;
262
263 Ok(AuthorizationGrant {
264 id,
265 stage: AuthorizationGrantStage::Pending,
266 code,
267 redirect_uri,
268 client_id: client.id,
269 scope,
270 state,
271 nonce,
272 response_mode,
273 created_at,
274 response_type_id_token,
275 login_hint,
276 locale,
277 raw_parameters,
278 })
279 }
280
281 #[tracing::instrument(
282 name = "db.oauth2_authorization_grant.lookup",
283 skip_all,
284 fields(
285 db.query.text,
286 grant.id = %id,
287 ),
288 err,
289 )]
290 async fn lookup(&mut self, id: Ulid) -> Result<Option<AuthorizationGrant>, Self::Error> {
291 let res = sqlx::query_as!(
292 GrantLookup,
293 r#"
294 SELECT oauth2_authorization_grant_id
295 , created_at
296 , cancelled_at
297 , fulfilled_at
298 , exchanged_at
299 , scope
300 , state
301 , redirect_uri
302 , response_mode
303 , nonce
304 , oauth2_client_id
305 , authorization_code
306 , response_type_code
307 , response_type_id_token
308 , code_challenge
309 , code_challenge_method
310 , login_hint
311 , locale
312 , raw_parameters AS "raw_parameters: Json<BTreeMap<String, String>>"
313 , oauth2_session_id
314 FROM
315 oauth2_authorization_grants
316
317 WHERE oauth2_authorization_grant_id = $1
318 "#,
319 Uuid::from(id),
320 )
321 .traced()
322 .fetch_optional(&mut *self.conn)
323 .await?;
324
325 let Some(res) = res else { return Ok(None) };
326
327 Ok(Some(res.try_into()?))
328 }
329
330 #[tracing::instrument(
331 name = "db.oauth2_authorization_grant.find_by_code",
332 skip_all,
333 fields(
334 db.query.text,
335 ),
336 err,
337 )]
338 async fn find_by_code(
339 &mut self,
340 code: &str,
341 ) -> Result<Option<AuthorizationGrant>, Self::Error> {
342 let res = sqlx::query_as!(
343 GrantLookup,
344 r#"
345 SELECT oauth2_authorization_grant_id
346 , created_at
347 , cancelled_at
348 , fulfilled_at
349 , exchanged_at
350 , scope
351 , state
352 , redirect_uri
353 , response_mode
354 , nonce
355 , oauth2_client_id
356 , authorization_code
357 , response_type_code
358 , response_type_id_token
359 , code_challenge
360 , code_challenge_method
361 , login_hint
362 , locale
363 , raw_parameters AS "raw_parameters: Json<BTreeMap<String, String>>"
364 , oauth2_session_id
365 FROM
366 oauth2_authorization_grants
367
368 WHERE authorization_code = $1
369 "#,
370 code,
371 )
372 .traced()
373 .fetch_optional(&mut *self.conn)
374 .await?;
375
376 let Some(res) = res else { return Ok(None) };
377
378 Ok(Some(res.try_into()?))
379 }
380
381 #[tracing::instrument(
382 name = "db.oauth2_authorization_grant.fulfill",
383 skip_all,
384 fields(
385 db.query.text,
386 %grant.id,
387 client.id = %grant.client_id,
388 %session.id,
389 ),
390 err,
391 )]
392 async fn fulfill(
393 &mut self,
394 clock: &dyn Clock,
395 session: &Session,
396 grant: AuthorizationGrant,
397 ) -> Result<AuthorizationGrant, Self::Error> {
398 let fulfilled_at = clock.now();
399 let res = sqlx::query!(
400 r#"
401 UPDATE oauth2_authorization_grants
402 SET fulfilled_at = $2
403 , oauth2_session_id = $3
404 WHERE oauth2_authorization_grant_id = $1
405 "#,
406 Uuid::from(grant.id),
407 fulfilled_at,
408 Uuid::from(session.id),
409 )
410 .traced()
411 .execute(&mut *self.conn)
412 .await?;
413
414 DatabaseError::ensure_affected_rows(&res, 1)?;
415
416 let grant = grant
418 .fulfill(fulfilled_at, session)
419 .map_err(DatabaseError::to_invalid_operation)?;
420
421 Ok(grant)
422 }
423
424 #[tracing::instrument(
425 name = "db.oauth2_authorization_grant.exchange",
426 skip_all,
427 fields(
428 db.query.text,
429 %grant.id,
430 client.id = %grant.client_id,
431 ),
432 err,
433 )]
434 async fn exchange(
435 &mut self,
436 clock: &dyn Clock,
437 grant: AuthorizationGrant,
438 ) -> Result<AuthorizationGrant, Self::Error> {
439 let exchanged_at = clock.now();
440 let res = sqlx::query!(
441 r#"
442 UPDATE oauth2_authorization_grants
443 SET exchanged_at = $2
444 WHERE oauth2_authorization_grant_id = $1
445 "#,
446 Uuid::from(grant.id),
447 exchanged_at,
448 )
449 .traced()
450 .execute(&mut *self.conn)
451 .await?;
452
453 DatabaseError::ensure_affected_rows(&res, 1)?;
454
455 let grant = grant
456 .exchange(exchanged_at)
457 .map_err(DatabaseError::to_invalid_operation)?;
458
459 Ok(grant)
460 }
461
462 #[tracing::instrument(
463 name = "db.oauth2_authorization_grant.cleanup",
464 skip_all,
465 fields(
466 db.query.text,
467 since = since.map(tracing::field::display),
468 until = %until,
469 limit = limit,
470 ),
471 err,
472 )]
473 async fn cleanup(
474 &mut self,
475 since: Option<Ulid>,
476 until: Ulid,
477 limit: usize,
478 ) -> Result<(usize, Option<Ulid>), Self::Error> {
479 let res = sqlx::query_scalar!(
484 r#"
485 WITH to_delete AS (
486 SELECT oauth2_authorization_grant_id
487 FROM oauth2_authorization_grants
488 WHERE ($1::uuid IS NULL OR oauth2_authorization_grant_id > $1)
489 AND oauth2_authorization_grant_id <= $2
490 ORDER BY oauth2_authorization_grant_id
491 LIMIT $3
492 )
493 DELETE FROM oauth2_authorization_grants
494 USING to_delete
495 WHERE oauth2_authorization_grants.oauth2_authorization_grant_id = to_delete.oauth2_authorization_grant_id
496 RETURNING oauth2_authorization_grants.oauth2_authorization_grant_id
497 "#,
498 since.map(Uuid::from),
499 Uuid::from(until),
500 i64::try_from(limit).unwrap_or(i64::MAX)
501 )
502 .traced()
503 .fetch_all(&mut *self.conn)
504 .await?;
505
506 let count = res.len();
507 let max_id = res.into_iter().max();
508
509 Ok((count, max_id.map(Ulid::from)))
510 }
511}