update the tokenstore sweeping function to rely on caches

This commit is contained in:
kim 2024-04-15 12:37:48 +01:00
parent 511a556cb2
commit d476a62fb5
4 changed files with 67 additions and 4 deletions

View file

@ -1210,6 +1210,7 @@ func (c *Caches) initToken() {
c.GTS.Token.Init(structr.CacheConfig[*gtsmodel.Token]{
Indices: []structr.IndexConfig{
{Fields: "ID"},
{Fields: "Code"},
{Fields: "Access"},
{Fields: "Refresh"},

View file

@ -45,6 +45,9 @@ type Application interface {
// DeleteClientByID ...
DeleteClientByID(ctx context.Context, id string) error
// GetAllTokens ...
GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, error)
// GetTokenByCode ...
GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error)
@ -57,6 +60,9 @@ type Application interface {
// PutToken ...
PutToken(ctx context.Context, token *gtsmodel.Token) error
// DeleteTokenByID ...
DeleteTokenByID(ctx context.Context, id string) error
// DeleteTokenByCode ...
DeleteTokenByCode(ctx context.Context, code string) error

View file

@ -22,6 +22,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@ -131,6 +132,48 @@ func (a *applicationDB) DeleteClientByID(ctx context.Context, id string) error {
return nil
}
func (a *applicationDB) GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, error) {
var tokenIDs []string
// Select ALL token IDs.
if err := a.db.NewSelect().
Table("tokens").
Column("id").
Scan(ctx, &tokenIDs); err != nil {
return nil, err
}
// Load all input token IDs via cache loader callback.
tokens, err := a.state.Caches.GTS.Token.LoadIDs("ID",
tokenIDs,
func(uncached []string) ([]*gtsmodel.Token, error) {
// Preallocate expected length of uncached tokens.
tokens := make([]*gtsmodel.Token, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) token IDs.
if err := a.db.NewSelect().
Model(tokens).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
}
return tokens, nil
},
)
if err != nil {
return nil, err
}
// Reoroder the tokens by their
// IDs to ensure in correct order.
getID := func(t *gtsmodel.Token) string { return t.ID }
util.OrderBy(tokens, tokenIDs, getID)
return tokens, nil
}
func (a *applicationDB) GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error) {
return a.getTokenBy(
"Code",
@ -180,6 +223,19 @@ func (a *applicationDB) PutToken(ctx context.Context, token *gtsmodel.Token) err
})
}
func (a *applicationDB) DeleteTokenByID(ctx context.Context, id string) error {
_, err := a.db.NewDelete().
Table("tokens").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx)
if err != nil {
return err
}
a.state.Caches.GTS.Token.Invalidate("ID", id)
return nil
}
func (a *applicationDB) DeleteTokenByCode(ctx context.Context, code string) error {
_, err := a.db.NewDelete().
Table("tokens").

View file

@ -68,19 +68,19 @@ func newTokenStore(ctx context.Context, db db.DB) oauth2.TokenStore {
func (ts *tokenStore) sweep(ctx context.Context) error {
// select *all* tokens from the db
// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way.
tokens := new([]*gtsmodel.Token)
if err := ts.db.GetAll(ctx, tokens); err != nil {
tokens, err := ts.db.GetAllTokens(ctx)
if err != nil {
return err
}
// iterate through and remove expired tokens
now := time.Now()
for _, dbt := range *tokens {
for _, dbt := range tokens {
// The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So:
// we only want to check if a token expired before now if the expiry time is *not zero*;
// ie., if it's been explicity set.
if !dbt.CodeExpiresAt.IsZero() && dbt.CodeExpiresAt.Before(now) || !dbt.RefreshExpiresAt.IsZero() && dbt.RefreshExpiresAt.Before(now) || !dbt.AccessExpiresAt.IsZero() && dbt.AccessExpiresAt.Before(now) {
if err := ts.db.DeleteByID(ctx, dbt.ID, dbt); err != nil {
if err := ts.db.DeleteTokenByID(ctx, dbt.ID); err != nil {
return err
}
}