mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2024-06-13 02:39:27 +00:00
update the tokenstore sweeping function to rely on caches
This commit is contained in:
parent
511a556cb2
commit
d476a62fb5
1
internal/cache/db.go
vendored
1
internal/cache/db.go
vendored
|
@ -1210,6 +1210,7 @@ func (c *Caches) initToken() {
|
|||
|
||||
c.GTS.Token.Init(structr.CacheConfig[*gtsmodel.Token]{
|
||||
Indices: []structr.IndexConfig{
|
||||
{Fields: "ID"},
|
||||
{Fields: "Code"},
|
||||
{Fields: "Access"},
|
||||
{Fields: "Refresh"},
|
||||
|
|
|
@ -45,6 +45,9 @@ type Application interface {
|
|||
// DeleteClientByID ...
|
||||
DeleteClientByID(ctx context.Context, id string) error
|
||||
|
||||
// GetAllTokens ...
|
||||
GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, error)
|
||||
|
||||
// GetTokenByCode ...
|
||||
GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error)
|
||||
|
||||
|
@ -57,6 +60,9 @@ type Application interface {
|
|||
// PutToken ...
|
||||
PutToken(ctx context.Context, token *gtsmodel.Token) error
|
||||
|
||||
// DeleteTokenByID ...
|
||||
DeleteTokenByID(ctx context.Context, id string) error
|
||||
|
||||
// DeleteTokenByCode ...
|
||||
DeleteTokenByCode(ctx context.Context, code string) error
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/util"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
|
@ -131,6 +132,48 @@ func (a *applicationDB) DeleteClientByID(ctx context.Context, id string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a *applicationDB) GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, error) {
|
||||
var tokenIDs []string
|
||||
|
||||
// Select ALL token IDs.
|
||||
if err := a.db.NewSelect().
|
||||
Table("tokens").
|
||||
Column("id").
|
||||
Scan(ctx, &tokenIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load all input token IDs via cache loader callback.
|
||||
tokens, err := a.state.Caches.GTS.Token.LoadIDs("ID",
|
||||
tokenIDs,
|
||||
func(uncached []string) ([]*gtsmodel.Token, error) {
|
||||
// Preallocate expected length of uncached tokens.
|
||||
tokens := make([]*gtsmodel.Token, 0, len(uncached))
|
||||
|
||||
// Perform database query scanning
|
||||
// the remaining (uncached) token IDs.
|
||||
if err := a.db.NewSelect().
|
||||
Model(tokens).
|
||||
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
|
||||
Scan(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Reoroder the tokens by their
|
||||
// IDs to ensure in correct order.
|
||||
getID := func(t *gtsmodel.Token) string { return t.ID }
|
||||
util.OrderBy(tokens, tokenIDs, getID)
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func (a *applicationDB) GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error) {
|
||||
return a.getTokenBy(
|
||||
"Code",
|
||||
|
@ -180,6 +223,19 @@ func (a *applicationDB) PutToken(ctx context.Context, token *gtsmodel.Token) err
|
|||
})
|
||||
}
|
||||
|
||||
func (a *applicationDB) DeleteTokenByID(ctx context.Context, id string) error {
|
||||
_, err := a.db.NewDelete().
|
||||
Table("tokens").
|
||||
Where("? = ?", bun.Ident("id"), id).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
a.state.Caches.GTS.Token.Invalidate("ID", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *applicationDB) DeleteTokenByCode(ctx context.Context, code string) error {
|
||||
_, err := a.db.NewDelete().
|
||||
Table("tokens").
|
||||
|
|
|
@ -68,19 +68,19 @@ func newTokenStore(ctx context.Context, db db.DB) oauth2.TokenStore {
|
|||
func (ts *tokenStore) sweep(ctx context.Context) error {
|
||||
// select *all* tokens from the db
|
||||
// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way.
|
||||
tokens := new([]*gtsmodel.Token)
|
||||
if err := ts.db.GetAll(ctx, tokens); err != nil {
|
||||
tokens, err := ts.db.GetAllTokens(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// iterate through and remove expired tokens
|
||||
now := time.Now()
|
||||
for _, dbt := range *tokens {
|
||||
for _, dbt := range tokens {
|
||||
// The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So:
|
||||
// we only want to check if a token expired before now if the expiry time is *not zero*;
|
||||
// ie., if it's been explicity set.
|
||||
if !dbt.CodeExpiresAt.IsZero() && dbt.CodeExpiresAt.Before(now) || !dbt.RefreshExpiresAt.IsZero() && dbt.RefreshExpiresAt.Before(now) || !dbt.AccessExpiresAt.IsZero() && dbt.AccessExpiresAt.Before(now) {
|
||||
if err := ts.db.DeleteByID(ctx, dbt.ID, dbt); err != nil {
|
||||
if err := ts.db.DeleteTokenByID(ctx, dbt.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue