change stats function signatures

This commit is contained in:
tobi 2024-04-16 12:36:18 +02:00
parent 30aa7666c6
commit 9ff5e0cbab
10 changed files with 69 additions and 98 deletions

View file

@ -194,9 +194,7 @@ func (suite *StatusPinTestSuite) TestPinStatusTooManyPins() {
}
// Regenerate account stats to set pinned count.
var err error
testAccount.Stats, err = suite.db.RegenerateAccountStats(ctx, testAccount.ID)
if err != nil {
if err := suite.db.RegenerateAccountStats(ctx, testAccount); err != nil {
suite.FailNow(err.Error())
}

View file

@ -137,13 +137,19 @@ type Account interface {
// Update local account settings.
UpdateAccountSettings(ctx context.Context, settings *gtsmodel.AccountSettings, columns ...string) error
// Get (or create and get) account stats for the given accountID.
GetAccountStats(ctx context.Context, accountID string) (*gtsmodel.AccountStats, error)
// PopulateAccountStats gets (or creates and gets) account stats for
// the given account, and attaches them to the account model.
PopulateAccountStats(ctx context.Context, account *gtsmodel.Account) error
// RegenerateAccountStats creates, upserts, and returns stats for the given accountID.
// Unlike GetAccountStats, it will always get the database stats fresh. This can be
// used to "refresh" stats. Callers should prefer GetAccountStats in 99% of cases.
RegenerateAccountStats(ctx context.Context, accountID string) (*gtsmodel.AccountStats, error)
// RegenerateAccountStats creates, upserts, and returns stats
// for the given account, and attaches them to the account model.
//
// Unlike GetAccountStats, it will always get the database stats fresh.
// This can be used to "refresh" stats.
//
// Because this involves database calls that can be expensive (on Postgres
// specifically), callers should prefer GetAccountStats in 99% of cases.
RegenerateAccountStats(ctx context.Context, account *gtsmodel.Account) error
// Update account stats.
UpdateAccountStats(ctx context.Context, stats *gtsmodel.AccountStats, columns ...string) error

View file

@ -632,11 +632,7 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou
if account.Stats == nil {
// Get / Create stats for this account.
account.Stats, err = a.state.DB.GetAccountStats(
ctx,
account.ID,
)
if err != nil {
if err := a.state.DB.PopulateAccountStats(ctx, account); err != nil {
errs.Appendf("error populating account stats: %w", err)
}
}
@ -1081,7 +1077,7 @@ func (a *accountDB) UpdateAccountSettings(
})
}
func (a *accountDB) GetAccountStats(ctx context.Context, accountID string) (*gtsmodel.AccountStats, error) {
func (a *accountDB) PopulateAccountStats(ctx context.Context, account *gtsmodel.Account) error {
// Fetch stats from db cache with loader callback.
stats, err := a.state.Caches.GTS.AccountStats.LoadOne(
"AccountID",
@ -1091,45 +1087,39 @@ func (a *accountDB) GetAccountStats(ctx context.Context, accountID string) (*gts
if err := a.db.
NewSelect().
Model(&stats).
Where("? = ?", bun.Ident("account_stats.account_id"), accountID).
Where("? = ?", bun.Ident("account_stats.account_id"), account.ID).
Scan(ctx); err != nil {
return nil, err
}
return &stats, nil
},
accountID,
account.ID,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
// Real error.
return nil, err
return err
}
if stats == nil {
// Don't have stats yet, generate them.
return a.RegenerateAccountStats(ctx, accountID)
return a.RegenerateAccountStats(ctx, account)
}
// We have a stats, attach
// it to the account.
account.Stats = stats
// Check if this is a local
// stats by looking at the
// account they pertain to.
statsAcct, err := a.GetAccountByID(
// Must be barebones to avoid
// getting stuck in a loop.
gtscontext.SetBarebones(ctx),
accountID,
)
if err != nil {
return nil, gtserror.Newf("db error getting stats account: %w", err)
}
if statsAcct.IsRemote() {
if account.IsRemote() {
// Account is remote. Updating
// stats for remote accounts is
// handled in the dereferencer.
//
// Nothing more to do!
return stats, nil
return nil
}
// Stats account is local, check
@ -1138,41 +1128,41 @@ func (a *accountDB) GetAccountStats(ctx context.Context, accountID string) (*gts
expiry := stats.RegeneratedAt.Add(statsFreshness)
if time.Now().After(expiry) {
// Stats have expired, regenerate them.
return a.RegenerateAccountStats(ctx, accountID)
return a.RegenerateAccountStats(ctx, account)
}
// Stats are still fresh.
return stats, nil
return nil
}
func (a *accountDB) RegenerateAccountStats(ctx context.Context, accountID string) (*gtsmodel.AccountStats, error) {
func (a *accountDB) RegenerateAccountStats(ctx context.Context, account *gtsmodel.Account) error {
// Initialize a new stats struct.
stats := &gtsmodel.AccountStats{
AccountID: accountID,
AccountID: account.ID,
RegeneratedAt: time.Now(),
}
// Count followers outside of transaction since
// it uses a cache + requires its own db calls.
followerIDs, err := a.state.DB.GetAccountFollowerIDs(ctx, accountID, nil)
followerIDs, err := a.state.DB.GetAccountFollowerIDs(ctx, account.ID, nil)
if err != nil {
return nil, err
return err
}
stats.FollowersCount = util.Ptr(len(followerIDs))
// Count following outside of transaction since
// it uses a cache + requires its own db calls.
followIDs, err := a.state.DB.GetAccountFollowIDs(ctx, accountID, nil)
followIDs, err := a.state.DB.GetAccountFollowIDs(ctx, account.ID, nil)
if err != nil {
return nil, err
return err
}
stats.FollowingCount = util.Ptr(len(followIDs))
// Count follow requests outside of transaction since
// it uses a cache + requires its own db calls.
followRequestIDs, err := a.state.DB.GetAccountFollowRequestIDs(ctx, accountID, nil)
followRequestIDs, err := a.state.DB.GetAccountFollowRequestIDs(ctx, account.ID, nil)
if err != nil {
return nil, err
return err
}
stats.FollowRequestsCount = util.Ptr(len(followRequestIDs))
@ -1184,7 +1174,7 @@ func (a *accountDB) RegenerateAccountStats(ctx context.Context, accountID string
// Scan database for account statuses.
statusesCount, err := tx.NewSelect().
Table("statuses").
Where("? = ?", bun.Ident("account_id"), accountID).
Where("? = ?", bun.Ident("account_id"), account.ID).
Count(ctx)
if err != nil {
return err
@ -1194,7 +1184,7 @@ func (a *accountDB) RegenerateAccountStats(ctx context.Context, accountID string
// Scan database for pinned statuses.
statusesPinnedCount, err := tx.NewSelect().
Table("statuses").
Where("? = ?", bun.Ident("account_id"), accountID).
Where("? = ?", bun.Ident("account_id"), account.ID).
Where("? IS NOT NULL", bun.Ident("pinned_at")).
Count(ctx)
if err != nil {
@ -1208,7 +1198,7 @@ func (a *accountDB) RegenerateAccountStats(ctx context.Context, accountID string
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Column("status.created_at").
Where("? = ?", bun.Ident("status.account_id"), accountID).
Where("? = ?", bun.Ident("status.account_id"), account.ID).
Order("status.id DESC").
Limit(1).
Scan(ctx, &lastStatusAt)
@ -1219,12 +1209,12 @@ func (a *accountDB) RegenerateAccountStats(ctx context.Context, accountID string
return nil
}); err != nil {
return nil, err
return err
}
// Upsert this stats in case a race
// meant someone else inserted it first.
err = a.state.Caches.GTS.AccountStats.Store(stats, func() error {
if err := a.state.Caches.GTS.AccountStats.Store(stats, func() error {
if _, err := NewUpsert(a.db).
Model(stats).
Constraint("account_id").
@ -1232,9 +1222,12 @@ func (a *accountDB) RegenerateAccountStats(ctx context.Context, accountID string
return err
}
return nil
})
}); err != nil {
return err
}
return stats, err
account.Stats = stats
return nil
}
func (a *accountDB) UpdateAccountStats(ctx context.Context, stats *gtsmodel.AccountStats, columns ...string) error {

View file

@ -656,21 +656,21 @@ func (suite *AccountTestSuite) TestAccountStatsAll() {
// Get stats for the first time. They
// should all be generated now since
// they're not stored in the test rig.
stats, err := suite.db.GetAccountStats(ctx, account.ID)
if err != nil {
if err := suite.db.PopulateAccountStats(ctx, account); err != nil {
suite.FailNow(err.Error())
}
stats := account.Stats
suite.NotNil(stats)
suite.WithinDuration(time.Now(), stats.RegeneratedAt, 5*time.Second)
// Get stats a second time. They shouldn't
// be regenerated since we just did it.
stats2, err := suite.db.GetAccountStats(ctx, account.ID)
if err != nil {
if err := suite.db.PopulateAccountStats(ctx, account); err != nil {
suite.FailNow(err.Error())
}
suite.NotNil(stats)
suite.Equal(stats.RegeneratedAt, stats2.RegeneratedAt)
stats2 := account.Stats
suite.NotNil(stats2)
suite.Equal(stats2.RegeneratedAt, stats.RegeneratedAt)
// Update the stats to indicate they're out of date.
stats2.RegeneratedAt = time.Now().Add(-72 * time.Hour)
@ -681,11 +681,11 @@ func (suite *AccountTestSuite) TestAccountStatsAll() {
// Get stats for a third time, they
// should get regenerated now, but
// only for local accounts.
stats3, err := suite.db.GetAccountStats(ctx, account.ID)
if err != nil {
if err := suite.db.PopulateAccountStats(ctx, account); err != nil {
suite.FailNow(err.Error())
}
suite.NotNil(stats)
stats3 := account.Stats
suite.NotNil(stats3)
if account.IsLocal() {
suite.True(stats3.RegeneratedAt.After(stats.RegeneratedAt))
} else {

View file

@ -1044,9 +1044,7 @@ func (d *Dereferencer) fetchRemoteAccountEmojis(ctx context.Context, targetAccou
func (d *Dereferencer) fetchRemoteAccountStats(ctx context.Context, account *gtsmodel.Account, requestUser string) error {
// Ensure we have a stats model for this account.
if account.Stats == nil {
var err error
account.Stats, err = d.state.DB.GetAccountStats(ctx, account.ID)
if err != nil {
if err := d.state.DB.PopulateAccountStats(ctx, account); err != nil {
return gtserror.Newf("db error getting account stats: %w", err)
}
}

View file

@ -71,8 +71,7 @@ func (p *Processor) GetRSSFeedForUsername(ctx context.Context, username string)
// Ensure account stats populated.
if account.Stats == nil {
account.Stats, err = p.state.DB.GetAccountStats(ctx, account.ID)
if err != nil {
if err := p.state.DB.PopulateAccountStats(ctx, account); err != nil {
err = gtserror.Newf("db error getting account stats %s: %w", username, err)
return nil, never, gtserror.NewErrorInternalError(err)
}

View file

@ -128,8 +128,7 @@ func (p *Processor) FollowersGet(ctx context.Context, requestedUser string, page
// Ensure we have stats for this account.
if receiver.Stats == nil {
receiver.Stats, err = p.state.DB.GetAccountStats(ctx, receiver.ID)
if err != nil {
if err := p.state.DB.PopulateAccountStats(ctx, receiver); err != nil {
err := gtserror.Newf("error getting stats for account %s: %w", receiver.ID, err)
return nil, gtserror.NewErrorInternalError(err)
}
@ -239,8 +238,7 @@ func (p *Processor) FollowingGet(ctx context.Context, requestedUser string, page
// Ensure we have stats for this account.
if receiver.Stats == nil {
receiver.Stats, err = p.state.DB.GetAccountStats(ctx, receiver.ID)
if err != nil {
if err := p.state.DB.PopulateAccountStats(ctx, receiver); err != nil {
err := gtserror.Newf("error getting stats for account %s: %w", receiver.ID, err)
return nil, gtserror.NewErrorInternalError(err)
}

View file

@ -93,9 +93,7 @@ func (p *Processor) PinCreate(ctx context.Context, requestingAccount *gtsmodel.A
// Ensure account stats populated.
if requestingAccount.Stats == nil {
var err error
requestingAccount.Stats, err = p.state.DB.GetAccountStats(ctx, requestingAccount.ID)
if err != nil {
if err := p.state.DB.PopulateAccountStats(ctx, requestingAccount); err != nil {
err = gtserror.Newf("db error getting account stats: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
@ -160,9 +158,7 @@ func (p *Processor) PinRemove(ctx context.Context, requestingAccount *gtsmodel.A
// Ensure account stats populated.
if requestingAccount.Stats == nil {
var err error
requestingAccount.Stats, err = p.state.DB.GetAccountStats(ctx, requestingAccount.ID)
if err != nil {
if err := p.state.DB.PopulateAccountStats(ctx, requestingAccount); err != nil {
err = gtserror.Newf("db error getting account stats: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}

View file

@ -246,9 +246,7 @@ func (u *utilF) incrementStatusesCount(
) error {
// Populate stats.
if account.Stats == nil {
var err error
account.Stats, err = u.state.DB.GetAccountStats(ctx, account.ID)
if err != nil {
if err := u.state.DB.PopulateAccountStats(ctx, account); err != nil {
return gtserror.Newf("db error getting account stats: %w", err)
}
}
@ -275,9 +273,7 @@ func (u *utilF) decrementStatusesCount(
) error {
// Populate stats.
if account.Stats == nil {
var err error
account.Stats, err = u.state.DB.GetAccountStats(ctx, account.ID)
if err != nil {
if err := u.state.DB.PopulateAccountStats(ctx, account); err != nil {
return gtserror.Newf("db error getting account stats: %w", err)
}
}
@ -307,9 +303,7 @@ func (u *utilF) incrementFollowersCount(
) error {
// Populate stats.
if account.Stats == nil {
var err error
account.Stats, err = u.state.DB.GetAccountStats(ctx, account.ID)
if err != nil {
if err := u.state.DB.PopulateAccountStats(ctx, account); err != nil {
return gtserror.Newf("db error getting account stats: %w", err)
}
}
@ -334,9 +328,7 @@ func (u *utilF) decrementFollowersCount(
) error {
// Populate stats.
if account.Stats == nil {
var err error
account.Stats, err = u.state.DB.GetAccountStats(ctx, account.ID)
if err != nil {
if err := u.state.DB.PopulateAccountStats(ctx, account); err != nil {
return gtserror.Newf("db error getting account stats: %w", err)
}
}
@ -366,9 +358,7 @@ func (u *utilF) incrementFollowingCount(
) error {
// Populate stats.
if account.Stats == nil {
var err error
account.Stats, err = u.state.DB.GetAccountStats(ctx, account.ID)
if err != nil {
if err := u.state.DB.PopulateAccountStats(ctx, account); err != nil {
return gtserror.Newf("db error getting account stats: %w", err)
}
}
@ -393,9 +383,7 @@ func (u *utilF) decrementFollowingCount(
) error {
// Populate stats.
if account.Stats == nil {
var err error
account.Stats, err = u.state.DB.GetAccountStats(ctx, account.ID)
if err != nil {
if err := u.state.DB.PopulateAccountStats(ctx, account); err != nil {
return gtserror.Newf("db error getting account stats: %w", err)
}
}
@ -425,9 +413,7 @@ func (u *utilF) incrementFollowRequestsCount(
) error {
// Populate stats.
if account.Stats == nil {
var err error
account.Stats, err = u.state.DB.GetAccountStats(ctx, account.ID)
if err != nil {
if err := u.state.DB.PopulateAccountStats(ctx, account); err != nil {
return gtserror.Newf("db error getting account stats: %w", err)
}
}
@ -452,9 +438,7 @@ func (u *utilF) decrementFollowRequestsCount(
) error {
// Populate stats.
if account.Stats == nil {
var err error
account.Stats, err = u.state.DB.GetAccountStats(ctx, account.ID)
if err != nil {
if err := u.state.DB.PopulateAccountStats(ctx, account); err != nil {
return gtserror.Newf("db error getting account stats: %w", err)
}
}

View file

@ -73,8 +73,7 @@ func (c *Converter) AccountToAPIAccountSensitive(ctx context.Context, a *gtsmode
// Ensure account stats populated.
if a.Stats == nil {
a.Stats, err = c.state.DB.GetAccountStats(ctx, a.ID)
if err != nil {
if err := c.state.DB.PopulateAccountStats(ctx, a); err != nil {
return nil, gtserror.Newf(
"error getting stats for account %s: %w",
a.ID, err,