use crate::authentication::AuthError; use crate::authentication::{validate_credentials, Credentials}; use crate::routes::error_chain_fmt; use crate::session_state::TypedSession; use actix_web::error::InternalError; use actix_web::http::header::LOCATION; use actix_web::web; use actix_web::HttpResponse; use actix_web_flash_messages::FlashMessage; use secrecy::Secret; use sqlx::PgPool; #[derive(serde::Deserialize)] pub struct FormData { username: String, password: Secret, } #[tracing::instrument( skip(form, pool, session), fields(username=tracing::field::Empty, user_id=tracing::field::Empty) )] // We are now injecting `PgPool` to retrieve stored credentials from the database pub async fn login( form: web::Form, pool: web::Data, session: TypedSession, ) -> Result> { let credentials = Credentials { username: form.0.username, password: form.0.password, }; tracing::Span::current().record("username", &tracing::field::display(&credentials.username)); match validate_credentials(credentials, &pool).await { Ok(user_id) => { tracing::Span::current().record("user_id", &tracing::field::display(&user_id)); session.renew(); session .insert_user_id(user_id) .map_err(|e| login_redirect(LoginError::UnexpectedError(e.into())))?; Ok(HttpResponse::SeeOther() .insert_header((LOCATION, "/admin/dashboard")) .finish()) } Err(e) => { let e = match e { AuthError::InvalidCredentials(_) => LoginError::AuthError(e.into()), AuthError::UnexpectedError(_) => LoginError::UnexpectedError(e.into()), }; Err(login_redirect(e)) } } } fn login_redirect(e: LoginError) -> InternalError { FlashMessage::error(e.to_string()).send(); let response = HttpResponse::SeeOther() .insert_header((LOCATION, "/login")) .finish(); InternalError::from_response(e, response) } #[derive(thiserror::Error)] pub enum LoginError { #[error("Authentication failed")] AuthError(#[source] anyhow::Error), #[error("Something went wrong")] UnexpectedError(#[from] anyhow::Error), } impl std::fmt::Debug for LoginError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { error_chain_fmt(self, f) } }