diff --git a/internal/db/bundb/migrations/20231110142330_small_poll_table_tweaks.go b/internal/db/bundb/migrations/20231110142330_small_poll_table_tweaks.go index c9f2b3d0f..dad943efa 100644 --- a/internal/db/bundb/migrations/20231110142330_small_poll_table_tweaks.go +++ b/internal/db/bundb/migrations/20231110142330_small_poll_table_tweaks.go @@ -44,7 +44,7 @@ func init() { Table("polls"). Column("expires_at_new"). Set("? = ?", bun.Ident("expires_at_new"), bun.Ident("expires_at")). - Where("1"). // bun gets angry performing update over all rows + Where("TRUE"). // bun gets angry performing update over all rows Exec(ctx); err != nil { return err } diff --git a/internal/db/bundb/poll.go b/internal/db/bundb/poll.go index ab6edb4b9..830fb88ec 100644 --- a/internal/db/bundb/poll.go +++ b/internal/db/bundb/poll.go @@ -341,9 +341,12 @@ func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error var poll gtsmodel.Poll - // Select poll counts from DB. + // Select current poll counts from DB, + // taking minimal columns needed to + // increment/decrement votes. if err := tx.NewSelect(). Model(&poll). + Column("options", "votes", "voters"). Where("? = ?", bun.Ident("id"), vote.PollID). Scan(ctx); err != nil { return err @@ -365,31 +368,35 @@ func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error { err := p.db.RunInTx(ctx, func(tx Tx) error { - // Delete all vote in poll, - // returning all vote choices. - switch _, err := tx.NewDelete(). + // Delete all votes in poll. + res, err := tx.NewDelete(). Table("poll_votes"). Where("? = ?", bun.Ident("poll_id"), pollID). - Exec(ctx); { - - case err == nil: - // no issue. - - case errors.Is(err, db.ErrNoEntries): - // no votes found, - // return here. - return nil - - default: - // irrecoverable. + Exec(ctx) + if err != nil { + // irrecoverable return err } - var poll gtsmodel.Poll + ra, err := res.RowsAffected() + if err != nil { + // irrecoverable + return err + } - // Select poll counts from DB. + if ra == 0 { + // No poll votes deleted, + // nothing to update. + return nil + } + + // Select current poll counts from DB, + // taking minimal columns needed to + // increment/decrement votes. + var poll gtsmodel.Poll switch err := tx.NewSelect(). Model(&poll). + Column("options", "votes", "voters"). Where("? = ?", bun.Ident("id"), pollID). Scan(ctx); { @@ -410,7 +417,7 @@ func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error { poll.ResetVotes() // Finally, update the poll entry. - _, err := tx.NewUpdate(). + _, err = tx.NewUpdate(). Model(&poll). Column("votes", "voters"). Where("? = ?", bun.Ident("id"), pollID). @@ -432,35 +439,37 @@ func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error { func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID string) error { err := p.db.RunInTx(ctx, func(tx Tx) error { - var choices []int + // Slice should only ever be of length + // 0 or 1; it's a slice of slices only + // because we can't LIMIT deletes to 1. + var choicesSl [][]int // Delete vote in poll by account, // returning the ID + choices of the vote. - switch err := tx.NewDelete(). + if err := tx.NewDelete(). Table("poll_votes"). Where("? = ?", bun.Ident("poll_id"), pollID). Where("? = ?", bun.Ident("account_id"), accountID). - Returning("choices"). - Scan(ctx, &choices); { - - case err == nil: - // no issue. - - case errors.Is(err, db.ErrNoEntries): - // no votes found, - // return here. - return nil - - default: + Returning("?", bun.Ident("choices")). + Scan(ctx, &choicesSl); err != nil { // irrecoverable. return err } - var poll gtsmodel.Poll + if len(choicesSl) != 1 { + // No poll votes by this + // acct on this poll. + return nil + } + choices := choicesSl[0] - // Select poll counts from DB. + // Select current poll counts from DB, + // taking minimal columns needed to + // increment/decrement votes. + var poll gtsmodel.Poll switch err := tx.NewSelect(). Model(&poll). + Column("options", "votes", "voters"). Where("? = ?", bun.Ident("id"), pollID). Scan(ctx); { @@ -468,7 +477,7 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID // no issue. case errors.Is(err, db.ErrNoEntries): - // no votes found, + // no poll found, // return here. return nil diff --git a/internal/db/bundb/poll_test.go b/internal/db/bundb/poll_test.go index 53da2514b..479557c55 100644 --- a/internal/db/bundb/poll_test.go +++ b/internal/db/bundb/poll_test.go @@ -26,6 +26,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/util" @@ -304,15 +305,64 @@ func (suite *PollTestSuite) TestDeletePollVotes() { suite.NoError(err) // Fetch latest version of poll from database. - poll, err = suite.db.GetPollByID(ctx, poll.ID) + poll, err = suite.db.GetPollByID( + gtscontext.SetBarebones(ctx), + poll.ID, + ) suite.NoError(err) // Check that poll counts are all zero. suite.Equal(*poll.Voters, 0) - suite.Equal(poll.Votes, make([]int, len(poll.Options))) + suite.Equal(make([]int, len(poll.Options)), poll.Votes) } } +func (suite *PollTestSuite) TestDeletePollVotesNoPoll() { + // Create a new context for this test. + ctx, cncl := context.WithCancel(context.Background()) + defer cncl() + + // Try to delete votes of nonexistent poll. + nonPollID := "01HF6V4XWTSZWJ80JNPPDTD4DB" + + err := suite.db.DeletePollVotes(ctx, nonPollID) + suite.NoError(err) +} + +func (suite *PollTestSuite) TestDeletePollVotesBy() { + ctx, cncl := context.WithCancel(context.Background()) + defer cncl() + + for _, vote := range suite.testPollVotes { + // Fetch before version of pollBefore from database. + pollBefore, err := suite.db.GetPollByID(ctx, vote.PollID) + suite.NoError(err) + + // Delete this poll vote. + err = suite.db.DeletePollVoteBy(ctx, vote.PollID, vote.AccountID) + suite.NoError(err) + + // Fetch after version of poll from database. + pollAfter, err := suite.db.GetPollByID(ctx, vote.PollID) + suite.NoError(err) + + // Voters count should be reduced by 1. + suite.Equal(*pollBefore.Voters-1, *pollAfter.Voters) + } +} + +func (suite *PollTestSuite) TestDeletePollVotesByNoAccount() { + ctx, cncl := context.WithCancel(context.Background()) + defer cncl() + + // Try to delete a poll by nonexisting account. + pollID := suite.testPolls["local_account_1_status_6_poll"].ID + nonAccountID := "01HF6T545G1G8ZNMY1S3ZXJ608" + + err := suite.db.DeletePollVoteBy(ctx, pollID, nonAccountID) + suite.NoError(err) +} + func TestPollTestSuite(t *testing.T) { suite.Run(t, new(PollTestSuite)) }