From bb82e5dd438928fb65eba746f59bc982104d8fb2 Mon Sep 17 00:00:00 2001 From: Dessalines Date: Tue, 11 Mar 2025 05:39:48 -0400 Subject: [PATCH] Register users in a transaction. (#5480) * Register users in a transaction. - Fixes #5477 * Removing transactions from language code. * Replace build_transactions with async transactions. * Forgot one transaction. --- Cargo.lock | 2 + crates/api/Cargo.toml | 1 + .../site/registration_applications/approve.rs | 11 +- crates/api_crud/Cargo.toml | 1 + crates/api_crud/src/user/create.rs | 272 ++++++++++-------- crates/db_schema/src/impls/actor_language.rs | 130 ++++----- crates/db_schema/src/impls/images.rs | 10 +- .../src/impls/local_site_url_blocklist.rs | 25 +- 8 files changed, 240 insertions(+), 212 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b43194b52..a53241b9e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2582,6 +2582,7 @@ dependencies = [ "bcrypt", "captcha", "chrono", + "diesel-async", "elementtree", "hound", "lemmy_api_common", @@ -2646,6 +2647,7 @@ dependencies = [ "anyhow", "bcrypt", "chrono", + "diesel-async", "futures", "lemmy_api_common", "lemmy_db_schema", diff --git a/crates/api/Cargo.toml b/crates/api/Cargo.toml index 7d954d968..1afcaa5d3 100644 --- a/crates/api/Cargo.toml +++ b/crates/api/Cargo.toml @@ -35,6 +35,7 @@ regex = { workspace = true } hound = "3.5.1" sitemap-rs = "0.2.2" totp-rs = { version = "5.6.0", features = ["gen_secret", "otpauth"] } +diesel-async = { workspace = true, features = ["deadpool", "postgres"] } [dev-dependencies] serial_test = { workspace = true } diff --git a/crates/api/src/site/registration_applications/approve.rs b/crates/api/src/site/registration_applications/approve.rs index b8cd6c0ea..1129d8bf8 100644 --- a/crates/api/src/site/registration_applications/approve.rs +++ b/crates/api/src/site/registration_applications/approve.rs @@ -1,5 +1,6 @@ use activitypub_federation::config::Data; use actix_web::web::Json; +use diesel_async::{scoped_futures::ScopedFutureExt, AsyncConnection}; use lemmy_api_common::{ context::LemmyContext, site::{ApproveRegistrationApplication, RegistrationApplicationResponse}, @@ -30,9 +31,8 @@ pub async fn approve_registration_application( let conn = &mut get_conn(pool).await?; let tx_data = data.clone(); let approved_user_id = conn - .build_transaction() - .run(|conn| { - Box::pin(async move { + .transaction::<_, LemmyError, _>(|conn| { + async move { // Update the registration with reason, admin_id let deny_reason = diesel_string_update(tx_data.deny_reason.as_deref()); let app_form = RegistrationApplicationUpdateForm { @@ -52,8 +52,9 @@ pub async fn approve_registration_application( let approved_user_id = registration_application.local_user_id; LocalUser::update(&mut conn.into(), approved_user_id, &local_user_form).await?; - Ok::<_, LemmyError>(approved_user_id) - }) as _ + Ok(approved_user_id) + } + .scope_boxed() }) .await?; diff --git a/crates/api_crud/Cargo.toml b/crates/api_crud/Cargo.toml index 012f20d4c..dc64ef1b0 100644 --- a/crates/api_crud/Cargo.toml +++ b/crates/api_crud/Cargo.toml @@ -30,6 +30,7 @@ regex = { workspace = true } serde_json = { workspace = true } serde = { workspace = true } serde_with = { workspace = true } +diesel-async = { workspace = true, features = ["deadpool", "postgres"] } [package.metadata.cargo-shear] ignored = ["futures"] diff --git a/crates/api_crud/src/user/create.rs b/crates/api_crud/src/user/create.rs index 72779092d..a80ac4678 100644 --- a/crates/api_crud/src/user/create.rs +++ b/crates/api_crud/src/user/create.rs @@ -1,5 +1,6 @@ use activitypub_federation::{config::Data, http_signatures::generate_actor_keypair}; use actix_web::{web::Json, HttpRequest}; +use diesel_async::{scoped_futures::ScopedFutureExt, AsyncConnection, AsyncPgConnection}; use lemmy_api_common::{ claims::Claims, context::LemmyContext, @@ -31,11 +32,13 @@ use lemmy_db_schema::{ registration_application::{RegistrationApplication, RegistrationApplicationInsertForm}, }, traits::Crud, + utils::get_conn, RegistrationMode, }; use lemmy_db_views::structs::{LocalUserView, SiteView}; use lemmy_utils::{ error::{LemmyError, LemmyErrorExt, LemmyErrorType, LemmyResult}, + settings::structs::Settings, utils::{ slurs::{check_slurs, check_slurs_opt}, validation::is_valid_actor_name, @@ -62,7 +65,8 @@ pub async fn register( req: HttpRequest, context: Data, ) -> LemmyResult> { - let site_view = SiteView::read_local(&mut context.pool()).await?; + let pool = &mut context.pool(); + let site_view = SiteView::read_local(pool).await?; let local_site = site_view.local_site; let require_registration_application = local_site.registration_mode == RegistrationMode::RequireApplication; @@ -91,7 +95,7 @@ pub async fn register( if local_site.site_setup && local_site.captcha_enabled { let uuid = uuid::Uuid::parse_str(&data.captcha_uuid.clone().unwrap_or_default())?; CaptchaAnswer::check_captcha( - &mut context.pool(), + pool, CheckCaptchaAnswer { uuid, answer: data.captcha_answer.clone().unwrap_or_default(), @@ -104,21 +108,12 @@ pub async fn register( check_slurs(&data.username, &slur_regex)?; check_slurs_opt(&data.answer, &slur_regex)?; - Person::check_username_taken(&mut context.pool(), &data.username).await?; + Person::check_username_taken(pool, &data.username).await?; if let Some(email) = &data.email { - LocalUser::check_is_email_taken(&mut context.pool(), email).await?; + LocalUser::check_is_email_taken(pool, email).await?; } - // We have to create both a person, and local_user - let inserted_person = create_person( - data.username.clone(), - &local_site, - site_view.site.instance_id, - &context, - ) - .await?; - // Automatically set their application as accepted, if they created this with open registration. // Also fixes a bug which allows users to log in when registrations are changed to closed. let accepted_application = Some(!require_registration_application); @@ -130,33 +125,57 @@ pub async fn register( let language_tags = get_language_tags(&req); - // Create the local user - let local_user_form = LocalUserInsertForm { - email: data.email.as_deref().map(str::to_lowercase), - show_nsfw: Some(show_nsfw), - accepted_application, - ..LocalUserInsertForm::new(inserted_person.id, Some(data.password.to_string())) - }; + // Wrap the insert person, insert local user, and create registration, + // in a transaction, so that if any fail, the rows aren't created. + let conn = &mut get_conn(pool).await?; + let tx_data = data.clone(); + let tx_local_site = local_site.clone(); + let tx_settings = context.settings(); + let (person, local_user) = conn + .transaction::<_, LemmyError, _>(|conn| { + async move { + // We have to create both a person, and local_user + let person = create_person( + tx_data.username.clone(), + &tx_local_site, + site_view.site.instance_id, + tx_settings, + conn, + ) + .await?; - let inserted_local_user = - create_local_user(&context, language_tags, local_user_form, &local_site).await?; + // Create the local user + let local_user_form = LocalUserInsertForm { + email: tx_data.email.as_deref().map(str::to_lowercase), + show_nsfw: Some(show_nsfw), + accepted_application, + ..LocalUserInsertForm::new(person.id, Some(tx_data.password.to_string())) + }; - if local_site.site_setup && require_registration_application { - if let Some(answer) = data.answer.clone() { - // Create the registration application - let form = RegistrationApplicationInsertForm { - local_user_id: inserted_local_user.id, - answer, - }; + let local_user = + create_local_user(conn, language_tags, local_user_form, &tx_local_site).await?; - RegistrationApplication::create(&mut context.pool(), &form).await?; - } - } + if local_site.site_setup && require_registration_application { + if let Some(answer) = tx_data.answer.clone() { + // Create the registration application + let form = RegistrationApplicationInsertForm { + local_user_id: local_user.id, + answer, + }; + + RegistrationApplication::create(&mut conn.into(), &form).await?; + } + } + + Ok((person, local_user)) + } + .scope_boxed() + }) + .await?; // Email the admins, only if email verification is not required if local_site.application_email_admins && !local_site.require_email_verification { - send_new_applicant_email_to_admins(&data.username, &mut context.pool(), context.settings()) - .await?; + send_new_applicant_email_to_admins(&data.username, pool, context.settings()).await?; } let mut login_response = LoginResponse { @@ -170,16 +189,11 @@ pub async fn register( if !local_site.site_setup || (!require_registration_application && !local_site.require_email_verification) { - let jwt = Claims::generate(inserted_local_user.id, req, &context).await?; + let jwt = Claims::generate(local_user.id, req, &context).await?; login_response.jwt = Some(jwt); } else { - login_response.verify_email_sent = send_verification_email_if_required( - &context, - &local_site, - &inserted_local_user, - &inserted_person, - ) - .await?; + login_response.verify_email_sent = + send_verification_email_if_required(&context, &local_site, &local_user, &person).await?; if require_registration_application { login_response.registration_created = true; @@ -194,9 +208,17 @@ pub async fn authenticate_with_oauth( req: HttpRequest, context: Data, ) -> LemmyResult> { - let site_view = SiteView::read_local(&mut context.pool()).await?; + let pool = &mut context.pool(); + let site_view = SiteView::read_local(pool).await?; let local_site = site_view.local_site.clone(); + // Show nsfw content if param is true, or if content_warning exists + let show_nsfw = data + .show_nsfw + .unwrap_or(site_view.site.content_warning.is_some()); + + let language_tags = get_language_tags(&req); + // validate inputs if data.oauth_provider_id == OAuthProviderId(0) || data.code.is_empty() || data.code.len() > 300 { return Err(LemmyErrorType::OauthAuthorizationInvalid)?; @@ -218,7 +240,7 @@ pub async fn authenticate_with_oauth( // Fetch the OAUTH provider and make sure it's enabled let oauth_provider_id = data.oauth_provider_id; - let oauth_provider = OAuthProvider::read(&mut context.pool(), oauth_provider_id) + let oauth_provider = OAuthProvider::read(pool, oauth_provider_id) .await .ok() .ok_or(LemmyErrorType::OauthAuthorizationInvalid)?; @@ -253,16 +275,16 @@ pub async fn authenticate_with_oauth( // Lookup user by oauth_user_id let mut local_user_view = - LocalUserView::find_by_oauth_id(&mut context.pool(), oauth_provider.id, &oauth_user_id).await; + LocalUserView::find_by_oauth_id(pool, oauth_provider.id, &oauth_user_id).await; - let local_user: LocalUser; - if let Ok(user_view) = local_user_view { + let local_user = if let Ok(user_view) = local_user_view { // user found by oauth_user_id => Login user - local_user = user_view.clone().local_user; + let local_user = user_view.clone().local_user; check_user_valid(&user_view.person)?; check_email_verified(&user_view, &site_view)?; - check_registration_application(&user_view, &site_view.local_site, &mut context.pool()).await?; + check_registration_application(&user_view, &site_view.local_site, pool).await?; + local_user } else { // user has never previously registered using oauth @@ -283,9 +305,8 @@ pub async fn authenticate_with_oauth( local_site.registration_mode == RegistrationMode::RequireApplication; // Lookup user by OAUTH email and link accounts - local_user_view = LocalUserView::find_by_email(&mut context.pool(), &email).await; + local_user_view = LocalUserView::find_by_email(pool, &email).await; - let person; if let Ok(user_view) = local_user_view { // user found by email => link and login if linking is allowed @@ -299,18 +320,17 @@ pub async fn authenticate_with_oauth( check_user_valid(&user_view.person)?; check_email_verified(&user_view, &site_view)?; - check_registration_application(&user_view, &site_view.local_site, &mut context.pool()) - .await?; + check_registration_application(&user_view, &site_view.local_site, pool).await?; // Link with OAUTH => Login user let oauth_account_form = OAuthAccountInsertForm::new(user_view.local_user.id, oauth_provider.id, oauth_user_id); - OAuthAccount::create(&mut context.pool(), &oauth_account_form) + OAuthAccount::create(pool, &oauth_account_form) .await .with_lemmy_type(LemmyErrorType::OauthLoginFailed)?; - local_user = user_view.local_user.clone(); + user_view.local_user.clone() } else { return Err(LemmyErrorType::EmailAlreadyExists)?; } @@ -320,79 +340,90 @@ pub async fn authenticate_with_oauth( // make sure the registration answer is provided when the registration application is required validate_registration_answer(require_registration_application, &data.answer)?; - // make sure the username is provided - let username = data - .username - .as_ref() - .ok_or(LemmyErrorType::RegistrationUsernameRequired)?; - let slur_regex = slur_regex(&context).await?; - check_slurs(username, &slur_regex)?; - check_slurs_opt(&data.answer, &slur_regex)?; - Person::check_username_taken(&mut context.pool(), username).await?; + // Wrap the insert person, insert local user, and create registration, + // in a transaction, so that if any fail, the rows aren't created. + let conn = &mut get_conn(pool).await?; + let tx_data = data.clone(); + let tx_local_site = local_site.clone(); + let tx_settings = context.settings(); + let (person, local_user) = conn + .transaction::<_, LemmyError, _>(|conn| { + async move { + // make sure the username is provided + let username = tx_data + .username + .as_ref() + .ok_or(LemmyErrorType::RegistrationUsernameRequired)?; - // We have to create a person, a local_user, and an oauth_account - person = create_person( - username.clone(), - &local_site, - site_view.site.instance_id, - &context, - ) - .await?; + check_slurs(username, &slur_regex)?; + check_slurs_opt(&data.answer, &slur_regex)?; - // Show nsfw content if param is true, or if content_warning exists - let show_nsfw = data - .show_nsfw - .unwrap_or(site_view.site.content_warning.is_some()); + Person::check_username_taken(&mut conn.into(), username).await?; - let language_tags = get_language_tags(&req); + // We have to create a person, a local_user, and an oauth_account + let person = create_person( + username.clone(), + &tx_local_site, + site_view.site.instance_id, + tx_settings, + conn, + ) + .await?; - // Create the local user - let local_user_form = LocalUserInsertForm { - email: Some(str::to_lowercase(&email)), - show_nsfw: Some(show_nsfw), - accepted_application: Some(!require_registration_application), - email_verified: Some(oauth_provider.auto_verify_email), - ..LocalUserInsertForm::new(person.id, None) - }; + // Create the local user + let local_user_form = LocalUserInsertForm { + email: Some(str::to_lowercase(&email)), + show_nsfw: Some(show_nsfw), + accepted_application: Some(!require_registration_application), + email_verified: Some(oauth_provider.auto_verify_email), + ..LocalUserInsertForm::new(person.id, None) + }; - local_user = create_local_user(&context, language_tags, local_user_form, &local_site).await?; + let local_user = + create_local_user(conn, language_tags, local_user_form, &tx_local_site).await?; - // Create the oauth account - let oauth_account_form = - OAuthAccountInsertForm::new(local_user.id, oauth_provider.id, oauth_user_id); + // Create the oauth account + let oauth_account_form = + OAuthAccountInsertForm::new(local_user.id, oauth_provider.id, oauth_user_id); - OAuthAccount::create(&mut context.pool(), &oauth_account_form) - .await - .with_lemmy_type(LemmyErrorType::IncorrectLogin)?; + OAuthAccount::create(&mut conn.into(), &oauth_account_form) + .await + .with_lemmy_type(LemmyErrorType::IncorrectLogin)?; - // prevent sign in until application is accepted - if local_site.site_setup - && require_registration_application - && !local_user.accepted_application - && !local_user.admin - { - if let Some(answer) = data.answer.clone() { - // Create the registration application - RegistrationApplication::create( - &mut context.pool(), - &RegistrationApplicationInsertForm { - local_user_id: local_user.id, - answer, - }, - ) - .await?; + // prevent sign in until application is accepted + if local_site.site_setup + && require_registration_application + && !local_user.accepted_application + && !local_user.admin + { + if let Some(answer) = data.answer.clone() { + // Create the registration application + RegistrationApplication::create( + &mut conn.into(), + &RegistrationApplicationInsertForm { + local_user_id: local_user.id, + answer, + }, + ) + .await?; - login_response.registration_created = true; - } - } + login_response.registration_created = true; + } + } + Ok((person, local_user)) + } + .scope_boxed() + }) + .await?; // Check email is verified when required login_response.verify_email_sent = send_verification_email_if_required(&context, &local_site, &local_user, &person).await?; + local_user } - } + }; if !login_response.registration_created && !login_response.verify_email_sent { let jwt = Claims::generate(local_user.id, req, &context).await?; @@ -406,11 +437,12 @@ async fn create_person( username: String, local_site: &LocalSite, instance_id: InstanceId, - context: &Data, + settings: &Settings, + conn: &mut AsyncPgConnection, ) -> Result { let actor_keypair = generate_actor_keypair()?; is_valid_actor_name(&username, local_site.actor_name_max_length as usize)?; - let ap_id = Person::local_url(&username, context.settings())?; + let ap_id = Person::local_url(&username, settings)?; // Register the new person let person_form = PersonInsertForm { @@ -421,7 +453,7 @@ async fn create_person( }; // insert the person - let inserted_person = Person::create(&mut context.pool(), &person_form) + let inserted_person = Person::create(&mut conn.into(), &person_form) .await .with_lemmy_type(LemmyErrorType::UserAlreadyExists)?; @@ -441,17 +473,18 @@ fn get_language_tags(req: &HttpRequest) -> Vec { } async fn create_local_user( - context: &Data, + conn: &mut AsyncPgConnection, language_tags: Vec, mut local_user_form: LocalUserInsertForm, local_site: &LocalSite, ) -> Result { - let all_languages = Language::read_all(&mut context.pool()).await?; + let conn_ = &mut conn.into(); + let all_languages = Language::read_all(conn_).await?; // use hashset to avoid duplicates let mut language_ids = HashSet::new(); // Enable site languages. Ignored if all languages are enabled. - let discussion_languages = SiteLanguage::read(&mut context.pool(), local_site.site_id).await?; + let discussion_languages = SiteLanguage::read(conn_, local_site.site_id).await?; // Enable languages from `Accept-Language` header only if no site languages are set. Otherwise it // is possible that browser languages are only set to e.g. French, and the user won't see any @@ -472,8 +505,7 @@ async fn create_local_user( // If its the initial site setup, they are an admin local_user_form.admin = Some(!local_site.site_setup); local_user_form.interface_language = language_tags.first().cloned(); - let inserted_local_user = - LocalUser::create(&mut context.pool(), &local_user_form, language_ids).await?; + let inserted_local_user = LocalUser::create(conn_, &local_user_form, language_ids).await?; Ok(inserted_local_user) } diff --git a/crates/db_schema/src/impls/actor_language.rs b/crates/db_schema/src/impls/actor_language.rs index 4c0cfc0fc..cdc286f30 100644 --- a/crates/db_schema/src/impls/actor_language.rs +++ b/crates/db_schema/src/impls/actor_language.rs @@ -1,7 +1,7 @@ use crate::{ diesel::JoinOnDsl, newtypes::{CommunityId, InstanceId, LanguageId, LocalUserId, SiteId}, - schema::{local_site, site, site_language}, + schema::{community_language, local_site, local_user_language, site, site_language}, source::{ actor_language::{ CommunityLanguage, @@ -25,7 +25,12 @@ use diesel::{ ExpressionMethods, QueryDsl, }; -use diesel_async::{AsyncPgConnection, RunQueryDsl}; +use diesel_async::{ + scoped_futures::ScopedFutureExt, + AsyncConnection, + AsyncPgConnection, + RunQueryDsl, +}; use lemmy_utils::error::{LemmyErrorType, LemmyResult}; use tokio::sync::OnceCell; @@ -36,17 +41,12 @@ impl LocalUserLanguage { pool: &mut DbPool<'_>, for_local_user_id: LocalUserId, ) -> Result, Error> { - use crate::schema::local_user_language::dsl::{ - language_id, - local_user_id, - local_user_language, - }; let conn = &mut get_conn(pool).await?; - let langs = local_user_language - .filter(local_user_id.eq(for_local_user_id)) - .order(language_id) - .select(language_id) + let langs = local_user_language::table + .filter(local_user_language::local_user_id.eq(for_local_user_id)) + .order(local_user_language::language_id) + .select(local_user_language::language_id) .get_results(conn) .await?; convert_read_languages(conn, langs).await @@ -59,30 +59,25 @@ impl LocalUserLanguage { pool: &mut DbPool<'_>, language_ids: Vec, for_local_user_id: LocalUserId, - ) -> Result<(), Error> { + ) -> Result { let conn = &mut get_conn(pool).await?; let lang_ids = convert_update_languages(conn, language_ids).await?; // No need to update if languages are unchanged let current = LocalUserLanguage::read(&mut conn.into(), for_local_user_id).await?; if current == lang_ids { - return Ok(()); + return Ok(0); } conn - .build_transaction() - .run(|conn| { - Box::pin(async move { - use crate::schema::local_user_language::dsl::{ - language_id, - local_user_id, - local_user_language, - }; + .transaction::<_, Error, _>(|conn| { + async move { // Delete old languages, not including new languages - let delete_old = delete(local_user_language) - .filter(local_user_id.eq(for_local_user_id)) - .filter(language_id.ne_all(&lang_ids)) - .execute(conn); + delete(local_user_language::table) + .filter(local_user_language::local_user_id.eq(for_local_user_id)) + .filter(local_user_language::language_id.ne_all(&lang_ids)) + .execute(conn) + .await?; let forms = lang_ids .iter() @@ -93,15 +88,17 @@ impl LocalUserLanguage { .collect::>(); // Insert new languages - let insert_new = insert_into(local_user_language) + insert_into(local_user_language::table) .values(forms) - .on_conflict((language_id, local_user_id)) + .on_conflict(( + local_user_language::language_id, + local_user_language::local_user_id, + )) .do_nothing() - .execute(conn); - - tokio::try_join!(delete_old, insert_new)?; - Ok(()) - }) as _ + .execute(conn) + .await + } + .scope_boxed() }) .await } @@ -148,16 +145,14 @@ impl SiteLanguage { } conn - .build_transaction() - .run(|conn| { - Box::pin(async move { - use crate::schema::site_language::dsl::{language_id, site_id, site_language}; - + .transaction::<_, Error, _>(|conn| { + async move { // Delete old languages, not including new languages - let delete_old = delete(site_language) - .filter(site_id.eq(for_site_id)) - .filter(language_id.ne_all(&lang_ids)) - .execute(conn); + delete(site_language::table) + .filter(site_language::site_id.eq(for_site_id)) + .filter(site_language::language_id.ne_all(&lang_ids)) + .execute(conn) + .await?; let forms = lang_ids .iter() @@ -168,18 +163,17 @@ impl SiteLanguage { .collect::>(); // Insert new languages - let insert_new = insert_into(site_language) + insert_into(site_language::table) .values(forms) - .on_conflict((site_id, language_id)) + .on_conflict((site_language::site_id, site_language::language_id)) .do_nothing() - .execute(conn); - - tokio::try_join!(delete_old, insert_new)?; + .execute(conn) + .await?; CommunityLanguage::limit_languages(conn, instance_id).await?; - Ok(()) - }) as _ + } + .scope_boxed() }) .await } @@ -257,7 +251,7 @@ impl CommunityLanguage { pool: &mut DbPool<'_>, mut language_ids: Vec, for_community_id: CommunityId, - ) -> Result<(), Error> { + ) -> Result { if language_ids.is_empty() { language_ids = SiteLanguage::read_local_raw(pool).await?; } @@ -267,7 +261,7 @@ impl CommunityLanguage { // No need to update if languages are unchanged let current = CommunityLanguage::read(&mut conn.into(), for_community_id).await?; if current == lang_ids { - return Ok(()); + return Ok(0); } let form = lang_ids @@ -279,31 +273,27 @@ impl CommunityLanguage { .collect::>(); conn - .build_transaction() - .run(|conn| { - Box::pin(async move { - use crate::schema::community_language::dsl::{ - community_id, - community_language, - language_id, - }; + .transaction::<_, Error, _>(|conn| { + async move { // Delete old languages, not including new languages - let delete_old = delete(community_language) - .filter(community_id.eq(for_community_id)) - .filter(language_id.ne_all(&lang_ids)) - .execute(conn); + delete(community_language::table) + .filter(community_language::community_id.eq(for_community_id)) + .filter(community_language::language_id.ne_all(&lang_ids)) + .execute(conn) + .await?; // Insert new languages - let insert_new = insert_into(community_language) + insert_into(community_language::table) .values(form) - .on_conflict((community_id, language_id)) + .on_conflict(( + community_language::community_id, + community_language::language_id, + )) .do_nothing() - .execute(conn); - - tokio::try_join!(delete_old, insert_new)?; - - Ok(()) - }) as _ + .execute(conn) + .await + } + .scope_boxed() }) .await } diff --git a/crates/db_schema/src/impls/images.rs b/crates/db_schema/src/impls/images.rs index deb9a9e78..f4f53db4c 100644 --- a/crates/db_schema/src/impls/images.rs +++ b/crates/db_schema/src/impls/images.rs @@ -14,7 +14,7 @@ use diesel::{ NotFound, QueryDsl, }; -use diesel_async::RunQueryDsl; +use diesel_async::{scoped_futures::ScopedFutureExt, AsyncConnection, RunQueryDsl}; use url::Url; impl LocalImage { @@ -25,9 +25,8 @@ impl LocalImage { ) -> Result { let conn = &mut get_conn(pool).await?; conn - .build_transaction() - .run(|conn| { - Box::pin(async move { + .transaction::<_, Error, _>(|conn| { + async move { let local_insert = insert_into(local_image::table) .values(form) .get_result::(conn) @@ -36,7 +35,8 @@ impl LocalImage { ImageDetails::create(&mut conn.into(), image_details_form).await?; local_insert - }) as _ + } + .scope_boxed() }) .await } diff --git a/crates/db_schema/src/impls/local_site_url_blocklist.rs b/crates/db_schema/src/impls/local_site_url_blocklist.rs index 73dedabce..d2b0d5918 100644 --- a/crates/db_schema/src/impls/local_site_url_blocklist.rs +++ b/crates/db_schema/src/impls/local_site_url_blocklist.rs @@ -4,18 +4,20 @@ use crate::{ utils::{get_conn, DbPool}, }; use diesel::{dsl::insert_into, result::Error}; -use diesel_async::{AsyncPgConnection, RunQueryDsl}; +use diesel_async::{ + scoped_futures::ScopedFutureExt, + AsyncConnection, + AsyncPgConnection, + RunQueryDsl, +}; impl LocalSiteUrlBlocklist { - pub async fn replace(pool: &mut DbPool<'_>, url_blocklist: Vec) -> Result<(), Error> { + pub async fn replace(pool: &mut DbPool<'_>, url_blocklist: Vec) -> Result { let conn = &mut get_conn(pool).await?; conn - .build_transaction() - .run(|conn| { - Box::pin(async move { - use crate::schema::local_site_url_blocklist::dsl::local_site_url_blocklist; - + .transaction::<_, Error, _>(|conn| { + async move { Self::clear(conn).await?; let forms = url_blocklist @@ -23,13 +25,12 @@ impl LocalSiteUrlBlocklist { .map(|url| LocalSiteUrlBlocklistForm { url, updated: None }) .collect::>(); - insert_into(local_site_url_blocklist) + insert_into(local_site_url_blocklist::table) .values(forms) .execute(conn) - .await?; - - Ok(()) - }) as _ + .await + } + .scope_boxed() }) .await }