feat: add PKCE

This commit is contained in:
avdb13 2024-11-15 00:17:17 +00:00
parent 7f4e26e29e
commit 81b06cc911
10 changed files with 61 additions and 10 deletions

View file

@ -25,6 +25,8 @@ pub struct CreateOAuthProvider {
#[cfg_attr(feature = "full", ts(optional))] #[cfg_attr(feature = "full", ts(optional))]
pub account_linking_enabled: Option<bool>, pub account_linking_enabled: Option<bool>,
#[cfg_attr(feature = "full", ts(optional))] #[cfg_attr(feature = "full", ts(optional))]
pub use_pkce: Option<bool>,
#[cfg_attr(feature = "full", ts(optional))]
pub enabled: Option<bool>, pub enabled: Option<bool>,
} }
@ -54,6 +56,8 @@ pub struct EditOAuthProvider {
#[cfg_attr(feature = "full", ts(optional))] #[cfg_attr(feature = "full", ts(optional))]
pub account_linking_enabled: Option<bool>, pub account_linking_enabled: Option<bool>,
#[cfg_attr(feature = "full", ts(optional))] #[cfg_attr(feature = "full", ts(optional))]
pub use_pkce: Option<bool>,
#[cfg_attr(feature = "full", ts(optional))]
pub enabled: Option<bool>, pub enabled: Option<bool>,
} }
@ -72,6 +76,7 @@ pub struct DeleteOAuthProvider {
/// Logging in with an OAuth 2.0 authorization /// Logging in with an OAuth 2.0 authorization
pub struct AuthenticateWithOauth { pub struct AuthenticateWithOauth {
pub code: String, pub code: String,
pub pkce_code_verifier: Option<String>,
pub oauth_provider_id: OAuthProviderId, pub oauth_provider_id: OAuthProviderId,
pub redirect_uri: Url, pub redirect_uri: Url,
#[cfg_attr(feature = "full", ts(optional))] #[cfg_attr(feature = "full", ts(optional))]

View file

@ -1162,6 +1162,19 @@ fn build_proxied_image_url(
)) ))
} }
pub fn check_code_verifier(code_verifier: &str) -> LemmyResult<&str> {
static VALID_CODE_VERIFIER_REGEX: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"^[a-zA-Z0-9\-._~]{43,128}$").expect("compile regex"));
let check = VALID_CODE_VERIFIER_REGEX.is_match(code_verifier);
if !check {
Err(LemmyErrorType::InvalidCodeVerifier.into())
} else {
Ok(code_verifier)
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {

View file

@ -35,6 +35,7 @@ pub async fn create_oauth_provider(
scopes: data.scopes.to_string(), scopes: data.scopes.to_string(),
auto_verify_email: data.auto_verify_email, auto_verify_email: data.auto_verify_email,
account_linking_enabled: data.account_linking_enabled, account_linking_enabled: data.account_linking_enabled,
use_pkce: data.use_pkce,
enabled: data.enabled, enabled: data.enabled,
}; };
let oauth_provider = OAuthProvider::create(&mut context.pool(), &oauth_provider_form).await?; let oauth_provider = OAuthProvider::create(&mut context.pool(), &oauth_provider_form).await?;

View file

@ -32,6 +32,7 @@ pub async fn update_oauth_provider(
auto_verify_email: data.auto_verify_email, auto_verify_email: data.auto_verify_email,
account_linking_enabled: data.account_linking_enabled, account_linking_enabled: data.account_linking_enabled,
enabled: data.enabled, enabled: data.enabled,
use_pkce: data.use_pkce,
updated: Some(Some(naive_now())), updated: Some(Some(naive_now())),
}; };

View file

@ -6,6 +6,7 @@ use lemmy_api_common::{
oauth_provider::AuthenticateWithOauth, oauth_provider::AuthenticateWithOauth,
person::{LoginResponse, Register}, person::{LoginResponse, Register},
utils::{ utils::{
check_code_verifier,
check_email_verified, check_email_verified,
check_registration_application, check_registration_application,
check_user_valid, check_user_valid,
@ -228,9 +229,20 @@ pub async fn authenticate_with_oauth(
return Err(LemmyErrorType::OauthAuthorizationInvalid)?; return Err(LemmyErrorType::OauthAuthorizationInvalid)?;
} }
let token_response = let pkce_code_verifier = data
oauth_request_access_token(&context, &oauth_provider, &data.code, redirect_uri.as_str()) .pkce_code_verifier
.await?; .as_deref()
.map(check_code_verifier)
.transpose()?;
let token_response = oauth_request_access_token(
&context,
&oauth_provider,
&data.code,
pkce_code_verifier,
redirect_uri.as_str(),
)
.await?;
let user_info = oidc_get_user_info( let user_info = oidc_get_user_info(
&context, &context,
@ -509,20 +521,26 @@ async fn oauth_request_access_token(
context: &Data<LemmyContext>, context: &Data<LemmyContext>,
oauth_provider: &OAuthProvider, oauth_provider: &OAuthProvider,
code: &str, code: &str,
pkce_code_verifier: Option<&str>,
redirect_uri: &str, redirect_uri: &str,
) -> LemmyResult<TokenResponse> { ) -> LemmyResult<TokenResponse> {
let mut form = vec![
("grant_type", "authorization_code"),
("code", code),
("redirect_uri", redirect_uri),
("client_id", &oauth_provider.client_id),
("client_secret", &oauth_provider.client_secret),
];
if let Some(code_verifier) = pkce_code_verifier {
form.push(("code_verifier", code_verifier));
}
// Request an Access Token from the OAUTH provider // Request an Access Token from the OAUTH provider
let response = context let response = context
.client() .client()
.post(oauth_provider.token_endpoint.as_str()) .post(oauth_provider.token_endpoint.as_str())
.header("Accept", "application/json") .header("Accept", "application/json")
.form(&[ .form(&*form)
("grant_type", "authorization_code"),
("code", code),
("redirect_uri", redirect_uri),
("client_id", &oauth_provider.client_id),
("client_secret", &oauth_provider.client_secret),
])
.send() .send()
.await; .await;

View file

@ -631,6 +631,7 @@ diesel::table! {
scopes -> Text, scopes -> Text,
auto_verify_email -> Bool, auto_verify_email -> Bool,
account_linking_enabled -> Bool, account_linking_enabled -> Bool,
use_pkce -> Bool,
enabled -> Bool, enabled -> Bool,
published -> Timestamptz, published -> Timestamptz,
updated -> Nullable<Timestamptz>, updated -> Nullable<Timestamptz>,

View file

@ -57,6 +57,8 @@ pub struct OAuthProvider {
pub auto_verify_email: bool, pub auto_verify_email: bool,
/// Allows linking an OAUTH account to an existing user account by matching emails /// Allows linking an OAUTH account to an existing user account by matching emails
pub account_linking_enabled: bool, pub account_linking_enabled: bool,
/// switch to enable or disable PKCE
pub use_pkce: bool,
/// switch to enable or disable an oauth provider /// switch to enable or disable an oauth provider
pub enabled: bool, pub enabled: bool,
pub published: DateTime<Utc>, pub published: DateTime<Utc>,
@ -83,6 +85,7 @@ impl Serialize for PublicOAuthProvider {
state.serialize_field("authorization_endpoint", &self.0.authorization_endpoint)?; state.serialize_field("authorization_endpoint", &self.0.authorization_endpoint)?;
state.serialize_field("client_id", &self.0.client_id)?; state.serialize_field("client_id", &self.0.client_id)?;
state.serialize_field("scopes", &self.0.scopes)?; state.serialize_field("scopes", &self.0.scopes)?;
state.serialize_field("use_pkce", &self.0.use_pkce)?;
state.end() state.end()
} }
} }
@ -102,6 +105,7 @@ pub struct OAuthProviderInsertForm {
pub scopes: String, pub scopes: String,
pub auto_verify_email: Option<bool>, pub auto_verify_email: Option<bool>,
pub account_linking_enabled: Option<bool>, pub account_linking_enabled: Option<bool>,
pub use_pkce: Option<bool>,
pub enabled: Option<bool>, pub enabled: Option<bool>,
} }
@ -118,6 +122,7 @@ pub struct OAuthProviderUpdateForm {
pub scopes: Option<String>, pub scopes: Option<String>,
pub auto_verify_email: Option<bool>, pub auto_verify_email: Option<bool>,
pub account_linking_enabled: Option<bool>, pub account_linking_enabled: Option<bool>,
pub use_pkce: Option<bool>,
pub enabled: Option<bool>, pub enabled: Option<bool>,
pub updated: Option<Option<DateTime<Utc>>>, pub updated: Option<Option<DateTime<Utc>>>,
} }

View file

@ -76,6 +76,7 @@ pub enum LemmyErrorType {
MissingAnEmail, MissingAnEmail,
RateLimitError, RateLimitError,
InvalidName, InvalidName,
InvalidCodeVerifier,
InvalidDisplayName, InvalidDisplayName,
InvalidMatrixId, InvalidMatrixId,
InvalidPostTitle, InvalidPostTitle,

View file

@ -0,0 +1,3 @@
ALTER TABLE oauth_provider
DROP COLUMN use_pkce;

View file

@ -0,0 +1,3 @@
ALTER TABLE oauth_provider
ADD COLUMN use_pkce boolean DEFAULT FALSE NOT NULL;