diff --git a/models/user/user.go b/models/user/user.go index ec9b35964d..98903bf18e 100644 --- a/models/user/user.go +++ b/models/user/user.go @@ -842,48 +842,46 @@ func countUsers(ctx context.Context, opts *CountUserFilter) int64 { // VerifyUserActiveCode verifies that the code is valid for the given purpose for this user. // If delete is specified, the token will be deleted. -func VerifyUserAuthorizationToken(ctx context.Context, code string, purpose auth.AuthorizationPurpose, delete bool) (*User, error) { +func VerifyUserAuthorizationToken(ctx context.Context, code string, purpose auth.AuthorizationPurpose) (user *User, deleteToken func() error, err error) { lookupKey, validator, found := strings.Cut(code, ":") if !found { - return nil, nil + return nil, nil, nil } authToken, err := auth.FindAuthToken(ctx, lookupKey, purpose) if err != nil { if errors.Is(err, util.ErrNotExist) { - return nil, nil + return nil, nil, nil } - return nil, err + return nil, nil, err } if authToken.IsExpired() { - return nil, auth.DeleteAuthToken(ctx, authToken) + return nil, nil, auth.DeleteAuthToken(ctx, authToken) } rawValidator, err := hex.DecodeString(validator) if err != nil { - return nil, err + return nil, nil, err } if subtle.ConstantTimeCompare([]byte(authToken.HashedValidator), []byte(auth.HashValidator(rawValidator))) == 0 { - return nil, errors.New("validator doesn't match") + return nil, nil, errors.New("validator doesn't match") } u, err := GetUserByID(ctx, authToken.UID) if err != nil { if IsErrUserNotExist(err) { - return nil, nil + return nil, nil, nil } - return nil, err + return nil, nil, err } - if delete { - if err := auth.DeleteAuthToken(ctx, authToken); err != nil { - return nil, err - } + deleteToken = func() error { + return auth.DeleteAuthToken(ctx, authToken) } - return u, nil + return u, deleteToken, nil } // ValidateUser check if user is valid to insert / update into database diff --git a/models/user/user_test.go b/models/user/user_test.go index bc23a5da48..68778f4002 100644 --- a/models/user/user_test.go +++ b/models/user/user_test.go @@ -741,13 +741,13 @@ func TestVerifyUserAuthorizationToken(t *testing.T) { assert.True(t, ok) t.Run("Wrong purpose", func(t *testing.T) { - u, err := user_model.VerifyUserAuthorizationToken(db.DefaultContext, code, auth.PasswordReset, false) + u, _, err := user_model.VerifyUserAuthorizationToken(db.DefaultContext, code, auth.PasswordReset) require.NoError(t, err) assert.Nil(t, u) }) t.Run("No delete", func(t *testing.T) { - u, err := user_model.VerifyUserAuthorizationToken(db.DefaultContext, code, auth.UserActivation, false) + u, _, err := user_model.VerifyUserAuthorizationToken(db.DefaultContext, code, auth.UserActivation) require.NoError(t, err) assert.EqualValues(t, user.ID, u.ID) @@ -757,9 +757,10 @@ func TestVerifyUserAuthorizationToken(t *testing.T) { }) t.Run("Delete", func(t *testing.T) { - u, err := user_model.VerifyUserAuthorizationToken(db.DefaultContext, code, auth.UserActivation, true) + u, deleteToken, err := user_model.VerifyUserAuthorizationToken(db.DefaultContext, code, auth.UserActivation) require.NoError(t, err) assert.EqualValues(t, user.ID, u.ID) + require.NoError(t, deleteToken()) authToken, err := auth.FindAuthToken(db.DefaultContext, lookupKey, auth.UserActivation) require.ErrorIs(t, err, util.ErrNotExist) diff --git a/routers/web/auth/auth.go b/routers/web/auth/auth.go index 4be3489891..2afaad45a2 100644 --- a/routers/web/auth/auth.go +++ b/routers/web/auth/auth.go @@ -61,7 +61,7 @@ func autoSignIn(ctx *context.Context) (bool, error) { return false, nil } - u, err := user_model.VerifyUserAuthorizationToken(ctx, authCookie, auth.LongTermAuthorization, false) + u, _, err := user_model.VerifyUserAuthorizationToken(ctx, authCookie, auth.LongTermAuthorization) if err != nil { return false, fmt.Errorf("VerifyUserAuthorizationToken: %w", err) } @@ -677,7 +677,7 @@ func Activate(ctx *context.Context) { return } - user, err := user_model.VerifyUserAuthorizationToken(ctx, code, auth.UserActivation, false) + user, deleteToken, err := user_model.VerifyUserAuthorizationToken(ctx, code, auth.UserActivation) if err != nil { ctx.ServerError("VerifyUserAuthorizationToken", err) return @@ -698,6 +698,11 @@ func Activate(ctx *context.Context) { return } + if err := deleteToken(); err != nil { + ctx.ServerError("deleteToken", err) + return + } + handleAccountActivation(ctx, user) } @@ -746,7 +751,7 @@ func ActivatePost(ctx *context.Context) { return } - user, err := user_model.VerifyUserAuthorizationToken(ctx, code, auth.UserActivation, true) + user, deleteToken, err := user_model.VerifyUserAuthorizationToken(ctx, code, auth.UserActivation) if err != nil { ctx.ServerError("VerifyUserAuthorizationToken", err) return @@ -775,6 +780,11 @@ func ActivatePost(ctx *context.Context) { } } + if err := deleteToken(); err != nil { + ctx.ServerError("deleteToken", err) + return + } + handleAccountActivation(ctx, user) } @@ -835,7 +845,7 @@ func ActivateEmail(ctx *context.Context) { code := ctx.FormString("code") emailStr := ctx.FormString("email") - u, err := user_model.VerifyUserAuthorizationToken(ctx, code, auth.EmailActivation(emailStr), true) + u, deleteToken, err := user_model.VerifyUserAuthorizationToken(ctx, code, auth.EmailActivation(emailStr)) if err != nil { ctx.ServerError("VerifyUserAuthorizationToken", err) return @@ -845,6 +855,11 @@ func ActivateEmail(ctx *context.Context) { return } + if err := deleteToken(); err != nil { + ctx.ServerError("deleteToken", err) + return + } + email, err := user_model.GetEmailAddressOfUser(ctx, emailStr, u.ID) if err != nil { ctx.ServerError("GetEmailAddressOfUser", err) diff --git a/routers/web/auth/password.go b/routers/web/auth/password.go index 363c01c6a8..84f343bfca 100644 --- a/routers/web/auth/password.go +++ b/routers/web/auth/password.go @@ -116,7 +116,7 @@ func commonResetPassword(ctx *context.Context, shouldDeleteToken bool) (*user_mo } // Fail early, don't frustrate the user - u, err := user_model.VerifyUserAuthorizationToken(ctx, code, auth.PasswordReset, shouldDeleteToken) + u, deleteToken, err := user_model.VerifyUserAuthorizationToken(ctx, code, auth.PasswordReset) if err != nil { ctx.ServerError("VerifyUserAuthorizationToken", err) return nil, nil @@ -127,6 +127,13 @@ func commonResetPassword(ctx *context.Context, shouldDeleteToken bool) (*user_mo return nil, nil } + if shouldDeleteToken { + if err := deleteToken(); err != nil { + ctx.ServerError("deleteToken", err) + return nil, nil + } + } + twofa, err := auth.GetTwoFactorByUID(ctx, u.ID) if err != nil { if !auth.IsErrTwoFactorNotEnrolled(err) { diff --git a/tests/integration/user_test.go b/tests/integration/user_test.go index d2b5f112a3..dba1feb399 100644 --- a/tests/integration/user_test.go +++ b/tests/integration/user_test.go @@ -886,11 +886,27 @@ func TestUserActivate(t *testing.T) { assert.False(t, authToken.IsExpired()) assert.EqualValues(t, authToken.HashedValidator, auth_model.HashValidator(rawValidator)) - req = NewRequest(t, "POST", "/user/activate?code="+code) - session.MakeRequest(t, req, http.StatusOK) + t.Run("No password", func(t *testing.T) { + defer tests.PrintCurrentTest(t)() - unittest.AssertNotExistsBean(t, &auth_model.AuthorizationToken{ID: authToken.ID}) - unittest.AssertExistsAndLoadBean(t, &user_model.User{Name: "doesnotexist", IsActive: true}) + req = NewRequest(t, "POST", "/user/activate?code="+code) + session.MakeRequest(t, req, http.StatusOK) + + unittest.AssertExistsIf(t, true, &auth_model.AuthorizationToken{ID: authToken.ID}) + unittest.AssertExistsAndLoadBean(t, &user_model.User{Name: "doesnotexist"}, "is_active = false") + }) + + t.Run("With password", func(t *testing.T) { + defer tests.PrintCurrentTest(t)() + + req = NewRequestWithValues(t, "POST", "/user/activate?code="+code, map[string]string{ + "password": "examplePassword!1", + }) + session.MakeRequest(t, req, http.StatusSeeOther) + + unittest.AssertExistsIf(t, false, &auth_model.AuthorizationToken{ID: authToken.ID}) + unittest.AssertExistsAndLoadBean(t, &user_model.User{Name: "doesnotexist"}, "is_active = true") + }) } func TestUserPasswordReset(t *testing.T) {