mas_storage_pg/user/
mod.rs

1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8//! A module containing the PostgreSQL implementation of the user-related
9//! repositories
10
11use async_trait::async_trait;
12use mas_data_model::{Clock, User};
13use mas_storage::user::{UserFilter, UserRepository};
14use rand::RngCore;
15use sea_query::{Expr, PostgresQueryBuilder, Query, extension::postgres::PgExpr as _};
16use sea_query_binder::SqlxBinder;
17use sqlx::PgConnection;
18use ulid::Ulid;
19use uuid::Uuid;
20
21use crate::{
22    DatabaseError,
23    filter::{Filter, StatementExt},
24    iden::Users,
25    pagination::QueryBuilderExt,
26    tracing::ExecuteExt,
27};
28
29mod email;
30mod password;
31mod recovery;
32mod registration;
33mod registration_token;
34mod session;
35mod terms;
36
37#[cfg(test)]
38mod tests;
39
40pub use self::{
41    email::PgUserEmailRepository, password::PgUserPasswordRepository,
42    recovery::PgUserRecoveryRepository, registration::PgUserRegistrationRepository,
43    registration_token::PgUserRegistrationTokenRepository, session::PgBrowserSessionRepository,
44    terms::PgUserTermsRepository,
45};
46
47/// An implementation of [`UserRepository`] for a PostgreSQL connection
48pub struct PgUserRepository<'c> {
49    conn: &'c mut PgConnection,
50}
51
52impl<'c> PgUserRepository<'c> {
53    /// Create a new [`PgUserRepository`] from an active PostgreSQL connection
54    pub fn new(conn: &'c mut PgConnection) -> Self {
55        Self { conn }
56    }
57}
58
59mod priv_ {
60    // The enum_def macro generates a public enum, which we don't want, because it
61    // triggers the missing docs warning
62    #![allow(missing_docs)]
63
64    use chrono::{DateTime, Utc};
65    use mas_storage::pagination::Node;
66    use sea_query::enum_def;
67    use ulid::Ulid;
68    use uuid::Uuid;
69
70    #[derive(Debug, Clone, sqlx::FromRow)]
71    #[enum_def]
72    pub(super) struct UserLookup {
73        pub(super) user_id: Uuid,
74        pub(super) username: String,
75        pub(super) created_at: DateTime<Utc>,
76        pub(super) locked_at: Option<DateTime<Utc>>,
77        pub(super) deactivated_at: Option<DateTime<Utc>>,
78        pub(super) can_request_admin: bool,
79        pub(super) is_guest: bool,
80    }
81
82    impl Node<Ulid> for UserLookup {
83        fn cursor(&self) -> Ulid {
84            self.user_id.into()
85        }
86    }
87}
88
89use priv_::{UserLookup, UserLookupIden};
90
91impl From<UserLookup> for User {
92    fn from(value: UserLookup) -> Self {
93        let id = value.user_id.into();
94        Self {
95            id,
96            username: value.username,
97            sub: id.to_string(),
98            created_at: value.created_at,
99            locked_at: value.locked_at,
100            deactivated_at: value.deactivated_at,
101            can_request_admin: value.can_request_admin,
102            is_guest: value.is_guest,
103        }
104    }
105}
106
107impl Filter for UserFilter<'_> {
108    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
109        sea_query::Condition::all()
110            .add_option(self.state().map(|state| {
111                match state {
112                    mas_storage::user::UserState::Deactivated => {
113                        Expr::col((Users::Table, Users::DeactivatedAt)).is_not_null()
114                    }
115                    mas_storage::user::UserState::Locked => {
116                        Expr::col((Users::Table, Users::LockedAt)).is_not_null()
117                    }
118                    mas_storage::user::UserState::Active => {
119                        Expr::col((Users::Table, Users::LockedAt))
120                            .is_null()
121                            .and(Expr::col((Users::Table, Users::DeactivatedAt)).is_null())
122                    }
123                }
124            }))
125            .add_option(self.can_request_admin().map(|can_request_admin| {
126                Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
127            }))
128            .add_option(
129                self.is_guest()
130                    .map(|is_guest| Expr::col((Users::Table, Users::IsGuest)).eq(is_guest)),
131            )
132            .add_option(self.search().map(|search| {
133                Expr::col((Users::Table, Users::Username)).ilike(format!("%{search}%"))
134            }))
135    }
136}
137
138#[async_trait]
139impl UserRepository for PgUserRepository<'_> {
140    type Error = DatabaseError;
141
142    #[tracing::instrument(
143        name = "db.user.lookup",
144        skip_all,
145        fields(
146            db.query.text,
147            user.id = %id,
148        ),
149        err,
150    )]
151    async fn lookup(&mut self, id: Ulid) -> Result<Option<User>, Self::Error> {
152        let res = sqlx::query_as!(
153            UserLookup,
154            r#"
155                SELECT user_id
156                     , username
157                     , created_at
158                     , locked_at
159                     , deactivated_at
160                     , can_request_admin
161                     , is_guest
162                FROM users
163                WHERE user_id = $1
164            "#,
165            Uuid::from(id),
166        )
167        .traced()
168        .fetch_optional(&mut *self.conn)
169        .await?;
170
171        let Some(res) = res else { return Ok(None) };
172
173        Ok(Some(res.into()))
174    }
175
176    #[tracing::instrument(
177        name = "db.user.find_by_username",
178        skip_all,
179        fields(
180            db.query.text,
181            user.username = username,
182        ),
183        err,
184    )]
185    async fn find_by_username(&mut self, username: &str) -> Result<Option<User>, Self::Error> {
186        // We may have multiple users with the same username, but with a different
187        // casing. In this case, we want to return the one which matches the exact
188        // casing
189        let res = sqlx::query_as!(
190            UserLookup,
191            r#"
192                SELECT user_id
193                     , username
194                     , created_at
195                     , locked_at
196                     , deactivated_at
197                     , can_request_admin
198                     , is_guest
199                FROM users
200                WHERE LOWER(username) = LOWER($1)
201            "#,
202            username,
203        )
204        .traced()
205        .fetch_all(&mut *self.conn)
206        .await?;
207
208        match &res[..] {
209            // Happy path: there is only one user matching the username…
210            [user] => Ok(Some(user.clone().into())),
211            // …or none.
212            [] => Ok(None),
213            list => {
214                // If there are multiple users with the same username, we want to
215                // return the one which matches the exact casing
216                if let Some(user) = list.iter().find(|user| user.username == username) {
217                    Ok(Some(user.clone().into()))
218                } else {
219                    // If none match exactly, we prefer to return nothing
220                    Ok(None)
221                }
222            }
223        }
224    }
225
226    #[tracing::instrument(
227        name = "db.user.add",
228        skip_all,
229        fields(
230            db.query.text,
231            user.username = username,
232            user.id,
233        ),
234        err,
235    )]
236    async fn add(
237        &mut self,
238        rng: &mut (dyn RngCore + Send),
239        clock: &dyn Clock,
240        username: String,
241    ) -> Result<User, Self::Error> {
242        let created_at = clock.now();
243        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
244        tracing::Span::current().record("user.id", tracing::field::display(id));
245
246        let res = sqlx::query!(
247            r#"
248                INSERT INTO users (user_id, username, created_at)
249                VALUES ($1, $2, $3)
250                ON CONFLICT (username) DO NOTHING
251            "#,
252            Uuid::from(id),
253            username,
254            created_at,
255        )
256        .traced()
257        .execute(&mut *self.conn)
258        .await?;
259
260        // If the user already exists, want to return an error but not poison the
261        // transaction
262        DatabaseError::ensure_affected_rows(&res, 1)?;
263
264        Ok(User {
265            id,
266            username,
267            sub: id.to_string(),
268            created_at,
269            locked_at: None,
270            deactivated_at: None,
271            can_request_admin: false,
272            is_guest: false,
273        })
274    }
275
276    #[tracing::instrument(
277        name = "db.user.exists",
278        skip_all,
279        fields(
280            db.query.text,
281            user.username = username,
282        ),
283        err,
284    )]
285    async fn exists(&mut self, username: &str) -> Result<bool, Self::Error> {
286        let exists = sqlx::query_scalar!(
287            r#"
288                SELECT EXISTS(
289                    SELECT 1 FROM users WHERE LOWER(username) = LOWER($1)
290                ) AS "exists!"
291            "#,
292            username
293        )
294        .traced()
295        .fetch_one(&mut *self.conn)
296        .await?;
297
298        Ok(exists)
299    }
300
301    #[tracing::instrument(
302        name = "db.user.lock",
303        skip_all,
304        fields(
305            db.query.text,
306            %user.id,
307        ),
308        err,
309    )]
310    async fn lock(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
311        if user.locked_at.is_some() {
312            return Ok(user);
313        }
314
315        let locked_at = clock.now();
316        let res = sqlx::query!(
317            r#"
318                UPDATE users
319                SET locked_at = $1
320                WHERE user_id = $2
321            "#,
322            locked_at,
323            Uuid::from(user.id),
324        )
325        .traced()
326        .execute(&mut *self.conn)
327        .await?;
328
329        DatabaseError::ensure_affected_rows(&res, 1)?;
330
331        user.locked_at = Some(locked_at);
332
333        Ok(user)
334    }
335
336    #[tracing::instrument(
337        name = "db.user.unlock",
338        skip_all,
339        fields(
340            db.query.text,
341            %user.id,
342        ),
343        err,
344    )]
345    async fn unlock(&mut self, mut user: User) -> Result<User, Self::Error> {
346        if user.locked_at.is_none() {
347            return Ok(user);
348        }
349
350        let res = sqlx::query!(
351            r#"
352                UPDATE users
353                SET locked_at = NULL
354                WHERE user_id = $1
355            "#,
356            Uuid::from(user.id),
357        )
358        .traced()
359        .execute(&mut *self.conn)
360        .await?;
361
362        DatabaseError::ensure_affected_rows(&res, 1)?;
363
364        user.locked_at = None;
365
366        Ok(user)
367    }
368
369    #[tracing::instrument(
370        name = "db.user.deactivate",
371        skip_all,
372        fields(
373            db.query.text,
374            %user.id,
375        ),
376        err,
377    )]
378    async fn deactivate(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
379        if user.deactivated_at.is_some() {
380            return Ok(user);
381        }
382
383        let deactivated_at = clock.now();
384        let res = sqlx::query!(
385            r#"
386                UPDATE users
387                SET deactivated_at = $2
388                WHERE user_id = $1
389                  AND deactivated_at IS NULL
390            "#,
391            Uuid::from(user.id),
392            deactivated_at,
393        )
394        .traced()
395        .execute(&mut *self.conn)
396        .await?;
397
398        DatabaseError::ensure_affected_rows(&res, 1)?;
399
400        user.deactivated_at = Some(deactivated_at);
401
402        Ok(user)
403    }
404
405    #[tracing::instrument(
406        name = "db.user.reactivate",
407        skip_all,
408        fields(
409            db.query.text,
410            %user.id,
411        ),
412        err,
413    )]
414    async fn reactivate(&mut self, mut user: User) -> Result<User, Self::Error> {
415        if user.deactivated_at.is_none() {
416            return Ok(user);
417        }
418
419        let res = sqlx::query!(
420            r#"
421                UPDATE users
422                SET deactivated_at = NULL
423                WHERE user_id = $1
424            "#,
425            Uuid::from(user.id),
426        )
427        .traced()
428        .execute(&mut *self.conn)
429        .await?;
430
431        DatabaseError::ensure_affected_rows(&res, 1)?;
432
433        user.deactivated_at = None;
434
435        Ok(user)
436    }
437
438    #[tracing::instrument(
439        name = "db.user.delete_unsupported_threepids",
440        skip_all,
441        fields(
442            db.query.text,
443            %user.id,
444        ),
445        err,
446    )]
447    async fn delete_unsupported_threepids(&mut self, user: &User) -> Result<usize, Self::Error> {
448        let res = sqlx::query!(
449            r#"
450                DELETE FROM user_unsupported_third_party_ids
451                WHERE user_id = $1
452            "#,
453            Uuid::from(user.id),
454        )
455        .traced()
456        .execute(&mut *self.conn)
457        .await?;
458
459        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
460    }
461
462    #[tracing::instrument(
463        name = "db.user.set_can_request_admin",
464        skip_all,
465        fields(
466            db.query.text,
467            %user.id,
468            user.can_request_admin = can_request_admin,
469        ),
470        err,
471    )]
472    async fn set_can_request_admin(
473        &mut self,
474        mut user: User,
475        can_request_admin: bool,
476    ) -> Result<User, Self::Error> {
477        let res = sqlx::query!(
478            r#"
479                UPDATE users
480                SET can_request_admin = $2
481                WHERE user_id = $1
482            "#,
483            Uuid::from(user.id),
484            can_request_admin,
485        )
486        .traced()
487        .execute(&mut *self.conn)
488        .await?;
489
490        DatabaseError::ensure_affected_rows(&res, 1)?;
491
492        user.can_request_admin = can_request_admin;
493
494        Ok(user)
495    }
496
497    #[tracing::instrument(
498        name = "db.user.list",
499        skip_all,
500        fields(
501            db.query.text,
502        ),
503        err,
504    )]
505    async fn list(
506        &mut self,
507        filter: UserFilter<'_>,
508        pagination: mas_storage::Pagination,
509    ) -> Result<mas_storage::Page<User>, Self::Error> {
510        let (sql, arguments) = Query::select()
511            .expr_as(
512                Expr::col((Users::Table, Users::UserId)),
513                UserLookupIden::UserId,
514            )
515            .expr_as(
516                Expr::col((Users::Table, Users::Username)),
517                UserLookupIden::Username,
518            )
519            .expr_as(
520                Expr::col((Users::Table, Users::CreatedAt)),
521                UserLookupIden::CreatedAt,
522            )
523            .expr_as(
524                Expr::col((Users::Table, Users::LockedAt)),
525                UserLookupIden::LockedAt,
526            )
527            .expr_as(
528                Expr::col((Users::Table, Users::DeactivatedAt)),
529                UserLookupIden::DeactivatedAt,
530            )
531            .expr_as(
532                Expr::col((Users::Table, Users::CanRequestAdmin)),
533                UserLookupIden::CanRequestAdmin,
534            )
535            .expr_as(
536                Expr::col((Users::Table, Users::IsGuest)),
537                UserLookupIden::IsGuest,
538            )
539            .from(Users::Table)
540            .apply_filter(filter)
541            .generate_pagination((Users::Table, Users::UserId), pagination)
542            .build_sqlx(PostgresQueryBuilder);
543
544        let edges: Vec<UserLookup> = sqlx::query_as_with(&sql, arguments)
545            .traced()
546            .fetch_all(&mut *self.conn)
547            .await?;
548
549        let page = pagination.process(edges).map(User::from);
550
551        Ok(page)
552    }
553
554    #[tracing::instrument(
555        name = "db.user.count",
556        skip_all,
557        fields(
558            db.query.text,
559        ),
560        err,
561    )]
562    async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error> {
563        let (sql, arguments) = Query::select()
564            .expr(Expr::col((Users::Table, Users::UserId)).count())
565            .from(Users::Table)
566            .apply_filter(filter)
567            .build_sqlx(PostgresQueryBuilder);
568
569        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
570            .traced()
571            .fetch_one(&mut *self.conn)
572            .await?;
573
574        count
575            .try_into()
576            .map_err(DatabaseError::to_invalid_operation)
577    }
578
579    #[tracing::instrument(
580        name = "db.user.acquire_lock_for_sync",
581        skip_all,
582        fields(
583            db.query.text,
584            user.id = %user.id,
585        ),
586        err,
587    )]
588    async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error> {
589        // XXX: this lock isn't stictly scoped to users, but as we don't use many
590        // postgres advisory locks, it's fine for now. Later on, we could use row-level
591        // locks to make sure we don't get into trouble
592
593        // Convert the user ID to a u128 and grab the lower 64 bits
594        // As this includes 64bit of the random part of the ULID, it should be random
595        // enough to not collide
596        let lock_id = (u128::from(user.id) & 0xffff_ffff_ffff_ffff) as i64;
597
598        // Use a PG advisory lock, which will be released when the transaction is
599        // committed or rolled back
600        sqlx::query!(
601            r#"
602                SELECT pg_advisory_xact_lock($1)
603            "#,
604            lock_id,
605        )
606        .traced()
607        .execute(&mut *self.conn)
608        .await?;
609
610        Ok(())
611    }
612}