[performance] retry db queries on busy errors (#2025)

* catch SQLITE_BUSY errors, wrap bun.DB to use our own busy retrier, remove unnecessary db.Error type

Signed-off-by: kim <grufwub@gmail.com>

* remove dead code

Signed-off-by: kim <grufwub@gmail.com>

* remove more dead code, add missing error arguments

Signed-off-by: kim <grufwub@gmail.com>

* update sqlite to use maxOpenConns()

Signed-off-by: kim <grufwub@gmail.com>

* add uncommitted changes

Signed-off-by: kim <grufwub@gmail.com>

* use direct calls-through for the ConnIface to make sure we don't double query hook

Signed-off-by: kim <grufwub@gmail.com>

* expose underlying bun.DB better

Signed-off-by: kim <grufwub@gmail.com>

* retry on the correct busy error

Signed-off-by: kim <grufwub@gmail.com>

* use longer possible maxRetries for db retry-backoff

Signed-off-by: kim <grufwub@gmail.com>

* remove the note regarding max-open-conns only applying to postgres

Signed-off-by: kim <grufwub@gmail.com>

* improved code commenting

Signed-off-by: kim <grufwub@gmail.com>

* remove unnecessary infof call (just use info)

Signed-off-by: kim <grufwub@gmail.com>

* rename DBConn to WrappedDB to better follow sql package name conventions

Signed-off-by: kim <grufwub@gmail.com>

* update test error string checks

Signed-off-by: kim <grufwub@gmail.com>

* shush linter

Signed-off-by: kim <grufwub@gmail.com>

* update backoff logic to be more transparent

Signed-off-by: kim <grufwub@gmail.com>

---------

Signed-off-by: kim <grufwub@gmail.com>
This commit is contained in:
kim 2023-07-25 09:34:05 +01:00 committed by GitHub
parent 9eff0d46e4
commit 5f3e095717
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
53 changed files with 1050 additions and 898 deletions

View file

@ -194,10 +194,6 @@ db-tls-ca-cert: ""
#
# If you set the multiplier to less than 1, only one open connection will be used regardless of cpu count.
#
# PLEASE NOTE!!: This setting currently only applies for Postgres. SQLite will always use 1 connection regardless
# of what is set here. This behavior will change in future when we implement better SQLITE_BUSY handling.
# See https://github.com/superseriousbusiness/gotosocial/issues/1407 for more details.
#
# Examples: [16, 8, 10, 2]
# Default: 8
db-max-open-conns-multiplier: 8

View file

@ -27,67 +27,67 @@ import (
// Account contains functions related to account getting/setting/creation.
type Account interface {
// GetAccountByID returns one account with the given ID, or an error if something goes wrong.
GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, Error)
GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, error)
// GetAccountByURI returns one account with the given URI, or an error if something goes wrong.
GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, error)
// GetAccountByURL returns one account with the given URL, or an error if something goes wrong.
GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error)
GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, error)
// GetAccountByUsernameDomain returns one account with the given username and domain, or an error if something goes wrong.
GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, Error)
GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, error)
// GetAccountByPubkeyID returns one account with the given public key URI (ID), or an error if something goes wrong.
GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, Error)
GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, error)
// GetAccountByInboxURI returns one account with the given inbox_uri, or an error if something goes wrong.
GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, error)
// GetAccountByOutboxURI returns one account with the given outbox_uri, or an error if something goes wrong.
GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, error)
// GetAccountByFollowingURI returns one account with the given following_uri, or an error if something goes wrong.
GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, error)
// GetAccountByFollowersURI returns one account with the given followers_uri, or an error if something goes wrong.
GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, error)
// PopulateAccount ensures that all sub-models of an account are populated (e.g. avatar, header etc).
PopulateAccount(ctx context.Context, account *gtsmodel.Account) error
// PutAccount puts one account in the database.
PutAccount(ctx context.Context, account *gtsmodel.Account) Error
PutAccount(ctx context.Context, account *gtsmodel.Account) error
// UpdateAccount updates one account by ID.
UpdateAccount(ctx context.Context, account *gtsmodel.Account, columns ...string) Error
UpdateAccount(ctx context.Context, account *gtsmodel.Account, columns ...string) error
// DeleteAccount deletes one account from the database by its ID.
// DO NOT USE THIS WHEN SUSPENDING ACCOUNTS! In that case you should mark the
// account as suspended instead, rather than deleting from the db entirely.
DeleteAccount(ctx context.Context, id string) Error
DeleteAccount(ctx context.Context, id string) error
// GetAccountCustomCSSByUsername returns the custom css of an account on this instance with the given username.
GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, Error)
GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, error)
// GetAccountFaves fetches faves/likes created by the target accountID.
GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, Error)
GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, error)
// GetAccountsUsingEmoji fetches all account models using emoji with given ID stored in their 'emojis' column.
GetAccountsUsingEmoji(ctx context.Context, emojiID string) ([]*gtsmodel.Account, error)
// GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID.
CountAccountStatuses(ctx context.Context, accountID string) (int, Error)
CountAccountStatuses(ctx context.Context, accountID string) (int, error)
// CountAccountPinned returns the total number of pinned statuses owned by account with the given id.
CountAccountPinned(ctx context.Context, accountID string) (int, Error)
CountAccountPinned(ctx context.Context, accountID string) (int, error)
// GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided
// then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can
// be very memory intensive so you probably shouldn't do this!
//
// In the case of no statuses, this function will return db.ErrNoEntries.
GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, Error)
GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, error)
// GetAccountPinnedStatuses returns ONLY statuses owned by the give accountID for which a corresponding StatusPin
// exists in the database. Statuses which are not pinned will not be returned by this function.
@ -95,28 +95,28 @@ type Account interface {
// Statuses will be returned in the order in which they were pinned, from latest pinned to oldest pinned (descending).
//
// In the case of no statuses, this function will return db.ErrNoEntries.
GetAccountPinnedStatuses(ctx context.Context, accountID string) ([]*gtsmodel.Status, Error)
GetAccountPinnedStatuses(ctx context.Context, accountID string) ([]*gtsmodel.Status, error)
// GetAccountWebStatuses is similar to GetAccountStatuses, but it's specifically for returning statuses that
// should be visible via the web view of an account. So, only public, federated statuses that aren't boosts
// or replies.
//
// In the case of no statuses, this function will return db.ErrNoEntries.
GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, Error)
GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, error)
GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, Error)
GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error)
// GetAccountLastPosted simply gets the timestamp of the most recent post by the account.
//
// If webOnly is true, then the time of the last non-reply, non-boost, public status of the account will be returned.
//
// The returned time will be zero if account has never posted anything.
GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, Error)
GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, error)
// SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment.
SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) Error
SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) error
// GetInstanceAccount returns the instance account for the given domain.
// If domain is empty, this instance account will be returned.
GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, Error)
GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, error)
}

View file

@ -27,26 +27,26 @@ import (
type Admin interface {
// IsUsernameAvailable checks whether a given username is available on our domain.
// Returns an error if the username is already taken, or something went wrong in the db.
IsUsernameAvailable(ctx context.Context, username string) (bool, Error)
IsUsernameAvailable(ctx context.Context, username string) (bool, error)
// IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain.
// Return an error if:
// A) the email is already associated with an account
// B) we block signups from this email domain
// C) something went wrong in the db
IsEmailAvailable(ctx context.Context, email string) (bool, Error)
IsEmailAvailable(ctx context.Context, email string) (bool, error)
// NewSignup creates a new user in the database with the given parameters.
// By the time this function is called, it should be assumed that all the parameters have passed validation!
NewSignup(ctx context.Context, newSignup gtsmodel.NewSignup) (*gtsmodel.User, Error)
NewSignup(ctx context.Context, newSignup gtsmodel.NewSignup) (*gtsmodel.User, error)
// CreateInstanceAccount creates an account in the database with the same username as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'.
// This is needed for things like serving files that belong to the instance and not an individual user/account.
CreateInstanceAccount(ctx context.Context) Error
CreateInstanceAccount(ctx context.Context) error
// CreateInstanceInstance creates an instance in the database with the same domain as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'.
// This is needed for things like serving instance information through /api/v1/instance
CreateInstanceInstance(ctx context.Context) Error
CreateInstanceInstance(ctx context.Context) error
}

View file

@ -23,58 +23,58 @@ import "context"
type Basic interface {
// CreateTable creates a table for the given interface.
// For implementations that don't use tables, this can just return nil.
CreateTable(ctx context.Context, i interface{}) Error
CreateTable(ctx context.Context, i interface{}) error
// CreateAllTables creates *all* tables necessary for the running of GoToSocial.
// Because it uses the 'if not exists' parameter it is safe to run against a GtS that's already been initialized.
CreateAllTables(ctx context.Context) Error
CreateAllTables(ctx context.Context) error
// DropTable drops the table for the given interface.
// For implementations that don't use tables, this can just return nil.
DropTable(ctx context.Context, i interface{}) Error
DropTable(ctx context.Context, i interface{}) error
// Stop should stop and close the database connection cleanly, returning an error if this is not possible.
// If the database implementation doesn't need to be stopped, this can just return nil.
Stop(ctx context.Context) Error
Stop(ctx context.Context) error
// IsHealthy should return nil if the database connection is healthy, or an error if not.
IsHealthy(ctx context.Context) Error
IsHealthy(ctx context.Context) error
// GetByID gets one entry by its id. In a database like postgres, this might be the 'id' field of the entry,
// for other implementations (for example, in-memory) it might just be the key of a map.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetByID(ctx context.Context, id string, i interface{}) Error
GetByID(ctx context.Context, id string, i interface{}) error
// GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the
// name of the key to select from.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetWhere(ctx context.Context, where []Where, i interface{}) Error
GetWhere(ctx context.Context, where []Where, i interface{}) error
// GetAll will try to get all entries of type i.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetAll(ctx context.Context, i interface{}) Error
GetAll(ctx context.Context, i interface{}) error
// Put simply stores i. It is up to the implementation to figure out how to store it, and using what key.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
Put(ctx context.Context, i interface{}) Error
Put(ctx context.Context, i interface{}) error
// UpdateByID updates values of i based on its id.
// If any columns are specified, these will be updated exclusively.
// Otherwise, the whole model will be updated.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
UpdateByID(ctx context.Context, i interface{}, id string, columns ...string) Error
UpdateByID(ctx context.Context, i interface{}, id string, columns ...string) error
// UpdateWhere updates column key of interface i with the given value, where the given parameters apply.
UpdateWhere(ctx context.Context, where []Where, key string, value interface{}, i interface{}) Error
UpdateWhere(ctx context.Context, where []Where, key string, value interface{}, i interface{}) error
// DeleteByID removes i with id id.
// If i didn't exist anyway, then no error should be returned.
DeleteByID(ctx context.Context, id string, i interface{}) Error
DeleteByID(ctx context.Context, id string, i interface{}) error
// DeleteWhere deletes i where key = value
// If i didn't exist anyway, then no error should be returned.
DeleteWhere(ctx context.Context, where []Where, i interface{}) Error
DeleteWhere(ctx context.Context, where []Where, i interface{}) error
}

View file

@ -38,16 +38,16 @@ import (
)
type accountDB struct {
conn *DBConn
db *WrappedDB
state *state.State
}
func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, error) {
return a.getAccount(
ctx,
"ID",
func(account *gtsmodel.Account) error {
return a.conn.NewSelect().
return a.db.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.id"), id).
Scan(ctx)
@ -77,12 +77,12 @@ func (a *accountDB) GetAccountsByIDs(ctx context.Context, ids []string) ([]*gtsm
return accounts, nil
}
func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, error) {
return a.getAccount(
ctx,
"URI",
func(account *gtsmodel.Account) error {
return a.conn.NewSelect().
return a.db.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.uri"), uri).
Scan(ctx)
@ -91,12 +91,12 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.
)
}
func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, db.Error) {
func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, error) {
return a.getAccount(
ctx,
"URL",
func(account *gtsmodel.Account) error {
return a.conn.NewSelect().
return a.db.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.url"), url).
Scan(ctx)
@ -105,7 +105,7 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.
)
}
func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, db.Error) {
func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, error) {
if domain != "" {
// Normalize the domain as punycode
var err error
@ -119,7 +119,7 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str
ctx,
"Username.Domain",
func(account *gtsmodel.Account) error {
q := a.conn.NewSelect().
q := a.db.NewSelect().
Model(account)
if domain != "" {
@ -139,12 +139,12 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str
)
}
func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, error) {
return a.getAccount(
ctx,
"PublicKeyURI",
func(account *gtsmodel.Account) error {
return a.conn.NewSelect().
return a.db.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.public_key_uri"), id).
Scan(ctx)
@ -153,12 +153,12 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo
)
}
func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, error) {
return a.getAccount(
ctx,
"InboxURI",
func(account *gtsmodel.Account) error {
return a.conn.NewSelect().
return a.db.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.inbox_uri"), uri).
Scan(ctx)
@ -167,12 +167,12 @@ func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsm
)
}
func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, error) {
return a.getAccount(
ctx,
"OutboxURI",
func(account *gtsmodel.Account) error {
return a.conn.NewSelect().
return a.db.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.outbox_uri"), uri).
Scan(ctx)
@ -181,12 +181,12 @@ func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gts
)
}
func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, error) {
return a.getAccount(
ctx,
"FollowersURI",
func(account *gtsmodel.Account) error {
return a.conn.NewSelect().
return a.db.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.followers_uri"), uri).
Scan(ctx)
@ -195,12 +195,12 @@ func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (*
)
}
func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, error) {
return a.getAccount(
ctx,
"FollowingURI",
func(account *gtsmodel.Account) error {
return a.conn.NewSelect().
return a.db.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.following_uri"), uri).
Scan(ctx)
@ -209,7 +209,7 @@ func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (*
)
}
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, error) {
var username string
if domain == "" {
@ -223,14 +223,14 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts
return a.GetAccountByUsernameDomain(ctx, username, domain)
}
func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, db.Error) {
func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, error) {
// Fetch account from database cache with loader callback
account, err := a.state.Caches.GTS.Account().Load(lookup, func() (*gtsmodel.Account, error) {
var account gtsmodel.Account
// Not cached! Perform database query
if err := dbQuery(&account); err != nil {
return nil, a.conn.ProcessError(err)
return nil, a.db.ProcessError(err)
}
return &account, nil
@ -294,12 +294,12 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou
return errs.Combine()
}
func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) db.Error {
func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) error {
return a.state.Caches.GTS.Account().Store(account, func() error {
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
return a.conn.RunInTx(ctx, func(tx bun.Tx) error {
return a.db.RunInTx(ctx, func(tx bun.Tx) error {
// create links between this account and any emojis it uses
for _, i := range account.EmojiIDs {
if _, err := tx.NewInsert().Model(&gtsmodel.AccountToEmoji{
@ -317,7 +317,7 @@ func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) d
})
}
func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account, columns ...string) db.Error {
func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account, columns ...string) error {
account.UpdatedAt = time.Now()
if len(columns) > 0 {
// If we're updating by column, ensure "updated_at" is included.
@ -328,7 +328,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
return a.conn.RunInTx(ctx, func(tx bun.Tx) error {
return a.db.RunInTx(ctx, func(tx bun.Tx) error {
// create links between this account and any emojis it uses
// first clear out any old emoji links
if _, err := tx.
@ -362,7 +362,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
})
}
func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
func (a *accountDB) DeleteAccount(ctx context.Context, id string) error {
defer a.state.Caches.GTS.Account().Invalidate("ID", id)
// Load account into cache before attempting a delete,
@ -376,7 +376,7 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
return err
}
return a.conn.RunInTx(ctx, func(tx bun.Tx) error {
return a.db.RunInTx(ctx, func(tx bun.Tx) error {
// clear out any emoji links
if _, err := tx.
NewDelete().
@ -396,10 +396,10 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
})
}
func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, db.Error) {
func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, error) {
createdAt := time.Time{}
q := a.conn.
q := a.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Column("status.created_at").
@ -416,12 +416,12 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string,
}
if err := q.Scan(ctx, &createdAt); err != nil {
return time.Time{}, a.conn.ProcessError(err)
return time.Time{}, a.db.ProcessError(err)
}
return createdAt, nil
}
func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error {
func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) error {
if *mediaAttachment.Avatar && *mediaAttachment.Header {
return errors.New("one media attachment cannot be both header and avatar")
}
@ -437,26 +437,26 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen
}
// TODO: there are probably more side effects here that need to be handled
if _, err := a.conn.
if _, err := a.db.
NewInsert().
Model(mediaAttachment).
Exec(ctx); err != nil {
return a.conn.ProcessError(err)
return a.db.ProcessError(err)
}
if _, err := a.conn.
if _, err := a.db.
NewUpdate().
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
Set("? = ?", column, mediaAttachment.ID).
Where("? = ?", bun.Ident("account.id"), accountID).
Exec(ctx); err != nil {
return a.conn.ProcessError(err)
return a.db.ProcessError(err)
}
return nil
}
func (a *accountDB) GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, db.Error) {
func (a *accountDB) GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, error) {
account, err := a.GetAccountByUsernameDomain(ctx, username, "")
if err != nil {
return "", err
@ -469,7 +469,7 @@ func (a *accountDB) GetAccountsUsingEmoji(ctx context.Context, emojiID string) (
var accountIDs []string
// Create SELECT account query.
q := a.conn.NewSelect().
q := a.db.NewSelect().
Table("accounts").
Column("id")
@ -486,37 +486,37 @@ func (a *accountDB) GetAccountsUsingEmoji(ctx context.Context, emojiID string) (
// Execute the query, scanning destination into accountIDs.
if _, err := q.Exec(ctx, &accountIDs); err != nil {
return nil, a.conn.ProcessError(err)
return nil, a.db.ProcessError(err)
}
// Convert account IDs into account objects.
return a.GetAccountsByIDs(ctx, accountIDs)
}
func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, db.Error) {
func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, error) {
faves := new([]*gtsmodel.StatusFave)
if err := a.conn.
if err := a.db.
NewSelect().
Model(faves).
Where("? = ?", bun.Ident("status_fave.account_id"), accountID).
Scan(ctx); err != nil {
return nil, a.conn.ProcessError(err)
return nil, a.db.ProcessError(err)
}
return *faves, nil
}
func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) {
return a.conn.
func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, error) {
return a.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Where("? = ?", bun.Ident("status.account_id"), accountID).
Count(ctx)
}
func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (int, db.Error) {
return a.conn.
func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (int, error) {
return a.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Where("? = ?", bun.Ident("status.account_id"), accountID).
@ -524,7 +524,7 @@ func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (i
Count(ctx)
}
func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, db.Error) {
func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, error) {
// Ensure reasonable
if limit < 0 {
limit = 0
@ -536,7 +536,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
frontToBack = true
)
q := a.conn.
q := a.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
// Select only IDs from table
@ -562,7 +562,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
// implementation differs between SQLite and Postgres,
// so we have to be thorough to cover all eventualities
q = q.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery {
switch a.conn.Dialect().Name() {
switch a.db.Dialect().Name() {
case dialect.PG:
return q.
Where("? IS NOT NULL", bun.Ident("status.attachments")).
@ -613,7 +613,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
}
if err := q.Scan(ctx, &statusIDs); err != nil {
return nil, a.conn.ProcessError(err)
return nil, a.db.ProcessError(err)
}
// If we're paging up, we still want statuses
@ -628,10 +628,10 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
return a.statusesFromIDs(ctx, statusIDs)
}
func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID string) ([]*gtsmodel.Status, db.Error) {
func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID string) ([]*gtsmodel.Status, error) {
statusIDs := []string{}
q := a.conn.
q := a.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Column("status.id").
@ -640,13 +640,13 @@ func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID stri
Order("status.pinned_at DESC")
if err := q.Scan(ctx, &statusIDs); err != nil {
return nil, a.conn.ProcessError(err)
return nil, a.db.ProcessError(err)
}
return a.statusesFromIDs(ctx, statusIDs)
}
func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, db.Error) {
func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, error) {
// Ensure reasonable
if limit < 0 {
limit = 0
@ -655,7 +655,7 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string,
// Make educated guess for slice size
statusIDs := make([]string, 0, limit)
q := a.conn.
q := a.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
// Select only IDs from table
@ -688,16 +688,16 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string,
q = q.Order("status.id DESC")
if err := q.Scan(ctx, &statusIDs); err != nil {
return nil, a.conn.ProcessError(err)
return nil, a.db.ProcessError(err)
}
return a.statusesFromIDs(ctx, statusIDs)
}
func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) {
func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error) {
blocks := []*gtsmodel.Block{}
fq := a.conn.
fq := a.db.
NewSelect().
Model(&blocks).
Where("? = ?", bun.Ident("block.account_id"), accountID).
@ -717,7 +717,7 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI
}
if err := fq.Scan(ctx); err != nil {
return nil, "", "", a.conn.ProcessError(err)
return nil, "", "", a.db.ProcessError(err)
}
if len(blocks) == 0 {
@ -734,7 +734,7 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI
return accounts, nextMaxID, prevMinID, nil
}
func (a *accountDB) statusesFromIDs(ctx context.Context, statusIDs []string) ([]*gtsmodel.Status, db.Error) {
func (a *accountDB) statusesFromIDs(ctx context.Context, statusIDs []string) ([]*gtsmodel.Status, error) {
// Catch case of no statuses early
if len(statusIDs) == 0 {
return nil, db.ErrNoEntries

View file

@ -260,7 +260,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() {
}
noCache := &gtsmodel.Account{}
err = dbService.GetConn().
err = dbService.DB().
NewSelect().
Model(noCache).
Where("? = ?", bun.Ident("account.id"), testAccount.ID).
@ -288,7 +288,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() {
suite.Empty(updated.EmojiIDs)
suite.WithinDuration(time.Now(), updated.UpdatedAt, 5*time.Second)
err = dbService.GetConn().
err = dbService.DB().
NewSelect().
Model(noCache).
Where("? = ?", bun.Ident("account.id"), testAccount.ID).

View file

@ -44,21 +44,21 @@ import (
const rsaKeyBits = 2048
type adminDB struct {
conn *DBConn
db *WrappedDB
state *state.State
}
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
q := a.conn.
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, error) {
q := a.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
Column("account.id").
Where("? = ?", bun.Ident("account.username"), username).
Where("? IS NULL", bun.Ident("account.domain"))
return a.conn.NotExists(ctx, q)
return a.db.NotExists(ctx, q)
}
func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.Error) {
func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, error) {
// parse the domain from the email
m, err := mail.ParseAddress(email)
if err != nil {
@ -67,12 +67,12 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
// check if the email domain is blocked
emailDomainBlockedQ := a.conn.
emailDomainBlockedQ := a.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("email_domain_blocks"), bun.Ident("email_domain_block")).
Column("email_domain_block.id").
Where("? = ?", bun.Ident("email_domain_block.domain"), domain)
emailDomainBlocked, err := a.conn.Exists(ctx, emailDomainBlockedQ)
emailDomainBlocked, err := a.db.Exists(ctx, emailDomainBlockedQ)
if err != nil {
return false, err
}
@ -81,16 +81,16 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.
}
// check if this email is associated with a user already
q := a.conn.
q := a.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")).
Column("user.id").
Where("? = ?", bun.Ident("user.email"), email).
WhereOr("? = ?", bun.Ident("user.unconfirmed_email"), email)
return a.conn.NotExists(ctx, q)
return a.db.NotExists(ctx, q)
}
func (a *adminDB) NewSignup(ctx context.Context, newSignup gtsmodel.NewSignup) (*gtsmodel.User, db.Error) {
func (a *adminDB) NewSignup(ctx context.Context, newSignup gtsmodel.NewSignup) (*gtsmodel.User, error) {
// If something went wrong previously while doing a new
// sign up with this username, we might already have an
// account, so check first.
@ -220,17 +220,17 @@ func (a *adminDB) NewSignup(ctx context.Context, newSignup gtsmodel.NewSignup) (
return user, nil
}
func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
func (a *adminDB) CreateInstanceAccount(ctx context.Context) error {
username := config.GetHost()
q := a.conn.
q := a.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
Column("account.id").
Where("? = ?", bun.Ident("account.username"), username).
Where("? IS NULL", bun.Ident("account.domain"))
exists, err := a.conn.Exists(ctx, q)
exists, err := a.db.Exists(ctx, q)
if err != nil {
return err
}
@ -277,18 +277,18 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
return nil
}
func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
func (a *adminDB) CreateInstanceInstance(ctx context.Context) error {
protocol := config.GetProtocol()
host := config.GetHost()
// check if instance entry already exists
q := a.conn.
q := a.db.
NewSelect().
Column("instance.id").
TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")).
Where("? = ?", bun.Ident("instance.domain"), host)
exists, err := a.conn.Exists(ctx, q)
exists, err := a.db.Exists(ctx, q)
if err != nil {
return err
}
@ -309,13 +309,13 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
URI: fmt.Sprintf("%s://%s", protocol, host),
}
insertQ := a.conn.
insertQ := a.db.
NewInsert().
Model(i)
_, err = insertQ.Exec(ctx)
if err != nil {
return a.conn.ProcessError(err)
return a.db.ProcessError(err)
}
log.Infof(ctx, "created instance instance %s with id %s", host, i.ID)

View file

@ -28,99 +28,99 @@ import (
)
type basicDB struct {
conn *DBConn
db *WrappedDB
}
func (b *basicDB) Put(ctx context.Context, i interface{}) db.Error {
_, err := b.conn.NewInsert().Model(i).Exec(ctx)
return b.conn.ProcessError(err)
func (b *basicDB) Put(ctx context.Context, i interface{}) error {
_, err := b.db.NewInsert().Model(i).Exec(ctx)
return b.db.ProcessError(err)
}
func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Error {
q := b.conn.
func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) error {
q := b.db.
NewSelect().
Model(i).
Where("id = ?", id)
err := q.Scan(ctx)
return b.conn.ProcessError(err)
return b.db.ProcessError(err)
}
func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) error {
if len(where) == 0 {
return errors.New("no queries provided")
}
q := b.conn.NewSelect().Model(i)
q := b.db.NewSelect().Model(i)
selectWhere(q, where)
err := q.Scan(ctx)
return b.conn.ProcessError(err)
return b.db.ProcessError(err)
}
func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error {
q := b.conn.
func (b *basicDB) GetAll(ctx context.Context, i interface{}) error {
q := b.db.
NewSelect().
Model(i)
err := q.Scan(ctx)
return b.conn.ProcessError(err)
return b.db.ProcessError(err)
}
func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.Error {
q := b.conn.
func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) error {
q := b.db.
NewDelete().
Model(i).
Where("id = ?", id)
_, err := q.Exec(ctx)
return b.conn.ProcessError(err)
return b.db.ProcessError(err)
}
func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) error {
if len(where) == 0 {
return errors.New("no queries provided")
}
q := b.conn.
q := b.db.
NewDelete().
Model(i)
deleteWhere(q, where)
_, err := q.Exec(ctx)
return b.conn.ProcessError(err)
return b.db.ProcessError(err)
}
func (b *basicDB) UpdateByID(ctx context.Context, i interface{}, id string, columns ...string) db.Error {
q := b.conn.
func (b *basicDB) UpdateByID(ctx context.Context, i interface{}, id string, columns ...string) error {
q := b.db.
NewUpdate().
Model(i).
Column(columns...).
Where("? = ?", bun.Ident("id"), id)
_, err := q.Exec(ctx)
return b.conn.ProcessError(err)
return b.db.ProcessError(err)
}
func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error {
q := b.conn.NewUpdate().Model(i)
func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) error {
q := b.db.NewUpdate().Model(i)
updateWhere(q, where)
q = q.Set("? = ?", bun.Ident(key), value)
_, err := q.Exec(ctx)
return b.conn.ProcessError(err)
return b.db.ProcessError(err)
}
func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error {
_, err := b.conn.NewCreateTable().Model(i).IfNotExists().Exec(ctx)
func (b *basicDB) CreateTable(ctx context.Context, i interface{}) error {
_, err := b.db.NewCreateTable().Model(i).IfNotExists().Exec(ctx)
return err
}
func (b *basicDB) CreateAllTables(ctx context.Context) db.Error {
func (b *basicDB) CreateAllTables(ctx context.Context) error {
models := []interface{}{
&gtsmodel.Account{},
&gtsmodel.Application{},
@ -154,16 +154,16 @@ func (b *basicDB) CreateAllTables(ctx context.Context) db.Error {
return nil
}
func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error {
_, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx)
return b.conn.ProcessError(err)
func (b *basicDB) DropTable(ctx context.Context, i interface{}) error {
_, err := b.db.NewDropTable().Model(i).IfExists().Exec(ctx)
return b.db.ProcessError(err)
}
func (b *basicDB) IsHealthy(ctx context.Context) db.Error {
return b.conn.PingContext(ctx)
func (b *basicDB) IsHealthy(ctx context.Context) error {
return b.db.DB.PingContext(ctx)
}
func (b *basicDB) Stop(ctx context.Context) db.Error {
func (b *basicDB) Stop(ctx context.Context) error {
log.Info(ctx, "closing db connection")
return b.conn.Close()
return b.db.DB.Close()
}

View file

@ -79,13 +79,13 @@ type DBService struct {
db.Timeline
db.User
db.Tombstone
conn *DBConn
db *WrappedDB
}
// GetConn returns the underlying bun connection.
// GetDB returns the underlying database connection pool.
// Should only be used in testing + exceptional circumstance.
func (dbService *DBService) GetConn() *DBConn {
return dbService.conn
func (dbService *DBService) DB() *WrappedDB {
return dbService.db
}
func doMigration(ctx context.Context, db *bun.DB) error {
@ -112,18 +112,18 @@ func doMigration(ctx context.Context, db *bun.DB) error {
// NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface.
// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection.
func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
var conn *DBConn
var db *WrappedDB
var err error
t := strings.ToLower(config.GetDbType())
switch t {
case "postgres":
conn, err = pgConn(ctx)
db, err = pgConn(ctx)
if err != nil {
return nil, err
}
case "sqlite":
conn, err = sqliteConn(ctx)
db, err = sqliteConn(ctx)
if err != nil {
return nil, err
}
@ -132,15 +132,15 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
}
// Add database query hooks.
conn.DB.AddQueryHook(queryHook{})
db.AddQueryHook(queryHook{})
if config.GetTracingEnabled() {
conn.DB.AddQueryHook(tracing.InstrumentBun())
db.AddQueryHook(tracing.InstrumentBun())
}
// execute sqlite pragmas *after* adding database hook;
// this allows the pragma queries to be logged
if t == "sqlite" {
if err := sqlitePragmas(ctx, conn); err != nil {
if err := sqlitePragmas(ctx, db); err != nil {
return nil, err
}
}
@ -148,103 +148,103 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
// table registration is needed for many-to-many, see:
// https://bun.uptrace.dev/orm/many-to-many-relation/
for _, t := range registerTables {
conn.RegisterModel(t)
db.RegisterModel(t)
}
// perform any pending database migrations: this includes
// the very first 'migration' on startup which just creates
// necessary tables
if err := doMigration(ctx, conn.DB); err != nil {
if err := doMigration(ctx, db.DB); err != nil {
return nil, fmt.Errorf("db migration error: %s", err)
}
ps := &DBService{
Account: &accountDB{
conn: conn,
db: db,
state: state,
},
Admin: &adminDB{
conn: conn,
db: db,
state: state,
},
Basic: &basicDB{
conn: conn,
db: db,
},
Domain: &domainDB{
conn: conn,
db: db,
state: state,
},
Emoji: &emojiDB{
conn: conn,
db: db,
state: state,
},
Instance: &instanceDB{
conn: conn,
db: db,
state: state,
},
List: &listDB{
conn: conn,
db: db,
state: state,
},
Media: &mediaDB{
conn: conn,
db: db,
state: state,
},
Mention: &mentionDB{
conn: conn,
db: db,
state: state,
},
Notification: &notificationDB{
conn: conn,
db: db,
state: state,
},
Relationship: &relationshipDB{
conn: conn,
db: db,
state: state,
},
Report: &reportDB{
conn: conn,
db: db,
state: state,
},
Search: &searchDB{
conn: conn,
db: db,
state: state,
},
Session: &sessionDB{
conn: conn,
db: db,
},
Status: &statusDB{
conn: conn,
db: db,
state: state,
},
StatusBookmark: &statusBookmarkDB{
conn: conn,
db: db,
state: state,
},
StatusFave: &statusFaveDB{
conn: conn,
db: db,
state: state,
},
Timeline: &timelineDB{
conn: conn,
db: db,
state: state,
},
User: &userDB{
conn: conn,
db: db,
state: state,
},
Tombstone: &tombstoneDB{
conn: conn,
db: db,
state: state,
},
conn: conn,
db: db,
}
// we can confidently return this useable service now
return ps, nil
}
func pgConn(ctx context.Context) (*DBConn, error) {
func pgConn(ctx context.Context) (*WrappedDB, error) {
opts, err := deriveBunDBPGOptions() //nolint:contextcheck
if err != nil {
return nil, fmt.Errorf("could not create bundb postgres options: %s", err)
@ -259,10 +259,10 @@ func pgConn(ctx context.Context) (*DBConn, error) {
sqldb.SetMaxIdleConns(2) // assume default 2; if max idle is less than max open, it will be automatically adjusted
sqldb.SetConnMaxLifetime(5 * time.Minute) // fine to kill old connections
conn := WrapDBConn(bun.NewDB(sqldb, pgdialect.New()))
conn := WrapDB(bun.NewDB(sqldb, pgdialect.New()))
// ping to check the db is there and listening
if err := conn.PingContext(ctx); err != nil {
if err := conn.DB.PingContext(ctx); err != nil {
return nil, fmt.Errorf("postgres ping: %s", err)
}
@ -270,7 +270,7 @@ func pgConn(ctx context.Context) (*DBConn, error) {
return conn, nil
}
func sqliteConn(ctx context.Context) (*DBConn, error) {
func sqliteConn(ctx context.Context) (*WrappedDB, error) {
// validate db address has actually been set
address := config.GetDbAddress()
if address == "" {
@ -326,15 +326,15 @@ func sqliteConn(ctx context.Context) (*DBConn, error) {
// Tune db connections for sqlite, see:
// - https://bun.uptrace.dev/guide/running-bun-in-production.html#database-sql
// - https://www.alexedwards.net/blog/configuring-sqldb
sqldb.SetMaxOpenConns(1) // only 1 connection regardless of multiplier, see https://github.com/superseriousbusiness/gotosocial/issues/1407
sqldb.SetMaxOpenConns(maxOpenConns()) // x number of conns per CPU
sqldb.SetMaxIdleConns(1) // only keep max 1 idle connection around
sqldb.SetConnMaxLifetime(0) // don't kill connections due to age
// Wrap Bun database conn in our own wrapper
conn := WrapDBConn(bun.NewDB(sqldb, sqlitedialect.New()))
conn := WrapDB(bun.NewDB(sqldb, sqlitedialect.New()))
// ping to check the db is there and listening
if err := conn.PingContext(ctx); err != nil {
if err := conn.DB.PingContext(ctx); err != nil {
if errWithCode, ok := err.(*sqlite.Error); ok {
err = errors.New(sqlite.ErrorCodeString[errWithCode.Code()])
}
@ -445,7 +445,7 @@ func deriveBunDBPGOptions() (*pgx.ConnConfig, error) {
// sqlitePragmas sets desired sqlite pragmas based on configured values, and
// logs the results of the pragma queries. Errors if something goes wrong.
func sqlitePragmas(ctx context.Context, conn *DBConn) error {
func sqlitePragmas(ctx context.Context, db *WrappedDB) error {
var pragmas [][]string
if mode := config.GetDbSqliteJournalMode(); mode != "" {
// Set the user provided SQLite journal mode
@ -475,12 +475,12 @@ func sqlitePragmas(ctx context.Context, conn *DBConn) error {
pk := p[0]
pv := p[1]
if _, err := conn.DB.ExecContext(ctx, "PRAGMA ?=?", bun.Ident(pk), bun.Safe(pv)); err != nil {
if _, err := db.ExecContext(ctx, "PRAGMA ?=?", bun.Ident(pk), bun.Safe(pv)); err != nil {
return fmt.Errorf("error executing sqlite pragma %s: %w", pk, err)
}
var res string
if err := conn.DB.NewRaw("PRAGMA ?", bun.Ident(pk)).Scan(ctx, &res); err != nil {
if err := db.NewRaw("PRAGMA ?", bun.Ident(pk)).Scan(ctx, &res); err != nil {
return fmt.Errorf("error scanning sqlite pragma %s: %w", pv, err)
}
@ -502,7 +502,7 @@ func (dbService *DBService) TagStringToTag(ctx context.Context, t string, origin
tag := &gtsmodel.Tag{}
// we can use selectorinsert here to create the new tag if it doesn't exist already
// inserted will be true if this is a new tag we just created
if err := dbService.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil && err != sql.ErrNoRows {
if err := dbService.db.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil && err != sql.ErrNoRows {
return nil, fmt.Errorf("error getting tag with name %s: %s", t, err)
}

View file

@ -1,113 +0,0 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"database/sql"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect"
)
// DBConn wrapps a bun.DB conn to provide SQL-type specific additional functionality
type DBConn struct {
errProc func(error) db.Error // errProc is the SQL-type specific error processor
*bun.DB // DB is the underlying bun.DB connection
}
// WrapDBConn wraps a bun DB connection to provide our own error processing dependent on DB dialect.
func WrapDBConn(dbConn *bun.DB) *DBConn {
var errProc func(error) db.Error
switch dbConn.Dialect().Name() {
case dialect.PG:
errProc = processPostgresError
case dialect.SQLite:
errProc = processSQLiteError
default:
panic("unknown dialect name: " + dbConn.Dialect().Name().String())
}
return &DBConn{
errProc: errProc,
DB: dbConn,
}
}
// RunInTx wraps execution of the supplied transaction function.
func (conn *DBConn) RunInTx(ctx context.Context, fn func(bun.Tx) error) db.Error {
return conn.ProcessError(func() error {
// Acquire a new transaction
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
return err
}
var done bool
defer func() {
if !done {
_ = tx.Rollback()
}
}()
// Perform supplied transaction
if err := fn(tx); err != nil {
return err
}
// Finally, commit
err = tx.Commit() //nolint:contextcheck
done = true
return err
}())
}
// ProcessError processes an error to replace any known values with our own db.Error types,
// making it easier to catch specific situations (e.g. no rows, already exists, etc)
func (conn *DBConn) ProcessError(err error) db.Error {
switch {
case err == nil:
return nil
case err == sql.ErrNoRows:
return db.ErrNoEntries
default:
return conn.errProc(err)
}
}
// Exists checks the results of a SelectQuery for the existence of the data in question, masking ErrNoEntries errors
func (conn *DBConn) Exists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) {
exists, err := query.Exists(ctx)
// Process error as our own and check if it exists
switch err := conn.ProcessError(err); err {
case nil:
return exists, nil
case db.ErrNoEntries:
return false, nil
default:
return false, err
}
}
// NotExists is the functional opposite of conn.Exists()
func (conn *DBConn) NotExists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) {
exists, err := conn.Exists(ctx, query)
return !exists, err
}

View file

@ -30,11 +30,11 @@ import (
)
type domainDB struct {
conn *DBConn
db *WrappedDB
state *state.State
}
func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) db.Error {
func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) error {
// Normalize the domain as punycode
var err error
block.Domain, err = util.Punify(block.Domain)
@ -43,10 +43,10 @@ func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.Domain
}
// Attempt to store domain block in DB
if _, err := d.conn.NewInsert().
if _, err := d.db.NewInsert().
Model(block).
Exec(ctx); err != nil {
return d.conn.ProcessError(err)
return d.db.ProcessError(err)
}
// Clear the domain block cache (for later reload)
@ -55,7 +55,7 @@ func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.Domain
return nil
}
func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, db.Error) {
func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, error) {
// Normalize the domain as punycode
domain, err := util.Punify(domain)
if err != nil {
@ -71,12 +71,12 @@ func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel
var block gtsmodel.DomainBlock
// Look for block matching domain in DB
q := d.conn.
q := d.db.
NewSelect().
Model(&block).
Where("? = ?", bun.Ident("domain_block.domain"), domain)
if err := q.Scan(ctx); err != nil {
return nil, d.conn.ProcessError(err)
return nil, d.db.ProcessError(err)
}
return &block, nil
@ -85,31 +85,31 @@ func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel
func (d *domainDB) GetDomainBlocks(ctx context.Context) ([]*gtsmodel.DomainBlock, error) {
blocks := []*gtsmodel.DomainBlock{}
if err := d.conn.
if err := d.db.
NewSelect().
Model(&blocks).
Scan(ctx); err != nil {
return nil, d.conn.ProcessError(err)
return nil, d.db.ProcessError(err)
}
return blocks, nil
}
func (d *domainDB) GetDomainBlockByID(ctx context.Context, id string) (*gtsmodel.DomainBlock, db.Error) {
func (d *domainDB) GetDomainBlockByID(ctx context.Context, id string) (*gtsmodel.DomainBlock, error) {
var block gtsmodel.DomainBlock
q := d.conn.
q := d.db.
NewSelect().
Model(&block).
Where("? = ?", bun.Ident("domain_block.id"), id)
if err := q.Scan(ctx); err != nil {
return nil, d.conn.ProcessError(err)
return nil, d.db.ProcessError(err)
}
return &block, nil
}
func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error {
func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) error {
// Normalize the domain as punycode
domain, err := util.Punify(domain)
if err != nil {
@ -117,11 +117,11 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro
}
// Attempt to delete domain block
if _, err := d.conn.NewDelete().
if _, err := d.db.NewDelete().
Model((*gtsmodel.DomainBlock)(nil)).
Where("? = ?", bun.Ident("domain_block.domain"), domain).
Exec(ctx); err != nil {
return d.conn.ProcessError(err)
return d.db.ProcessError(err)
}
// Clear the domain block cache (for later reload)
@ -130,7 +130,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro
return nil
}
func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) {
func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, error) {
// Normalize the domain as punycode
domain, err := util.Punify(domain)
if err != nil {
@ -148,18 +148,18 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db
var domains []string
// Scan list of all blocked domains from DB
q := d.conn.NewSelect().
q := d.db.NewSelect().
Table("domain_blocks").
Column("domain")
if err := q.Scan(ctx, &domains); err != nil {
return nil, d.conn.ProcessError(err)
return nil, d.db.ProcessError(err)
}
return domains, nil
})
}
func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) {
func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, error) {
for _, domain := range domains {
if blocked, err := d.IsDomainBlocked(ctx, domain); err != nil {
return false, err
@ -170,11 +170,11 @@ func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (boo
return false, nil
}
func (d *domainDB) IsURIBlocked(ctx context.Context, uri *url.URL) (bool, db.Error) {
func (d *domainDB) IsURIBlocked(ctx context.Context, uri *url.URL) (bool, error) {
return d.IsDomainBlocked(ctx, uri.Hostname())
}
func (d *domainDB) AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, db.Error) {
func (d *domainDB) AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, error) {
for _, uri := range uris {
if blocked, err := d.IsDomainBlocked(ctx, uri.Hostname()); err != nil {
return false, err

View file

@ -34,14 +34,14 @@ import (
)
type emojiDB struct {
conn *DBConn
db *WrappedDB
state *state.State
}
func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) db.Error {
func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) error {
return e.state.Caches.GTS.Emoji().Store(emoji, func() error {
_, err := e.conn.NewInsert().Model(emoji).Exec(ctx)
return e.conn.ProcessError(err)
_, err := e.db.NewInsert().Model(emoji).Exec(ctx)
return e.db.ProcessError(err)
})
}
@ -54,17 +54,17 @@ func (e *emojiDB) UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, column
// Update the emoji model in the database.
return e.state.Caches.GTS.Emoji().Store(emoji, func() error {
_, err := e.conn.
_, err := e.db.
NewUpdate().
Model(emoji).
Where("? = ?", bun.Ident("emoji.id"), emoji.ID).
Column(columns...).
Exec(ctx)
return e.conn.ProcessError(err)
return e.db.ProcessError(err)
})
}
func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) db.Error {
func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error {
var (
accountIDs []string
statusIDs []string
@ -105,7 +105,7 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) db.Error {
return err
}
return e.conn.RunInTx(ctx, func(tx bun.Tx) error {
return e.db.RunInTx(ctx, func(tx bun.Tx) error {
// delete links between this emoji and any statuses that use it
// TODO: remove when we delete this table
if _, err := tx.
@ -229,7 +229,7 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) db.Error {
func (e *emojiDB) GetEmojisBy(ctx context.Context, domain string, includeDisabled bool, includeEnabled bool, shortcode string, maxShortcodeDomain string, minShortcodeDomain string, limit int) ([]*gtsmodel.Emoji, error) {
emojiIDs := []string{}
subQuery := e.conn.
subQuery := e.db.
NewSelect().
ColumnExpr("? AS ?", bun.Ident("emoji.id"), bun.Ident("emoji_ids"))
@ -255,7 +255,7 @@ func (e *emojiDB) GetEmojisBy(ctx context.Context, domain string, includeDisable
// "emojis" AS "emoji"
// ORDER BY
// "shortcode_domain" ASC
switch e.conn.Dialect().Name() {
switch e.db.Dialect().Name() {
case dialect.SQLite:
subQuery = subQuery.ColumnExpr("LOWER(? || ? || COALESCE(?, ?)) AS ?", bun.Ident("emoji.shortcode"), "@", bun.Ident("emoji.domain"), "", bun.Ident("shortcode_domain"))
case dialect.PG:
@ -321,12 +321,12 @@ func (e *emojiDB) GetEmojisBy(ctx context.Context, domain string, includeDisable
// ORDER BY
// "shortcode_domain" ASC
// ) AS "subquery"
if err := e.conn.
if err := e.db.
NewSelect().
Column("subquery.emoji_ids").
TableExpr("(?) AS ?", subQuery, bun.Ident("subquery")).
Scan(ctx, &emojiIDs); err != nil {
return nil, e.conn.ProcessError(err)
return nil, e.db.ProcessError(err)
}
if order == "DESC" {
@ -346,7 +346,7 @@ func (e *emojiDB) GetEmojisBy(ctx context.Context, domain string, includeDisable
func (e *emojiDB) GetEmojis(ctx context.Context, maxID string, limit int) ([]*gtsmodel.Emoji, error) {
var emojiIDs []string
q := e.conn.NewSelect().
q := e.db.NewSelect().
Table("emojis").
Column("id").
Order("id DESC")
@ -360,7 +360,7 @@ func (e *emojiDB) GetEmojis(ctx context.Context, maxID string, limit int) ([]*gt
}
if err := q.Scan(ctx, &emojiIDs); err != nil {
return nil, e.conn.ProcessError(err)
return nil, e.db.ProcessError(err)
}
return e.GetEmojisByIDs(ctx, emojiIDs)
@ -369,7 +369,7 @@ func (e *emojiDB) GetEmojis(ctx context.Context, maxID string, limit int) ([]*gt
func (e *emojiDB) GetRemoteEmojis(ctx context.Context, maxID string, limit int) ([]*gtsmodel.Emoji, error) {
var emojiIDs []string
q := e.conn.NewSelect().
q := e.db.NewSelect().
Table("emojis").
Column("id").
Where("domain IS NOT NULL").
@ -384,7 +384,7 @@ func (e *emojiDB) GetRemoteEmojis(ctx context.Context, maxID string, limit int)
}
if err := q.Scan(ctx, &emojiIDs); err != nil {
return nil, e.conn.ProcessError(err)
return nil, e.db.ProcessError(err)
}
return e.GetEmojisByIDs(ctx, emojiIDs)
@ -393,7 +393,7 @@ func (e *emojiDB) GetRemoteEmojis(ctx context.Context, maxID string, limit int)
func (e *emojiDB) GetCachedEmojisOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.Emoji, error) {
var emojiIDs []string
q := e.conn.NewSelect().
q := e.db.NewSelect().
Table("emojis").
Column("id").
Where("cached = true").
@ -406,16 +406,16 @@ func (e *emojiDB) GetCachedEmojisOlderThan(ctx context.Context, olderThan time.T
}
if err := q.Scan(ctx, &emojiIDs); err != nil {
return nil, e.conn.ProcessError(err)
return nil, e.db.ProcessError(err)
}
return e.GetEmojisByIDs(ctx, emojiIDs)
}
func (e *emojiDB) GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.Error) {
func (e *emojiDB) GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, error) {
emojiIDs := []string{}
q := e.conn.
q := e.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("emojis"), bun.Ident("emoji")).
Column("emoji.id").
@ -425,18 +425,18 @@ func (e *emojiDB) GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.E
Order("emoji.shortcode ASC")
if err := q.Scan(ctx, &emojiIDs); err != nil {
return nil, e.conn.ProcessError(err)
return nil, e.db.ProcessError(err)
}
return e.GetEmojisByIDs(ctx, emojiIDs)
}
func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, db.Error) {
func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, error) {
return e.getEmoji(
ctx,
"ID",
func(emoji *gtsmodel.Emoji) error {
return e.conn.
return e.db.
NewSelect().
Model(emoji).
Where("? = ?", bun.Ident("emoji.id"), id).Scan(ctx)
@ -445,12 +445,12 @@ func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji,
)
}
func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoji, db.Error) {
func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoji, error) {
return e.getEmoji(
ctx,
"URI",
func(emoji *gtsmodel.Emoji) error {
return e.conn.
return e.db.
NewSelect().
Model(emoji).
Where("? = ?", bun.Ident("emoji.uri"), uri).Scan(ctx)
@ -459,12 +459,12 @@ func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoj
)
}
func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, db.Error) {
func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, error) {
return e.getEmoji(
ctx,
"Shortcode.Domain",
func(emoji *gtsmodel.Emoji) error {
q := e.conn.
q := e.db.
NewSelect().
Model(emoji)
@ -483,12 +483,12 @@ func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode strin
)
}
func (e *emojiDB) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string) (*gtsmodel.Emoji, db.Error) {
func (e *emojiDB) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string) (*gtsmodel.Emoji, error) {
return e.getEmoji(
ctx,
"ImageStaticURL",
func(emoji *gtsmodel.Emoji) error {
return e.conn.
return e.db.
NewSelect().
Model(emoji).
Where("? = ?", bun.Ident("emoji.image_static_url"), imageStaticURL).
@ -498,35 +498,35 @@ func (e *emojiDB) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string
)
}
func (e *emojiDB) PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) db.Error {
func (e *emojiDB) PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) error {
return e.state.Caches.GTS.EmojiCategory().Store(emojiCategory, func() error {
_, err := e.conn.NewInsert().Model(emojiCategory).Exec(ctx)
return e.conn.ProcessError(err)
_, err := e.db.NewInsert().Model(emojiCategory).Exec(ctx)
return e.db.ProcessError(err)
})
}
func (e *emojiDB) GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCategory, db.Error) {
func (e *emojiDB) GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCategory, error) {
emojiCategoryIDs := []string{}
q := e.conn.
q := e.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("emoji_categories"), bun.Ident("emoji_category")).
Column("emoji_category.id").
Order("emoji_category.name ASC")
if err := q.Scan(ctx, &emojiCategoryIDs); err != nil {
return nil, e.conn.ProcessError(err)
return nil, e.db.ProcessError(err)
}
return e.GetEmojiCategoriesByIDs(ctx, emojiCategoryIDs)
}
func (e *emojiDB) GetEmojiCategory(ctx context.Context, id string) (*gtsmodel.EmojiCategory, db.Error) {
func (e *emojiDB) GetEmojiCategory(ctx context.Context, id string) (*gtsmodel.EmojiCategory, error) {
return e.getEmojiCategory(
ctx,
"ID",
func(emojiCategory *gtsmodel.EmojiCategory) error {
return e.conn.
return e.db.
NewSelect().
Model(emojiCategory).
Where("? = ?", bun.Ident("emoji_category.id"), id).Scan(ctx)
@ -535,12 +535,12 @@ func (e *emojiDB) GetEmojiCategory(ctx context.Context, id string) (*gtsmodel.Em
)
}
func (e *emojiDB) GetEmojiCategoryByName(ctx context.Context, name string) (*gtsmodel.EmojiCategory, db.Error) {
func (e *emojiDB) GetEmojiCategoryByName(ctx context.Context, name string) (*gtsmodel.EmojiCategory, error) {
return e.getEmojiCategory(
ctx,
"Name",
func(emojiCategory *gtsmodel.EmojiCategory) error {
return e.conn.
return e.db.
NewSelect().
Model(emojiCategory).
Where("LOWER(?) = ?", bun.Ident("emoji_category.name"), strings.ToLower(name)).Scan(ctx)
@ -549,14 +549,14 @@ func (e *emojiDB) GetEmojiCategoryByName(ctx context.Context, name string) (*gts
)
}
func (e *emojiDB) getEmoji(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Emoji) error, keyParts ...any) (*gtsmodel.Emoji, db.Error) {
func (e *emojiDB) getEmoji(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Emoji) error, keyParts ...any) (*gtsmodel.Emoji, error) {
// Fetch emoji from database cache with loader callback
emoji, err := e.state.Caches.GTS.Emoji().Load(lookup, func() (*gtsmodel.Emoji, error) {
var emoji gtsmodel.Emoji
// Not cached! Perform database query
if err := dbQuery(&emoji); err != nil {
return nil, e.conn.ProcessError(err)
return nil, e.db.ProcessError(err)
}
return &emoji, nil
@ -580,7 +580,7 @@ func (e *emojiDB) getEmoji(ctx context.Context, lookup string, dbQuery func(*gts
return emoji, nil
}
func (e *emojiDB) GetEmojisByIDs(ctx context.Context, emojiIDs []string) ([]*gtsmodel.Emoji, db.Error) {
func (e *emojiDB) GetEmojisByIDs(ctx context.Context, emojiIDs []string) ([]*gtsmodel.Emoji, error) {
if len(emojiIDs) == 0 {
return nil, db.ErrNoEntries
}
@ -600,20 +600,20 @@ func (e *emojiDB) GetEmojisByIDs(ctx context.Context, emojiIDs []string) ([]*gts
return emojis, nil
}
func (e *emojiDB) getEmojiCategory(ctx context.Context, lookup string, dbQuery func(*gtsmodel.EmojiCategory) error, keyParts ...any) (*gtsmodel.EmojiCategory, db.Error) {
func (e *emojiDB) getEmojiCategory(ctx context.Context, lookup string, dbQuery func(*gtsmodel.EmojiCategory) error, keyParts ...any) (*gtsmodel.EmojiCategory, error) {
return e.state.Caches.GTS.EmojiCategory().Load(lookup, func() (*gtsmodel.EmojiCategory, error) {
var category gtsmodel.EmojiCategory
// Not cached! Perform database query
if err := dbQuery(&category); err != nil {
return nil, e.conn.ProcessError(err)
return nil, e.db.ProcessError(err)
}
return &category, nil
}, keyParts...)
}
func (e *emojiDB) GetEmojiCategoriesByIDs(ctx context.Context, emojiCategoryIDs []string) ([]*gtsmodel.EmojiCategory, db.Error) {
func (e *emojiDB) GetEmojiCategoriesByIDs(ctx context.Context, emojiCategoryIDs []string) ([]*gtsmodel.EmojiCategory, error) {
if len(emojiCategoryIDs) == 0 {
return nil, db.ErrNoEntries
}

View file

@ -18,14 +18,20 @@
package bundb
import (
"errors"
"github.com/jackc/pgconn"
"github.com/superseriousbusiness/gotosocial/internal/db"
"modernc.org/sqlite"
sqlite3 "modernc.org/sqlite/lib"
)
// errBusy is a sentinel error indicating
// busy database (e.g. retry needed).
var errBusy = errors.New("busy")
// processPostgresError processes an error, replacing any postgres specific errors with our own error type
func processPostgresError(err error) db.Error {
func processPostgresError(err error) error {
// Attempt to cast as postgres
pgErr, ok := err.(*pgconn.PgError)
if !ok {
@ -34,16 +40,16 @@ func processPostgresError(err error) db.Error {
// Handle supplied error code:
// (https://www.postgresql.org/docs/10/errcodes-appendix.html)
switch pgErr.Code {
switch pgErr.Code { //nolint
case "23505" /* unique_violation */ :
return db.ErrAlreadyExists
default:
return err
}
return err
}
// processSQLiteError processes an error, replacing any sqlite specific errors with our own error type
func processSQLiteError(err error) db.Error {
func processSQLiteError(err error) error {
// Attempt to cast as sqlite
sqliteErr, ok := err.(*sqlite.Error)
if !ok {
@ -55,7 +61,11 @@ func processSQLiteError(err error) db.Error {
case sqlite3.SQLITE_CONSTRAINT_UNIQUE,
sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY:
return db.ErrAlreadyExists
default:
return err
case sqlite3.SQLITE_BUSY:
return errBusy
case sqlite3.SQLITE_BUSY_TIMEOUT:
return db.ErrBusyTimeout
}
return err
}

View file

@ -34,12 +34,12 @@ import (
)
type instanceDB struct {
conn *DBConn
db *WrappedDB
state *state.State
}
func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) {
q := i.conn.
func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, error) {
q := i.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
Column("account.id").
@ -56,13 +56,13 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int
count, err := q.Count(ctx)
if err != nil {
return 0, i.conn.ProcessError(err)
return 0, i.db.ProcessError(err)
}
return count, nil
}
func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) {
q := i.conn.
func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, error) {
q := i.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status"))
@ -78,13 +78,13 @@ func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (
count, err := q.Count(ctx)
if err != nil {
return 0, i.conn.ProcessError(err)
return 0, i.db.ProcessError(err)
}
return count, nil
}
func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) {
q := i.conn.
func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, error) {
q := i.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance"))
@ -101,12 +101,12 @@ func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (i
count, err := q.Count(ctx)
if err != nil {
return 0, i.conn.ProcessError(err)
return 0, i.db.ProcessError(err)
}
return count, nil
}
func (i *instanceDB) GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, db.Error) {
func (i *instanceDB) GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, error) {
// Normalize the domain as punycode
var err error
domain, err = util.Punify(domain)
@ -118,7 +118,7 @@ func (i *instanceDB) GetInstance(ctx context.Context, domain string) (*gtsmodel.
ctx,
"Domain",
func(instance *gtsmodel.Instance) error {
return i.conn.NewSelect().
return i.db.NewSelect().
Model(instance).
Where("? = ?", bun.Ident("instance.domain"), domain).
Scan(ctx)
@ -132,7 +132,7 @@ func (i *instanceDB) GetInstanceByID(ctx context.Context, id string) (*gtsmodel.
ctx,
"ID",
func(instance *gtsmodel.Instance) error {
return i.conn.NewSelect().
return i.db.NewSelect().
Model(instance).
Where("? = ?", bun.Ident("instance.id"), id).
Scan(ctx)
@ -141,14 +141,14 @@ func (i *instanceDB) GetInstanceByID(ctx context.Context, id string) (*gtsmodel.
)
}
func (i *instanceDB) getInstance(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Instance) error, keyParts ...any) (*gtsmodel.Instance, db.Error) {
func (i *instanceDB) getInstance(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Instance) error, keyParts ...any) (*gtsmodel.Instance, error) {
// Fetch instance from database cache with loader callback
instance, err := i.state.Caches.GTS.Instance().Load(lookup, func() (*gtsmodel.Instance, error) {
var instance gtsmodel.Instance
// Not cached! Perform database query.
if err := dbQuery(&instance); err != nil {
return nil, i.conn.ProcessError(err)
return nil, i.db.ProcessError(err)
}
return &instance, nil
@ -210,8 +210,8 @@ func (i *instanceDB) PutInstance(ctx context.Context, instance *gtsmodel.Instanc
}
return i.state.Caches.GTS.Instance().Store(instance, func() error {
_, err := i.conn.NewInsert().Model(instance).Exec(ctx)
return i.conn.ProcessError(err)
_, err := i.db.NewInsert().Model(instance).Exec(ctx)
return i.db.ProcessError(err)
})
}
@ -230,20 +230,20 @@ func (i *instanceDB) UpdateInstance(ctx context.Context, instance *gtsmodel.Inst
}
return i.state.Caches.GTS.Instance().Store(instance, func() error {
_, err := i.conn.
_, err := i.db.
NewUpdate().
Model(instance).
Where("? = ?", bun.Ident("instance.id"), instance.ID).
Column(columns...).
Exec(ctx)
return i.conn.ProcessError(err)
return i.db.ProcessError(err)
})
}
func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool) ([]*gtsmodel.Instance, db.Error) {
func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool) ([]*gtsmodel.Instance, error) {
instanceIDs := []string{}
q := i.conn.
q := i.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")).
// Select just the IDs of each instance.
@ -256,7 +256,7 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool
}
if err := q.Scan(ctx, &instanceIDs); err != nil {
return nil, i.conn.ProcessError(err)
return nil, i.db.ProcessError(err)
}
if len(instanceIDs) == 0 {
@ -280,7 +280,7 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool
return instances, nil
}
func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, error) {
// Ensure reasonable
if limit < 0 {
limit = 0
@ -296,7 +296,7 @@ func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, max
// Make educated guess for slice size
accountIDs := make([]string, 0, limit)
q := i.conn.
q := i.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
// Select just the account ID.
@ -315,7 +315,7 @@ func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, max
}
if err := q.Scan(ctx, &accountIDs); err != nil {
return nil, i.conn.ProcessError(err)
return nil, i.db.ProcessError(err)
}
// Catch case of no accounts early.
@ -340,13 +340,13 @@ func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, max
return accounts, nil
}
func (i *instanceDB) GetInstanceModeratorAddresses(ctx context.Context) ([]string, db.Error) {
func (i *instanceDB) GetInstanceModeratorAddresses(ctx context.Context) ([]string, error) {
addresses := []string{}
// Select email addresses of approved, confirmed,
// and enabled moderators or admins.
q := i.conn.
q := i.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")).
Column("user.email").
@ -361,7 +361,7 @@ func (i *instanceDB) GetInstanceModeratorAddresses(ctx context.Context) ([]strin
OrderExpr("? ASC", bun.Ident("user.email"))
if err := q.Scan(ctx, &addresses); err != nil {
return nil, i.conn.ProcessError(err)
return nil, i.db.ProcessError(err)
}
if len(addresses) == 0 {

View file

@ -33,7 +33,7 @@ import (
)
type listDB struct {
conn *DBConn
db *WrappedDB
state *state.State
}
@ -46,7 +46,7 @@ func (l *listDB) GetListByID(ctx context.Context, id string) (*gtsmodel.List, er
ctx,
"ID",
func(list *gtsmodel.List) error {
return l.conn.NewSelect().
return l.db.NewSelect().
Model(list).
Where("? = ?", bun.Ident("list.id"), id).
Scan(ctx)
@ -61,7 +61,7 @@ func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmo
// Not cached! Perform database query.
if err := dbQuery(&list); err != nil {
return nil, l.conn.ProcessError(err)
return nil, l.db.ProcessError(err)
}
return &list, nil
@ -86,14 +86,14 @@ func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmo
func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error) {
// Fetch IDs of all lists owned by this account.
var listIDs []string
if err := l.conn.
if err := l.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("lists"), bun.Ident("list")).
Column("list.id").
Where("? = ?", bun.Ident("list.account_id"), accountID).
Order("list.id DESC").
Scan(ctx, &listIDs); err != nil {
return nil, l.conn.ProcessError(err)
return nil, l.db.ProcessError(err)
}
if len(listIDs) == 0 {
@ -148,8 +148,8 @@ func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error {
func (l *listDB) PutList(ctx context.Context, list *gtsmodel.List) error {
return l.state.Caches.GTS.List().Store(list, func() error {
_, err := l.conn.NewInsert().Model(list).Exec(ctx)
return l.conn.ProcessError(err)
_, err := l.db.NewInsert().Model(list).Exec(ctx)
return l.db.ProcessError(err)
})
}
@ -171,12 +171,12 @@ func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns ..
}()
return l.state.Caches.GTS.List().Store(list, func() error {
_, err := l.conn.NewUpdate().
_, err := l.db.NewUpdate().
Model(list).
Where("? = ?", bun.Ident("list.id"), list.ID).
Column(columns...).
Exec(ctx)
return l.conn.ProcessError(err)
return l.db.ProcessError(err)
})
}
@ -207,7 +207,7 @@ func (l *listDB) DeleteListByID(ctx context.Context, id string) error {
}
}()
return l.conn.RunInTx(ctx, func(tx bun.Tx) error {
return l.db.RunInTx(ctx, func(tx bun.Tx) error {
// Delete all entries attached to list.
if _, err := tx.NewDelete().
Table("list_entries").
@ -234,7 +234,7 @@ func (l *listDB) GetListEntryByID(ctx context.Context, id string) (*gtsmodel.Lis
ctx,
"ID",
func(listEntry *gtsmodel.ListEntry) error {
return l.conn.NewSelect().
return l.db.NewSelect().
Model(listEntry).
Where("? = ?", bun.Ident("list_entry.id"), id).
Scan(ctx)
@ -249,7 +249,7 @@ func (l *listDB) getListEntry(ctx context.Context, lookup string, dbQuery func(*
// Not cached! Perform database query.
if err := dbQuery(&listEntry); err != nil {
return nil, l.conn.ProcessError(err)
return nil, l.db.ProcessError(err)
}
return &listEntry, nil
@ -289,7 +289,7 @@ func (l *listDB) GetListEntries(ctx context.Context,
frontToBack = true
)
q := l.conn.
q := l.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("entry")).
// Select only IDs from table
@ -329,7 +329,7 @@ func (l *listDB) GetListEntries(ctx context.Context,
}
if err := q.Scan(ctx, &entryIDs); err != nil {
return nil, l.conn.ProcessError(err)
return nil, l.db.ProcessError(err)
}
if len(entryIDs) == 0 {
@ -362,7 +362,7 @@ func (l *listDB) GetListEntries(ctx context.Context,
func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error) {
var entryIDs []string
if err := l.conn.
if err := l.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("entry")).
// Select only IDs from table
@ -370,7 +370,7 @@ func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string)
// Select only entries belonging with given followID.
Where("? = ?", bun.Ident("entry.follow_id"), followID).
Scan(ctx, &entryIDs); err != nil {
return nil, l.conn.ProcessError(err)
return nil, l.db.ProcessError(err)
}
if len(entryIDs) == 0 {
@ -424,7 +424,7 @@ func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEnt
}()
// Finally, insert each list entry into the database.
return l.conn.RunInTx(ctx, func(tx bun.Tx) error {
return l.db.RunInTx(ctx, func(tx bun.Tx) error {
for _, entry := range entries {
if err := l.state.Caches.GTS.ListEntry().Store(entry, func() error {
_, err := tx.
@ -468,7 +468,7 @@ func (l *listDB) DeleteListEntry(ctx context.Context, id string) error {
}()
// Finally delete the list entry.
_, err = l.conn.NewDelete().
_, err = l.db.NewDelete().
Table("list_entries").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx)
@ -479,14 +479,14 @@ func (l *listDB) DeleteListEntriesForFollowID(ctx context.Context, followID stri
var entryIDs []string
// Fetch entry IDs for follow ID.
if err := l.conn.
if err := l.db.
NewSelect().
Table("list_entries").
Column("id").
Where("? = ?", bun.Ident("follow_id"), followID).
Order("id DESC").
Scan(ctx, &entryIDs); err != nil {
return l.conn.ProcessError(err)
return l.db.ProcessError(err)
}
for _, id := range entryIDs {

View file

@ -32,16 +32,16 @@ import (
)
type mediaDB struct {
conn *DBConn
db *WrappedDB
state *state.State
}
func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, db.Error) {
func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, error) {
return m.getAttachment(
ctx,
"ID",
func(attachment *gtsmodel.MediaAttachment) error {
return m.conn.NewSelect().
return m.db.NewSelect().
Model(attachment).
Where("? = ?", bun.Ident("media_attachment.id"), id).
Scan(ctx)
@ -68,13 +68,13 @@ func (m *mediaDB) GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gts
return attachments, nil
}
func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func(*gtsmodel.MediaAttachment) error, keyParts ...any) (*gtsmodel.MediaAttachment, db.Error) {
func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func(*gtsmodel.MediaAttachment) error, keyParts ...any) (*gtsmodel.MediaAttachment, error) {
return m.state.Caches.GTS.Media().Load(lookup, func() (*gtsmodel.MediaAttachment, error) {
var attachment gtsmodel.MediaAttachment
// Not cached! Perform database query
if err := dbQuery(&attachment); err != nil {
return nil, m.conn.ProcessError(err)
return nil, m.db.ProcessError(err)
}
return &attachment, nil
@ -83,8 +83,8 @@ func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func
func (m *mediaDB) PutAttachment(ctx context.Context, media *gtsmodel.MediaAttachment) error {
return m.state.Caches.GTS.Media().Store(media, func() error {
_, err := m.conn.NewInsert().Model(media).Exec(ctx)
return m.conn.ProcessError(err)
_, err := m.db.NewInsert().Model(media).Exec(ctx)
return m.db.ProcessError(err)
})
}
@ -96,12 +96,12 @@ func (m *mediaDB) UpdateAttachment(ctx context.Context, media *gtsmodel.MediaAtt
}
return m.state.Caches.GTS.Media().Store(media, func() error {
_, err := m.conn.NewUpdate().
_, err := m.db.NewUpdate().
Model(media).
Where("? = ?", bun.Ident("media_attachment.id"), media.ID).
Column(columns...).
Exec(ctx)
return m.conn.ProcessError(err)
return m.db.ProcessError(err)
})
}
@ -126,7 +126,7 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {
)
// Delete media attachment in new transaction.
err = m.conn.RunInTx(ctx, func(tx bun.Tx) error {
err = m.db.RunInTx(ctx, func(tx bun.Tx) error {
if media.AccountID != "" {
var account gtsmodel.Account
@ -229,11 +229,11 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {
m.state.Caches.GTS.Status().Invalidate("ID", media.StatusID)
}
return m.conn.ProcessError(err)
return m.db.ProcessError(err)
}
func (m *mediaDB) CountRemoteOlderThan(ctx context.Context, olderThan time.Time) (int, db.Error) {
q := m.conn.
func (m *mediaDB) CountRemoteOlderThan(ctx context.Context, olderThan time.Time) (int, error) {
q := m.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("media_attachments"), bun.Ident("media_attachment")).
Column("media_attachment.id").
@ -243,7 +243,7 @@ func (m *mediaDB) CountRemoteOlderThan(ctx context.Context, olderThan time.Time)
count, err := q.Count(ctx)
if err != nil {
return 0, m.conn.ProcessError(err)
return 0, m.db.ProcessError(err)
}
return count, nil
@ -252,7 +252,7 @@ func (m *mediaDB) CountRemoteOlderThan(ctx context.Context, olderThan time.Time)
func (m *mediaDB) GetAttachments(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, error) {
attachmentIDs := make([]string, 0, limit)
q := m.conn.NewSelect().
q := m.db.NewSelect().
Table("media_attachments").
Column("id").
Order("id DESC")
@ -266,7 +266,7 @@ func (m *mediaDB) GetAttachments(ctx context.Context, maxID string, limit int) (
}
if err := q.Scan(ctx, &attachmentIDs); err != nil {
return nil, m.conn.ProcessError(err)
return nil, m.db.ProcessError(err)
}
return m.GetAttachmentsByIDs(ctx, attachmentIDs)
@ -275,7 +275,7 @@ func (m *mediaDB) GetAttachments(ctx context.Context, maxID string, limit int) (
func (m *mediaDB) GetRemoteAttachments(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, error) {
attachmentIDs := make([]string, 0, limit)
q := m.conn.NewSelect().
q := m.db.NewSelect().
Table("media_attachments").
Column("id").
Where("remote_url IS NOT NULL").
@ -290,16 +290,16 @@ func (m *mediaDB) GetRemoteAttachments(ctx context.Context, maxID string, limit
}
if err := q.Scan(ctx, &attachmentIDs); err != nil {
return nil, m.conn.ProcessError(err)
return nil, m.db.ProcessError(err)
}
return m.GetAttachmentsByIDs(ctx, attachmentIDs)
}
func (m *mediaDB) GetCachedAttachmentsOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, db.Error) {
func (m *mediaDB) GetCachedAttachmentsOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, error) {
attachmentIDs := make([]string, 0, limit)
q := m.conn.
q := m.db.
NewSelect().
Table("media_attachments").
Column("id").
@ -313,16 +313,16 @@ func (m *mediaDB) GetCachedAttachmentsOlderThan(ctx context.Context, olderThan t
}
if err := q.Scan(ctx, &attachmentIDs); err != nil {
return nil, m.conn.ProcessError(err)
return nil, m.db.ProcessError(err)
}
return m.GetAttachmentsByIDs(ctx, attachmentIDs)
}
func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, db.Error) {
func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, error) {
attachmentIDs := make([]string, 0, limit)
q := m.conn.NewSelect().
q := m.db.NewSelect().
TableExpr("? AS ?", bun.Ident("media_attachments"), bun.Ident("media_attachment")).
Column("media_attachment.id").
WhereGroup(" AND ", func(innerQ *bun.SelectQuery) *bun.SelectQuery {
@ -341,16 +341,16 @@ func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit
}
if err := q.Scan(ctx, &attachmentIDs); err != nil {
return nil, m.conn.ProcessError(err)
return nil, m.db.ProcessError(err)
}
return m.GetAttachmentsByIDs(ctx, attachmentIDs)
}
func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, db.Error) {
func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, error) {
attachmentIDs := make([]string, 0, limit)
q := m.conn.
q := m.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("media_attachments"), bun.Ident("media_attachment")).
Column("media_attachment.id").
@ -367,14 +367,14 @@ func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan tim
}
if err := q.Scan(ctx, &attachmentIDs); err != nil {
return nil, m.conn.ProcessError(err)
return nil, m.db.ProcessError(err)
}
return m.GetAttachmentsByIDs(ctx, attachmentIDs)
}
func (m *mediaDB) CountLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time) (int, db.Error) {
q := m.conn.
func (m *mediaDB) CountLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time) (int, error) {
q := m.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("media_attachments"), bun.Ident("media_attachment")).
Column("media_attachment.id").
@ -387,7 +387,7 @@ func (m *mediaDB) CountLocalUnattachedOlderThan(ctx context.Context, olderThan t
count, err := q.Count(ctx)
if err != nil {
return 0, m.conn.ProcessError(err)
return 0, m.db.ProcessError(err)
}
return count, nil

View file

@ -31,21 +31,21 @@ import (
)
type mentionDB struct {
conn *DBConn
db *WrappedDB
state *state.State
}
func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, error) {
mention, err := m.state.Caches.GTS.Mention().Load("ID", func() (*gtsmodel.Mention, error) {
var mention gtsmodel.Mention
q := m.conn.
q := m.db.
NewSelect().
Model(&mention).
Where("? = ?", bun.Ident("mention.id"), id)
if err := q.Scan(ctx); err != nil {
return nil, m.conn.ProcessError(err)
return nil, m.db.ProcessError(err)
}
return &mention, nil
@ -84,7 +84,7 @@ func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mentio
return mention, nil
}
func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) {
func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, error) {
mentions := make([]*gtsmodel.Mention, 0, len(ids))
for _, id := range ids {
@ -104,8 +104,8 @@ func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.
func (m *mentionDB) PutMention(ctx context.Context, mention *gtsmodel.Mention) error {
return m.state.Caches.GTS.Mention().Store(mention, func() error {
_, err := m.conn.NewInsert().Model(mention).Exec(ctx)
return m.conn.ProcessError(err)
_, err := m.db.NewInsert().Model(mention).Exec(ctx)
return m.db.ProcessError(err)
})
}
@ -125,9 +125,9 @@ func (m *mentionDB) DeleteMentionByID(ctx context.Context, id string) error {
}
// Finally delete mention from DB.
_, err = m.conn.NewDelete().
_, err = m.db.NewDelete().
Table("mentions").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx)
return m.conn.ProcessError(err)
return m.db.ProcessError(err)
}

View file

@ -31,19 +31,19 @@ import (
)
type notificationDB struct {
conn *DBConn
db *WrappedDB
state *state.State
}
func (n *notificationDB) GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) {
func (n *notificationDB) GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, error) {
return n.state.Caches.GTS.Notification().Load("ID", func() (*gtsmodel.Notification, error) {
var notif gtsmodel.Notification
q := n.conn.NewSelect().
q := n.db.NewSelect().
Model(&notif).
Where("? = ?", bun.Ident("notification.id"), id)
if err := q.Scan(ctx); err != nil {
return nil, n.conn.ProcessError(err)
return nil, n.db.ProcessError(err)
}
return &notif, nil
@ -56,11 +56,11 @@ func (n *notificationDB) GetNotification(
targetAccountID string,
originAccountID string,
statusID string,
) (*gtsmodel.Notification, db.Error) {
) (*gtsmodel.Notification, error) {
return n.state.Caches.GTS.Notification().Load("NotificationType.TargetAccountID.OriginAccountID.StatusID", func() (*gtsmodel.Notification, error) {
var notif gtsmodel.Notification
q := n.conn.NewSelect().
q := n.db.NewSelect().
Model(&notif).
Where("? = ?", bun.Ident("notification_type"), notificationType).
Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
@ -68,7 +68,7 @@ func (n *notificationDB) GetNotification(
Where("? = ?", bun.Ident("status_id"), statusID)
if err := q.Scan(ctx); err != nil {
return nil, n.conn.ProcessError(err)
return nil, n.db.ProcessError(err)
}
return &notif, nil
@ -83,7 +83,7 @@ func (n *notificationDB) GetAccountNotifications(
minID string,
limit int,
excludeTypes []string,
) ([]*gtsmodel.Notification, db.Error) {
) ([]*gtsmodel.Notification, error) {
// Ensure reasonable
if limit < 0 {
limit = 0
@ -95,7 +95,7 @@ func (n *notificationDB) GetAccountNotifications(
frontToBack = true
)
q := n.conn.
q := n.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
Column("notification.id")
@ -140,7 +140,7 @@ func (n *notificationDB) GetAccountNotifications(
}
if err := q.Scan(ctx, &notifIDs); err != nil {
return nil, n.conn.ProcessError(err)
return nil, n.db.ProcessError(err)
}
if len(notifIDs) == 0 {
@ -174,12 +174,12 @@ func (n *notificationDB) GetAccountNotifications(
func (n *notificationDB) PutNotification(ctx context.Context, notif *gtsmodel.Notification) error {
return n.state.Caches.GTS.Notification().Store(notif, func() error {
_, err := n.conn.NewInsert().Model(notif).Exec(ctx)
return n.conn.ProcessError(err)
_, err := n.db.NewInsert().Model(notif).Exec(ctx)
return n.db.ProcessError(err)
})
}
func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) db.Error {
func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) error {
defer n.state.Caches.GTS.Notification().Invalidate("ID", id)
// Load notif into cache before attempting a delete,
@ -195,21 +195,21 @@ func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string)
}
// Finally delete notif from DB.
_, err = n.conn.NewDelete().
_, err = n.db.NewDelete().
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
Where("? = ?", bun.Ident("notification.id"), id).
Exec(ctx)
return n.conn.ProcessError(err)
return n.db.ProcessError(err)
}
func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) db.Error {
func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) error {
if targetAccountID == "" && originAccountID == "" {
return errors.New("DeleteNotifications: one of targetAccountID or originAccountID must be set")
}
var notifIDs []string
q := n.conn.
q := n.db.
NewSelect().
Column("id").
Table("notifications")
@ -227,7 +227,7 @@ func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string
}
if _, err := q.Exec(ctx, &notifIDs); err != nil {
return n.conn.ProcessError(err)
return n.db.ProcessError(err)
}
defer func() {
@ -248,24 +248,24 @@ func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string
}
// Finally delete all from DB.
_, err := n.conn.NewDelete().
_, err := n.db.NewDelete().
Table("notifications").
Where("? IN (?)", bun.Ident("id"), bun.In(notifIDs)).
Exec(ctx)
return n.conn.ProcessError(err)
return n.db.ProcessError(err)
}
func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statusID string) db.Error {
func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statusID string) error {
var notifIDs []string
q := n.conn.
q := n.db.
NewSelect().
Column("id").
Table("notifications").
Where("? = ?", bun.Ident("status_id"), statusID)
if _, err := q.Exec(ctx, &notifIDs); err != nil {
return n.conn.ProcessError(err)
return n.db.ProcessError(err)
}
defer func() {
@ -286,9 +286,9 @@ func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statu
}
// Finally delete all from DB.
_, err := n.conn.NewDelete().
_, err := n.db.NewDelete().
Table("notifications").
Where("? IN (?)", bun.Ident("id"), bun.In(notifIDs)).
Exec(ctx)
return n.conn.ProcessError(err)
return n.db.ProcessError(err)
}

View file

@ -30,11 +30,11 @@ import (
)
type relationshipDB struct {
conn *DBConn
db *WrappedDB
state *state.State
}
func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error) {
var rel gtsmodel.Relationship
rel.ID = targetAccount
@ -90,91 +90,91 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount
func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
var followIDs []string
if err := newSelectFollows(r.conn, accountID).
if err := newSelectFollows(r.db, accountID).
Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err)
return nil, r.db.ProcessError(err)
}
return r.GetFollowsByIDs(ctx, followIDs)
}
func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
var followIDs []string
if err := newSelectLocalFollows(r.conn, accountID).
if err := newSelectLocalFollows(r.db, accountID).
Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err)
return nil, r.db.ProcessError(err)
}
return r.GetFollowsByIDs(ctx, followIDs)
}
func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
var followIDs []string
if err := newSelectFollowers(r.conn, accountID).
if err := newSelectFollowers(r.db, accountID).
Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err)
return nil, r.db.ProcessError(err)
}
return r.GetFollowsByIDs(ctx, followIDs)
}
func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
var followIDs []string
if err := newSelectLocalFollowers(r.conn, accountID).
if err := newSelectLocalFollowers(r.db, accountID).
Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err)
return nil, r.db.ProcessError(err)
}
return r.GetFollowsByIDs(ctx, followIDs)
}
func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) {
n, err := newSelectFollows(r.conn, accountID).Count(ctx)
return n, r.conn.ProcessError(err)
n, err := newSelectFollows(r.db, accountID).Count(ctx)
return n, r.db.ProcessError(err)
}
func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) {
n, err := newSelectLocalFollows(r.conn, accountID).Count(ctx)
return n, r.conn.ProcessError(err)
n, err := newSelectLocalFollows(r.db, accountID).Count(ctx)
return n, r.db.ProcessError(err)
}
func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) {
n, err := newSelectFollowers(r.conn, accountID).Count(ctx)
return n, r.conn.ProcessError(err)
n, err := newSelectFollowers(r.db, accountID).Count(ctx)
return n, r.db.ProcessError(err)
}
func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) {
n, err := newSelectLocalFollowers(r.conn, accountID).Count(ctx)
return n, r.conn.ProcessError(err)
n, err := newSelectLocalFollowers(r.db, accountID).Count(ctx)
return n, r.db.ProcessError(err)
}
func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) {
var followReqIDs []string
if err := newSelectFollowRequests(r.conn, accountID).
if err := newSelectFollowRequests(r.db, accountID).
Scan(ctx, &followReqIDs); err != nil {
return nil, r.conn.ProcessError(err)
return nil, r.db.ProcessError(err)
}
return r.GetFollowRequestsByIDs(ctx, followReqIDs)
}
func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) {
var followReqIDs []string
if err := newSelectFollowRequesting(r.conn, accountID).
if err := newSelectFollowRequesting(r.db, accountID).
Scan(ctx, &followReqIDs); err != nil {
return nil, r.conn.ProcessError(err)
return nil, r.db.ProcessError(err)
}
return r.GetFollowRequestsByIDs(ctx, followReqIDs)
}
func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) {
n, err := newSelectFollowRequests(r.conn, accountID).Count(ctx)
return n, r.conn.ProcessError(err)
n, err := newSelectFollowRequests(r.db, accountID).Count(ctx)
return n, r.db.ProcessError(err)
}
func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) {
n, err := newSelectFollowRequesting(r.conn, accountID).Count(ctx)
return n, r.conn.ProcessError(err)
n, err := newSelectFollowRequesting(r.db, accountID).Count(ctx)
return n, r.db.ProcessError(err)
}
// newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID.
func newSelectFollowRequests(conn *DBConn, accountID string) *bun.SelectQuery {
return conn.NewSelect().
func newSelectFollowRequests(db *WrappedDB, accountID string) *bun.SelectQuery {
return db.NewSelect().
TableExpr("?", bun.Ident("follow_requests")).
ColumnExpr("?", bun.Ident("id")).
Where("? = ?", bun.Ident("target_account_id"), accountID).
@ -182,8 +182,8 @@ func newSelectFollowRequests(conn *DBConn, accountID string) *bun.SelectQuery {
}
// newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID.
func newSelectFollowRequesting(conn *DBConn, accountID string) *bun.SelectQuery {
return conn.NewSelect().
func newSelectFollowRequesting(db *WrappedDB, accountID string) *bun.SelectQuery {
return db.NewSelect().
TableExpr("?", bun.Ident("follow_requests")).
ColumnExpr("?", bun.Ident("id")).
Where("? = ?", bun.Ident("target_account_id"), accountID).
@ -191,8 +191,8 @@ func newSelectFollowRequesting(conn *DBConn, accountID string) *bun.SelectQuery
}
// newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID.
func newSelectFollows(conn *DBConn, accountID string) *bun.SelectQuery {
return conn.NewSelect().
func newSelectFollows(db *WrappedDB, accountID string) *bun.SelectQuery {
return db.NewSelect().
Table("follows").
Column("id").
Where("? = ?", bun.Ident("account_id"), accountID).
@ -201,15 +201,15 @@ func newSelectFollows(conn *DBConn, accountID string) *bun.SelectQuery {
// newSelectLocalFollows returns a new select query for all rows in the follows table with
// account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local).
func newSelectLocalFollows(conn *DBConn, accountID string) *bun.SelectQuery {
return conn.NewSelect().
func newSelectLocalFollows(db *WrappedDB, accountID string) *bun.SelectQuery {
return db.NewSelect().
Table("follows").
Column("id").
Where("? = ? AND ? IN (?)",
bun.Ident("account_id"),
accountID,
bun.Ident("target_account_id"),
conn.NewSelect().
db.NewSelect().
Table("accounts").
Column("id").
Where("? IS NULL", bun.Ident("domain")),
@ -218,8 +218,8 @@ func newSelectLocalFollows(conn *DBConn, accountID string) *bun.SelectQuery {
}
// newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID.
func newSelectFollowers(conn *DBConn, accountID string) *bun.SelectQuery {
return conn.NewSelect().
func newSelectFollowers(db *WrappedDB, accountID string) *bun.SelectQuery {
return db.NewSelect().
Table("follows").
Column("id").
Where("? = ?", bun.Ident("target_account_id"), accountID).
@ -228,15 +228,15 @@ func newSelectFollowers(conn *DBConn, accountID string) *bun.SelectQuery {
// newSelectLocalFollowers returns a new select query for all rows in the follows table with
// target_account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local).
func newSelectLocalFollowers(conn *DBConn, accountID string) *bun.SelectQuery {
return conn.NewSelect().
func newSelectLocalFollowers(db *WrappedDB, accountID string) *bun.SelectQuery {
return db.NewSelect().
Table("follows").
Column("id").
Where("? = ? AND ? IN (?)",
bun.Ident("target_account_id"),
accountID,
bun.Ident("account_id"),
conn.NewSelect().
db.NewSelect().
Table("accounts").
Column("id").
Where("? IS NULL", bun.Ident("domain")),

View file

@ -28,7 +28,7 @@ import (
"github.com/uptrace/bun"
)
func (r *relationshipDB) IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) {
func (r *relationshipDB) IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error) {
block, err := r.GetBlock(
gtscontext.SetBarebones(ctx),
sourceAccountID,
@ -61,7 +61,7 @@ func (r *relationshipDB) GetBlockByID(ctx context.Context, id string) (*gtsmodel
ctx,
"ID",
func(block *gtsmodel.Block) error {
return r.conn.NewSelect().Model(block).
return r.db.NewSelect().Model(block).
Where("? = ?", bun.Ident("block.id"), id).
Scan(ctx)
},
@ -74,7 +74,7 @@ func (r *relationshipDB) GetBlockByURI(ctx context.Context, uri string) (*gtsmod
ctx,
"URI",
func(block *gtsmodel.Block) error {
return r.conn.NewSelect().Model(block).
return r.db.NewSelect().Model(block).
Where("? = ?", bun.Ident("block.uri"), uri).
Scan(ctx)
},
@ -87,7 +87,7 @@ func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, t
ctx,
"AccountID.TargetAccountID",
func(block *gtsmodel.Block) error {
return r.conn.NewSelect().Model(block).
return r.db.NewSelect().Model(block).
Where("? = ?", bun.Ident("block.account_id"), sourceAccountID).
Where("? = ?", bun.Ident("block.target_account_id"), targetAccountID).
Scan(ctx)
@ -104,7 +104,7 @@ func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery fu
// Not cached! Perform database query
if err := dbQuery(&block); err != nil {
return nil, r.conn.ProcessError(err)
return nil, r.db.ProcessError(err)
}
return &block, nil
@ -142,8 +142,8 @@ func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery fu
func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) error {
return r.state.Caches.GTS.Block().Store(block, func() error {
_, err := r.conn.NewInsert().Model(block).Exec(ctx)
return r.conn.ProcessError(err)
_, err := r.db.NewInsert().Model(block).Exec(ctx)
return r.db.ProcessError(err)
})
}
@ -163,11 +163,11 @@ func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error {
}
// Finally delete block from DB.
_, err = r.conn.NewDelete().
_, err = r.db.NewDelete().
Table("blocks").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx)
return r.conn.ProcessError(err)
return r.db.ProcessError(err)
}
func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error {
@ -186,18 +186,18 @@ func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error
}
// Finally delete block from DB.
_, err = r.conn.NewDelete().
_, err = r.db.NewDelete().
Table("blocks").
Where("? = ?", bun.Ident("uri"), uri).
Exec(ctx)
return r.conn.ProcessError(err)
return r.db.ProcessError(err)
}
func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID string) error {
var blockIDs []string
// Get full list of IDs.
if err := r.conn.NewSelect().
if err := r.db.NewSelect().
Column("id").
Table("blocks").
WhereOr("? = ? OR ? = ?",
@ -207,7 +207,7 @@ func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID stri
accountID,
).
Scan(ctx, &blockIDs); err != nil {
return r.conn.ProcessError(err)
return r.db.ProcessError(err)
}
defer func() {
@ -228,9 +228,9 @@ func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID stri
}
// Finally delete all from DB.
_, err := r.conn.NewDelete().
_, err := r.db.NewDelete().
Table("blocks").
Where("? IN (?)", bun.Ident("id"), bun.In(blockIDs)).
Exec(ctx)
return r.conn.ProcessError(err)
return r.db.ProcessError(err)
}

View file

@ -36,7 +36,7 @@ func (r *relationshipDB) GetFollowByID(ctx context.Context, id string) (*gtsmode
ctx,
"ID",
func(follow *gtsmodel.Follow) error {
return r.conn.NewSelect().
return r.db.NewSelect().
Model(follow).
Where("? = ?", bun.Ident("id"), id).
Scan(ctx)
@ -50,7 +50,7 @@ func (r *relationshipDB) GetFollowByURI(ctx context.Context, uri string) (*gtsmo
ctx,
"URI",
func(follow *gtsmodel.Follow) error {
return r.conn.NewSelect().
return r.db.NewSelect().
Model(follow).
Where("? = ?", bun.Ident("uri"), uri).
Scan(ctx)
@ -64,7 +64,7 @@ func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string,
ctx,
"AccountID.TargetAccountID",
func(follow *gtsmodel.Follow) error {
return r.conn.NewSelect().
return r.db.NewSelect().
Model(follow).
Where("? = ?", bun.Ident("account_id"), sourceAccountID).
Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
@ -94,7 +94,7 @@ func (r *relationshipDB) GetFollowsByIDs(ctx context.Context, ids []string) ([]*
return follows, nil
}
func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) {
func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error) {
follow, err := r.GetFollow(
gtscontext.SetBarebones(ctx),
sourceAccountID,
@ -106,7 +106,7 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccountID string
return (follow != nil), nil
}
func (r *relationshipDB) IsMutualFollowing(ctx context.Context, accountID1 string, accountID2 string) (bool, db.Error) {
func (r *relationshipDB) IsMutualFollowing(ctx context.Context, accountID1 string, accountID2 string) (bool, error) {
// make sure account 1 follows account 2
f1, err := r.IsFollowing(ctx,
accountID1,
@ -135,7 +135,7 @@ func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery f
// Not cached! Perform database query
if err := dbQuery(&follow); err != nil {
return nil, r.conn.ProcessError(err)
return nil, r.db.ProcessError(err)
}
return &follow, nil
@ -190,8 +190,8 @@ func (r *relationshipDB) PopulateFollow(ctx context.Context, follow *gtsmodel.Fo
func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error {
return r.state.Caches.GTS.Follow().Store(follow, func() error {
_, err := r.conn.NewInsert().Model(follow).Exec(ctx)
return r.conn.ProcessError(err)
_, err := r.db.NewInsert().Model(follow).Exec(ctx)
return r.db.ProcessError(err)
})
}
@ -203,12 +203,12 @@ func (r *relationshipDB) UpdateFollow(ctx context.Context, follow *gtsmodel.Foll
}
return r.state.Caches.GTS.Follow().Store(follow, func() error {
if _, err := r.conn.NewUpdate().
if _, err := r.db.NewUpdate().
Model(follow).
Where("? = ?", bun.Ident("follow.id"), follow.ID).
Column(columns...).
Exec(ctx); err != nil {
return r.conn.ProcessError(err)
return r.db.ProcessError(err)
}
return nil
@ -217,11 +217,11 @@ func (r *relationshipDB) UpdateFollow(ctx context.Context, follow *gtsmodel.Foll
func (r *relationshipDB) deleteFollow(ctx context.Context, id string) error {
// Delete the follow itself using the given ID.
if _, err := r.conn.NewDelete().
if _, err := r.db.NewDelete().
Table("follows").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx); err != nil {
return r.conn.ProcessError(err)
return r.db.ProcessError(err)
}
// Delete every list entry that used this followID.
@ -297,7 +297,7 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str
var followIDs []string
// Get full list of IDs.
if _, err := r.conn.
if _, err := r.db.
NewSelect().
Column("id").
Table("follows").
@ -308,7 +308,7 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str
accountID,
).
Exec(ctx, &followIDs); err != nil {
return r.conn.ProcessError(err)
return r.db.ProcessError(err)
}
defer func() {

View file

@ -35,7 +35,7 @@ func (r *relationshipDB) GetFollowRequestByID(ctx context.Context, id string) (*
ctx,
"ID",
func(followReq *gtsmodel.FollowRequest) error {
return r.conn.NewSelect().
return r.db.NewSelect().
Model(followReq).
Where("? = ?", bun.Ident("id"), id).
Scan(ctx)
@ -49,7 +49,7 @@ func (r *relationshipDB) GetFollowRequestByURI(ctx context.Context, uri string)
ctx,
"URI",
func(followReq *gtsmodel.FollowRequest) error {
return r.conn.NewSelect().
return r.db.NewSelect().
Model(followReq).
Where("? = ?", bun.Ident("uri"), uri).
Scan(ctx)
@ -63,7 +63,7 @@ func (r *relationshipDB) GetFollowRequest(ctx context.Context, sourceAccountID s
ctx,
"AccountID.TargetAccountID",
func(followReq *gtsmodel.FollowRequest) error {
return r.conn.NewSelect().
return r.db.NewSelect().
Model(followReq).
Where("? = ?", bun.Ident("account_id"), sourceAccountID).
Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
@ -93,7 +93,7 @@ func (r *relationshipDB) GetFollowRequestsByIDs(ctx context.Context, ids []strin
return followReqs, nil
}
func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) {
func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error) {
followReq, err := r.GetFollowRequest(
gtscontext.SetBarebones(ctx),
sourceAccountID,
@ -112,7 +112,7 @@ func (r *relationshipDB) getFollowRequest(ctx context.Context, lookup string, db
// Not cached! Perform database query
if err := dbQuery(&followReq); err != nil {
return nil, r.conn.ProcessError(err)
return nil, r.db.ProcessError(err)
}
return &followReq, nil
@ -150,8 +150,8 @@ func (r *relationshipDB) getFollowRequest(ctx context.Context, lookup string, db
func (r *relationshipDB) PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error {
return r.state.Caches.GTS.FollowRequest().Store(follow, func() error {
_, err := r.conn.NewInsert().Model(follow).Exec(ctx)
return r.conn.ProcessError(err)
_, err := r.db.NewInsert().Model(follow).Exec(ctx)
return r.db.ProcessError(err)
})
}
@ -163,19 +163,19 @@ func (r *relationshipDB) UpdateFollowRequest(ctx context.Context, followRequest
}
return r.state.Caches.GTS.FollowRequest().Store(followRequest, func() error {
if _, err := r.conn.NewUpdate().
if _, err := r.db.NewUpdate().
Model(followRequest).
Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
Column(columns...).
Exec(ctx); err != nil {
return r.conn.ProcessError(err)
return r.db.ProcessError(err)
}
return nil
})
}
func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error) {
// Get original follow request.
followReq, err := r.GetFollowRequest(ctx, sourceAccountID, targetAccountID)
if err != nil {
@ -198,12 +198,12 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI
if err := r.state.Caches.GTS.Follow().Store(follow, func() error {
// If the follow already exists, just
// replace the URI with the new one.
_, err := r.conn.
_, err := r.db.
NewInsert().
Model(follow).
On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI).
Exec(ctx)
return r.conn.ProcessError(err)
return r.db.ProcessError(err)
}); err != nil {
return nil, err
}
@ -212,12 +212,12 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI
defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", followReq.ID)
// Delete original follow request.
if _, err := r.conn.
if _, err := r.db.
NewDelete().
Table("follow_requests").
Where("? = ?", bun.Ident("id"), followReq.ID).
Exec(ctx); err != nil {
return nil, r.conn.ProcessError(err)
return nil, r.db.ProcessError(err)
}
// Delete original follow request notification
@ -230,7 +230,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI
return follow, nil
}
func (r *relationshipDB) RejectFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) db.Error {
func (r *relationshipDB) RejectFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) error {
// Delete follow request first.
if err := r.DeleteFollowRequest(ctx, sourceAccountID, targetAccountID); err != nil {
return err
@ -262,11 +262,11 @@ func (r *relationshipDB) DeleteFollowRequest(ctx context.Context, sourceAccountI
}
// Finally delete followreq from DB.
_, err = r.conn.NewDelete().
_, err = r.db.NewDelete().
Table("follow_requests").
Where("? = ?", bun.Ident("id"), follow.ID).
Exec(ctx)
return r.conn.ProcessError(err)
return r.db.ProcessError(err)
}
func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string) error {
@ -285,11 +285,11 @@ func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string)
}
// Finally delete followreq from DB.
_, err = r.conn.NewDelete().
_, err = r.db.NewDelete().
Table("follow_requests").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx)
return r.conn.ProcessError(err)
return r.db.ProcessError(err)
}
func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri string) error {
@ -308,18 +308,18 @@ func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri strin
}
// Finally delete followreq from DB.
_, err = r.conn.NewDelete().
_, err = r.db.NewDelete().
Table("follow_requests").
Where("? = ?", bun.Ident("uri"), uri).
Exec(ctx)
return r.conn.ProcessError(err)
return r.db.ProcessError(err)
}
func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accountID string) error {
var followReqIDs []string
// Get full list of IDs.
if _, err := r.conn.
if _, err := r.db.
NewSelect().
Column("id").
Table("follow_requestss").
@ -330,7 +330,7 @@ func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accoun
accountID,
).
Exec(ctx, &followReqIDs); err != nil {
return r.conn.ProcessError(err)
return r.db.ProcessError(err)
}
defer func() {
@ -351,9 +351,9 @@ func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accoun
}
// Finally delete all from DB.
_, err := r.conn.NewDelete().
_, err := r.db.NewDelete().
Table("follow_requests").
Where("? IN (?)", bun.Ident("id"), bun.In(followReqIDs)).
Exec(ctx)
return r.conn.ProcessError(err)
return r.db.ProcessError(err)
}

View file

@ -32,15 +32,15 @@ import (
)
type reportDB struct {
conn *DBConn
db *WrappedDB
state *state.State
}
func (r *reportDB) newReportQ(report interface{}) *bun.SelectQuery {
return r.conn.NewSelect().Model(report)
return r.db.NewSelect().Model(report)
}
func (r *reportDB) GetReportByID(ctx context.Context, id string) (*gtsmodel.Report, db.Error) {
func (r *reportDB) GetReportByID(ctx context.Context, id string) (*gtsmodel.Report, error) {
return r.getReport(
ctx,
"ID",
@ -51,10 +51,10 @@ func (r *reportDB) GetReportByID(ctx context.Context, id string) (*gtsmodel.Repo
)
}
func (r *reportDB) GetReports(ctx context.Context, resolved *bool, accountID string, targetAccountID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Report, db.Error) {
func (r *reportDB) GetReports(ctx context.Context, resolved *bool, accountID string, targetAccountID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Report, error) {
reportIDs := []string{}
q := r.conn.
q := r.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("reports"), bun.Ident("report")).
Column("report.id").
@ -94,7 +94,7 @@ func (r *reportDB) GetReports(ctx context.Context, resolved *bool, accountID str
}
if err := q.Scan(ctx, &reportIDs); err != nil {
return nil, r.conn.ProcessError(err)
return nil, r.db.ProcessError(err)
}
// Catch case of no reports early
@ -118,14 +118,14 @@ func (r *reportDB) GetReports(ctx context.Context, resolved *bool, accountID str
return reports, nil
}
func (r *reportDB) getReport(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Report) error, keyParts ...any) (*gtsmodel.Report, db.Error) {
func (r *reportDB) getReport(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Report) error, keyParts ...any) (*gtsmodel.Report, error) {
// Fetch report from database cache with loader callback
report, err := r.state.Caches.GTS.Report().Load(lookup, func() (*gtsmodel.Report, error) {
var report gtsmodel.Report
// Not cached! Perform database query
if err := dbQuery(&report); err != nil {
return nil, r.conn.ProcessError(err)
return nil, r.db.ProcessError(err)
}
return &report, nil
@ -166,34 +166,34 @@ func (r *reportDB) getReport(ctx context.Context, lookup string, dbQuery func(*g
return report, nil
}
func (r *reportDB) PutReport(ctx context.Context, report *gtsmodel.Report) db.Error {
func (r *reportDB) PutReport(ctx context.Context, report *gtsmodel.Report) error {
return r.state.Caches.GTS.Report().Store(report, func() error {
_, err := r.conn.NewInsert().Model(report).Exec(ctx)
return r.conn.ProcessError(err)
_, err := r.db.NewInsert().Model(report).Exec(ctx)
return r.db.ProcessError(err)
})
}
func (r *reportDB) UpdateReport(ctx context.Context, report *gtsmodel.Report, columns ...string) (*gtsmodel.Report, db.Error) {
func (r *reportDB) UpdateReport(ctx context.Context, report *gtsmodel.Report, columns ...string) (*gtsmodel.Report, error) {
// Update the report's last-updated
report.UpdatedAt = time.Now()
if len(columns) != 0 {
columns = append(columns, "updated_at")
}
if _, err := r.conn.
if _, err := r.db.
NewUpdate().
Model(report).
Where("? = ?", bun.Ident("report.id"), report.ID).
Column(columns...).
Exec(ctx); err != nil {
return nil, r.conn.ProcessError(err)
return nil, r.db.ProcessError(err)
}
r.state.Caches.GTS.Report().Invalidate("ID", report.ID)
return report, nil
}
func (r *reportDB) DeleteReportByID(ctx context.Context, id string) db.Error {
func (r *reportDB) DeleteReportByID(ctx context.Context, id string) error {
defer r.state.Caches.GTS.Report().Invalidate("ID", id)
// Load status into cache before attempting a delete,
@ -209,9 +209,9 @@ func (r *reportDB) DeleteReportByID(ctx context.Context, id string) db.Error {
}
// Finally delete report from DB.
_, err = r.conn.NewDelete().
_, err = r.db.NewDelete().
TableExpr("? AS ?", bun.Ident("reports"), bun.Ident("report")).
Where("? = ?", bun.Ident("report.id"), id).
Exec(ctx)
return r.conn.ProcessError(err)
return r.db.ProcessError(err)
}

View file

@ -56,7 +56,7 @@ import (
// This isn't ideal, of course, but at least we could cover the most common use case of
// a caller paging down through results.
type searchDB struct {
conn *DBConn
db *WrappedDB
state *state.State
}
@ -89,7 +89,7 @@ func (s *searchDB) SearchForAccounts(
frontToBack = true
)
q := s.conn.
q := s.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
// Select only IDs from table.
@ -148,7 +148,7 @@ func (s *searchDB) SearchForAccounts(
}
if err := q.Scan(ctx, &accountIDs); err != nil {
return nil, s.conn.ProcessError(err)
return nil, s.db.ProcessError(err)
}
if len(accountIDs) == 0 {
@ -183,7 +183,7 @@ func (s *searchDB) SearchForAccounts(
// followedAccounts returns a subquery that selects only IDs
// of accounts that are followed by the given accountID.
func (s *searchDB) followedAccounts(accountID string) *bun.SelectQuery {
return s.conn.
return s.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
Column("follow.target_account_id").
@ -196,7 +196,7 @@ func (s *searchDB) followedAccounts(accountID string) *bun.SelectQuery {
// in the concatenation.
func (s *searchDB) accountText(following bool) *bun.SelectQuery {
var (
accountText = s.conn.NewSelect()
accountText = s.db.NewSelect()
query string
args []interface{}
)
@ -225,7 +225,7 @@ func (s *searchDB) accountText(following bool) *bun.SelectQuery {
// different number of placeholders depending on
// following/not following. COALESCE calls ensure
// that we're not trying to concatenate null values.
d := s.conn.Dialect().Name()
d := s.db.Dialect().Name()
switch {
case d == dialect.SQLite && following:
@ -276,7 +276,7 @@ func (s *searchDB) SearchForStatuses(
frontToBack = true
)
q := s.conn.
q := s.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
// Select only IDs from table
@ -326,7 +326,7 @@ func (s *searchDB) SearchForStatuses(
}
if err := q.Scan(ctx, &statusIDs); err != nil {
return nil, s.conn.ProcessError(err)
return nil, s.db.ProcessError(err)
}
if len(statusIDs) == 0 {
@ -361,11 +361,11 @@ func (s *searchDB) SearchForStatuses(
// statusText returns a subquery that selects a concatenation
// of status content and content warning as "status_text".
func (s *searchDB) statusText() *bun.SelectQuery {
statusText := s.conn.NewSelect()
statusText := s.db.NewSelect()
// SQLite and Postgres use different
// syntaxes for concatenation.
switch s.conn.Dialect().Name() {
switch s.db.Dialect().Name() {
case dialect.SQLite:
statusText = statusText.ColumnExpr(

View file

@ -22,26 +22,25 @@ import (
"crypto/rand"
"io"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
)
type sessionDB struct {
conn *DBConn
db *WrappedDB
}
func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, error) {
rss := make([]*gtsmodel.RouterSession, 0, 1)
// get the first router session in the db or...
if err := s.conn.
if err := s.db.
NewSelect().
Model(&rss).
Limit(1).
Order("router_session.id DESC").
Scan(ctx); err != nil {
return nil, s.conn.ProcessError(err)
return nil, s.db.ProcessError(err)
}
// ... create a new one
@ -52,7 +51,7 @@ func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db
return rss[0], nil
}
func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession, error) {
buf := make([]byte, 64)
auth := buf[:32]
crypt := buf[32:64]
@ -67,11 +66,11 @@ func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession,
Crypt: crypt,
}
if _, err := s.conn.
if _, err := s.db.
NewInsert().
Model(rs).
Exec(ctx); err != nil {
return nil, s.conn.ProcessError(err)
return nil, s.db.ProcessError(err)
}
return rs, nil

View file

@ -35,19 +35,19 @@ import (
)
type statusDB struct {
conn *DBConn
db *WrappedDB
state *state.State
}
func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
return s.conn.
return s.db.
NewSelect().
Model(status).
Relation("Tags").
Relation("CreatedWithApplication")
}
func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) {
func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, error) {
return s.getStatus(
ctx,
"ID",
@ -76,7 +76,7 @@ func (s *statusDB) GetStatusesByIDs(ctx context.Context, ids []string) ([]*gtsmo
return statuses, nil
}
func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, error) {
return s.getStatus(
ctx,
"URI",
@ -87,7 +87,7 @@ func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.St
)
}
func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) {
func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, error) {
return s.getStatus(
ctx,
"URL",
@ -98,14 +98,14 @@ func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.St
)
}
func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Status) error, keyParts ...any) (*gtsmodel.Status, db.Error) {
func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Status) error, keyParts ...any) (*gtsmodel.Status, error) {
// Fetch status from database cache with loader callback
status, err := s.state.Caches.GTS.Status().Load(lookup, func() (*gtsmodel.Status, error) {
var status gtsmodel.Status
// Not cached! Perform database query.
if err := dbQuery(&status); err != nil {
return nil, s.conn.ProcessError(err)
return nil, s.db.ProcessError(err)
}
return &status, nil
@ -243,12 +243,12 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)
return errs.Combine()
}
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) error {
return s.state.Caches.GTS.Status().Store(status, func() error {
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
return s.conn.RunInTx(ctx, func(tx bun.Tx) error {
return s.db.RunInTx(ctx, func(tx bun.Tx) error {
// create links between this status and any emojis it uses
for _, i := range status.EmojiIDs {
if _, err := tx.
@ -259,7 +259,7 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
}).
On("CONFLICT (?, ?) DO NOTHING", bun.Ident("status_id"), bun.Ident("emoji_id")).
Exec(ctx); err != nil {
err = s.conn.ProcessError(err)
err = s.db.ProcessError(err)
if !errors.Is(err, db.ErrAlreadyExists) {
return err
}
@ -276,7 +276,7 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
}).
On("CONFLICT (?, ?) DO NOTHING", bun.Ident("status_id"), bun.Ident("tag_id")).
Exec(ctx); err != nil {
err = s.conn.ProcessError(err)
err = s.db.ProcessError(err)
if !errors.Is(err, db.ErrAlreadyExists) {
return err
}
@ -292,7 +292,7 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
Model(a).
Where("? = ?", bun.Ident("media_attachment.id"), a.ID).
Exec(ctx); err != nil {
err = s.conn.ProcessError(err)
err = s.db.ProcessError(err)
if !errors.Is(err, db.ErrAlreadyExists) {
return err
}
@ -306,7 +306,7 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
})
}
func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, columns ...string) db.Error {
func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, columns ...string) error {
status.UpdatedAt = time.Now()
if len(columns) > 0 {
// If we're updating by column, ensure "updated_at" is included.
@ -317,7 +317,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
return s.conn.RunInTx(ctx, func(tx bun.Tx) error {
return s.db.RunInTx(ctx, func(tx bun.Tx) error {
// create links between this status and any emojis it uses
for _, i := range status.EmojiIDs {
if _, err := tx.
@ -328,7 +328,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co
}).
On("CONFLICT (?, ?) DO NOTHING", bun.Ident("status_id"), bun.Ident("emoji_id")).
Exec(ctx); err != nil {
err = s.conn.ProcessError(err)
err = s.db.ProcessError(err)
if !errors.Is(err, db.ErrAlreadyExists) {
return err
}
@ -345,7 +345,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co
}).
On("CONFLICT (?, ?) DO NOTHING", bun.Ident("status_id"), bun.Ident("tag_id")).
Exec(ctx); err != nil {
err = s.conn.ProcessError(err)
err = s.db.ProcessError(err)
if !errors.Is(err, db.ErrAlreadyExists) {
return err
}
@ -361,7 +361,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co
Model(a).
Where("? = ?", bun.Ident("media_attachment.id"), a.ID).
Exec(ctx); err != nil {
err = s.conn.ProcessError(err)
err = s.db.ProcessError(err)
if !errors.Is(err, db.ErrAlreadyExists) {
return err
}
@ -380,7 +380,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co
})
}
func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error {
func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) error {
defer s.state.Caches.GTS.Status().Invalidate("ID", id)
// Load status into cache before attempting a delete,
@ -397,7 +397,7 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error {
return err
}
return s.conn.RunInTx(ctx, func(tx bun.Tx) error {
return s.db.RunInTx(ctx, func(tx bun.Tx) error {
// delete links between this status and any emojis it uses
if _, err := tx.
NewDelete().
@ -433,7 +433,7 @@ func (s *statusDB) GetStatusesUsingEmoji(ctx context.Context, emojiID string) ([
var statusIDs []string
// Create SELECT status query.
q := s.conn.NewSelect().
q := s.db.NewSelect().
Table("statuses").
Column("id")
@ -450,14 +450,14 @@ func (s *statusDB) GetStatusesUsingEmoji(ctx context.Context, emojiID string) ([
// Execute the query, scanning destination into statusIDs.
if _, err := q.Exec(ctx, &statusIDs); err != nil {
return nil, s.conn.ProcessError(err)
return nil, s.db.ProcessError(err)
}
// Convert status IDs into status objects.
return s.GetStatusesByIDs(ctx, statusIDs)
}
func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error) {
if onlyDirect {
// Only want the direct parent, no further than first level
parent, err := s.GetStatusByID(ctx, status.InReplyToID)
@ -485,7 +485,7 @@ func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status
return parents, nil
}
func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) {
func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error) {
foundStatuses := &list.List{}
foundStatuses.PushFront(status)
s.statusChildren(ctx, status, foundStatuses, onlyDirect, minID)
@ -509,7 +509,7 @@ func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Statu
func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
var childIDs []string
q := s.conn.
q := s.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Column("status.id").
@ -554,71 +554,71 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status,
}
}
func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
return s.conn.
func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, error) {
return s.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Where("? = ?", bun.Ident("status.in_reply_to_id"), status.ID).
Count(ctx)
}
func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
return s.conn.
func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, error) {
return s.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Where("? = ?", bun.Ident("status.boost_of_id"), status.ID).
Count(ctx)
}
func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
return s.conn.
func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, error) {
return s.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
Where("? = ?", bun.Ident("status_fave.status_id"), status.ID).
Count(ctx)
}
func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
q := s.conn.
func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) {
q := s.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
Where("? = ?", bun.Ident("status_fave.status_id"), status.ID).
Where("? = ?", bun.Ident("status_fave.account_id"), accountID)
return s.conn.Exists(ctx, q)
return s.db.Exists(ctx, q)
}
func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
q := s.conn.
func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) {
q := s.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Where("? = ?", bun.Ident("status.boost_of_id"), status.ID).
Where("? = ?", bun.Ident("status.account_id"), accountID)
return s.conn.Exists(ctx, q)
return s.db.Exists(ctx, q)
}
func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
q := s.conn.
func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) {
q := s.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("status_mutes"), bun.Ident("status_mute")).
Where("? = ?", bun.Ident("status_mute.status_id"), status.ID).
Where("? = ?", bun.Ident("status_mute.account_id"), accountID)
return s.conn.Exists(ctx, q)
return s.db.Exists(ctx, q)
}
func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
q := s.conn.
func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) {
q := s.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")).
Where("? = ?", bun.Ident("status_bookmark.status_id"), status.ID).
Where("? = ?", bun.Ident("status_bookmark.account_id"), accountID)
return s.conn.Exists(ctx, q)
return s.db.Exists(ctx, q)
}
func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, error) {
reblogs := []*gtsmodel.Status{}
q := s.
@ -626,7 +626,7 @@ func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status
Where("? = ?", bun.Ident("status.boost_of_id"), status.ID)
if err := q.Scan(ctx); err != nil {
return nil, s.conn.ProcessError(err)
return nil, s.db.ProcessError(err)
}
return reblogs, nil
}

View file

@ -22,7 +22,6 @@ import (
"errors"
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
@ -30,20 +29,20 @@ import (
)
type statusBookmarkDB struct {
conn *DBConn
db *WrappedDB
state *state.State
}
func (s *statusBookmarkDB) GetStatusBookmark(ctx context.Context, id string) (*gtsmodel.StatusBookmark, db.Error) {
func (s *statusBookmarkDB) GetStatusBookmark(ctx context.Context, id string) (*gtsmodel.StatusBookmark, error) {
bookmark := new(gtsmodel.StatusBookmark)
err := s.conn.
err := s.db.
NewSelect().
Model(bookmark).
Where("? = ?", bun.Ident("status_bookmark.id"), id).
Scan(ctx)
if err != nil {
return nil, s.conn.ProcessError(err)
return nil, s.db.ProcessError(err)
}
bookmark.Account, err = s.state.DB.GetAccountByID(ctx, bookmark.AccountID)
@ -64,10 +63,10 @@ func (s *statusBookmarkDB) GetStatusBookmark(ctx context.Context, id string) (*g
return bookmark, nil
}
func (s *statusBookmarkDB) GetStatusBookmarkID(ctx context.Context, accountID string, statusID string) (string, db.Error) {
func (s *statusBookmarkDB) GetStatusBookmarkID(ctx context.Context, accountID string, statusID string) (string, error) {
var id string
q := s.conn.
q := s.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")).
Column("status_bookmark.id").
@ -76,13 +75,13 @@ func (s *statusBookmarkDB) GetStatusBookmarkID(ctx context.Context, accountID st
Limit(1)
if err := q.Scan(ctx, &id); err != nil {
return "", s.conn.ProcessError(err)
return "", s.db.ProcessError(err)
}
return id, nil
}
func (s *statusBookmarkDB) GetStatusBookmarks(ctx context.Context, accountID string, limit int, maxID string, minID string) ([]*gtsmodel.StatusBookmark, db.Error) {
func (s *statusBookmarkDB) GetStatusBookmarks(ctx context.Context, accountID string, limit int, maxID string, minID string) ([]*gtsmodel.StatusBookmark, error) {
// Ensure reasonable
if limit < 0 {
limit = 0
@ -91,7 +90,7 @@ func (s *statusBookmarkDB) GetStatusBookmarks(ctx context.Context, accountID str
// Guess size of IDs based on limit.
ids := make([]string, 0, limit)
q := s.conn.
q := s.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")).
Column("status_bookmark.id").
@ -115,7 +114,7 @@ func (s *statusBookmarkDB) GetStatusBookmarks(ctx context.Context, accountID str
}
if err := q.Scan(ctx, &ids); err != nil {
return nil, s.conn.ProcessError(err)
return nil, s.db.ProcessError(err)
}
bookmarks := make([]*gtsmodel.StatusBookmark, 0, len(ids))
@ -133,26 +132,26 @@ func (s *statusBookmarkDB) GetStatusBookmarks(ctx context.Context, accountID str
return bookmarks, nil
}
func (s *statusBookmarkDB) PutStatusBookmark(ctx context.Context, statusBookmark *gtsmodel.StatusBookmark) db.Error {
_, err := s.conn.
func (s *statusBookmarkDB) PutStatusBookmark(ctx context.Context, statusBookmark *gtsmodel.StatusBookmark) error {
_, err := s.db.
NewInsert().
Model(statusBookmark).
Exec(ctx)
return s.conn.ProcessError(err)
return s.db.ProcessError(err)
}
func (s *statusBookmarkDB) DeleteStatusBookmark(ctx context.Context, id string) db.Error {
_, err := s.conn.
func (s *statusBookmarkDB) DeleteStatusBookmark(ctx context.Context, id string) error {
_, err := s.db.
NewDelete().
TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")).
Where("? = ?", bun.Ident("status_bookmark.id"), id).
Exec(ctx)
return s.conn.ProcessError(err)
return s.db.ProcessError(err)
}
func (s *statusBookmarkDB) DeleteStatusBookmarks(ctx context.Context, targetAccountID string, originAccountID string) db.Error {
func (s *statusBookmarkDB) DeleteStatusBookmarks(ctx context.Context, targetAccountID string, originAccountID string) error {
if targetAccountID == "" && originAccountID == "" {
return errors.New("DeleteBookmarks: one of targetAccountID or originAccountID must be set")
}
@ -161,7 +160,7 @@ func (s *statusBookmarkDB) DeleteStatusBookmarks(ctx context.Context, targetAcco
// statement (when bookmarks have a cache),
// + use the IDs to invalidate cache entries.
q := s.conn.
q := s.db.
NewDelete().
TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark"))
@ -174,24 +173,24 @@ func (s *statusBookmarkDB) DeleteStatusBookmarks(ctx context.Context, targetAcco
}
if _, err := q.Exec(ctx); err != nil {
return s.conn.ProcessError(err)
return s.db.ProcessError(err)
}
return nil
}
func (s *statusBookmarkDB) DeleteStatusBookmarksForStatus(ctx context.Context, statusID string) db.Error {
func (s *statusBookmarkDB) DeleteStatusBookmarksForStatus(ctx context.Context, statusID string) error {
// TODO: Capture bookmark IDs in a RETURNING
// statement (when bookmarks have a cache),
// + use the IDs to invalidate cache entries.
q := s.conn.
q := s.db.
NewDelete().
TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")).
Where("? = ?", bun.Ident("status_bookmark.status_id"), statusID)
if _, err := q.Exec(ctx); err != nil {
return s.conn.ProcessError(err)
return s.db.ProcessError(err)
}
return nil

View file

@ -32,16 +32,16 @@ import (
)
type statusFaveDB struct {
conn *DBConn
db *WrappedDB
state *state.State
}
func (s *statusFaveDB) GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, db.Error) {
func (s *statusFaveDB) GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, error) {
return s.getStatusFave(
ctx,
"AccountID.StatusID",
func(fave *gtsmodel.StatusFave) error {
return s.conn.
return s.db.
NewSelect().
Model(fave).
Where("? = ?", bun.Ident("account_id"), accountID).
@ -53,12 +53,12 @@ func (s *statusFaveDB) GetStatusFave(ctx context.Context, accountID string, stat
)
}
func (s *statusFaveDB) GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, db.Error) {
func (s *statusFaveDB) GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, error) {
return s.getStatusFave(
ctx,
"ID",
func(fave *gtsmodel.StatusFave) error {
return s.conn.
return s.db.
NewSelect().
Model(fave).
Where("? = ?", bun.Ident("id"), id).
@ -75,7 +75,7 @@ func (s *statusFaveDB) getStatusFave(ctx context.Context, lookup string, dbQuery
// Not cached! Perform database query.
if err := dbQuery(&fave); err != nil {
return nil, s.conn.ProcessError(err)
return nil, s.db.ProcessError(err)
}
return &fave, nil
@ -119,16 +119,16 @@ func (s *statusFaveDB) getStatusFave(ctx context.Context, lookup string, dbQuery
return fave, nil
}
func (s *statusFaveDB) GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, db.Error) {
func (s *statusFaveDB) GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, error) {
ids := []string{}
if err := s.conn.
if err := s.db.
NewSelect().
Table("status_faves").
Column("id").
Where("? = ?", bun.Ident("status_id"), statusID).
Scan(ctx, &ids); err != nil {
return nil, s.conn.ProcessError(err)
return nil, s.db.ProcessError(err)
}
faves := make([]*gtsmodel.StatusFave, 0, len(ids))
@ -188,17 +188,17 @@ func (s *statusFaveDB) PopulateStatusFave(ctx context.Context, statusFave *gtsmo
return errs.Combine()
}
func (s *statusFaveDB) PutStatusFave(ctx context.Context, fave *gtsmodel.StatusFave) db.Error {
func (s *statusFaveDB) PutStatusFave(ctx context.Context, fave *gtsmodel.StatusFave) error {
return s.state.Caches.GTS.StatusFave().Store(fave, func() error {
_, err := s.conn.
_, err := s.db.
NewInsert().
Model(fave).
Exec(ctx)
return s.conn.ProcessError(err)
return s.db.ProcessError(err)
})
}
func (s *statusFaveDB) DeleteStatusFaveByID(ctx context.Context, id string) db.Error {
func (s *statusFaveDB) DeleteStatusFaveByID(ctx context.Context, id string) error {
defer s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
// Load fave into cache before attempting a delete,
@ -214,21 +214,21 @@ func (s *statusFaveDB) DeleteStatusFaveByID(ctx context.Context, id string) db.E
}
// Finally delete fave from DB.
_, err = s.conn.NewDelete().
_, err = s.db.NewDelete().
Table("status_faves").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx)
return s.conn.ProcessError(err)
return s.db.ProcessError(err)
}
func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID string, originAccountID string) db.Error {
func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID string, originAccountID string) error {
if targetAccountID == "" && originAccountID == "" {
return errors.New("DeleteStatusFaves: one of targetAccountID or originAccountID must be set")
}
var faveIDs []string
q := s.conn.
q := s.db.
NewSelect().
Column("id").
Table("status_faves")
@ -242,7 +242,7 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st
}
if _, err := q.Exec(ctx, &faveIDs); err != nil {
return s.conn.ProcessError(err)
return s.db.ProcessError(err)
}
defer func() {
@ -263,24 +263,24 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st
}
// Finally delete all from DB.
_, err := s.conn.NewDelete().
_, err := s.db.NewDelete().
Table("status_faves").
Where("? IN (?)", bun.Ident("id"), bun.In(faveIDs)).
Exec(ctx)
return s.conn.ProcessError(err)
return s.db.ProcessError(err)
}
func (s *statusFaveDB) DeleteStatusFavesForStatus(ctx context.Context, statusID string) db.Error {
func (s *statusFaveDB) DeleteStatusFavesForStatus(ctx context.Context, statusID string) error {
// Capture fave IDs in a RETURNING statement.
var faveIDs []string
q := s.conn.
q := s.db.
NewSelect().
Column("id").
Table("status_faves").
Where("? = ?", bun.Ident("status_id"), statusID)
if _, err := q.Exec(ctx, &faveIDs); err != nil {
return s.conn.ProcessError(err)
return s.db.ProcessError(err)
}
defer func() {
@ -301,9 +301,9 @@ func (s *statusFaveDB) DeleteStatusFavesForStatus(ctx context.Context, statusID
}
// Finally delete all from DB.
_, err := s.conn.NewDelete().
_, err := s.db.NewDelete().
Table("status_faves").
Where("? IN (?)", bun.Ident("id"), bun.In(faveIDs)).
Exec(ctx)
return s.conn.ProcessError(err)
return s.db.ProcessError(err)
}

View file

@ -33,11 +33,11 @@ import (
)
type timelineDB struct {
conn *DBConn
db *WrappedDB
state *state.State
}
func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) {
// Ensure reasonable
if limit < 0 {
limit = 0
@ -49,7 +49,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
frontToBack = true
)
q := t.conn.
q := t.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
// Select only IDs from table
@ -103,7 +103,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
// Subquery to select target (followed) account
// IDs from follows owned by given accountID.
subQ := t.conn.
subQ := t.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
Column("follow.target_account_id").
@ -119,7 +119,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
})
if err := q.Scan(ctx, &statusIDs); err != nil {
return nil, t.conn.ProcessError(err)
return nil, t.db.ProcessError(err)
}
if len(statusIDs) == 0 {
@ -151,7 +151,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
return statuses, nil
}
func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) {
// Ensure reasonable
if limit < 0 {
limit = 0
@ -160,7 +160,7 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceI
// Make educated guess for slice size
statusIDs := make([]string, 0, limit)
q := t.conn.
q := t.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Column("status.id").
@ -202,7 +202,7 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceI
}
if err := q.Scan(ctx, &statusIDs); err != nil {
return nil, t.conn.ProcessError(err)
return nil, t.db.ProcessError(err)
}
statuses := make([]*gtsmodel.Status, 0, len(statusIDs))
@ -224,7 +224,7 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceI
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, error) {
// Ensure reasonable
if limit < 0 {
limit = 0
@ -233,7 +233,7 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max
// Make educated guess for slice size
faves := make([]*gtsmodel.StatusFave, 0, limit)
fq := t.conn.
fq := t.db.
NewSelect().
Model(&faves).
Where("? = ?", bun.Ident("status_fave.account_id"), accountID).
@ -253,7 +253,7 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max
err := fq.Scan(ctx)
if err != nil {
return nil, "", "", t.conn.ProcessError(err)
return nil, "", "", t.db.ProcessError(err)
}
if len(faves) == 0 {
@ -322,7 +322,7 @@ func (t *timelineDB) GetListTimeline(
}
// Select target account IDs from follows.
subQ := t.conn.
subQ := t.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
Column("follow.target_account_id").
@ -330,7 +330,7 @@ func (t *timelineDB) GetListTimeline(
// Select only status IDs created
// by one of the followed accounts.
q := t.conn.
q := t.db.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
// Select only IDs from table
@ -379,7 +379,7 @@ func (t *timelineDB) GetListTimeline(
}
if err := q.Scan(ctx, &statusIDs); err != nil {
return nil, t.conn.ProcessError(err)
return nil, t.db.ProcessError(err)
}
if len(statusIDs) == 0 {

View file

@ -27,28 +27,28 @@ import (
)
type tombstoneDB struct {
conn *DBConn
db *WrappedDB
state *state.State
}
func (t *tombstoneDB) GetTombstoneByURI(ctx context.Context, uri string) (*gtsmodel.Tombstone, db.Error) {
func (t *tombstoneDB) GetTombstoneByURI(ctx context.Context, uri string) (*gtsmodel.Tombstone, error) {
return t.state.Caches.GTS.Tombstone().Load("URI", func() (*gtsmodel.Tombstone, error) {
var tomb gtsmodel.Tombstone
q := t.conn.
q := t.db.
NewSelect().
Model(&tomb).
Where("? = ?", bun.Ident("tombstone.uri"), uri)
if err := q.Scan(ctx); err != nil {
return nil, t.conn.ProcessError(err)
return nil, t.db.ProcessError(err)
}
return &tomb, nil
}, uri)
}
func (t *tombstoneDB) TombstoneExistsWithURI(ctx context.Context, uri string) (bool, db.Error) {
func (t *tombstoneDB) TombstoneExistsWithURI(ctx context.Context, uri string) (bool, error) {
tomb, err := t.GetTombstoneByURI(ctx, uri)
if err == db.ErrNoEntries {
err = nil
@ -56,23 +56,23 @@ func (t *tombstoneDB) TombstoneExistsWithURI(ctx context.Context, uri string) (b
return (tomb != nil), err
}
func (t *tombstoneDB) PutTombstone(ctx context.Context, tombstone *gtsmodel.Tombstone) db.Error {
func (t *tombstoneDB) PutTombstone(ctx context.Context, tombstone *gtsmodel.Tombstone) error {
return t.state.Caches.GTS.Tombstone().Store(tombstone, func() error {
_, err := t.conn.
_, err := t.db.
NewInsert().
Model(tombstone).
Exec(ctx)
return t.conn.ProcessError(err)
return t.db.ProcessError(err)
})
}
func (t *tombstoneDB) DeleteTombstone(ctx context.Context, id string) db.Error {
func (t *tombstoneDB) DeleteTombstone(ctx context.Context, id string) error {
defer t.state.Caches.GTS.Tombstone().Invalidate("ID", id)
// Delete tombstone from DB.
_, err := t.conn.NewDelete().
_, err := t.db.NewDelete().
TableExpr("? AS ?", bun.Ident("tombstones"), bun.Ident("tombstone")).
Where("? = ?", bun.Ident("tombstone.id"), id).
Exec(ctx)
return t.conn.ProcessError(err)
return t.db.ProcessError(err)
}

View file

@ -30,125 +30,125 @@ import (
)
type userDB struct {
conn *DBConn
db *WrappedDB
state *state.State
}
func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db.Error) {
func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, error) {
return u.state.Caches.GTS.User().Load("ID", func() (*gtsmodel.User, error) {
var user gtsmodel.User
q := u.conn.
q := u.db.
NewSelect().
Model(&user).
Relation("Account").
Where("? = ?", bun.Ident("user.id"), id)
if err := q.Scan(ctx); err != nil {
return nil, u.conn.ProcessError(err)
return nil, u.db.ProcessError(err)
}
return &user, nil
}, id)
}
func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, db.Error) {
func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, error) {
return u.state.Caches.GTS.User().Load("AccountID", func() (*gtsmodel.User, error) {
var user gtsmodel.User
q := u.conn.
q := u.db.
NewSelect().
Model(&user).
Relation("Account").
Where("? = ?", bun.Ident("user.account_id"), accountID)
if err := q.Scan(ctx); err != nil {
return nil, u.conn.ProcessError(err)
return nil, u.db.ProcessError(err)
}
return &user, nil
}, accountID)
}
func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, db.Error) {
func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, error) {
return u.state.Caches.GTS.User().Load("Email", func() (*gtsmodel.User, error) {
var user gtsmodel.User
q := u.conn.
q := u.db.
NewSelect().
Model(&user).
Relation("Account").
Where("? = ?", bun.Ident("user.email"), emailAddress)
if err := q.Scan(ctx); err != nil {
return nil, u.conn.ProcessError(err)
return nil, u.db.ProcessError(err)
}
return &user, nil
}, emailAddress)
}
func (u *userDB) GetUserByExternalID(ctx context.Context, id string) (*gtsmodel.User, db.Error) {
func (u *userDB) GetUserByExternalID(ctx context.Context, id string) (*gtsmodel.User, error) {
return u.state.Caches.GTS.User().Load("ExternalID", func() (*gtsmodel.User, error) {
var user gtsmodel.User
q := u.conn.
q := u.db.
NewSelect().
Model(&user).
Relation("Account").
Where("? = ?", bun.Ident("user.external_id"), id)
if err := q.Scan(ctx); err != nil {
return nil, u.conn.ProcessError(err)
return nil, u.db.ProcessError(err)
}
return &user, nil
}, id)
}
func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, db.Error) {
func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, error) {
return u.state.Caches.GTS.User().Load("ConfirmationToken", func() (*gtsmodel.User, error) {
var user gtsmodel.User
q := u.conn.
q := u.db.
NewSelect().
Model(&user).
Relation("Account").
Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken)
if err := q.Scan(ctx); err != nil {
return nil, u.conn.ProcessError(err)
return nil, u.db.ProcessError(err)
}
return &user, nil
}, confirmationToken)
}
func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, db.Error) {
func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error) {
var users []*gtsmodel.User
q := u.conn.
q := u.db.
NewSelect().
Model(&users).
Relation("Account")
if err := q.Scan(ctx); err != nil {
return nil, u.conn.ProcessError(err)
return nil, u.db.ProcessError(err)
}
return users, nil
}
func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) db.Error {
func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) error {
return u.state.Caches.GTS.User().Store(user, func() error {
_, err := u.conn.
_, err := u.db.
NewInsert().
Model(user).
Exec(ctx)
return u.conn.ProcessError(err)
return u.db.ProcessError(err)
})
}
func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) db.Error {
func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) error {
// Update the user's last-updated
user.UpdatedAt = time.Now()
@ -158,17 +158,17 @@ func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ..
}
return u.state.Caches.GTS.User().Store(user, func() error {
_, err := u.conn.
_, err := u.db.
NewUpdate().
Model(user).
Where("? = ?", bun.Ident("user.id"), user.ID).
Column(columns...).
Exec(ctx)
return u.conn.ProcessError(err)
return u.db.ProcessError(err)
})
}
func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error {
func (u *userDB) DeleteUserByID(ctx context.Context, userID string) error {
defer u.state.Caches.GTS.User().Invalidate("ID", userID)
// Load user into cache before attempting a delete,
@ -184,9 +184,9 @@ func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error {
}
// Finally delete user from DB.
_, err = u.conn.NewDelete().
_, err = u.db.NewDelete().
TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")).
Where("? = ?", bun.Ident("user.id"), userID).
Exec(ctx)
return u.conn.ProcessError(err)
return u.db.ProcessError(err)
}

258
internal/db/bundb/wrap.go Normal file
View file

@ -0,0 +1,258 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"database/sql"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect"
)
// WrappedDB wraps a bun database instance
// to provide common per-dialect SQL error
// conversions to common types, and retries
// on returned busy errors (SQLite only for now).
type WrappedDB struct {
errHook func(error) error
*bun.DB // underlying conn
}
// WrapDB wraps a bun database instance in our own WrappedDB type.
func WrapDB(db *bun.DB) *WrappedDB {
var errProc func(error) error
switch name := db.Dialect().Name(); name {
case dialect.PG:
errProc = processPostgresError
case dialect.SQLite:
errProc = processSQLiteError
default:
panic("unknown dialect name: " + name.String())
}
return &WrappedDB{
errHook: errProc,
DB: db,
}
}
// BeginTx wraps bun.DB.BeginTx() with retry-busy timeout.
func (db *WrappedDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (tx bun.Tx, err error) {
err = retryOnBusy(ctx, func() error {
tx, err = db.DB.BeginTx(ctx, opts)
err = db.ProcessError(err)
return err
})
return
}
// ExecContext wraps bun.DB.ExecContext() with retry-busy timeout.
func (db *WrappedDB) ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) {
err = retryOnBusy(ctx, func() error {
result, err = db.DB.ExecContext(ctx, query, args...)
err = db.ProcessError(err)
return err
})
return
}
// QueryContext wraps bun.DB.QueryContext() with retry-busy timeout.
func (db *WrappedDB) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) {
err = retryOnBusy(ctx, func() error {
rows, err = db.DB.QueryContext(ctx, query, args...)
err = db.ProcessError(err)
return err
})
return
}
// QueryRowContext wraps bun.DB.QueryRowContext() with retry-busy timeout.
func (db *WrappedDB) QueryRowContext(ctx context.Context, query string, args ...any) (row *sql.Row) {
_ = retryOnBusy(ctx, func() error {
row = db.DB.QueryRowContext(ctx, query, args...)
err := db.ProcessError(row.Err())
return err
})
return
}
// RunInTx is functionally the same as bun.DB.RunInTx() but with retry-busy timeouts.
func (db *WrappedDB) RunInTx(ctx context.Context, fn func(bun.Tx) error) error {
// Attempt to start new transaction.
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
var done bool
defer func() {
if !done {
// Rollback (with retry-backoff).
_ = retryOnBusy(ctx, func() error {
err := tx.Rollback()
return db.errHook(err)
})
}
}()
// Perform supplied transaction
if err := fn(tx); err != nil {
return db.errHook(err)
}
// Commit (with retry-backoff).
err = retryOnBusy(ctx, func() error {
err := tx.Commit()
return db.errHook(err)
})
done = true
return err
}
func (db *WrappedDB) NewValues(model interface{}) *bun.ValuesQuery {
return bun.NewValuesQuery(db.DB, model).Conn(db)
}
func (db *WrappedDB) NewMerge() *bun.MergeQuery {
return bun.NewMergeQuery(db.DB).Conn(db)
}
func (db *WrappedDB) NewSelect() *bun.SelectQuery {
return bun.NewSelectQuery(db.DB).Conn(db)
}
func (db *WrappedDB) NewInsert() *bun.InsertQuery {
return bun.NewInsertQuery(db.DB).Conn(db)
}
func (db *WrappedDB) NewUpdate() *bun.UpdateQuery {
return bun.NewUpdateQuery(db.DB).Conn(db)
}
func (db *WrappedDB) NewDelete() *bun.DeleteQuery {
return bun.NewDeleteQuery(db.DB).Conn(db)
}
func (db *WrappedDB) NewRaw(query string, args ...interface{}) *bun.RawQuery {
return bun.NewRawQuery(db.DB, query, args...).Conn(db)
}
func (db *WrappedDB) NewCreateTable() *bun.CreateTableQuery {
return bun.NewCreateTableQuery(db.DB).Conn(db)
}
func (db *WrappedDB) NewDropTable() *bun.DropTableQuery {
return bun.NewDropTableQuery(db.DB).Conn(db)
}
func (db *WrappedDB) NewCreateIndex() *bun.CreateIndexQuery {
return bun.NewCreateIndexQuery(db.DB).Conn(db)
}
func (db *WrappedDB) NewDropIndex() *bun.DropIndexQuery {
return bun.NewDropIndexQuery(db.DB).Conn(db)
}
func (db *WrappedDB) NewTruncateTable() *bun.TruncateTableQuery {
return bun.NewTruncateTableQuery(db.DB).Conn(db)
}
func (db *WrappedDB) NewAddColumn() *bun.AddColumnQuery {
return bun.NewAddColumnQuery(db.DB).Conn(db)
}
func (db *WrappedDB) NewDropColumn() *bun.DropColumnQuery {
return bun.NewDropColumnQuery(db.DB).Conn(db)
}
// ProcessError processes an error to replace any known values with our own error types,
// making it easier to catch specific situations (e.g. no rows, already exists, etc)
func (db *WrappedDB) ProcessError(err error) error {
if err == nil {
return nil
}
return db.errHook(err)
}
// Exists checks the results of a SelectQuery for the existence of the data in question, masking ErrNoEntries errors
func (db *WrappedDB) Exists(ctx context.Context, query *bun.SelectQuery) (bool, error) {
exists, err := query.Exists(ctx)
switch err {
case nil:
return exists, nil
case sql.ErrNoRows:
return false, nil
default:
return false, err
}
}
// NotExists is the functional opposite of conn.Exists()
func (db *WrappedDB) NotExists(ctx context.Context, query *bun.SelectQuery) (bool, error) {
exists, err := db.Exists(ctx, query)
return !exists, err
}
// retryOnBusy will retry given function on returned 'errBusy'.
func retryOnBusy(ctx context.Context, fn func() error) error {
var backoff time.Duration
for i := 0; ; i++ {
// Perform func.
err := fn()
if err != errBusy {
// May be nil, or may be
// some other error, either
// way return here.
return err
}
// backoff according to a multiplier of 2ms * 2^2n,
// up to a maximum possible backoff time of 5 minutes.
//
// this works out as the following:
// 4ms
// 16ms
// 64ms
// 256ms
// 1.024s
// 4.096s
// 16.384s
// 1m5.536s
// 4m22.144s
backoff = 2 * time.Millisecond * (1 << (2*i + 1))
if backoff >= 5*time.Minute {
break
}
select {
// Context cancelled.
case <-ctx.Done():
// Backoff for some time.
case <-time.After(backoff):
}
}
return gtserror.Newf("%w (waited > %s)", db.ErrBusyTimeout, backoff)
}

View file

@ -27,29 +27,29 @@ import (
// Domain contains DB functions related to domains and domain blocks.
type Domain interface {
// CreateDomainBlock puts the given instance-level domain block into the database.
CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) Error
CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) error
// GetDomainBlock returns one instance-level domain block with the given domain, if it exists.
GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, Error)
GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, error)
// GetDomainBlockByID returns one instance-level domain block with the given id, if it exists.
GetDomainBlockByID(ctx context.Context, id string) (*gtsmodel.DomainBlock, Error)
GetDomainBlockByID(ctx context.Context, id string) (*gtsmodel.DomainBlock, error)
// GetDomainBlocks returns all instance-level domain blocks currently enforced by this instance.
GetDomainBlocks(ctx context.Context) ([]*gtsmodel.DomainBlock, error)
// DeleteDomainBlock deletes an instance-level domain block with the given domain, if it exists.
DeleteDomainBlock(ctx context.Context, domain string) Error
DeleteDomainBlock(ctx context.Context, domain string) error
// IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`).
IsDomainBlocked(ctx context.Context, domain string) (bool, Error)
IsDomainBlocked(ctx context.Context, domain string) (bool, error)
// AreDomainsBlocked checks if an instance-level domain block exists for any of the given domains strings, and returns true if even one is found.
AreDomainsBlocked(ctx context.Context, domains []string) (bool, Error)
AreDomainsBlocked(ctx context.Context, domains []string) (bool, error)
// IsURIBlocked checks if an instance-level domain block exists for the `host` in the given URI (eg., `https://example.org/users/whatever`).
IsURIBlocked(ctx context.Context, uri *url.URL) (bool, Error)
IsURIBlocked(ctx context.Context, uri *url.URL) (bool, error)
// AreURIsBlocked checks if an instance-level domain block exists for any `host` in the given URI slice, and returns true if even one is found.
AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, Error)
AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, error)
}

View file

@ -31,16 +31,16 @@ const EmojiAllDomains string = "all"
// Emoji contains functions for getting emoji in the database.
type Emoji interface {
// PutEmoji puts one emoji in the database.
PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) Error
PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) error
// UpdateEmoji updates the given columns of one emoji.
// If no columns are specified, every column is updated.
UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, columns ...string) error
// DeleteEmojiByID deletes one emoji by its database ID.
DeleteEmojiByID(ctx context.Context, id string) Error
DeleteEmojiByID(ctx context.Context, id string) error
// GetEmojisByIDs gets emojis for the given IDs.
GetEmojisByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Emoji, Error)
GetEmojisByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Emoji, error)
// GetUseableEmojis gets all emojis which are useable by accounts on this instance.
GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, Error)
GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, error)
// GetEmojis fetches all emojis with IDs less than 'maxID', up to a maximum of 'limit' emojis.
GetEmojis(ctx context.Context, maxID string, limit int) ([]*gtsmodel.Emoji, error)
@ -54,22 +54,22 @@ type Emoji interface {
// GetEmojisBy gets emojis based on given parameters. Useful for admin actions.
GetEmojisBy(ctx context.Context, domain string, includeDisabled bool, includeEnabled bool, shortcode string, maxShortcodeDomain string, minShortcodeDomain string, limit int) ([]*gtsmodel.Emoji, error)
// GetEmojiByID gets a specific emoji by its database ID.
GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, Error)
GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, error)
// GetEmojiByShortcodeDomain gets an emoji based on its shortcode and domain.
// For local emoji, domain should be an empty string.
GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, Error)
GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, error)
// GetEmojiByURI returns one emoji based on its ActivityPub URI.
GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoji, Error)
GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoji, error)
// GetEmojiByStaticURL gets an emoji using the URL of the static version of the emoji image.
GetEmojiByStaticURL(ctx context.Context, imageStaticURL string) (*gtsmodel.Emoji, Error)
GetEmojiByStaticURL(ctx context.Context, imageStaticURL string) (*gtsmodel.Emoji, error)
// PutEmojiCategory puts one new emoji category in the database.
PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) Error
PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) error
// GetEmojiCategoriesByIDs gets emoji categories for given IDs.
GetEmojiCategoriesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.EmojiCategory, Error)
GetEmojiCategoriesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.EmojiCategory, error)
// GetEmojiCategories gets a slice of the names of all existing emoji categories.
GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCategory, Error)
GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCategory, error)
// GetEmojiCategory gets one emoji category by its id.
GetEmojiCategory(ctx context.Context, id string) (*gtsmodel.EmojiCategory, Error)
GetEmojiCategory(ctx context.Context, id string) (*gtsmodel.EmojiCategory, error)
// GetEmojiCategoryByName gets one emoji category by its name.
GetEmojiCategoryByName(ctx context.Context, name string) (*gtsmodel.EmojiCategory, Error)
GetEmojiCategoryByName(ctx context.Context, name string) (*gtsmodel.EmojiCategory, error)
}

View file

@ -17,18 +17,20 @@
package db
import "fmt"
// Error denotes a database error.
type Error error
import (
"database/sql"
"errors"
)
var (
// ErrNoEntries is returned when a caller expected an entry for a query, but none was found.
ErrNoEntries Error = fmt.Errorf("no entries")
// ErrMultipleEntries is returned when a caller expected ONE entry for a query, but multiples were found.
ErrMultipleEntries Error = fmt.Errorf("multiple entries")
// ErrNoEntries is a direct ptr to sql.ErrNoRows since that is returned regardless
// of DB dialect. It is returned when no rows (entries) can be found for a query.
ErrNoEntries = sql.ErrNoRows
// ErrAlreadyExists is returned when a conflict was encountered in the db when doing an insert.
ErrAlreadyExists Error = fmt.Errorf("already exists")
// ErrUnknown denotes an unknown database error.
ErrUnknown Error = fmt.Errorf("unknown error")
ErrAlreadyExists = errors.New("already exists")
// ErrBusyTimeout is returned if the database connection indicates the connection is too busy
// to complete the supplied query. This is generally intended to be handled internally by the DB.
ErrBusyTimeout = errors.New("busy timeout")
)

View file

@ -26,16 +26,16 @@ import (
// Instance contains functions for instance-level actions (counting instance users etc.).
type Instance interface {
// CountInstanceUsers returns the number of known accounts registered with the given domain.
CountInstanceUsers(ctx context.Context, domain string) (int, Error)
CountInstanceUsers(ctx context.Context, domain string) (int, error)
// CountInstanceStatuses returns the number of known statuses posted from the given domain.
CountInstanceStatuses(ctx context.Context, domain string) (int, Error)
CountInstanceStatuses(ctx context.Context, domain string) (int, error)
// CountInstanceDomains returns the number of known instances known that the given domain federates with.
CountInstanceDomains(ctx context.Context, domain string) (int, Error)
CountInstanceDomains(ctx context.Context, domain string) (int, error)
// GetInstance returns the instance entry for the given domain, if it exists.
GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, Error)
GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, error)
// GetInstanceByID returns the instance entry corresponding to the given id, if it exists.
GetInstanceByID(ctx context.Context, id string) (*gtsmodel.Instance, error)
@ -47,12 +47,12 @@ type Instance interface {
UpdateInstance(ctx context.Context, instance *gtsmodel.Instance, columns ...string) error
// GetInstanceAccounts returns a slice of accounts from the given instance, arranged by ID.
GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, Error)
GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, error)
// GetInstancePeers returns a slice of instances that the host instance knows about.
GetInstancePeers(ctx context.Context, includeSuspended bool) ([]*gtsmodel.Instance, Error)
GetInstancePeers(ctx context.Context, includeSuspended bool) ([]*gtsmodel.Instance, error)
// GetInstanceModeratorAddresses returns a slice of email addresses belonging to active
// (as in, not suspended) moderators + admins on this instance.
GetInstanceModeratorAddresses(ctx context.Context) ([]string, Error)
GetInstanceModeratorAddresses(ctx context.Context) ([]string, error)
}

View file

@ -27,7 +27,7 @@ import (
// Media contains functions related to creating/getting/removing media attachments.
type Media interface {
// GetAttachmentByID gets a single attachment by its ID.
GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, Error)
GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, error)
// GetAttachmentsByIDs fetches a list of media attachments for given IDs.
GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, error)
@ -49,25 +49,25 @@ type Media interface {
// GetCachedAttachmentsOlderThan gets limit n remote attachments (including avatars and headers) older than
// the given time. These will be returned in order of attachment.created_at descending (i.e. newest to oldest).
GetCachedAttachmentsOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, Error)
GetCachedAttachmentsOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, error)
// CountRemoteOlderThan is like GetRemoteOlderThan, except instead of getting limit n attachments,
// it just counts how many remote attachments in the database (including avatars and headers) meet
// the olderThan criteria.
CountRemoteOlderThan(ctx context.Context, olderThan time.Time) (int, Error)
CountRemoteOlderThan(ctx context.Context, olderThan time.Time) (int, error)
// GetAvatarsAndHeaders fetches limit n avatars and headers with an id < maxID. These headers
// and avis may be in use or not; the caller should check this if it's important.
GetAvatarsAndHeaders(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, Error)
GetAvatarsAndHeaders(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, error)
// GetLocalUnattachedOlderThan fetches limit n local media attachments (including avatars and headers), older than
// the given time, which aren't header or avatars, and aren't attached to a status. In other words, attachments which were
// uploaded but never used for whatever reason, or attachments that were attached to a status which was subsequently deleted.
//
// These will be returned in order of attachment.created_at descending (newest to oldest in other words).
GetLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, Error)
GetLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, error)
// CountLocalUnattachedOlderThan is like GetLocalUnattachedOlderThan, except instead of getting limit n attachments,
// it just counts how many local attachments in the database meet the olderThan criteria.
CountLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time) (int, Error)
CountLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time) (int, error)
}

View file

@ -26,10 +26,10 @@ import (
// Mention contains functions for getting/creating mentions in the database.
type Mention interface {
// GetMention gets a single mention by ID
GetMention(ctx context.Context, id string) (*gtsmodel.Mention, Error)
GetMention(ctx context.Context, id string) (*gtsmodel.Mention, error)
// GetMentions gets multiple mentions.
GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, Error)
GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, error)
// PutMention will insert the given mention into the database.
PutMention(ctx context.Context, mention *gtsmodel.Mention) error

View file

@ -28,21 +28,21 @@ type Notification interface {
// GetNotifications returns a slice of notifications that pertain to the given accountID.
//
// Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest).
GetAccountNotifications(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, excludeTypes []string) ([]*gtsmodel.Notification, Error)
GetAccountNotifications(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, excludeTypes []string) ([]*gtsmodel.Notification, error)
// GetNotification returns one notification according to its id.
GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, Error)
GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, error)
// GetNotification gets one notification according to the provided parameters, if it exists.
// Since not all notifications are about a status, statusID can be an empty string.
GetNotification(ctx context.Context, notificationType gtsmodel.NotificationType, targetAccountID string, originAccountID string, statusID string) (*gtsmodel.Notification, Error)
GetNotification(ctx context.Context, notificationType gtsmodel.NotificationType, targetAccountID string, originAccountID string, statusID string) (*gtsmodel.Notification, error)
// PutNotification will insert the given notification into the database.
PutNotification(ctx context.Context, notif *gtsmodel.Notification) error
// DeleteNotificationByID deletes one notification according to its id,
// and removes that notification from the in-memory cache.
DeleteNotificationByID(ctx context.Context, id string) Error
DeleteNotificationByID(ctx context.Context, id string) error
// DeleteNotifications mass deletes notifications targeting targetAccountID
// and/or originating from originAccountID.
@ -57,10 +57,10 @@ type Notification interface {
// originate from originAccountID will be deleted.
//
// At least one parameter must not be an empty string.
DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) Error
DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) error
// DeleteNotificationsForStatus deletes all notifications that relate to
// the given statusID. This function is useful when a status has been deleted,
// and so notifications relating to that status must also be deleted.
DeleteNotificationsForStatus(ctx context.Context, statusID string) Error
DeleteNotificationsForStatus(ctx context.Context, statusID string) error
}

View file

@ -26,7 +26,7 @@ import (
// Relationship contains functions for getting or modifying the relationship between two accounts.
type Relationship interface {
// IsBlocked checks whether source account has a block in place against target.
IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error)
IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error)
// IsEitherBlocked checks whether there is a block in place between either of account1 and account2.
IsEitherBlocked(ctx context.Context, accountID1 string, accountID2 string) (bool, error)
@ -53,7 +53,7 @@ type Relationship interface {
DeleteAccountBlocks(ctx context.Context, accountID string) error
// GetRelationship retrieves the relationship of the targetAccount to the requestingAccount.
GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error)
GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error)
// GetFollowByID fetches follow with given ID from the database.
GetFollowByID(ctx context.Context, id string) (*gtsmodel.Follow, error)
@ -77,13 +77,13 @@ type Relationship interface {
GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error)
// IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error)
IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error)
// IsMutualFollowing returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out.
IsMutualFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error)
IsMutualFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error)
// IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error)
IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error)
// PutFollow attempts to place the given account follow in the database.
PutFollow(ctx context.Context, follow *gtsmodel.Follow) error
@ -125,10 +125,10 @@ type Relationship interface {
// In other words, it should create the follow, and delete the existing follow request.
//
// It will return the newly created follow for further processing.
AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error)
AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, error)
// RejectFollowRequest fetches a follow request from the database, and then deletes it.
RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) Error
RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) error
// GetAccountFollows returns a slice of follows owned by the given accountID.
GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)

View file

@ -26,18 +26,18 @@ import (
// Report handles getting/creation/deletion/updating of user reports/flags.
type Report interface {
// GetReportByID gets one report by its db id
GetReportByID(ctx context.Context, id string) (*gtsmodel.Report, Error)
GetReportByID(ctx context.Context, id string) (*gtsmodel.Report, error)
// GetReports gets limit n reports using the given parameters.
// Parameters that are empty / zero are ignored.
GetReports(ctx context.Context, resolved *bool, accountID string, targetAccountID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Report, Error)
GetReports(ctx context.Context, resolved *bool, accountID string, targetAccountID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Report, error)
// PutReport puts the given report in the database.
PutReport(ctx context.Context, report *gtsmodel.Report) Error
PutReport(ctx context.Context, report *gtsmodel.Report) error
// UpdateReport updates one report by its db id.
// The given columns will be updated; if no columns are
// provided, then all columns will be updated.
// updated_at will also be updated, no need to pass this
// as a specific column.
UpdateReport(ctx context.Context, report *gtsmodel.Report, columns ...string) (*gtsmodel.Report, Error)
UpdateReport(ctx context.Context, report *gtsmodel.Report, columns ...string) (*gtsmodel.Report, error)
// DeleteReportByID deletes report with the given id.
DeleteReportByID(ctx context.Context, id string) Error
DeleteReportByID(ctx context.Context, id string) error
}

View file

@ -25,5 +25,5 @@ import (
// Session handles getting/creation of router sessions.
type Session interface {
GetSession(ctx context.Context) (*gtsmodel.RouterSession, Error)
GetSession(ctx context.Context) (*gtsmodel.RouterSession, error)
}

View file

@ -26,34 +26,34 @@ import (
// Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses.
type Status interface {
// GetStatusByID returns one status from the database, with no rel fields populated, only their linking ID / URIs
GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, Error)
GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, error)
// GetStatusByURI returns one status from the database, with no rel fields populated, only their linking ID / URIs
GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, Error)
GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, error)
// GetStatusByURL returns one status from the database, with no rel fields populated, only their linking ID / URIs
GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, Error)
GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, error)
// PopulateStatus ensures that all sub-models of a status are populated (e.g. mentions, attachments, etc).
PopulateStatus(ctx context.Context, status *gtsmodel.Status) error
// PutStatus stores one status in the database.
PutStatus(ctx context.Context, status *gtsmodel.Status) Error
PutStatus(ctx context.Context, status *gtsmodel.Status) error
// UpdateStatus updates one status in the database.
UpdateStatus(ctx context.Context, status *gtsmodel.Status, columns ...string) Error
UpdateStatus(ctx context.Context, status *gtsmodel.Status, columns ...string) error
// DeleteStatusByID deletes one status from the database.
DeleteStatusByID(ctx context.Context, id string) Error
DeleteStatusByID(ctx context.Context, id string) error
// CountStatusReplies returns the amount of replies recorded for a status, or an error if something goes wrong
CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, Error)
CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, error)
// CountStatusReblogs returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong
CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, Error)
CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, error)
// CountStatusFaves returns the amount of faves/likes recorded for a status, or an error if something goes wrong
CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, Error)
CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, error)
// GetStatuses gets a slice of statuses corresponding to the given status IDs.
GetStatusesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Status, error)
@ -64,26 +64,26 @@ type Status interface {
// GetStatusParents gets the parent statuses of a given status.
//
// If onlyDirect is true, only the immediate parent will be returned.
GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, Error)
GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error)
// GetStatusChildren gets the child statuses of a given status.
//
// If onlyDirect is true, only the immediate children will be returned.
GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, Error)
GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error)
// IsStatusFavedBy checks if a given status has been faved by a given account ID
IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error)
IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error)
// IsStatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID
IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error)
IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error)
// IsStatusMutedBy checks if a given status has been muted by a given account ID
IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error)
IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error)
// IsStatusBookmarkedBy checks if a given status has been bookmarked by a given account ID
IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error)
IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error)
// GetStatusReblogs returns a slice of statuses that are a boost/reblog of the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, Error)
GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, error)
}

View file

@ -25,24 +25,24 @@ import (
type StatusBookmark interface {
// GetStatusBookmark gets one status bookmark with the given ID.
GetStatusBookmark(ctx context.Context, id string) (*gtsmodel.StatusBookmark, Error)
GetStatusBookmark(ctx context.Context, id string) (*gtsmodel.StatusBookmark, error)
// GetStatusBookmarkID is a shortcut function for returning just the database ID
// of a status bookmark created by the given accountID, targeting the given statusID.
GetStatusBookmarkID(ctx context.Context, accountID string, statusID string) (string, Error)
GetStatusBookmarkID(ctx context.Context, accountID string, statusID string) (string, error)
// GetStatusBookmarks retrieves status bookmarks created by the given accountID,
// and using the provided parameters. If limit is < 0 then no limit will be set.
//
// This function is primarily useful for paging through bookmarks in a sort of
// timeline view.
GetStatusBookmarks(ctx context.Context, accountID string, limit int, maxID string, minID string) ([]*gtsmodel.StatusBookmark, Error)
GetStatusBookmarks(ctx context.Context, accountID string, limit int, maxID string, minID string) ([]*gtsmodel.StatusBookmark, error)
// PutStatusBookmark inserts the given statusBookmark into the database.
PutStatusBookmark(ctx context.Context, statusBookmark *gtsmodel.StatusBookmark) Error
PutStatusBookmark(ctx context.Context, statusBookmark *gtsmodel.StatusBookmark) error
// DeleteStatusBookmark deletes one status bookmark with the given ID.
DeleteStatusBookmark(ctx context.Context, id string) Error
DeleteStatusBookmark(ctx context.Context, id string) error
// DeleteStatusBookmarks mass deletes status bookmarks targeting targetAccountID
// and/or originating from originAccountID and/or bookmarking statusID.
@ -57,10 +57,10 @@ type StatusBookmark interface {
// originate from originAccountID will be deleted.
//
// At least one parameter must not be an empty string.
DeleteStatusBookmarks(ctx context.Context, targetAccountID string, originAccountID string) Error
DeleteStatusBookmarks(ctx context.Context, targetAccountID string, originAccountID string) error
// DeleteStatusBookmarksForStatus deletes all status bookmarks that target the
// given status ID. This is useful when a status has been deleted, and you need
// to clean up after it.
DeleteStatusBookmarksForStatus(ctx context.Context, statusID string) Error
DeleteStatusBookmarksForStatus(ctx context.Context, statusID string) error
}

View file

@ -26,23 +26,23 @@ import (
type StatusFave interface {
// GetStatusFaveByAccountID gets one status fave created by the given
// accountID, targeting the given statusID.
GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, Error)
GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, error)
// GetStatusFave returns one status fave with the given id.
GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, Error)
GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, error)
// GetStatusFaves returns a slice of faves/likes of the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, Error)
GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, error)
// PopulateStatusFave ensures that all sub-models of a fave are populated (account, status, etc).
PopulateStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) error
// PutStatusFave inserts the given statusFave into the database.
PutStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) Error
PutStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) error
// DeleteStatusFave deletes one status fave with the given id.
DeleteStatusFaveByID(ctx context.Context, id string) Error
DeleteStatusFaveByID(ctx context.Context, id string) error
// DeleteStatusFaves mass deletes status faves targeting targetAccountID
// and/or originating from originAccountID and/or faving statusID.
@ -57,10 +57,10 @@ type StatusFave interface {
// originate from originAccountID will be deleted.
//
// At least one parameter must not be an empty string.
DeleteStatusFaves(ctx context.Context, targetAccountID string, originAccountID string) Error
DeleteStatusFaves(ctx context.Context, targetAccountID string, originAccountID string) error
// DeleteStatusFavesForStatus deletes all status faves that target the
// given status ID. This is useful when a status has been deleted, and you need
// to clean up after it.
DeleteStatusFavesForStatus(ctx context.Context, statusID string) Error
DeleteStatusFavesForStatus(ctx context.Context, statusID string) error
}

View file

@ -28,13 +28,13 @@ type Timeline interface {
// GetHomeTimeline returns a slice of statuses from accounts that are followed by the given account id.
//
// Statuses should be returned in descending order of when they were created (newest first).
GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error)
// GetPublicTimeline fetches the account's PUBLIC timeline -- ie., posts and replies that are public.
// It will use the given filters and try to return as many statuses as possible up to the limit.
//
// Statuses should be returned in descending order of when they were created (newest first).
GetPublicTimeline(ctx context.Context, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
GetPublicTimeline(ctx context.Context, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error)
// GetFavedTimeline fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved.
// It will use the given filters and try to return as many statuses as possible up to the limit.
@ -43,7 +43,7 @@ type Timeline interface {
// In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created.
//
// Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers.
GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error)
GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, error)
// GetListTimeline returns a slice of statuses from followed accounts collected within the list with the given listID.
// Statuses should be returned in descending order of when they were created (newest first).

View file

@ -26,14 +26,14 @@ import (
// Tombstone contains functionality for storing + retrieving tombstones for remote AP Activities + Objects.
type Tombstone interface {
// GetTombstoneByURI attempts to fetch a tombstone by the given URI.
GetTombstoneByURI(ctx context.Context, uri string) (*gtsmodel.Tombstone, Error)
GetTombstoneByURI(ctx context.Context, uri string) (*gtsmodel.Tombstone, error)
// TombstoneExistsWithURI returns true if a tombstone with the given URI exists.
TombstoneExistsWithURI(ctx context.Context, uri string) (bool, Error)
TombstoneExistsWithURI(ctx context.Context, uri string) (bool, error)
// PutTombstone creates a new tombstone in the database.
PutTombstone(ctx context.Context, tombstone *gtsmodel.Tombstone) Error
PutTombstone(ctx context.Context, tombstone *gtsmodel.Tombstone) error
// DeleteTombstone deletes a tombstone with the given ID.
DeleteTombstone(ctx context.Context, id string) Error
DeleteTombstone(ctx context.Context, id string) error
}

View file

@ -26,21 +26,21 @@ import (
// User contains functions related to user getting/setting/creation.
type User interface {
// GetAllUsers returns all local user accounts, or an error if something goes wrong.
GetAllUsers(ctx context.Context) ([]*gtsmodel.User, Error)
GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error)
// GetUserByID returns one user with the given ID, or an error if something goes wrong.
GetUserByID(ctx context.Context, id string) (*gtsmodel.User, Error)
GetUserByID(ctx context.Context, id string) (*gtsmodel.User, error)
// GetUserByAccountID returns one user by its account ID, or an error if something goes wrong.
GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, Error)
GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, error)
// GetUserByID returns one user with the given email address, or an error if something goes wrong.
GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, Error)
GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, error)
// GetUserByExternalID returns one user with the given external id, or an error if something goes wrong.
GetUserByExternalID(ctx context.Context, id string) (*gtsmodel.User, Error)
GetUserByExternalID(ctx context.Context, id string) (*gtsmodel.User, error)
// GetUserByConfirmationToken returns one user by its confirmation token, or an error if something goes wrong.
GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, Error)
GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, error)
// PutUser will attempt to place user in the database
PutUser(ctx context.Context, user *gtsmodel.User) Error
PutUser(ctx context.Context, user *gtsmodel.User) error
// UpdateUser updates one user by its primary key, updating either only the specified columns, or all of them.
UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) Error
UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) error
// DeleteUserByID deletes one user by its ID.
DeleteUserByID(ctx context.Context, userID string) Error
DeleteUserByID(ctx context.Context, userID string) error
}

View file

@ -25,6 +25,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -175,7 +176,7 @@ func (suite *AccountTestSuite) TestDereferenceLocalAccountWithUnknownUsername()
config.GetHost(),
)
suite.True(gtserror.Unretrievable(err))
suite.EqualError(err, "no entries")
suite.EqualError(err, db.ErrNoEntries.Error())
suite.Nil(fetchedAccount)
}
@ -189,7 +190,7 @@ func (suite *AccountTestSuite) TestDereferenceLocalAccountWithUnknownUsernameDom
"localhost:8080",
)
suite.True(gtserror.Unretrievable(err))
suite.EqualError(err, "no entries")
suite.EqualError(err, db.ErrNoEntries.Error())
suite.Nil(fetchedAccount)
}
@ -202,7 +203,7 @@ func (suite *AccountTestSuite) TestDereferenceLocalAccountWithUnknownUserURI() {
testrig.URLMustParse("http://localhost:8080/users/thisaccountdoesnotexist"),
)
suite.True(gtserror.Unretrievable(err))
suite.EqualError(err, "no entries")
suite.EqualError(err, db.ErrNoEntries.Error())
suite.Nil(fetchedAccount)
}

View file

@ -232,7 +232,7 @@ func (c *Client) DoSigned(r *http.Request, sign SignFunc) (rsp *http.Response, e
return nil, err
}
l.Infof("performing request")
l.Info("performing request")
// Perform the request.
rsp, err = c.do(r)

View file

@ -22,7 +22,6 @@ import (
"net/http"
"net/url"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/log"
@ -48,7 +47,7 @@ const (
// context for use down the line.
//
// In case of an error, the request will be aborted with http code 500.
func SignatureCheck(uriBlocked func(context.Context, *url.URL) (bool, db.Error)) func(*gin.Context) {
func SignatureCheck(uriBlocked func(context.Context, *url.URL) (bool, error)) func(*gin.Context) {
return func(c *gin.Context) {
ctx := c.Request.Context()

View file

@ -22,6 +22,7 @@ import (
"testing"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
)
type AuthorizeTestSuite struct {
@ -38,7 +39,7 @@ func (suite *AuthorizeTestSuite) TestAuthorize() {
suite.Equal(suite.testAccounts["local_account_2"].ID, account2.ID)
noAccount, err := suite.streamProcessor.Authorize(context.Background(), "aaaaaaaaaaaaaaaaaaaaa!!")
suite.EqualError(err, "could not load access token: no entries")
suite.EqualError(err, "could not load access token: "+db.ErrNoEntries.Error())
suite.Nil(noAccount)
}

View file

@ -62,7 +62,7 @@ const (
type Module struct {
processor *processing.Processor
eTagCache cache.Cache[string, eTagCacheEntry]
isURIBlocked func(context.Context, *url.URL) (bool, db.Error)
isURIBlocked func(context.Context, *url.URL) (bool, error)
}
func New(db db.DB, processor *processing.Processor) *Module {