mirror of
https://github.com/LemmyNet/lemmy.git
synced 2024-11-26 03:11:08 +00:00
feat: add PKCE
This commit is contained in:
parent
7f4e26e29e
commit
81b06cc911
10 changed files with 61 additions and 10 deletions
|
@ -25,6 +25,8 @@ pub struct CreateOAuthProvider {
|
||||||
#[cfg_attr(feature = "full", ts(optional))]
|
#[cfg_attr(feature = "full", ts(optional))]
|
||||||
pub account_linking_enabled: Option<bool>,
|
pub account_linking_enabled: Option<bool>,
|
||||||
#[cfg_attr(feature = "full", ts(optional))]
|
#[cfg_attr(feature = "full", ts(optional))]
|
||||||
|
pub use_pkce: Option<bool>,
|
||||||
|
#[cfg_attr(feature = "full", ts(optional))]
|
||||||
pub enabled: Option<bool>,
|
pub enabled: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -54,6 +56,8 @@ pub struct EditOAuthProvider {
|
||||||
#[cfg_attr(feature = "full", ts(optional))]
|
#[cfg_attr(feature = "full", ts(optional))]
|
||||||
pub account_linking_enabled: Option<bool>,
|
pub account_linking_enabled: Option<bool>,
|
||||||
#[cfg_attr(feature = "full", ts(optional))]
|
#[cfg_attr(feature = "full", ts(optional))]
|
||||||
|
pub use_pkce: Option<bool>,
|
||||||
|
#[cfg_attr(feature = "full", ts(optional))]
|
||||||
pub enabled: Option<bool>,
|
pub enabled: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -72,6 +76,7 @@ pub struct DeleteOAuthProvider {
|
||||||
/// Logging in with an OAuth 2.0 authorization
|
/// Logging in with an OAuth 2.0 authorization
|
||||||
pub struct AuthenticateWithOauth {
|
pub struct AuthenticateWithOauth {
|
||||||
pub code: String,
|
pub code: String,
|
||||||
|
pub pkce_code_verifier: Option<String>,
|
||||||
pub oauth_provider_id: OAuthProviderId,
|
pub oauth_provider_id: OAuthProviderId,
|
||||||
pub redirect_uri: Url,
|
pub redirect_uri: Url,
|
||||||
#[cfg_attr(feature = "full", ts(optional))]
|
#[cfg_attr(feature = "full", ts(optional))]
|
||||||
|
|
|
@ -1162,6 +1162,19 @@ fn build_proxied_image_url(
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn check_code_verifier(code_verifier: &str) -> LemmyResult<&str> {
|
||||||
|
static VALID_CODE_VERIFIER_REGEX: LazyLock<Regex> =
|
||||||
|
LazyLock::new(|| Regex::new(r"^[a-zA-Z0-9\-._~]{43,128}$").expect("compile regex"));
|
||||||
|
|
||||||
|
let check = VALID_CODE_VERIFIER_REGEX.is_match(code_verifier);
|
||||||
|
|
||||||
|
if !check {
|
||||||
|
Err(LemmyErrorType::InvalidCodeVerifier.into())
|
||||||
|
} else {
|
||||||
|
Ok(code_verifier)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
|
||||||
|
|
|
@ -35,6 +35,7 @@ pub async fn create_oauth_provider(
|
||||||
scopes: data.scopes.to_string(),
|
scopes: data.scopes.to_string(),
|
||||||
auto_verify_email: data.auto_verify_email,
|
auto_verify_email: data.auto_verify_email,
|
||||||
account_linking_enabled: data.account_linking_enabled,
|
account_linking_enabled: data.account_linking_enabled,
|
||||||
|
use_pkce: data.use_pkce,
|
||||||
enabled: data.enabled,
|
enabled: data.enabled,
|
||||||
};
|
};
|
||||||
let oauth_provider = OAuthProvider::create(&mut context.pool(), &oauth_provider_form).await?;
|
let oauth_provider = OAuthProvider::create(&mut context.pool(), &oauth_provider_form).await?;
|
||||||
|
|
|
@ -32,6 +32,7 @@ pub async fn update_oauth_provider(
|
||||||
auto_verify_email: data.auto_verify_email,
|
auto_verify_email: data.auto_verify_email,
|
||||||
account_linking_enabled: data.account_linking_enabled,
|
account_linking_enabled: data.account_linking_enabled,
|
||||||
enabled: data.enabled,
|
enabled: data.enabled,
|
||||||
|
use_pkce: data.use_pkce,
|
||||||
updated: Some(Some(naive_now())),
|
updated: Some(Some(naive_now())),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ use lemmy_api_common::{
|
||||||
oauth_provider::AuthenticateWithOauth,
|
oauth_provider::AuthenticateWithOauth,
|
||||||
person::{LoginResponse, Register},
|
person::{LoginResponse, Register},
|
||||||
utils::{
|
utils::{
|
||||||
|
check_code_verifier,
|
||||||
check_email_verified,
|
check_email_verified,
|
||||||
check_registration_application,
|
check_registration_application,
|
||||||
check_user_valid,
|
check_user_valid,
|
||||||
|
@ -228,9 +229,20 @@ pub async fn authenticate_with_oauth(
|
||||||
return Err(LemmyErrorType::OauthAuthorizationInvalid)?;
|
return Err(LemmyErrorType::OauthAuthorizationInvalid)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let token_response =
|
let pkce_code_verifier = data
|
||||||
oauth_request_access_token(&context, &oauth_provider, &data.code, redirect_uri.as_str())
|
.pkce_code_verifier
|
||||||
.await?;
|
.as_deref()
|
||||||
|
.map(check_code_verifier)
|
||||||
|
.transpose()?;
|
||||||
|
|
||||||
|
let token_response = oauth_request_access_token(
|
||||||
|
&context,
|
||||||
|
&oauth_provider,
|
||||||
|
&data.code,
|
||||||
|
pkce_code_verifier,
|
||||||
|
redirect_uri.as_str(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
let user_info = oidc_get_user_info(
|
let user_info = oidc_get_user_info(
|
||||||
&context,
|
&context,
|
||||||
|
@ -509,20 +521,26 @@ async fn oauth_request_access_token(
|
||||||
context: &Data<LemmyContext>,
|
context: &Data<LemmyContext>,
|
||||||
oauth_provider: &OAuthProvider,
|
oauth_provider: &OAuthProvider,
|
||||||
code: &str,
|
code: &str,
|
||||||
|
pkce_code_verifier: Option<&str>,
|
||||||
redirect_uri: &str,
|
redirect_uri: &str,
|
||||||
) -> LemmyResult<TokenResponse> {
|
) -> LemmyResult<TokenResponse> {
|
||||||
|
let mut form = vec![
|
||||||
|
("grant_type", "authorization_code"),
|
||||||
|
("code", code),
|
||||||
|
("redirect_uri", redirect_uri),
|
||||||
|
("client_id", &oauth_provider.client_id),
|
||||||
|
("client_secret", &oauth_provider.client_secret),
|
||||||
|
];
|
||||||
|
if let Some(code_verifier) = pkce_code_verifier {
|
||||||
|
form.push(("code_verifier", code_verifier));
|
||||||
|
}
|
||||||
|
|
||||||
// Request an Access Token from the OAUTH provider
|
// Request an Access Token from the OAUTH provider
|
||||||
let response = context
|
let response = context
|
||||||
.client()
|
.client()
|
||||||
.post(oauth_provider.token_endpoint.as_str())
|
.post(oauth_provider.token_endpoint.as_str())
|
||||||
.header("Accept", "application/json")
|
.header("Accept", "application/json")
|
||||||
.form(&[
|
.form(&*form)
|
||||||
("grant_type", "authorization_code"),
|
|
||||||
("code", code),
|
|
||||||
("redirect_uri", redirect_uri),
|
|
||||||
("client_id", &oauth_provider.client_id),
|
|
||||||
("client_secret", &oauth_provider.client_secret),
|
|
||||||
])
|
|
||||||
.send()
|
.send()
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
|
|
@ -631,6 +631,7 @@ diesel::table! {
|
||||||
scopes -> Text,
|
scopes -> Text,
|
||||||
auto_verify_email -> Bool,
|
auto_verify_email -> Bool,
|
||||||
account_linking_enabled -> Bool,
|
account_linking_enabled -> Bool,
|
||||||
|
use_pkce -> Bool,
|
||||||
enabled -> Bool,
|
enabled -> Bool,
|
||||||
published -> Timestamptz,
|
published -> Timestamptz,
|
||||||
updated -> Nullable<Timestamptz>,
|
updated -> Nullable<Timestamptz>,
|
||||||
|
|
|
@ -57,6 +57,8 @@ pub struct OAuthProvider {
|
||||||
pub auto_verify_email: bool,
|
pub auto_verify_email: bool,
|
||||||
/// Allows linking an OAUTH account to an existing user account by matching emails
|
/// Allows linking an OAUTH account to an existing user account by matching emails
|
||||||
pub account_linking_enabled: bool,
|
pub account_linking_enabled: bool,
|
||||||
|
/// switch to enable or disable PKCE
|
||||||
|
pub use_pkce: bool,
|
||||||
/// switch to enable or disable an oauth provider
|
/// switch to enable or disable an oauth provider
|
||||||
pub enabled: bool,
|
pub enabled: bool,
|
||||||
pub published: DateTime<Utc>,
|
pub published: DateTime<Utc>,
|
||||||
|
@ -83,6 +85,7 @@ impl Serialize for PublicOAuthProvider {
|
||||||
state.serialize_field("authorization_endpoint", &self.0.authorization_endpoint)?;
|
state.serialize_field("authorization_endpoint", &self.0.authorization_endpoint)?;
|
||||||
state.serialize_field("client_id", &self.0.client_id)?;
|
state.serialize_field("client_id", &self.0.client_id)?;
|
||||||
state.serialize_field("scopes", &self.0.scopes)?;
|
state.serialize_field("scopes", &self.0.scopes)?;
|
||||||
|
state.serialize_field("use_pkce", &self.0.use_pkce)?;
|
||||||
state.end()
|
state.end()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -102,6 +105,7 @@ pub struct OAuthProviderInsertForm {
|
||||||
pub scopes: String,
|
pub scopes: String,
|
||||||
pub auto_verify_email: Option<bool>,
|
pub auto_verify_email: Option<bool>,
|
||||||
pub account_linking_enabled: Option<bool>,
|
pub account_linking_enabled: Option<bool>,
|
||||||
|
pub use_pkce: Option<bool>,
|
||||||
pub enabled: Option<bool>,
|
pub enabled: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,6 +122,7 @@ pub struct OAuthProviderUpdateForm {
|
||||||
pub scopes: Option<String>,
|
pub scopes: Option<String>,
|
||||||
pub auto_verify_email: Option<bool>,
|
pub auto_verify_email: Option<bool>,
|
||||||
pub account_linking_enabled: Option<bool>,
|
pub account_linking_enabled: Option<bool>,
|
||||||
|
pub use_pkce: Option<bool>,
|
||||||
pub enabled: Option<bool>,
|
pub enabled: Option<bool>,
|
||||||
pub updated: Option<Option<DateTime<Utc>>>,
|
pub updated: Option<Option<DateTime<Utc>>>,
|
||||||
}
|
}
|
||||||
|
|
|
@ -76,6 +76,7 @@ pub enum LemmyErrorType {
|
||||||
MissingAnEmail,
|
MissingAnEmail,
|
||||||
RateLimitError,
|
RateLimitError,
|
||||||
InvalidName,
|
InvalidName,
|
||||||
|
InvalidCodeVerifier,
|
||||||
InvalidDisplayName,
|
InvalidDisplayName,
|
||||||
InvalidMatrixId,
|
InvalidMatrixId,
|
||||||
InvalidPostTitle,
|
InvalidPostTitle,
|
||||||
|
|
3
migrations/2024-11-12-080606_oauth_pkce/down.sql
Normal file
3
migrations/2024-11-12-080606_oauth_pkce/down.sql
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
ALTER TABLE oauth_provider
|
||||||
|
DROP COLUMN use_pkce;
|
||||||
|
|
3
migrations/2024-11-12-080606_oauth_pkce/up.sql
Normal file
3
migrations/2024-11-12-080606_oauth_pkce/up.sql
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
ALTER TABLE oauth_provider
|
||||||
|
ADD COLUMN use_pkce boolean DEFAULT FALSE NOT NULL;
|
||||||
|
|
Loading…
Reference in a new issue