[performance] retry db queries on busy errors (#2025)

* catch SQLITE_BUSY errors, wrap bun.DB to use our own busy retrier, remove unnecessary db.Error type

Signed-off-by: kim <grufwub@gmail.com>

* remove dead code

Signed-off-by: kim <grufwub@gmail.com>

* remove more dead code, add missing error arguments

Signed-off-by: kim <grufwub@gmail.com>

* update sqlite to use maxOpenConns()

Signed-off-by: kim <grufwub@gmail.com>

* add uncommitted changes

Signed-off-by: kim <grufwub@gmail.com>

* use direct calls-through for the ConnIface to make sure we don't double query hook

Signed-off-by: kim <grufwub@gmail.com>

* expose underlying bun.DB better

Signed-off-by: kim <grufwub@gmail.com>

* retry on the correct busy error

Signed-off-by: kim <grufwub@gmail.com>

* use longer possible maxRetries for db retry-backoff

Signed-off-by: kim <grufwub@gmail.com>

* remove the note regarding max-open-conns only applying to postgres

Signed-off-by: kim <grufwub@gmail.com>

* improved code commenting

Signed-off-by: kim <grufwub@gmail.com>

* remove unnecessary infof call (just use info)

Signed-off-by: kim <grufwub@gmail.com>

* rename DBConn to WrappedDB to better follow sql package name conventions

Signed-off-by: kim <grufwub@gmail.com>

* update test error string checks

Signed-off-by: kim <grufwub@gmail.com>

* shush linter

Signed-off-by: kim <grufwub@gmail.com>

* update backoff logic to be more transparent

Signed-off-by: kim <grufwub@gmail.com>

---------

Signed-off-by: kim <grufwub@gmail.com>
This commit is contained in:
kim 2023-07-25 09:34:05 +01:00 committed by GitHub
parent 9eff0d46e4
commit 5f3e095717
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
53 changed files with 1050 additions and 898 deletions

View file

@ -194,10 +194,6 @@ db-tls-ca-cert: ""
# #
# If you set the multiplier to less than 1, only one open connection will be used regardless of cpu count. # If you set the multiplier to less than 1, only one open connection will be used regardless of cpu count.
# #
# PLEASE NOTE!!: This setting currently only applies for Postgres. SQLite will always use 1 connection regardless
# of what is set here. This behavior will change in future when we implement better SQLITE_BUSY handling.
# See https://github.com/superseriousbusiness/gotosocial/issues/1407 for more details.
#
# Examples: [16, 8, 10, 2] # Examples: [16, 8, 10, 2]
# Default: 8 # Default: 8
db-max-open-conns-multiplier: 8 db-max-open-conns-multiplier: 8

View file

@ -27,67 +27,67 @@ import (
// Account contains functions related to account getting/setting/creation. // Account contains functions related to account getting/setting/creation.
type Account interface { type Account interface {
// GetAccountByID returns one account with the given ID, or an error if something goes wrong. // GetAccountByID returns one account with the given ID, or an error if something goes wrong.
GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, Error) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, error)
// GetAccountByURI returns one account with the given URI, or an error if something goes wrong. // GetAccountByURI returns one account with the given URI, or an error if something goes wrong.
GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, Error) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, error)
// GetAccountByURL returns one account with the given URL, or an error if something goes wrong. // GetAccountByURL returns one account with the given URL, or an error if something goes wrong.
GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, error)
// GetAccountByUsernameDomain returns one account with the given username and domain, or an error if something goes wrong. // GetAccountByUsernameDomain returns one account with the given username and domain, or an error if something goes wrong.
GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, Error) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, error)
// GetAccountByPubkeyID returns one account with the given public key URI (ID), or an error if something goes wrong. // GetAccountByPubkeyID returns one account with the given public key URI (ID), or an error if something goes wrong.
GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, Error) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, error)
// GetAccountByInboxURI returns one account with the given inbox_uri, or an error if something goes wrong. // GetAccountByInboxURI returns one account with the given inbox_uri, or an error if something goes wrong.
GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, Error) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, error)
// GetAccountByOutboxURI returns one account with the given outbox_uri, or an error if something goes wrong. // GetAccountByOutboxURI returns one account with the given outbox_uri, or an error if something goes wrong.
GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, Error) GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, error)
// GetAccountByFollowingURI returns one account with the given following_uri, or an error if something goes wrong. // GetAccountByFollowingURI returns one account with the given following_uri, or an error if something goes wrong.
GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, Error) GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, error)
// GetAccountByFollowersURI returns one account with the given followers_uri, or an error if something goes wrong. // GetAccountByFollowersURI returns one account with the given followers_uri, or an error if something goes wrong.
GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, Error) GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, error)
// PopulateAccount ensures that all sub-models of an account are populated (e.g. avatar, header etc). // PopulateAccount ensures that all sub-models of an account are populated (e.g. avatar, header etc).
PopulateAccount(ctx context.Context, account *gtsmodel.Account) error PopulateAccount(ctx context.Context, account *gtsmodel.Account) error
// PutAccount puts one account in the database. // PutAccount puts one account in the database.
PutAccount(ctx context.Context, account *gtsmodel.Account) Error PutAccount(ctx context.Context, account *gtsmodel.Account) error
// UpdateAccount updates one account by ID. // UpdateAccount updates one account by ID.
UpdateAccount(ctx context.Context, account *gtsmodel.Account, columns ...string) Error UpdateAccount(ctx context.Context, account *gtsmodel.Account, columns ...string) error
// DeleteAccount deletes one account from the database by its ID. // DeleteAccount deletes one account from the database by its ID.
// DO NOT USE THIS WHEN SUSPENDING ACCOUNTS! In that case you should mark the // DO NOT USE THIS WHEN SUSPENDING ACCOUNTS! In that case you should mark the
// account as suspended instead, rather than deleting from the db entirely. // account as suspended instead, rather than deleting from the db entirely.
DeleteAccount(ctx context.Context, id string) Error DeleteAccount(ctx context.Context, id string) error
// GetAccountCustomCSSByUsername returns the custom css of an account on this instance with the given username. // GetAccountCustomCSSByUsername returns the custom css of an account on this instance with the given username.
GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, Error) GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, error)
// GetAccountFaves fetches faves/likes created by the target accountID. // GetAccountFaves fetches faves/likes created by the target accountID.
GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, Error) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, error)
// GetAccountsUsingEmoji fetches all account models using emoji with given ID stored in their 'emojis' column. // GetAccountsUsingEmoji fetches all account models using emoji with given ID stored in their 'emojis' column.
GetAccountsUsingEmoji(ctx context.Context, emojiID string) ([]*gtsmodel.Account, error) GetAccountsUsingEmoji(ctx context.Context, emojiID string) ([]*gtsmodel.Account, error)
// GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID. // GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID.
CountAccountStatuses(ctx context.Context, accountID string) (int, Error) CountAccountStatuses(ctx context.Context, accountID string) (int, error)
// CountAccountPinned returns the total number of pinned statuses owned by account with the given id. // CountAccountPinned returns the total number of pinned statuses owned by account with the given id.
CountAccountPinned(ctx context.Context, accountID string) (int, Error) CountAccountPinned(ctx context.Context, accountID string) (int, error)
// GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided // GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided
// then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can // then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can
// be very memory intensive so you probably shouldn't do this! // be very memory intensive so you probably shouldn't do this!
// //
// In the case of no statuses, this function will return db.ErrNoEntries. // In the case of no statuses, this function will return db.ErrNoEntries.
GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, Error) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, error)
// GetAccountPinnedStatuses returns ONLY statuses owned by the give accountID for which a corresponding StatusPin // GetAccountPinnedStatuses returns ONLY statuses owned by the give accountID for which a corresponding StatusPin
// exists in the database. Statuses which are not pinned will not be returned by this function. // exists in the database. Statuses which are not pinned will not be returned by this function.
@ -95,28 +95,28 @@ type Account interface {
// Statuses will be returned in the order in which they were pinned, from latest pinned to oldest pinned (descending). // Statuses will be returned in the order in which they were pinned, from latest pinned to oldest pinned (descending).
// //
// In the case of no statuses, this function will return db.ErrNoEntries. // In the case of no statuses, this function will return db.ErrNoEntries.
GetAccountPinnedStatuses(ctx context.Context, accountID string) ([]*gtsmodel.Status, Error) GetAccountPinnedStatuses(ctx context.Context, accountID string) ([]*gtsmodel.Status, error)
// GetAccountWebStatuses is similar to GetAccountStatuses, but it's specifically for returning statuses that // GetAccountWebStatuses is similar to GetAccountStatuses, but it's specifically for returning statuses that
// should be visible via the web view of an account. So, only public, federated statuses that aren't boosts // should be visible via the web view of an account. So, only public, federated statuses that aren't boosts
// or replies. // or replies.
// //
// In the case of no statuses, this function will return db.ErrNoEntries. // In the case of no statuses, this function will return db.ErrNoEntries.
GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, Error) GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, error)
GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, Error) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error)
// GetAccountLastPosted simply gets the timestamp of the most recent post by the account. // GetAccountLastPosted simply gets the timestamp of the most recent post by the account.
// //
// If webOnly is true, then the time of the last non-reply, non-boost, public status of the account will be returned. // If webOnly is true, then the time of the last non-reply, non-boost, public status of the account will be returned.
// //
// The returned time will be zero if account has never posted anything. // The returned time will be zero if account has never posted anything.
GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, Error) GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, error)
// SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment. // SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment.
SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) Error SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) error
// GetInstanceAccount returns the instance account for the given domain. // GetInstanceAccount returns the instance account for the given domain.
// If domain is empty, this instance account will be returned. // If domain is empty, this instance account will be returned.
GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, Error) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, error)
} }

View file

@ -27,26 +27,26 @@ import (
type Admin interface { type Admin interface {
// IsUsernameAvailable checks whether a given username is available on our domain. // IsUsernameAvailable checks whether a given username is available on our domain.
// Returns an error if the username is already taken, or something went wrong in the db. // Returns an error if the username is already taken, or something went wrong in the db.
IsUsernameAvailable(ctx context.Context, username string) (bool, Error) IsUsernameAvailable(ctx context.Context, username string) (bool, error)
// IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain. // IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain.
// Return an error if: // Return an error if:
// A) the email is already associated with an account // A) the email is already associated with an account
// B) we block signups from this email domain // B) we block signups from this email domain
// C) something went wrong in the db // C) something went wrong in the db
IsEmailAvailable(ctx context.Context, email string) (bool, Error) IsEmailAvailable(ctx context.Context, email string) (bool, error)
// NewSignup creates a new user in the database with the given parameters. // NewSignup creates a new user in the database with the given parameters.
// By the time this function is called, it should be assumed that all the parameters have passed validation! // By the time this function is called, it should be assumed that all the parameters have passed validation!
NewSignup(ctx context.Context, newSignup gtsmodel.NewSignup) (*gtsmodel.User, Error) NewSignup(ctx context.Context, newSignup gtsmodel.NewSignup) (*gtsmodel.User, error)
// CreateInstanceAccount creates an account in the database with the same username as the instance host value. // CreateInstanceAccount creates an account in the database with the same username as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'. // Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'.
// This is needed for things like serving files that belong to the instance and not an individual user/account. // This is needed for things like serving files that belong to the instance and not an individual user/account.
CreateInstanceAccount(ctx context.Context) Error CreateInstanceAccount(ctx context.Context) error
// CreateInstanceInstance creates an instance in the database with the same domain as the instance host value. // CreateInstanceInstance creates an instance in the database with the same domain as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'. // Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'.
// This is needed for things like serving instance information through /api/v1/instance // This is needed for things like serving instance information through /api/v1/instance
CreateInstanceInstance(ctx context.Context) Error CreateInstanceInstance(ctx context.Context) error
} }

View file

@ -23,58 +23,58 @@ import "context"
type Basic interface { type Basic interface {
// CreateTable creates a table for the given interface. // CreateTable creates a table for the given interface.
// For implementations that don't use tables, this can just return nil. // For implementations that don't use tables, this can just return nil.
CreateTable(ctx context.Context, i interface{}) Error CreateTable(ctx context.Context, i interface{}) error
// CreateAllTables creates *all* tables necessary for the running of GoToSocial. // CreateAllTables creates *all* tables necessary for the running of GoToSocial.
// Because it uses the 'if not exists' parameter it is safe to run against a GtS that's already been initialized. // Because it uses the 'if not exists' parameter it is safe to run against a GtS that's already been initialized.
CreateAllTables(ctx context.Context) Error CreateAllTables(ctx context.Context) error
// DropTable drops the table for the given interface. // DropTable drops the table for the given interface.
// For implementations that don't use tables, this can just return nil. // For implementations that don't use tables, this can just return nil.
DropTable(ctx context.Context, i interface{}) Error DropTable(ctx context.Context, i interface{}) error
// Stop should stop and close the database connection cleanly, returning an error if this is not possible. // Stop should stop and close the database connection cleanly, returning an error if this is not possible.
// If the database implementation doesn't need to be stopped, this can just return nil. // If the database implementation doesn't need to be stopped, this can just return nil.
Stop(ctx context.Context) Error Stop(ctx context.Context) error
// IsHealthy should return nil if the database connection is healthy, or an error if not. // IsHealthy should return nil if the database connection is healthy, or an error if not.
IsHealthy(ctx context.Context) Error IsHealthy(ctx context.Context) error
// GetByID gets one entry by its id. In a database like postgres, this might be the 'id' field of the entry, // GetByID gets one entry by its id. In a database like postgres, this might be the 'id' field of the entry,
// for other implementations (for example, in-memory) it might just be the key of a map. // for other implementations (for example, in-memory) it might just be the key of a map.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned // In case of no entries, a 'no entries' error will be returned
GetByID(ctx context.Context, id string, i interface{}) Error GetByID(ctx context.Context, id string, i interface{}) error
// GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the // GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the
// name of the key to select from. // name of the key to select from.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned // In case of no entries, a 'no entries' error will be returned
GetWhere(ctx context.Context, where []Where, i interface{}) Error GetWhere(ctx context.Context, where []Where, i interface{}) error
// GetAll will try to get all entries of type i. // GetAll will try to get all entries of type i.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned // In case of no entries, a 'no entries' error will be returned
GetAll(ctx context.Context, i interface{}) Error GetAll(ctx context.Context, i interface{}) error
// Put simply stores i. It is up to the implementation to figure out how to store it, and using what key. // Put simply stores i. It is up to the implementation to figure out how to store it, and using what key.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
Put(ctx context.Context, i interface{}) Error Put(ctx context.Context, i interface{}) error
// UpdateByID updates values of i based on its id. // UpdateByID updates values of i based on its id.
// If any columns are specified, these will be updated exclusively. // If any columns are specified, these will be updated exclusively.
// Otherwise, the whole model will be updated. // Otherwise, the whole model will be updated.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
UpdateByID(ctx context.Context, i interface{}, id string, columns ...string) Error UpdateByID(ctx context.Context, i interface{}, id string, columns ...string) error
// UpdateWhere updates column key of interface i with the given value, where the given parameters apply. // UpdateWhere updates column key of interface i with the given value, where the given parameters apply.
UpdateWhere(ctx context.Context, where []Where, key string, value interface{}, i interface{}) Error UpdateWhere(ctx context.Context, where []Where, key string, value interface{}, i interface{}) error
// DeleteByID removes i with id id. // DeleteByID removes i with id id.
// If i didn't exist anyway, then no error should be returned. // If i didn't exist anyway, then no error should be returned.
DeleteByID(ctx context.Context, id string, i interface{}) Error DeleteByID(ctx context.Context, id string, i interface{}) error
// DeleteWhere deletes i where key = value // DeleteWhere deletes i where key = value
// If i didn't exist anyway, then no error should be returned. // If i didn't exist anyway, then no error should be returned.
DeleteWhere(ctx context.Context, where []Where, i interface{}) Error DeleteWhere(ctx context.Context, where []Where, i interface{}) error
} }

View file

@ -38,16 +38,16 @@ import (
) )
type accountDB struct { type accountDB struct {
conn *DBConn db *WrappedDB
state *state.State state *state.State
} }
func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, error) {
return a.getAccount( return a.getAccount(
ctx, ctx,
"ID", "ID",
func(account *gtsmodel.Account) error { func(account *gtsmodel.Account) error {
return a.conn.NewSelect(). return a.db.NewSelect().
Model(account). Model(account).
Where("? = ?", bun.Ident("account.id"), id). Where("? = ?", bun.Ident("account.id"), id).
Scan(ctx) Scan(ctx)
@ -77,12 +77,12 @@ func (a *accountDB) GetAccountsByIDs(ctx context.Context, ids []string) ([]*gtsm
return accounts, nil return accounts, nil
} }
func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, error) {
return a.getAccount( return a.getAccount(
ctx, ctx,
"URI", "URI",
func(account *gtsmodel.Account) error { func(account *gtsmodel.Account) error {
return a.conn.NewSelect(). return a.db.NewSelect().
Model(account). Model(account).
Where("? = ?", bun.Ident("account.uri"), uri). Where("? = ?", bun.Ident("account.uri"), uri).
Scan(ctx) Scan(ctx)
@ -91,12 +91,12 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.
) )
} }
func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, db.Error) { func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, error) {
return a.getAccount( return a.getAccount(
ctx, ctx,
"URL", "URL",
func(account *gtsmodel.Account) error { func(account *gtsmodel.Account) error {
return a.conn.NewSelect(). return a.db.NewSelect().
Model(account). Model(account).
Where("? = ?", bun.Ident("account.url"), url). Where("? = ?", bun.Ident("account.url"), url).
Scan(ctx) Scan(ctx)
@ -105,7 +105,7 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.
) )
} }
func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, db.Error) { func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, error) {
if domain != "" { if domain != "" {
// Normalize the domain as punycode // Normalize the domain as punycode
var err error var err error
@ -119,7 +119,7 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str
ctx, ctx,
"Username.Domain", "Username.Domain",
func(account *gtsmodel.Account) error { func(account *gtsmodel.Account) error {
q := a.conn.NewSelect(). q := a.db.NewSelect().
Model(account) Model(account)
if domain != "" { if domain != "" {
@ -139,12 +139,12 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str
) )
} }
func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) { func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, error) {
return a.getAccount( return a.getAccount(
ctx, ctx,
"PublicKeyURI", "PublicKeyURI",
func(account *gtsmodel.Account) error { func(account *gtsmodel.Account) error {
return a.conn.NewSelect(). return a.db.NewSelect().
Model(account). Model(account).
Where("? = ?", bun.Ident("account.public_key_uri"), id). Where("? = ?", bun.Ident("account.public_key_uri"), id).
Scan(ctx) Scan(ctx)
@ -153,12 +153,12 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo
) )
} }
func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, error) {
return a.getAccount( return a.getAccount(
ctx, ctx,
"InboxURI", "InboxURI",
func(account *gtsmodel.Account) error { func(account *gtsmodel.Account) error {
return a.conn.NewSelect(). return a.db.NewSelect().
Model(account). Model(account).
Where("? = ?", bun.Ident("account.inbox_uri"), uri). Where("? = ?", bun.Ident("account.inbox_uri"), uri).
Scan(ctx) Scan(ctx)
@ -167,12 +167,12 @@ func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsm
) )
} }
func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, error) {
return a.getAccount( return a.getAccount(
ctx, ctx,
"OutboxURI", "OutboxURI",
func(account *gtsmodel.Account) error { func(account *gtsmodel.Account) error {
return a.conn.NewSelect(). return a.db.NewSelect().
Model(account). Model(account).
Where("? = ?", bun.Ident("account.outbox_uri"), uri). Where("? = ?", bun.Ident("account.outbox_uri"), uri).
Scan(ctx) Scan(ctx)
@ -181,12 +181,12 @@ func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gts
) )
} }
func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, error) {
return a.getAccount( return a.getAccount(
ctx, ctx,
"FollowersURI", "FollowersURI",
func(account *gtsmodel.Account) error { func(account *gtsmodel.Account) error {
return a.conn.NewSelect(). return a.db.NewSelect().
Model(account). Model(account).
Where("? = ?", bun.Ident("account.followers_uri"), uri). Where("? = ?", bun.Ident("account.followers_uri"), uri).
Scan(ctx) Scan(ctx)
@ -195,12 +195,12 @@ func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (*
) )
} }
func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) { func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, error) {
return a.getAccount( return a.getAccount(
ctx, ctx,
"FollowingURI", "FollowingURI",
func(account *gtsmodel.Account) error { func(account *gtsmodel.Account) error {
return a.conn.NewSelect(). return a.db.NewSelect().
Model(account). Model(account).
Where("? = ?", bun.Ident("account.following_uri"), uri). Where("? = ?", bun.Ident("account.following_uri"), uri).
Scan(ctx) Scan(ctx)
@ -209,7 +209,7 @@ func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (*
) )
} }
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, error) {
var username string var username string
if domain == "" { if domain == "" {
@ -223,14 +223,14 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts
return a.GetAccountByUsernameDomain(ctx, username, domain) return a.GetAccountByUsernameDomain(ctx, username, domain)
} }
func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, db.Error) { func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, error) {
// Fetch account from database cache with loader callback // Fetch account from database cache with loader callback
account, err := a.state.Caches.GTS.Account().Load(lookup, func() (*gtsmodel.Account, error) { account, err := a.state.Caches.GTS.Account().Load(lookup, func() (*gtsmodel.Account, error) {
var account gtsmodel.Account var account gtsmodel.Account
// Not cached! Perform database query // Not cached! Perform database query
if err := dbQuery(&account); err != nil { if err := dbQuery(&account); err != nil {
return nil, a.conn.ProcessError(err) return nil, a.db.ProcessError(err)
} }
return &account, nil return &account, nil
@ -294,12 +294,12 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou
return errs.Combine() return errs.Combine()
} }
func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) db.Error { func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) error {
return a.state.Caches.GTS.Account().Store(account, func() error { return a.state.Caches.GTS.Account().Store(account, func() error {
// It is safe to run this database transaction within cache.Store // It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook. // as the cache does not attempt a mutex lock until AFTER hook.
// //
return a.conn.RunInTx(ctx, func(tx bun.Tx) error { return a.db.RunInTx(ctx, func(tx bun.Tx) error {
// create links between this account and any emojis it uses // create links between this account and any emojis it uses
for _, i := range account.EmojiIDs { for _, i := range account.EmojiIDs {
if _, err := tx.NewInsert().Model(&gtsmodel.AccountToEmoji{ if _, err := tx.NewInsert().Model(&gtsmodel.AccountToEmoji{
@ -317,7 +317,7 @@ func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) d
}) })
} }
func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account, columns ...string) db.Error { func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account, columns ...string) error {
account.UpdatedAt = time.Now() account.UpdatedAt = time.Now()
if len(columns) > 0 { if len(columns) > 0 {
// If we're updating by column, ensure "updated_at" is included. // If we're updating by column, ensure "updated_at" is included.
@ -328,7 +328,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
// It is safe to run this database transaction within cache.Store // It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook. // as the cache does not attempt a mutex lock until AFTER hook.
// //
return a.conn.RunInTx(ctx, func(tx bun.Tx) error { return a.db.RunInTx(ctx, func(tx bun.Tx) error {
// create links between this account and any emojis it uses // create links between this account and any emojis it uses
// first clear out any old emoji links // first clear out any old emoji links
if _, err := tx. if _, err := tx.
@ -362,7 +362,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
}) })
} }
func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error { func (a *accountDB) DeleteAccount(ctx context.Context, id string) error {
defer a.state.Caches.GTS.Account().Invalidate("ID", id) defer a.state.Caches.GTS.Account().Invalidate("ID", id)
// Load account into cache before attempting a delete, // Load account into cache before attempting a delete,
@ -376,7 +376,7 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
return err return err
} }
return a.conn.RunInTx(ctx, func(tx bun.Tx) error { return a.db.RunInTx(ctx, func(tx bun.Tx) error {
// clear out any emoji links // clear out any emoji links
if _, err := tx. if _, err := tx.
NewDelete(). NewDelete().
@ -396,10 +396,10 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
}) })
} }
func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, db.Error) { func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, error) {
createdAt := time.Time{} createdAt := time.Time{}
q := a.conn. q := a.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Column("status.created_at"). Column("status.created_at").
@ -416,12 +416,12 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string,
} }
if err := q.Scan(ctx, &createdAt); err != nil { if err := q.Scan(ctx, &createdAt); err != nil {
return time.Time{}, a.conn.ProcessError(err) return time.Time{}, a.db.ProcessError(err)
} }
return createdAt, nil return createdAt, nil
} }
func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error { func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) error {
if *mediaAttachment.Avatar && *mediaAttachment.Header { if *mediaAttachment.Avatar && *mediaAttachment.Header {
return errors.New("one media attachment cannot be both header and avatar") return errors.New("one media attachment cannot be both header and avatar")
} }
@ -437,26 +437,26 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen
} }
// TODO: there are probably more side effects here that need to be handled // TODO: there are probably more side effects here that need to be handled
if _, err := a.conn. if _, err := a.db.
NewInsert(). NewInsert().
Model(mediaAttachment). Model(mediaAttachment).
Exec(ctx); err != nil { Exec(ctx); err != nil {
return a.conn.ProcessError(err) return a.db.ProcessError(err)
} }
if _, err := a.conn. if _, err := a.db.
NewUpdate(). NewUpdate().
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
Set("? = ?", column, mediaAttachment.ID). Set("? = ?", column, mediaAttachment.ID).
Where("? = ?", bun.Ident("account.id"), accountID). Where("? = ?", bun.Ident("account.id"), accountID).
Exec(ctx); err != nil { Exec(ctx); err != nil {
return a.conn.ProcessError(err) return a.db.ProcessError(err)
} }
return nil return nil
} }
func (a *accountDB) GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, db.Error) { func (a *accountDB) GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, error) {
account, err := a.GetAccountByUsernameDomain(ctx, username, "") account, err := a.GetAccountByUsernameDomain(ctx, username, "")
if err != nil { if err != nil {
return "", err return "", err
@ -469,7 +469,7 @@ func (a *accountDB) GetAccountsUsingEmoji(ctx context.Context, emojiID string) (
var accountIDs []string var accountIDs []string
// Create SELECT account query. // Create SELECT account query.
q := a.conn.NewSelect(). q := a.db.NewSelect().
Table("accounts"). Table("accounts").
Column("id") Column("id")
@ -486,37 +486,37 @@ func (a *accountDB) GetAccountsUsingEmoji(ctx context.Context, emojiID string) (
// Execute the query, scanning destination into accountIDs. // Execute the query, scanning destination into accountIDs.
if _, err := q.Exec(ctx, &accountIDs); err != nil { if _, err := q.Exec(ctx, &accountIDs); err != nil {
return nil, a.conn.ProcessError(err) return nil, a.db.ProcessError(err)
} }
// Convert account IDs into account objects. // Convert account IDs into account objects.
return a.GetAccountsByIDs(ctx, accountIDs) return a.GetAccountsByIDs(ctx, accountIDs)
} }
func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, db.Error) { func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, error) {
faves := new([]*gtsmodel.StatusFave) faves := new([]*gtsmodel.StatusFave)
if err := a.conn. if err := a.db.
NewSelect(). NewSelect().
Model(faves). Model(faves).
Where("? = ?", bun.Ident("status_fave.account_id"), accountID). Where("? = ?", bun.Ident("status_fave.account_id"), accountID).
Scan(ctx); err != nil { Scan(ctx); err != nil {
return nil, a.conn.ProcessError(err) return nil, a.db.ProcessError(err)
} }
return *faves, nil return *faves, nil
} }
func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) { func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, error) {
return a.conn. return a.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Where("? = ?", bun.Ident("status.account_id"), accountID). Where("? = ?", bun.Ident("status.account_id"), accountID).
Count(ctx) Count(ctx)
} }
func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (int, db.Error) { func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (int, error) {
return a.conn. return a.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Where("? = ?", bun.Ident("status.account_id"), accountID). Where("? = ?", bun.Ident("status.account_id"), accountID).
@ -524,7 +524,7 @@ func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (i
Count(ctx) Count(ctx)
} }
func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, db.Error) { func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, error) {
// Ensure reasonable // Ensure reasonable
if limit < 0 { if limit < 0 {
limit = 0 limit = 0
@ -536,7 +536,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
frontToBack = true frontToBack = true
) )
q := a.conn. q := a.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
// Select only IDs from table // Select only IDs from table
@ -562,7 +562,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
// implementation differs between SQLite and Postgres, // implementation differs between SQLite and Postgres,
// so we have to be thorough to cover all eventualities // so we have to be thorough to cover all eventualities
q = q.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery { q = q.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery {
switch a.conn.Dialect().Name() { switch a.db.Dialect().Name() {
case dialect.PG: case dialect.PG:
return q. return q.
Where("? IS NOT NULL", bun.Ident("status.attachments")). Where("? IS NOT NULL", bun.Ident("status.attachments")).
@ -613,7 +613,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
} }
if err := q.Scan(ctx, &statusIDs); err != nil { if err := q.Scan(ctx, &statusIDs); err != nil {
return nil, a.conn.ProcessError(err) return nil, a.db.ProcessError(err)
} }
// If we're paging up, we still want statuses // If we're paging up, we still want statuses
@ -628,10 +628,10 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
return a.statusesFromIDs(ctx, statusIDs) return a.statusesFromIDs(ctx, statusIDs)
} }
func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID string) ([]*gtsmodel.Status, db.Error) { func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID string) ([]*gtsmodel.Status, error) {
statusIDs := []string{} statusIDs := []string{}
q := a.conn. q := a.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Column("status.id"). Column("status.id").
@ -640,13 +640,13 @@ func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID stri
Order("status.pinned_at DESC") Order("status.pinned_at DESC")
if err := q.Scan(ctx, &statusIDs); err != nil { if err := q.Scan(ctx, &statusIDs); err != nil {
return nil, a.conn.ProcessError(err) return nil, a.db.ProcessError(err)
} }
return a.statusesFromIDs(ctx, statusIDs) return a.statusesFromIDs(ctx, statusIDs)
} }
func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, db.Error) { func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, error) {
// Ensure reasonable // Ensure reasonable
if limit < 0 { if limit < 0 {
limit = 0 limit = 0
@ -655,7 +655,7 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string,
// Make educated guess for slice size // Make educated guess for slice size
statusIDs := make([]string, 0, limit) statusIDs := make([]string, 0, limit)
q := a.conn. q := a.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
// Select only IDs from table // Select only IDs from table
@ -688,16 +688,16 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string,
q = q.Order("status.id DESC") q = q.Order("status.id DESC")
if err := q.Scan(ctx, &statusIDs); err != nil { if err := q.Scan(ctx, &statusIDs); err != nil {
return nil, a.conn.ProcessError(err) return nil, a.db.ProcessError(err)
} }
return a.statusesFromIDs(ctx, statusIDs) return a.statusesFromIDs(ctx, statusIDs)
} }
func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) { func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error) {
blocks := []*gtsmodel.Block{} blocks := []*gtsmodel.Block{}
fq := a.conn. fq := a.db.
NewSelect(). NewSelect().
Model(&blocks). Model(&blocks).
Where("? = ?", bun.Ident("block.account_id"), accountID). Where("? = ?", bun.Ident("block.account_id"), accountID).
@ -717,7 +717,7 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI
} }
if err := fq.Scan(ctx); err != nil { if err := fq.Scan(ctx); err != nil {
return nil, "", "", a.conn.ProcessError(err) return nil, "", "", a.db.ProcessError(err)
} }
if len(blocks) == 0 { if len(blocks) == 0 {
@ -734,7 +734,7 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI
return accounts, nextMaxID, prevMinID, nil return accounts, nextMaxID, prevMinID, nil
} }
func (a *accountDB) statusesFromIDs(ctx context.Context, statusIDs []string) ([]*gtsmodel.Status, db.Error) { func (a *accountDB) statusesFromIDs(ctx context.Context, statusIDs []string) ([]*gtsmodel.Status, error) {
// Catch case of no statuses early // Catch case of no statuses early
if len(statusIDs) == 0 { if len(statusIDs) == 0 {
return nil, db.ErrNoEntries return nil, db.ErrNoEntries

View file

@ -260,7 +260,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() {
} }
noCache := &gtsmodel.Account{} noCache := &gtsmodel.Account{}
err = dbService.GetConn(). err = dbService.DB().
NewSelect(). NewSelect().
Model(noCache). Model(noCache).
Where("? = ?", bun.Ident("account.id"), testAccount.ID). Where("? = ?", bun.Ident("account.id"), testAccount.ID).
@ -288,7 +288,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() {
suite.Empty(updated.EmojiIDs) suite.Empty(updated.EmojiIDs)
suite.WithinDuration(time.Now(), updated.UpdatedAt, 5*time.Second) suite.WithinDuration(time.Now(), updated.UpdatedAt, 5*time.Second)
err = dbService.GetConn(). err = dbService.DB().
NewSelect(). NewSelect().
Model(noCache). Model(noCache).
Where("? = ?", bun.Ident("account.id"), testAccount.ID). Where("? = ?", bun.Ident("account.id"), testAccount.ID).

View file

@ -44,21 +44,21 @@ import (
const rsaKeyBits = 2048 const rsaKeyBits = 2048
type adminDB struct { type adminDB struct {
conn *DBConn db *WrappedDB
state *state.State state *state.State
} }
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) { func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, error) {
q := a.conn. q := a.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
Column("account.id"). Column("account.id").
Where("? = ?", bun.Ident("account.username"), username). Where("? = ?", bun.Ident("account.username"), username).
Where("? IS NULL", bun.Ident("account.domain")) Where("? IS NULL", bun.Ident("account.domain"))
return a.conn.NotExists(ctx, q) return a.db.NotExists(ctx, q)
} }
func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.Error) { func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, error) {
// parse the domain from the email // parse the domain from the email
m, err := mail.ParseAddress(email) m, err := mail.ParseAddress(email)
if err != nil { if err != nil {
@ -67,12 +67,12 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @ domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
// check if the email domain is blocked // check if the email domain is blocked
emailDomainBlockedQ := a.conn. emailDomainBlockedQ := a.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("email_domain_blocks"), bun.Ident("email_domain_block")). TableExpr("? AS ?", bun.Ident("email_domain_blocks"), bun.Ident("email_domain_block")).
Column("email_domain_block.id"). Column("email_domain_block.id").
Where("? = ?", bun.Ident("email_domain_block.domain"), domain) Where("? = ?", bun.Ident("email_domain_block.domain"), domain)
emailDomainBlocked, err := a.conn.Exists(ctx, emailDomainBlockedQ) emailDomainBlocked, err := a.db.Exists(ctx, emailDomainBlockedQ)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -81,16 +81,16 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.
} }
// check if this email is associated with a user already // check if this email is associated with a user already
q := a.conn. q := a.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")). TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")).
Column("user.id"). Column("user.id").
Where("? = ?", bun.Ident("user.email"), email). Where("? = ?", bun.Ident("user.email"), email).
WhereOr("? = ?", bun.Ident("user.unconfirmed_email"), email) WhereOr("? = ?", bun.Ident("user.unconfirmed_email"), email)
return a.conn.NotExists(ctx, q) return a.db.NotExists(ctx, q)
} }
func (a *adminDB) NewSignup(ctx context.Context, newSignup gtsmodel.NewSignup) (*gtsmodel.User, db.Error) { func (a *adminDB) NewSignup(ctx context.Context, newSignup gtsmodel.NewSignup) (*gtsmodel.User, error) {
// If something went wrong previously while doing a new // If something went wrong previously while doing a new
// sign up with this username, we might already have an // sign up with this username, we might already have an
// account, so check first. // account, so check first.
@ -220,17 +220,17 @@ func (a *adminDB) NewSignup(ctx context.Context, newSignup gtsmodel.NewSignup) (
return user, nil return user, nil
} }
func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error { func (a *adminDB) CreateInstanceAccount(ctx context.Context) error {
username := config.GetHost() username := config.GetHost()
q := a.conn. q := a.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
Column("account.id"). Column("account.id").
Where("? = ?", bun.Ident("account.username"), username). Where("? = ?", bun.Ident("account.username"), username).
Where("? IS NULL", bun.Ident("account.domain")) Where("? IS NULL", bun.Ident("account.domain"))
exists, err := a.conn.Exists(ctx, q) exists, err := a.db.Exists(ctx, q)
if err != nil { if err != nil {
return err return err
} }
@ -277,18 +277,18 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
return nil return nil
} }
func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error { func (a *adminDB) CreateInstanceInstance(ctx context.Context) error {
protocol := config.GetProtocol() protocol := config.GetProtocol()
host := config.GetHost() host := config.GetHost()
// check if instance entry already exists // check if instance entry already exists
q := a.conn. q := a.db.
NewSelect(). NewSelect().
Column("instance.id"). Column("instance.id").
TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")). TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")).
Where("? = ?", bun.Ident("instance.domain"), host) Where("? = ?", bun.Ident("instance.domain"), host)
exists, err := a.conn.Exists(ctx, q) exists, err := a.db.Exists(ctx, q)
if err != nil { if err != nil {
return err return err
} }
@ -309,13 +309,13 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
URI: fmt.Sprintf("%s://%s", protocol, host), URI: fmt.Sprintf("%s://%s", protocol, host),
} }
insertQ := a.conn. insertQ := a.db.
NewInsert(). NewInsert().
Model(i) Model(i)
_, err = insertQ.Exec(ctx) _, err = insertQ.Exec(ctx)
if err != nil { if err != nil {
return a.conn.ProcessError(err) return a.db.ProcessError(err)
} }
log.Infof(ctx, "created instance instance %s with id %s", host, i.ID) log.Infof(ctx, "created instance instance %s with id %s", host, i.ID)

View file

@ -28,99 +28,99 @@ import (
) )
type basicDB struct { type basicDB struct {
conn *DBConn db *WrappedDB
} }
func (b *basicDB) Put(ctx context.Context, i interface{}) db.Error { func (b *basicDB) Put(ctx context.Context, i interface{}) error {
_, err := b.conn.NewInsert().Model(i).Exec(ctx) _, err := b.db.NewInsert().Model(i).Exec(ctx)
return b.conn.ProcessError(err) return b.db.ProcessError(err)
} }
func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Error { func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) error {
q := b.conn. q := b.db.
NewSelect(). NewSelect().
Model(i). Model(i).
Where("id = ?", id) Where("id = ?", id)
err := q.Scan(ctx) err := q.Scan(ctx)
return b.conn.ProcessError(err) return b.db.ProcessError(err)
} }
func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) db.Error { func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) error {
if len(where) == 0 { if len(where) == 0 {
return errors.New("no queries provided") return errors.New("no queries provided")
} }
q := b.conn.NewSelect().Model(i) q := b.db.NewSelect().Model(i)
selectWhere(q, where) selectWhere(q, where)
err := q.Scan(ctx) err := q.Scan(ctx)
return b.conn.ProcessError(err) return b.db.ProcessError(err)
} }
func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error { func (b *basicDB) GetAll(ctx context.Context, i interface{}) error {
q := b.conn. q := b.db.
NewSelect(). NewSelect().
Model(i) Model(i)
err := q.Scan(ctx) err := q.Scan(ctx)
return b.conn.ProcessError(err) return b.db.ProcessError(err)
} }
func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.Error { func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) error {
q := b.conn. q := b.db.
NewDelete(). NewDelete().
Model(i). Model(i).
Where("id = ?", id) Where("id = ?", id)
_, err := q.Exec(ctx) _, err := q.Exec(ctx)
return b.conn.ProcessError(err) return b.db.ProcessError(err)
} }
func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) db.Error { func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) error {
if len(where) == 0 { if len(where) == 0 {
return errors.New("no queries provided") return errors.New("no queries provided")
} }
q := b.conn. q := b.db.
NewDelete(). NewDelete().
Model(i) Model(i)
deleteWhere(q, where) deleteWhere(q, where)
_, err := q.Exec(ctx) _, err := q.Exec(ctx)
return b.conn.ProcessError(err) return b.db.ProcessError(err)
} }
func (b *basicDB) UpdateByID(ctx context.Context, i interface{}, id string, columns ...string) db.Error { func (b *basicDB) UpdateByID(ctx context.Context, i interface{}, id string, columns ...string) error {
q := b.conn. q := b.db.
NewUpdate(). NewUpdate().
Model(i). Model(i).
Column(columns...). Column(columns...).
Where("? = ?", bun.Ident("id"), id) Where("? = ?", bun.Ident("id"), id)
_, err := q.Exec(ctx) _, err := q.Exec(ctx)
return b.conn.ProcessError(err) return b.db.ProcessError(err)
} }
func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error { func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) error {
q := b.conn.NewUpdate().Model(i) q := b.db.NewUpdate().Model(i)
updateWhere(q, where) updateWhere(q, where)
q = q.Set("? = ?", bun.Ident(key), value) q = q.Set("? = ?", bun.Ident(key), value)
_, err := q.Exec(ctx) _, err := q.Exec(ctx)
return b.conn.ProcessError(err) return b.db.ProcessError(err)
} }
func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error { func (b *basicDB) CreateTable(ctx context.Context, i interface{}) error {
_, err := b.conn.NewCreateTable().Model(i).IfNotExists().Exec(ctx) _, err := b.db.NewCreateTable().Model(i).IfNotExists().Exec(ctx)
return err return err
} }
func (b *basicDB) CreateAllTables(ctx context.Context) db.Error { func (b *basicDB) CreateAllTables(ctx context.Context) error {
models := []interface{}{ models := []interface{}{
&gtsmodel.Account{}, &gtsmodel.Account{},
&gtsmodel.Application{}, &gtsmodel.Application{},
@ -154,16 +154,16 @@ func (b *basicDB) CreateAllTables(ctx context.Context) db.Error {
return nil return nil
} }
func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error { func (b *basicDB) DropTable(ctx context.Context, i interface{}) error {
_, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx) _, err := b.db.NewDropTable().Model(i).IfExists().Exec(ctx)
return b.conn.ProcessError(err) return b.db.ProcessError(err)
} }
func (b *basicDB) IsHealthy(ctx context.Context) db.Error { func (b *basicDB) IsHealthy(ctx context.Context) error {
return b.conn.PingContext(ctx) return b.db.DB.PingContext(ctx)
} }
func (b *basicDB) Stop(ctx context.Context) db.Error { func (b *basicDB) Stop(ctx context.Context) error {
log.Info(ctx, "closing db connection") log.Info(ctx, "closing db connection")
return b.conn.Close() return b.db.DB.Close()
} }

View file

@ -79,13 +79,13 @@ type DBService struct {
db.Timeline db.Timeline
db.User db.User
db.Tombstone db.Tombstone
conn *DBConn db *WrappedDB
} }
// GetConn returns the underlying bun connection. // GetDB returns the underlying database connection pool.
// Should only be used in testing + exceptional circumstance. // Should only be used in testing + exceptional circumstance.
func (dbService *DBService) GetConn() *DBConn { func (dbService *DBService) DB() *WrappedDB {
return dbService.conn return dbService.db
} }
func doMigration(ctx context.Context, db *bun.DB) error { func doMigration(ctx context.Context, db *bun.DB) error {
@ -112,18 +112,18 @@ func doMigration(ctx context.Context, db *bun.DB) error {
// NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface. // NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface.
// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection. // Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection.
func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
var conn *DBConn var db *WrappedDB
var err error var err error
t := strings.ToLower(config.GetDbType()) t := strings.ToLower(config.GetDbType())
switch t { switch t {
case "postgres": case "postgres":
conn, err = pgConn(ctx) db, err = pgConn(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
case "sqlite": case "sqlite":
conn, err = sqliteConn(ctx) db, err = sqliteConn(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -132,15 +132,15 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
} }
// Add database query hooks. // Add database query hooks.
conn.DB.AddQueryHook(queryHook{}) db.AddQueryHook(queryHook{})
if config.GetTracingEnabled() { if config.GetTracingEnabled() {
conn.DB.AddQueryHook(tracing.InstrumentBun()) db.AddQueryHook(tracing.InstrumentBun())
} }
// execute sqlite pragmas *after* adding database hook; // execute sqlite pragmas *after* adding database hook;
// this allows the pragma queries to be logged // this allows the pragma queries to be logged
if t == "sqlite" { if t == "sqlite" {
if err := sqlitePragmas(ctx, conn); err != nil { if err := sqlitePragmas(ctx, db); err != nil {
return nil, err return nil, err
} }
} }
@ -148,103 +148,103 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
// table registration is needed for many-to-many, see: // table registration is needed for many-to-many, see:
// https://bun.uptrace.dev/orm/many-to-many-relation/ // https://bun.uptrace.dev/orm/many-to-many-relation/
for _, t := range registerTables { for _, t := range registerTables {
conn.RegisterModel(t) db.RegisterModel(t)
} }
// perform any pending database migrations: this includes // perform any pending database migrations: this includes
// the very first 'migration' on startup which just creates // the very first 'migration' on startup which just creates
// necessary tables // necessary tables
if err := doMigration(ctx, conn.DB); err != nil { if err := doMigration(ctx, db.DB); err != nil {
return nil, fmt.Errorf("db migration error: %s", err) return nil, fmt.Errorf("db migration error: %s", err)
} }
ps := &DBService{ ps := &DBService{
Account: &accountDB{ Account: &accountDB{
conn: conn, db: db,
state: state, state: state,
}, },
Admin: &adminDB{ Admin: &adminDB{
conn: conn, db: db,
state: state, state: state,
}, },
Basic: &basicDB{ Basic: &basicDB{
conn: conn, db: db,
}, },
Domain: &domainDB{ Domain: &domainDB{
conn: conn, db: db,
state: state, state: state,
}, },
Emoji: &emojiDB{ Emoji: &emojiDB{
conn: conn, db: db,
state: state, state: state,
}, },
Instance: &instanceDB{ Instance: &instanceDB{
conn: conn, db: db,
state: state, state: state,
}, },
List: &listDB{ List: &listDB{
conn: conn, db: db,
state: state, state: state,
}, },
Media: &mediaDB{ Media: &mediaDB{
conn: conn, db: db,
state: state, state: state,
}, },
Mention: &mentionDB{ Mention: &mentionDB{
conn: conn, db: db,
state: state, state: state,
}, },
Notification: &notificationDB{ Notification: &notificationDB{
conn: conn, db: db,
state: state, state: state,
}, },
Relationship: &relationshipDB{ Relationship: &relationshipDB{
conn: conn, db: db,
state: state, state: state,
}, },
Report: &reportDB{ Report: &reportDB{
conn: conn, db: db,
state: state, state: state,
}, },
Search: &searchDB{ Search: &searchDB{
conn: conn, db: db,
state: state, state: state,
}, },
Session: &sessionDB{ Session: &sessionDB{
conn: conn, db: db,
}, },
Status: &statusDB{ Status: &statusDB{
conn: conn, db: db,
state: state, state: state,
}, },
StatusBookmark: &statusBookmarkDB{ StatusBookmark: &statusBookmarkDB{
conn: conn, db: db,
state: state, state: state,
}, },
StatusFave: &statusFaveDB{ StatusFave: &statusFaveDB{
conn: conn, db: db,
state: state, state: state,
}, },
Timeline: &timelineDB{ Timeline: &timelineDB{
conn: conn, db: db,
state: state, state: state,
}, },
User: &userDB{ User: &userDB{
conn: conn, db: db,
state: state, state: state,
}, },
Tombstone: &tombstoneDB{ Tombstone: &tombstoneDB{
conn: conn, db: db,
state: state, state: state,
}, },
conn: conn, db: db,
} }
// we can confidently return this useable service now // we can confidently return this useable service now
return ps, nil return ps, nil
} }
func pgConn(ctx context.Context) (*DBConn, error) { func pgConn(ctx context.Context) (*WrappedDB, error) {
opts, err := deriveBunDBPGOptions() //nolint:contextcheck opts, err := deriveBunDBPGOptions() //nolint:contextcheck
if err != nil { if err != nil {
return nil, fmt.Errorf("could not create bundb postgres options: %s", err) return nil, fmt.Errorf("could not create bundb postgres options: %s", err)
@ -259,10 +259,10 @@ func pgConn(ctx context.Context) (*DBConn, error) {
sqldb.SetMaxIdleConns(2) // assume default 2; if max idle is less than max open, it will be automatically adjusted sqldb.SetMaxIdleConns(2) // assume default 2; if max idle is less than max open, it will be automatically adjusted
sqldb.SetConnMaxLifetime(5 * time.Minute) // fine to kill old connections sqldb.SetConnMaxLifetime(5 * time.Minute) // fine to kill old connections
conn := WrapDBConn(bun.NewDB(sqldb, pgdialect.New())) conn := WrapDB(bun.NewDB(sqldb, pgdialect.New()))
// ping to check the db is there and listening // ping to check the db is there and listening
if err := conn.PingContext(ctx); err != nil { if err := conn.DB.PingContext(ctx); err != nil {
return nil, fmt.Errorf("postgres ping: %s", err) return nil, fmt.Errorf("postgres ping: %s", err)
} }
@ -270,7 +270,7 @@ func pgConn(ctx context.Context) (*DBConn, error) {
return conn, nil return conn, nil
} }
func sqliteConn(ctx context.Context) (*DBConn, error) { func sqliteConn(ctx context.Context) (*WrappedDB, error) {
// validate db address has actually been set // validate db address has actually been set
address := config.GetDbAddress() address := config.GetDbAddress()
if address == "" { if address == "" {
@ -326,15 +326,15 @@ func sqliteConn(ctx context.Context) (*DBConn, error) {
// Tune db connections for sqlite, see: // Tune db connections for sqlite, see:
// - https://bun.uptrace.dev/guide/running-bun-in-production.html#database-sql // - https://bun.uptrace.dev/guide/running-bun-in-production.html#database-sql
// - https://www.alexedwards.net/blog/configuring-sqldb // - https://www.alexedwards.net/blog/configuring-sqldb
sqldb.SetMaxOpenConns(1) // only 1 connection regardless of multiplier, see https://github.com/superseriousbusiness/gotosocial/issues/1407 sqldb.SetMaxOpenConns(maxOpenConns()) // x number of conns per CPU
sqldb.SetMaxIdleConns(1) // only keep max 1 idle connection around sqldb.SetMaxIdleConns(1) // only keep max 1 idle connection around
sqldb.SetConnMaxLifetime(0) // don't kill connections due to age sqldb.SetConnMaxLifetime(0) // don't kill connections due to age
// Wrap Bun database conn in our own wrapper // Wrap Bun database conn in our own wrapper
conn := WrapDBConn(bun.NewDB(sqldb, sqlitedialect.New())) conn := WrapDB(bun.NewDB(sqldb, sqlitedialect.New()))
// ping to check the db is there and listening // ping to check the db is there and listening
if err := conn.PingContext(ctx); err != nil { if err := conn.DB.PingContext(ctx); err != nil {
if errWithCode, ok := err.(*sqlite.Error); ok { if errWithCode, ok := err.(*sqlite.Error); ok {
err = errors.New(sqlite.ErrorCodeString[errWithCode.Code()]) err = errors.New(sqlite.ErrorCodeString[errWithCode.Code()])
} }
@ -445,7 +445,7 @@ func deriveBunDBPGOptions() (*pgx.ConnConfig, error) {
// sqlitePragmas sets desired sqlite pragmas based on configured values, and // sqlitePragmas sets desired sqlite pragmas based on configured values, and
// logs the results of the pragma queries. Errors if something goes wrong. // logs the results of the pragma queries. Errors if something goes wrong.
func sqlitePragmas(ctx context.Context, conn *DBConn) error { func sqlitePragmas(ctx context.Context, db *WrappedDB) error {
var pragmas [][]string var pragmas [][]string
if mode := config.GetDbSqliteJournalMode(); mode != "" { if mode := config.GetDbSqliteJournalMode(); mode != "" {
// Set the user provided SQLite journal mode // Set the user provided SQLite journal mode
@ -475,12 +475,12 @@ func sqlitePragmas(ctx context.Context, conn *DBConn) error {
pk := p[0] pk := p[0]
pv := p[1] pv := p[1]
if _, err := conn.DB.ExecContext(ctx, "PRAGMA ?=?", bun.Ident(pk), bun.Safe(pv)); err != nil { if _, err := db.ExecContext(ctx, "PRAGMA ?=?", bun.Ident(pk), bun.Safe(pv)); err != nil {
return fmt.Errorf("error executing sqlite pragma %s: %w", pk, err) return fmt.Errorf("error executing sqlite pragma %s: %w", pk, err)
} }
var res string var res string
if err := conn.DB.NewRaw("PRAGMA ?", bun.Ident(pk)).Scan(ctx, &res); err != nil { if err := db.NewRaw("PRAGMA ?", bun.Ident(pk)).Scan(ctx, &res); err != nil {
return fmt.Errorf("error scanning sqlite pragma %s: %w", pv, err) return fmt.Errorf("error scanning sqlite pragma %s: %w", pv, err)
} }
@ -502,7 +502,7 @@ func (dbService *DBService) TagStringToTag(ctx context.Context, t string, origin
tag := &gtsmodel.Tag{} tag := &gtsmodel.Tag{}
// we can use selectorinsert here to create the new tag if it doesn't exist already // we can use selectorinsert here to create the new tag if it doesn't exist already
// inserted will be true if this is a new tag we just created // inserted will be true if this is a new tag we just created
if err := dbService.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil && err != sql.ErrNoRows { if err := dbService.db.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil && err != sql.ErrNoRows {
return nil, fmt.Errorf("error getting tag with name %s: %s", t, err) return nil, fmt.Errorf("error getting tag with name %s: %s", t, err)
} }

View file

@ -1,113 +0,0 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"database/sql"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect"
)
// DBConn wrapps a bun.DB conn to provide SQL-type specific additional functionality
type DBConn struct {
errProc func(error) db.Error // errProc is the SQL-type specific error processor
*bun.DB // DB is the underlying bun.DB connection
}
// WrapDBConn wraps a bun DB connection to provide our own error processing dependent on DB dialect.
func WrapDBConn(dbConn *bun.DB) *DBConn {
var errProc func(error) db.Error
switch dbConn.Dialect().Name() {
case dialect.PG:
errProc = processPostgresError
case dialect.SQLite:
errProc = processSQLiteError
default:
panic("unknown dialect name: " + dbConn.Dialect().Name().String())
}
return &DBConn{
errProc: errProc,
DB: dbConn,
}
}
// RunInTx wraps execution of the supplied transaction function.
func (conn *DBConn) RunInTx(ctx context.Context, fn func(bun.Tx) error) db.Error {
return conn.ProcessError(func() error {
// Acquire a new transaction
tx, err := conn.BeginTx(ctx, nil)
if err != nil {
return err
}
var done bool
defer func() {
if !done {
_ = tx.Rollback()
}
}()
// Perform supplied transaction
if err := fn(tx); err != nil {
return err
}
// Finally, commit
err = tx.Commit() //nolint:contextcheck
done = true
return err
}())
}
// ProcessError processes an error to replace any known values with our own db.Error types,
// making it easier to catch specific situations (e.g. no rows, already exists, etc)
func (conn *DBConn) ProcessError(err error) db.Error {
switch {
case err == nil:
return nil
case err == sql.ErrNoRows:
return db.ErrNoEntries
default:
return conn.errProc(err)
}
}
// Exists checks the results of a SelectQuery for the existence of the data in question, masking ErrNoEntries errors
func (conn *DBConn) Exists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) {
exists, err := query.Exists(ctx)
// Process error as our own and check if it exists
switch err := conn.ProcessError(err); err {
case nil:
return exists, nil
case db.ErrNoEntries:
return false, nil
default:
return false, err
}
}
// NotExists is the functional opposite of conn.Exists()
func (conn *DBConn) NotExists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) {
exists, err := conn.Exists(ctx, query)
return !exists, err
}

View file

@ -30,11 +30,11 @@ import (
) )
type domainDB struct { type domainDB struct {
conn *DBConn db *WrappedDB
state *state.State state *state.State
} }
func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) db.Error { func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) error {
// Normalize the domain as punycode // Normalize the domain as punycode
var err error var err error
block.Domain, err = util.Punify(block.Domain) block.Domain, err = util.Punify(block.Domain)
@ -43,10 +43,10 @@ func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.Domain
} }
// Attempt to store domain block in DB // Attempt to store domain block in DB
if _, err := d.conn.NewInsert(). if _, err := d.db.NewInsert().
Model(block). Model(block).
Exec(ctx); err != nil { Exec(ctx); err != nil {
return d.conn.ProcessError(err) return d.db.ProcessError(err)
} }
// Clear the domain block cache (for later reload) // Clear the domain block cache (for later reload)
@ -55,7 +55,7 @@ func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.Domain
return nil return nil
} }
func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, db.Error) { func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, error) {
// Normalize the domain as punycode // Normalize the domain as punycode
domain, err := util.Punify(domain) domain, err := util.Punify(domain)
if err != nil { if err != nil {
@ -71,12 +71,12 @@ func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel
var block gtsmodel.DomainBlock var block gtsmodel.DomainBlock
// Look for block matching domain in DB // Look for block matching domain in DB
q := d.conn. q := d.db.
NewSelect(). NewSelect().
Model(&block). Model(&block).
Where("? = ?", bun.Ident("domain_block.domain"), domain) Where("? = ?", bun.Ident("domain_block.domain"), domain)
if err := q.Scan(ctx); err != nil { if err := q.Scan(ctx); err != nil {
return nil, d.conn.ProcessError(err) return nil, d.db.ProcessError(err)
} }
return &block, nil return &block, nil
@ -85,31 +85,31 @@ func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel
func (d *domainDB) GetDomainBlocks(ctx context.Context) ([]*gtsmodel.DomainBlock, error) { func (d *domainDB) GetDomainBlocks(ctx context.Context) ([]*gtsmodel.DomainBlock, error) {
blocks := []*gtsmodel.DomainBlock{} blocks := []*gtsmodel.DomainBlock{}
if err := d.conn. if err := d.db.
NewSelect(). NewSelect().
Model(&blocks). Model(&blocks).
Scan(ctx); err != nil { Scan(ctx); err != nil {
return nil, d.conn.ProcessError(err) return nil, d.db.ProcessError(err)
} }
return blocks, nil return blocks, nil
} }
func (d *domainDB) GetDomainBlockByID(ctx context.Context, id string) (*gtsmodel.DomainBlock, db.Error) { func (d *domainDB) GetDomainBlockByID(ctx context.Context, id string) (*gtsmodel.DomainBlock, error) {
var block gtsmodel.DomainBlock var block gtsmodel.DomainBlock
q := d.conn. q := d.db.
NewSelect(). NewSelect().
Model(&block). Model(&block).
Where("? = ?", bun.Ident("domain_block.id"), id) Where("? = ?", bun.Ident("domain_block.id"), id)
if err := q.Scan(ctx); err != nil { if err := q.Scan(ctx); err != nil {
return nil, d.conn.ProcessError(err) return nil, d.db.ProcessError(err)
} }
return &block, nil return &block, nil
} }
func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error { func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) error {
// Normalize the domain as punycode // Normalize the domain as punycode
domain, err := util.Punify(domain) domain, err := util.Punify(domain)
if err != nil { if err != nil {
@ -117,11 +117,11 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro
} }
// Attempt to delete domain block // Attempt to delete domain block
if _, err := d.conn.NewDelete(). if _, err := d.db.NewDelete().
Model((*gtsmodel.DomainBlock)(nil)). Model((*gtsmodel.DomainBlock)(nil)).
Where("? = ?", bun.Ident("domain_block.domain"), domain). Where("? = ?", bun.Ident("domain_block.domain"), domain).
Exec(ctx); err != nil { Exec(ctx); err != nil {
return d.conn.ProcessError(err) return d.db.ProcessError(err)
} }
// Clear the domain block cache (for later reload) // Clear the domain block cache (for later reload)
@ -130,7 +130,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro
return nil return nil
} }
func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) { func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, error) {
// Normalize the domain as punycode // Normalize the domain as punycode
domain, err := util.Punify(domain) domain, err := util.Punify(domain)
if err != nil { if err != nil {
@ -148,18 +148,18 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db
var domains []string var domains []string
// Scan list of all blocked domains from DB // Scan list of all blocked domains from DB
q := d.conn.NewSelect(). q := d.db.NewSelect().
Table("domain_blocks"). Table("domain_blocks").
Column("domain") Column("domain")
if err := q.Scan(ctx, &domains); err != nil { if err := q.Scan(ctx, &domains); err != nil {
return nil, d.conn.ProcessError(err) return nil, d.db.ProcessError(err)
} }
return domains, nil return domains, nil
}) })
} }
func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) { func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, error) {
for _, domain := range domains { for _, domain := range domains {
if blocked, err := d.IsDomainBlocked(ctx, domain); err != nil { if blocked, err := d.IsDomainBlocked(ctx, domain); err != nil {
return false, err return false, err
@ -170,11 +170,11 @@ func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (boo
return false, nil return false, nil
} }
func (d *domainDB) IsURIBlocked(ctx context.Context, uri *url.URL) (bool, db.Error) { func (d *domainDB) IsURIBlocked(ctx context.Context, uri *url.URL) (bool, error) {
return d.IsDomainBlocked(ctx, uri.Hostname()) return d.IsDomainBlocked(ctx, uri.Hostname())
} }
func (d *domainDB) AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, db.Error) { func (d *domainDB) AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, error) {
for _, uri := range uris { for _, uri := range uris {
if blocked, err := d.IsDomainBlocked(ctx, uri.Hostname()); err != nil { if blocked, err := d.IsDomainBlocked(ctx, uri.Hostname()); err != nil {
return false, err return false, err

View file

@ -34,14 +34,14 @@ import (
) )
type emojiDB struct { type emojiDB struct {
conn *DBConn db *WrappedDB
state *state.State state *state.State
} }
func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) db.Error { func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) error {
return e.state.Caches.GTS.Emoji().Store(emoji, func() error { return e.state.Caches.GTS.Emoji().Store(emoji, func() error {
_, err := e.conn.NewInsert().Model(emoji).Exec(ctx) _, err := e.db.NewInsert().Model(emoji).Exec(ctx)
return e.conn.ProcessError(err) return e.db.ProcessError(err)
}) })
} }
@ -54,17 +54,17 @@ func (e *emojiDB) UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, column
// Update the emoji model in the database. // Update the emoji model in the database.
return e.state.Caches.GTS.Emoji().Store(emoji, func() error { return e.state.Caches.GTS.Emoji().Store(emoji, func() error {
_, err := e.conn. _, err := e.db.
NewUpdate(). NewUpdate().
Model(emoji). Model(emoji).
Where("? = ?", bun.Ident("emoji.id"), emoji.ID). Where("? = ?", bun.Ident("emoji.id"), emoji.ID).
Column(columns...). Column(columns...).
Exec(ctx) Exec(ctx)
return e.conn.ProcessError(err) return e.db.ProcessError(err)
}) })
} }
func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) db.Error { func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error {
var ( var (
accountIDs []string accountIDs []string
statusIDs []string statusIDs []string
@ -105,7 +105,7 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) db.Error {
return err return err
} }
return e.conn.RunInTx(ctx, func(tx bun.Tx) error { return e.db.RunInTx(ctx, func(tx bun.Tx) error {
// delete links between this emoji and any statuses that use it // delete links between this emoji and any statuses that use it
// TODO: remove when we delete this table // TODO: remove when we delete this table
if _, err := tx. if _, err := tx.
@ -229,7 +229,7 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) db.Error {
func (e *emojiDB) GetEmojisBy(ctx context.Context, domain string, includeDisabled bool, includeEnabled bool, shortcode string, maxShortcodeDomain string, minShortcodeDomain string, limit int) ([]*gtsmodel.Emoji, error) { func (e *emojiDB) GetEmojisBy(ctx context.Context, domain string, includeDisabled bool, includeEnabled bool, shortcode string, maxShortcodeDomain string, minShortcodeDomain string, limit int) ([]*gtsmodel.Emoji, error) {
emojiIDs := []string{} emojiIDs := []string{}
subQuery := e.conn. subQuery := e.db.
NewSelect(). NewSelect().
ColumnExpr("? AS ?", bun.Ident("emoji.id"), bun.Ident("emoji_ids")) ColumnExpr("? AS ?", bun.Ident("emoji.id"), bun.Ident("emoji_ids"))
@ -255,7 +255,7 @@ func (e *emojiDB) GetEmojisBy(ctx context.Context, domain string, includeDisable
// "emojis" AS "emoji" // "emojis" AS "emoji"
// ORDER BY // ORDER BY
// "shortcode_domain" ASC // "shortcode_domain" ASC
switch e.conn.Dialect().Name() { switch e.db.Dialect().Name() {
case dialect.SQLite: case dialect.SQLite:
subQuery = subQuery.ColumnExpr("LOWER(? || ? || COALESCE(?, ?)) AS ?", bun.Ident("emoji.shortcode"), "@", bun.Ident("emoji.domain"), "", bun.Ident("shortcode_domain")) subQuery = subQuery.ColumnExpr("LOWER(? || ? || COALESCE(?, ?)) AS ?", bun.Ident("emoji.shortcode"), "@", bun.Ident("emoji.domain"), "", bun.Ident("shortcode_domain"))
case dialect.PG: case dialect.PG:
@ -321,12 +321,12 @@ func (e *emojiDB) GetEmojisBy(ctx context.Context, domain string, includeDisable
// ORDER BY // ORDER BY
// "shortcode_domain" ASC // "shortcode_domain" ASC
// ) AS "subquery" // ) AS "subquery"
if err := e.conn. if err := e.db.
NewSelect(). NewSelect().
Column("subquery.emoji_ids"). Column("subquery.emoji_ids").
TableExpr("(?) AS ?", subQuery, bun.Ident("subquery")). TableExpr("(?) AS ?", subQuery, bun.Ident("subquery")).
Scan(ctx, &emojiIDs); err != nil { Scan(ctx, &emojiIDs); err != nil {
return nil, e.conn.ProcessError(err) return nil, e.db.ProcessError(err)
} }
if order == "DESC" { if order == "DESC" {
@ -346,7 +346,7 @@ func (e *emojiDB) GetEmojisBy(ctx context.Context, domain string, includeDisable
func (e *emojiDB) GetEmojis(ctx context.Context, maxID string, limit int) ([]*gtsmodel.Emoji, error) { func (e *emojiDB) GetEmojis(ctx context.Context, maxID string, limit int) ([]*gtsmodel.Emoji, error) {
var emojiIDs []string var emojiIDs []string
q := e.conn.NewSelect(). q := e.db.NewSelect().
Table("emojis"). Table("emojis").
Column("id"). Column("id").
Order("id DESC") Order("id DESC")
@ -360,7 +360,7 @@ func (e *emojiDB) GetEmojis(ctx context.Context, maxID string, limit int) ([]*gt
} }
if err := q.Scan(ctx, &emojiIDs); err != nil { if err := q.Scan(ctx, &emojiIDs); err != nil {
return nil, e.conn.ProcessError(err) return nil, e.db.ProcessError(err)
} }
return e.GetEmojisByIDs(ctx, emojiIDs) return e.GetEmojisByIDs(ctx, emojiIDs)
@ -369,7 +369,7 @@ func (e *emojiDB) GetEmojis(ctx context.Context, maxID string, limit int) ([]*gt
func (e *emojiDB) GetRemoteEmojis(ctx context.Context, maxID string, limit int) ([]*gtsmodel.Emoji, error) { func (e *emojiDB) GetRemoteEmojis(ctx context.Context, maxID string, limit int) ([]*gtsmodel.Emoji, error) {
var emojiIDs []string var emojiIDs []string
q := e.conn.NewSelect(). q := e.db.NewSelect().
Table("emojis"). Table("emojis").
Column("id"). Column("id").
Where("domain IS NOT NULL"). Where("domain IS NOT NULL").
@ -384,7 +384,7 @@ func (e *emojiDB) GetRemoteEmojis(ctx context.Context, maxID string, limit int)
} }
if err := q.Scan(ctx, &emojiIDs); err != nil { if err := q.Scan(ctx, &emojiIDs); err != nil {
return nil, e.conn.ProcessError(err) return nil, e.db.ProcessError(err)
} }
return e.GetEmojisByIDs(ctx, emojiIDs) return e.GetEmojisByIDs(ctx, emojiIDs)
@ -393,7 +393,7 @@ func (e *emojiDB) GetRemoteEmojis(ctx context.Context, maxID string, limit int)
func (e *emojiDB) GetCachedEmojisOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.Emoji, error) { func (e *emojiDB) GetCachedEmojisOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.Emoji, error) {
var emojiIDs []string var emojiIDs []string
q := e.conn.NewSelect(). q := e.db.NewSelect().
Table("emojis"). Table("emojis").
Column("id"). Column("id").
Where("cached = true"). Where("cached = true").
@ -406,16 +406,16 @@ func (e *emojiDB) GetCachedEmojisOlderThan(ctx context.Context, olderThan time.T
} }
if err := q.Scan(ctx, &emojiIDs); err != nil { if err := q.Scan(ctx, &emojiIDs); err != nil {
return nil, e.conn.ProcessError(err) return nil, e.db.ProcessError(err)
} }
return e.GetEmojisByIDs(ctx, emojiIDs) return e.GetEmojisByIDs(ctx, emojiIDs)
} }
func (e *emojiDB) GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.Error) { func (e *emojiDB) GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, error) {
emojiIDs := []string{} emojiIDs := []string{}
q := e.conn. q := e.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("emojis"), bun.Ident("emoji")). TableExpr("? AS ?", bun.Ident("emojis"), bun.Ident("emoji")).
Column("emoji.id"). Column("emoji.id").
@ -425,18 +425,18 @@ func (e *emojiDB) GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.E
Order("emoji.shortcode ASC") Order("emoji.shortcode ASC")
if err := q.Scan(ctx, &emojiIDs); err != nil { if err := q.Scan(ctx, &emojiIDs); err != nil {
return nil, e.conn.ProcessError(err) return nil, e.db.ProcessError(err)
} }
return e.GetEmojisByIDs(ctx, emojiIDs) return e.GetEmojisByIDs(ctx, emojiIDs)
} }
func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, db.Error) { func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, error) {
return e.getEmoji( return e.getEmoji(
ctx, ctx,
"ID", "ID",
func(emoji *gtsmodel.Emoji) error { func(emoji *gtsmodel.Emoji) error {
return e.conn. return e.db.
NewSelect(). NewSelect().
Model(emoji). Model(emoji).
Where("? = ?", bun.Ident("emoji.id"), id).Scan(ctx) Where("? = ?", bun.Ident("emoji.id"), id).Scan(ctx)
@ -445,12 +445,12 @@ func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji,
) )
} }
func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoji, db.Error) { func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoji, error) {
return e.getEmoji( return e.getEmoji(
ctx, ctx,
"URI", "URI",
func(emoji *gtsmodel.Emoji) error { func(emoji *gtsmodel.Emoji) error {
return e.conn. return e.db.
NewSelect(). NewSelect().
Model(emoji). Model(emoji).
Where("? = ?", bun.Ident("emoji.uri"), uri).Scan(ctx) Where("? = ?", bun.Ident("emoji.uri"), uri).Scan(ctx)
@ -459,12 +459,12 @@ func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoj
) )
} }
func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, db.Error) { func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, error) {
return e.getEmoji( return e.getEmoji(
ctx, ctx,
"Shortcode.Domain", "Shortcode.Domain",
func(emoji *gtsmodel.Emoji) error { func(emoji *gtsmodel.Emoji) error {
q := e.conn. q := e.db.
NewSelect(). NewSelect().
Model(emoji) Model(emoji)
@ -483,12 +483,12 @@ func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode strin
) )
} }
func (e *emojiDB) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string) (*gtsmodel.Emoji, db.Error) { func (e *emojiDB) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string) (*gtsmodel.Emoji, error) {
return e.getEmoji( return e.getEmoji(
ctx, ctx,
"ImageStaticURL", "ImageStaticURL",
func(emoji *gtsmodel.Emoji) error { func(emoji *gtsmodel.Emoji) error {
return e.conn. return e.db.
NewSelect(). NewSelect().
Model(emoji). Model(emoji).
Where("? = ?", bun.Ident("emoji.image_static_url"), imageStaticURL). Where("? = ?", bun.Ident("emoji.image_static_url"), imageStaticURL).
@ -498,35 +498,35 @@ func (e *emojiDB) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string
) )
} }
func (e *emojiDB) PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) db.Error { func (e *emojiDB) PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) error {
return e.state.Caches.GTS.EmojiCategory().Store(emojiCategory, func() error { return e.state.Caches.GTS.EmojiCategory().Store(emojiCategory, func() error {
_, err := e.conn.NewInsert().Model(emojiCategory).Exec(ctx) _, err := e.db.NewInsert().Model(emojiCategory).Exec(ctx)
return e.conn.ProcessError(err) return e.db.ProcessError(err)
}) })
} }
func (e *emojiDB) GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCategory, db.Error) { func (e *emojiDB) GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCategory, error) {
emojiCategoryIDs := []string{} emojiCategoryIDs := []string{}
q := e.conn. q := e.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("emoji_categories"), bun.Ident("emoji_category")). TableExpr("? AS ?", bun.Ident("emoji_categories"), bun.Ident("emoji_category")).
Column("emoji_category.id"). Column("emoji_category.id").
Order("emoji_category.name ASC") Order("emoji_category.name ASC")
if err := q.Scan(ctx, &emojiCategoryIDs); err != nil { if err := q.Scan(ctx, &emojiCategoryIDs); err != nil {
return nil, e.conn.ProcessError(err) return nil, e.db.ProcessError(err)
} }
return e.GetEmojiCategoriesByIDs(ctx, emojiCategoryIDs) return e.GetEmojiCategoriesByIDs(ctx, emojiCategoryIDs)
} }
func (e *emojiDB) GetEmojiCategory(ctx context.Context, id string) (*gtsmodel.EmojiCategory, db.Error) { func (e *emojiDB) GetEmojiCategory(ctx context.Context, id string) (*gtsmodel.EmojiCategory, error) {
return e.getEmojiCategory( return e.getEmojiCategory(
ctx, ctx,
"ID", "ID",
func(emojiCategory *gtsmodel.EmojiCategory) error { func(emojiCategory *gtsmodel.EmojiCategory) error {
return e.conn. return e.db.
NewSelect(). NewSelect().
Model(emojiCategory). Model(emojiCategory).
Where("? = ?", bun.Ident("emoji_category.id"), id).Scan(ctx) Where("? = ?", bun.Ident("emoji_category.id"), id).Scan(ctx)
@ -535,12 +535,12 @@ func (e *emojiDB) GetEmojiCategory(ctx context.Context, id string) (*gtsmodel.Em
) )
} }
func (e *emojiDB) GetEmojiCategoryByName(ctx context.Context, name string) (*gtsmodel.EmojiCategory, db.Error) { func (e *emojiDB) GetEmojiCategoryByName(ctx context.Context, name string) (*gtsmodel.EmojiCategory, error) {
return e.getEmojiCategory( return e.getEmojiCategory(
ctx, ctx,
"Name", "Name",
func(emojiCategory *gtsmodel.EmojiCategory) error { func(emojiCategory *gtsmodel.EmojiCategory) error {
return e.conn. return e.db.
NewSelect(). NewSelect().
Model(emojiCategory). Model(emojiCategory).
Where("LOWER(?) = ?", bun.Ident("emoji_category.name"), strings.ToLower(name)).Scan(ctx) Where("LOWER(?) = ?", bun.Ident("emoji_category.name"), strings.ToLower(name)).Scan(ctx)
@ -549,14 +549,14 @@ func (e *emojiDB) GetEmojiCategoryByName(ctx context.Context, name string) (*gts
) )
} }
func (e *emojiDB) getEmoji(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Emoji) error, keyParts ...any) (*gtsmodel.Emoji, db.Error) { func (e *emojiDB) getEmoji(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Emoji) error, keyParts ...any) (*gtsmodel.Emoji, error) {
// Fetch emoji from database cache with loader callback // Fetch emoji from database cache with loader callback
emoji, err := e.state.Caches.GTS.Emoji().Load(lookup, func() (*gtsmodel.Emoji, error) { emoji, err := e.state.Caches.GTS.Emoji().Load(lookup, func() (*gtsmodel.Emoji, error) {
var emoji gtsmodel.Emoji var emoji gtsmodel.Emoji
// Not cached! Perform database query // Not cached! Perform database query
if err := dbQuery(&emoji); err != nil { if err := dbQuery(&emoji); err != nil {
return nil, e.conn.ProcessError(err) return nil, e.db.ProcessError(err)
} }
return &emoji, nil return &emoji, nil
@ -580,7 +580,7 @@ func (e *emojiDB) getEmoji(ctx context.Context, lookup string, dbQuery func(*gts
return emoji, nil return emoji, nil
} }
func (e *emojiDB) GetEmojisByIDs(ctx context.Context, emojiIDs []string) ([]*gtsmodel.Emoji, db.Error) { func (e *emojiDB) GetEmojisByIDs(ctx context.Context, emojiIDs []string) ([]*gtsmodel.Emoji, error) {
if len(emojiIDs) == 0 { if len(emojiIDs) == 0 {
return nil, db.ErrNoEntries return nil, db.ErrNoEntries
} }
@ -600,20 +600,20 @@ func (e *emojiDB) GetEmojisByIDs(ctx context.Context, emojiIDs []string) ([]*gts
return emojis, nil return emojis, nil
} }
func (e *emojiDB) getEmojiCategory(ctx context.Context, lookup string, dbQuery func(*gtsmodel.EmojiCategory) error, keyParts ...any) (*gtsmodel.EmojiCategory, db.Error) { func (e *emojiDB) getEmojiCategory(ctx context.Context, lookup string, dbQuery func(*gtsmodel.EmojiCategory) error, keyParts ...any) (*gtsmodel.EmojiCategory, error) {
return e.state.Caches.GTS.EmojiCategory().Load(lookup, func() (*gtsmodel.EmojiCategory, error) { return e.state.Caches.GTS.EmojiCategory().Load(lookup, func() (*gtsmodel.EmojiCategory, error) {
var category gtsmodel.EmojiCategory var category gtsmodel.EmojiCategory
// Not cached! Perform database query // Not cached! Perform database query
if err := dbQuery(&category); err != nil { if err := dbQuery(&category); err != nil {
return nil, e.conn.ProcessError(err) return nil, e.db.ProcessError(err)
} }
return &category, nil return &category, nil
}, keyParts...) }, keyParts...)
} }
func (e *emojiDB) GetEmojiCategoriesByIDs(ctx context.Context, emojiCategoryIDs []string) ([]*gtsmodel.EmojiCategory, db.Error) { func (e *emojiDB) GetEmojiCategoriesByIDs(ctx context.Context, emojiCategoryIDs []string) ([]*gtsmodel.EmojiCategory, error) {
if len(emojiCategoryIDs) == 0 { if len(emojiCategoryIDs) == 0 {
return nil, db.ErrNoEntries return nil, db.ErrNoEntries
} }

View file

@ -18,14 +18,20 @@
package bundb package bundb
import ( import (
"errors"
"github.com/jackc/pgconn" "github.com/jackc/pgconn"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"modernc.org/sqlite" "modernc.org/sqlite"
sqlite3 "modernc.org/sqlite/lib" sqlite3 "modernc.org/sqlite/lib"
) )
// errBusy is a sentinel error indicating
// busy database (e.g. retry needed).
var errBusy = errors.New("busy")
// processPostgresError processes an error, replacing any postgres specific errors with our own error type // processPostgresError processes an error, replacing any postgres specific errors with our own error type
func processPostgresError(err error) db.Error { func processPostgresError(err error) error {
// Attempt to cast as postgres // Attempt to cast as postgres
pgErr, ok := err.(*pgconn.PgError) pgErr, ok := err.(*pgconn.PgError)
if !ok { if !ok {
@ -34,16 +40,16 @@ func processPostgresError(err error) db.Error {
// Handle supplied error code: // Handle supplied error code:
// (https://www.postgresql.org/docs/10/errcodes-appendix.html) // (https://www.postgresql.org/docs/10/errcodes-appendix.html)
switch pgErr.Code { switch pgErr.Code { //nolint
case "23505" /* unique_violation */ : case "23505" /* unique_violation */ :
return db.ErrAlreadyExists return db.ErrAlreadyExists
default:
return err
} }
return err
} }
// processSQLiteError processes an error, replacing any sqlite specific errors with our own error type // processSQLiteError processes an error, replacing any sqlite specific errors with our own error type
func processSQLiteError(err error) db.Error { func processSQLiteError(err error) error {
// Attempt to cast as sqlite // Attempt to cast as sqlite
sqliteErr, ok := err.(*sqlite.Error) sqliteErr, ok := err.(*sqlite.Error)
if !ok { if !ok {
@ -55,7 +61,11 @@ func processSQLiteError(err error) db.Error {
case sqlite3.SQLITE_CONSTRAINT_UNIQUE, case sqlite3.SQLITE_CONSTRAINT_UNIQUE,
sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY: sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY:
return db.ErrAlreadyExists return db.ErrAlreadyExists
default: case sqlite3.SQLITE_BUSY:
return err return errBusy
case sqlite3.SQLITE_BUSY_TIMEOUT:
return db.ErrBusyTimeout
} }
return err
} }

View file

@ -34,12 +34,12 @@ import (
) )
type instanceDB struct { type instanceDB struct {
conn *DBConn db *WrappedDB
state *state.State state *state.State
} }
func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) { func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, error) {
q := i.conn. q := i.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
Column("account.id"). Column("account.id").
@ -56,13 +56,13 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int
count, err := q.Count(ctx) count, err := q.Count(ctx)
if err != nil { if err != nil {
return 0, i.conn.ProcessError(err) return 0, i.db.ProcessError(err)
} }
return count, nil return count, nil
} }
func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) { func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, error) {
q := i.conn. q := i.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")) TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status"))
@ -78,13 +78,13 @@ func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (
count, err := q.Count(ctx) count, err := q.Count(ctx)
if err != nil { if err != nil {
return 0, i.conn.ProcessError(err) return 0, i.db.ProcessError(err)
} }
return count, nil return count, nil
} }
func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) { func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, error) {
q := i.conn. q := i.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")) TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance"))
@ -101,12 +101,12 @@ func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (i
count, err := q.Count(ctx) count, err := q.Count(ctx)
if err != nil { if err != nil {
return 0, i.conn.ProcessError(err) return 0, i.db.ProcessError(err)
} }
return count, nil return count, nil
} }
func (i *instanceDB) GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, db.Error) { func (i *instanceDB) GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, error) {
// Normalize the domain as punycode // Normalize the domain as punycode
var err error var err error
domain, err = util.Punify(domain) domain, err = util.Punify(domain)
@ -118,7 +118,7 @@ func (i *instanceDB) GetInstance(ctx context.Context, domain string) (*gtsmodel.
ctx, ctx,
"Domain", "Domain",
func(instance *gtsmodel.Instance) error { func(instance *gtsmodel.Instance) error {
return i.conn.NewSelect(). return i.db.NewSelect().
Model(instance). Model(instance).
Where("? = ?", bun.Ident("instance.domain"), domain). Where("? = ?", bun.Ident("instance.domain"), domain).
Scan(ctx) Scan(ctx)
@ -132,7 +132,7 @@ func (i *instanceDB) GetInstanceByID(ctx context.Context, id string) (*gtsmodel.
ctx, ctx,
"ID", "ID",
func(instance *gtsmodel.Instance) error { func(instance *gtsmodel.Instance) error {
return i.conn.NewSelect(). return i.db.NewSelect().
Model(instance). Model(instance).
Where("? = ?", bun.Ident("instance.id"), id). Where("? = ?", bun.Ident("instance.id"), id).
Scan(ctx) Scan(ctx)
@ -141,14 +141,14 @@ func (i *instanceDB) GetInstanceByID(ctx context.Context, id string) (*gtsmodel.
) )
} }
func (i *instanceDB) getInstance(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Instance) error, keyParts ...any) (*gtsmodel.Instance, db.Error) { func (i *instanceDB) getInstance(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Instance) error, keyParts ...any) (*gtsmodel.Instance, error) {
// Fetch instance from database cache with loader callback // Fetch instance from database cache with loader callback
instance, err := i.state.Caches.GTS.Instance().Load(lookup, func() (*gtsmodel.Instance, error) { instance, err := i.state.Caches.GTS.Instance().Load(lookup, func() (*gtsmodel.Instance, error) {
var instance gtsmodel.Instance var instance gtsmodel.Instance
// Not cached! Perform database query. // Not cached! Perform database query.
if err := dbQuery(&instance); err != nil { if err := dbQuery(&instance); err != nil {
return nil, i.conn.ProcessError(err) return nil, i.db.ProcessError(err)
} }
return &instance, nil return &instance, nil
@ -210,8 +210,8 @@ func (i *instanceDB) PutInstance(ctx context.Context, instance *gtsmodel.Instanc
} }
return i.state.Caches.GTS.Instance().Store(instance, func() error { return i.state.Caches.GTS.Instance().Store(instance, func() error {
_, err := i.conn.NewInsert().Model(instance).Exec(ctx) _, err := i.db.NewInsert().Model(instance).Exec(ctx)
return i.conn.ProcessError(err) return i.db.ProcessError(err)
}) })
} }
@ -230,20 +230,20 @@ func (i *instanceDB) UpdateInstance(ctx context.Context, instance *gtsmodel.Inst
} }
return i.state.Caches.GTS.Instance().Store(instance, func() error { return i.state.Caches.GTS.Instance().Store(instance, func() error {
_, err := i.conn. _, err := i.db.
NewUpdate(). NewUpdate().
Model(instance). Model(instance).
Where("? = ?", bun.Ident("instance.id"), instance.ID). Where("? = ?", bun.Ident("instance.id"), instance.ID).
Column(columns...). Column(columns...).
Exec(ctx) Exec(ctx)
return i.conn.ProcessError(err) return i.db.ProcessError(err)
}) })
} }
func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool) ([]*gtsmodel.Instance, db.Error) { func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool) ([]*gtsmodel.Instance, error) {
instanceIDs := []string{} instanceIDs := []string{}
q := i.conn. q := i.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")). TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")).
// Select just the IDs of each instance. // Select just the IDs of each instance.
@ -256,7 +256,7 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool
} }
if err := q.Scan(ctx, &instanceIDs); err != nil { if err := q.Scan(ctx, &instanceIDs); err != nil {
return nil, i.conn.ProcessError(err) return nil, i.db.ProcessError(err)
} }
if len(instanceIDs) == 0 { if len(instanceIDs) == 0 {
@ -280,7 +280,7 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool
return instances, nil return instances, nil
} }
func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) { func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, error) {
// Ensure reasonable // Ensure reasonable
if limit < 0 { if limit < 0 {
limit = 0 limit = 0
@ -296,7 +296,7 @@ func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, max
// Make educated guess for slice size // Make educated guess for slice size
accountIDs := make([]string, 0, limit) accountIDs := make([]string, 0, limit)
q := i.conn. q := i.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
// Select just the account ID. // Select just the account ID.
@ -315,7 +315,7 @@ func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, max
} }
if err := q.Scan(ctx, &accountIDs); err != nil { if err := q.Scan(ctx, &accountIDs); err != nil {
return nil, i.conn.ProcessError(err) return nil, i.db.ProcessError(err)
} }
// Catch case of no accounts early. // Catch case of no accounts early.
@ -340,13 +340,13 @@ func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, max
return accounts, nil return accounts, nil
} }
func (i *instanceDB) GetInstanceModeratorAddresses(ctx context.Context) ([]string, db.Error) { func (i *instanceDB) GetInstanceModeratorAddresses(ctx context.Context) ([]string, error) {
addresses := []string{} addresses := []string{}
// Select email addresses of approved, confirmed, // Select email addresses of approved, confirmed,
// and enabled moderators or admins. // and enabled moderators or admins.
q := i.conn. q := i.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")). TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")).
Column("user.email"). Column("user.email").
@ -361,7 +361,7 @@ func (i *instanceDB) GetInstanceModeratorAddresses(ctx context.Context) ([]strin
OrderExpr("? ASC", bun.Ident("user.email")) OrderExpr("? ASC", bun.Ident("user.email"))
if err := q.Scan(ctx, &addresses); err != nil { if err := q.Scan(ctx, &addresses); err != nil {
return nil, i.conn.ProcessError(err) return nil, i.db.ProcessError(err)
} }
if len(addresses) == 0 { if len(addresses) == 0 {

View file

@ -33,7 +33,7 @@ import (
) )
type listDB struct { type listDB struct {
conn *DBConn db *WrappedDB
state *state.State state *state.State
} }
@ -46,7 +46,7 @@ func (l *listDB) GetListByID(ctx context.Context, id string) (*gtsmodel.List, er
ctx, ctx,
"ID", "ID",
func(list *gtsmodel.List) error { func(list *gtsmodel.List) error {
return l.conn.NewSelect(). return l.db.NewSelect().
Model(list). Model(list).
Where("? = ?", bun.Ident("list.id"), id). Where("? = ?", bun.Ident("list.id"), id).
Scan(ctx) Scan(ctx)
@ -61,7 +61,7 @@ func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmo
// Not cached! Perform database query. // Not cached! Perform database query.
if err := dbQuery(&list); err != nil { if err := dbQuery(&list); err != nil {
return nil, l.conn.ProcessError(err) return nil, l.db.ProcessError(err)
} }
return &list, nil return &list, nil
@ -86,14 +86,14 @@ func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmo
func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error) { func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error) {
// Fetch IDs of all lists owned by this account. // Fetch IDs of all lists owned by this account.
var listIDs []string var listIDs []string
if err := l.conn. if err := l.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("lists"), bun.Ident("list")). TableExpr("? AS ?", bun.Ident("lists"), bun.Ident("list")).
Column("list.id"). Column("list.id").
Where("? = ?", bun.Ident("list.account_id"), accountID). Where("? = ?", bun.Ident("list.account_id"), accountID).
Order("list.id DESC"). Order("list.id DESC").
Scan(ctx, &listIDs); err != nil { Scan(ctx, &listIDs); err != nil {
return nil, l.conn.ProcessError(err) return nil, l.db.ProcessError(err)
} }
if len(listIDs) == 0 { if len(listIDs) == 0 {
@ -148,8 +148,8 @@ func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error {
func (l *listDB) PutList(ctx context.Context, list *gtsmodel.List) error { func (l *listDB) PutList(ctx context.Context, list *gtsmodel.List) error {
return l.state.Caches.GTS.List().Store(list, func() error { return l.state.Caches.GTS.List().Store(list, func() error {
_, err := l.conn.NewInsert().Model(list).Exec(ctx) _, err := l.db.NewInsert().Model(list).Exec(ctx)
return l.conn.ProcessError(err) return l.db.ProcessError(err)
}) })
} }
@ -171,12 +171,12 @@ func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns ..
}() }()
return l.state.Caches.GTS.List().Store(list, func() error { return l.state.Caches.GTS.List().Store(list, func() error {
_, err := l.conn.NewUpdate(). _, err := l.db.NewUpdate().
Model(list). Model(list).
Where("? = ?", bun.Ident("list.id"), list.ID). Where("? = ?", bun.Ident("list.id"), list.ID).
Column(columns...). Column(columns...).
Exec(ctx) Exec(ctx)
return l.conn.ProcessError(err) return l.db.ProcessError(err)
}) })
} }
@ -207,7 +207,7 @@ func (l *listDB) DeleteListByID(ctx context.Context, id string) error {
} }
}() }()
return l.conn.RunInTx(ctx, func(tx bun.Tx) error { return l.db.RunInTx(ctx, func(tx bun.Tx) error {
// Delete all entries attached to list. // Delete all entries attached to list.
if _, err := tx.NewDelete(). if _, err := tx.NewDelete().
Table("list_entries"). Table("list_entries").
@ -234,7 +234,7 @@ func (l *listDB) GetListEntryByID(ctx context.Context, id string) (*gtsmodel.Lis
ctx, ctx,
"ID", "ID",
func(listEntry *gtsmodel.ListEntry) error { func(listEntry *gtsmodel.ListEntry) error {
return l.conn.NewSelect(). return l.db.NewSelect().
Model(listEntry). Model(listEntry).
Where("? = ?", bun.Ident("list_entry.id"), id). Where("? = ?", bun.Ident("list_entry.id"), id).
Scan(ctx) Scan(ctx)
@ -249,7 +249,7 @@ func (l *listDB) getListEntry(ctx context.Context, lookup string, dbQuery func(*
// Not cached! Perform database query. // Not cached! Perform database query.
if err := dbQuery(&listEntry); err != nil { if err := dbQuery(&listEntry); err != nil {
return nil, l.conn.ProcessError(err) return nil, l.db.ProcessError(err)
} }
return &listEntry, nil return &listEntry, nil
@ -289,7 +289,7 @@ func (l *listDB) GetListEntries(ctx context.Context,
frontToBack = true frontToBack = true
) )
q := l.conn. q := l.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("entry")). TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("entry")).
// Select only IDs from table // Select only IDs from table
@ -329,7 +329,7 @@ func (l *listDB) GetListEntries(ctx context.Context,
} }
if err := q.Scan(ctx, &entryIDs); err != nil { if err := q.Scan(ctx, &entryIDs); err != nil {
return nil, l.conn.ProcessError(err) return nil, l.db.ProcessError(err)
} }
if len(entryIDs) == 0 { if len(entryIDs) == 0 {
@ -362,7 +362,7 @@ func (l *listDB) GetListEntries(ctx context.Context,
func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error) { func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error) {
var entryIDs []string var entryIDs []string
if err := l.conn. if err := l.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("entry")). TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("entry")).
// Select only IDs from table // Select only IDs from table
@ -370,7 +370,7 @@ func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string)
// Select only entries belonging with given followID. // Select only entries belonging with given followID.
Where("? = ?", bun.Ident("entry.follow_id"), followID). Where("? = ?", bun.Ident("entry.follow_id"), followID).
Scan(ctx, &entryIDs); err != nil { Scan(ctx, &entryIDs); err != nil {
return nil, l.conn.ProcessError(err) return nil, l.db.ProcessError(err)
} }
if len(entryIDs) == 0 { if len(entryIDs) == 0 {
@ -424,7 +424,7 @@ func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEnt
}() }()
// Finally, insert each list entry into the database. // Finally, insert each list entry into the database.
return l.conn.RunInTx(ctx, func(tx bun.Tx) error { return l.db.RunInTx(ctx, func(tx bun.Tx) error {
for _, entry := range entries { for _, entry := range entries {
if err := l.state.Caches.GTS.ListEntry().Store(entry, func() error { if err := l.state.Caches.GTS.ListEntry().Store(entry, func() error {
_, err := tx. _, err := tx.
@ -468,7 +468,7 @@ func (l *listDB) DeleteListEntry(ctx context.Context, id string) error {
}() }()
// Finally delete the list entry. // Finally delete the list entry.
_, err = l.conn.NewDelete(). _, err = l.db.NewDelete().
Table("list_entries"). Table("list_entries").
Where("? = ?", bun.Ident("id"), id). Where("? = ?", bun.Ident("id"), id).
Exec(ctx) Exec(ctx)
@ -479,14 +479,14 @@ func (l *listDB) DeleteListEntriesForFollowID(ctx context.Context, followID stri
var entryIDs []string var entryIDs []string
// Fetch entry IDs for follow ID. // Fetch entry IDs for follow ID.
if err := l.conn. if err := l.db.
NewSelect(). NewSelect().
Table("list_entries"). Table("list_entries").
Column("id"). Column("id").
Where("? = ?", bun.Ident("follow_id"), followID). Where("? = ?", bun.Ident("follow_id"), followID).
Order("id DESC"). Order("id DESC").
Scan(ctx, &entryIDs); err != nil { Scan(ctx, &entryIDs); err != nil {
return l.conn.ProcessError(err) return l.db.ProcessError(err)
} }
for _, id := range entryIDs { for _, id := range entryIDs {

View file

@ -32,16 +32,16 @@ import (
) )
type mediaDB struct { type mediaDB struct {
conn *DBConn db *WrappedDB
state *state.State state *state.State
} }
func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, db.Error) { func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, error) {
return m.getAttachment( return m.getAttachment(
ctx, ctx,
"ID", "ID",
func(attachment *gtsmodel.MediaAttachment) error { func(attachment *gtsmodel.MediaAttachment) error {
return m.conn.NewSelect(). return m.db.NewSelect().
Model(attachment). Model(attachment).
Where("? = ?", bun.Ident("media_attachment.id"), id). Where("? = ?", bun.Ident("media_attachment.id"), id).
Scan(ctx) Scan(ctx)
@ -68,13 +68,13 @@ func (m *mediaDB) GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gts
return attachments, nil return attachments, nil
} }
func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func(*gtsmodel.MediaAttachment) error, keyParts ...any) (*gtsmodel.MediaAttachment, db.Error) { func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func(*gtsmodel.MediaAttachment) error, keyParts ...any) (*gtsmodel.MediaAttachment, error) {
return m.state.Caches.GTS.Media().Load(lookup, func() (*gtsmodel.MediaAttachment, error) { return m.state.Caches.GTS.Media().Load(lookup, func() (*gtsmodel.MediaAttachment, error) {
var attachment gtsmodel.MediaAttachment var attachment gtsmodel.MediaAttachment
// Not cached! Perform database query // Not cached! Perform database query
if err := dbQuery(&attachment); err != nil { if err := dbQuery(&attachment); err != nil {
return nil, m.conn.ProcessError(err) return nil, m.db.ProcessError(err)
} }
return &attachment, nil return &attachment, nil
@ -83,8 +83,8 @@ func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func
func (m *mediaDB) PutAttachment(ctx context.Context, media *gtsmodel.MediaAttachment) error { func (m *mediaDB) PutAttachment(ctx context.Context, media *gtsmodel.MediaAttachment) error {
return m.state.Caches.GTS.Media().Store(media, func() error { return m.state.Caches.GTS.Media().Store(media, func() error {
_, err := m.conn.NewInsert().Model(media).Exec(ctx) _, err := m.db.NewInsert().Model(media).Exec(ctx)
return m.conn.ProcessError(err) return m.db.ProcessError(err)
}) })
} }
@ -96,12 +96,12 @@ func (m *mediaDB) UpdateAttachment(ctx context.Context, media *gtsmodel.MediaAtt
} }
return m.state.Caches.GTS.Media().Store(media, func() error { return m.state.Caches.GTS.Media().Store(media, func() error {
_, err := m.conn.NewUpdate(). _, err := m.db.NewUpdate().
Model(media). Model(media).
Where("? = ?", bun.Ident("media_attachment.id"), media.ID). Where("? = ?", bun.Ident("media_attachment.id"), media.ID).
Column(columns...). Column(columns...).
Exec(ctx) Exec(ctx)
return m.conn.ProcessError(err) return m.db.ProcessError(err)
}) })
} }
@ -126,7 +126,7 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {
) )
// Delete media attachment in new transaction. // Delete media attachment in new transaction.
err = m.conn.RunInTx(ctx, func(tx bun.Tx) error { err = m.db.RunInTx(ctx, func(tx bun.Tx) error {
if media.AccountID != "" { if media.AccountID != "" {
var account gtsmodel.Account var account gtsmodel.Account
@ -229,11 +229,11 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {
m.state.Caches.GTS.Status().Invalidate("ID", media.StatusID) m.state.Caches.GTS.Status().Invalidate("ID", media.StatusID)
} }
return m.conn.ProcessError(err) return m.db.ProcessError(err)
} }
func (m *mediaDB) CountRemoteOlderThan(ctx context.Context, olderThan time.Time) (int, db.Error) { func (m *mediaDB) CountRemoteOlderThan(ctx context.Context, olderThan time.Time) (int, error) {
q := m.conn. q := m.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("media_attachments"), bun.Ident("media_attachment")). TableExpr("? AS ?", bun.Ident("media_attachments"), bun.Ident("media_attachment")).
Column("media_attachment.id"). Column("media_attachment.id").
@ -243,7 +243,7 @@ func (m *mediaDB) CountRemoteOlderThan(ctx context.Context, olderThan time.Time)
count, err := q.Count(ctx) count, err := q.Count(ctx)
if err != nil { if err != nil {
return 0, m.conn.ProcessError(err) return 0, m.db.ProcessError(err)
} }
return count, nil return count, nil
@ -252,7 +252,7 @@ func (m *mediaDB) CountRemoteOlderThan(ctx context.Context, olderThan time.Time)
func (m *mediaDB) GetAttachments(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, error) { func (m *mediaDB) GetAttachments(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, error) {
attachmentIDs := make([]string, 0, limit) attachmentIDs := make([]string, 0, limit)
q := m.conn.NewSelect(). q := m.db.NewSelect().
Table("media_attachments"). Table("media_attachments").
Column("id"). Column("id").
Order("id DESC") Order("id DESC")
@ -266,7 +266,7 @@ func (m *mediaDB) GetAttachments(ctx context.Context, maxID string, limit int) (
} }
if err := q.Scan(ctx, &attachmentIDs); err != nil { if err := q.Scan(ctx, &attachmentIDs); err != nil {
return nil, m.conn.ProcessError(err) return nil, m.db.ProcessError(err)
} }
return m.GetAttachmentsByIDs(ctx, attachmentIDs) return m.GetAttachmentsByIDs(ctx, attachmentIDs)
@ -275,7 +275,7 @@ func (m *mediaDB) GetAttachments(ctx context.Context, maxID string, limit int) (
func (m *mediaDB) GetRemoteAttachments(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, error) { func (m *mediaDB) GetRemoteAttachments(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, error) {
attachmentIDs := make([]string, 0, limit) attachmentIDs := make([]string, 0, limit)
q := m.conn.NewSelect(). q := m.db.NewSelect().
Table("media_attachments"). Table("media_attachments").
Column("id"). Column("id").
Where("remote_url IS NOT NULL"). Where("remote_url IS NOT NULL").
@ -290,16 +290,16 @@ func (m *mediaDB) GetRemoteAttachments(ctx context.Context, maxID string, limit
} }
if err := q.Scan(ctx, &attachmentIDs); err != nil { if err := q.Scan(ctx, &attachmentIDs); err != nil {
return nil, m.conn.ProcessError(err) return nil, m.db.ProcessError(err)
} }
return m.GetAttachmentsByIDs(ctx, attachmentIDs) return m.GetAttachmentsByIDs(ctx, attachmentIDs)
} }
func (m *mediaDB) GetCachedAttachmentsOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, db.Error) { func (m *mediaDB) GetCachedAttachmentsOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, error) {
attachmentIDs := make([]string, 0, limit) attachmentIDs := make([]string, 0, limit)
q := m.conn. q := m.db.
NewSelect(). NewSelect().
Table("media_attachments"). Table("media_attachments").
Column("id"). Column("id").
@ -313,16 +313,16 @@ func (m *mediaDB) GetCachedAttachmentsOlderThan(ctx context.Context, olderThan t
} }
if err := q.Scan(ctx, &attachmentIDs); err != nil { if err := q.Scan(ctx, &attachmentIDs); err != nil {
return nil, m.conn.ProcessError(err) return nil, m.db.ProcessError(err)
} }
return m.GetAttachmentsByIDs(ctx, attachmentIDs) return m.GetAttachmentsByIDs(ctx, attachmentIDs)
} }
func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, db.Error) { func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, error) {
attachmentIDs := make([]string, 0, limit) attachmentIDs := make([]string, 0, limit)
q := m.conn.NewSelect(). q := m.db.NewSelect().
TableExpr("? AS ?", bun.Ident("media_attachments"), bun.Ident("media_attachment")). TableExpr("? AS ?", bun.Ident("media_attachments"), bun.Ident("media_attachment")).
Column("media_attachment.id"). Column("media_attachment.id").
WhereGroup(" AND ", func(innerQ *bun.SelectQuery) *bun.SelectQuery { WhereGroup(" AND ", func(innerQ *bun.SelectQuery) *bun.SelectQuery {
@ -341,16 +341,16 @@ func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit
} }
if err := q.Scan(ctx, &attachmentIDs); err != nil { if err := q.Scan(ctx, &attachmentIDs); err != nil {
return nil, m.conn.ProcessError(err) return nil, m.db.ProcessError(err)
} }
return m.GetAttachmentsByIDs(ctx, attachmentIDs) return m.GetAttachmentsByIDs(ctx, attachmentIDs)
} }
func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, db.Error) { func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, error) {
attachmentIDs := make([]string, 0, limit) attachmentIDs := make([]string, 0, limit)
q := m.conn. q := m.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("media_attachments"), bun.Ident("media_attachment")). TableExpr("? AS ?", bun.Ident("media_attachments"), bun.Ident("media_attachment")).
Column("media_attachment.id"). Column("media_attachment.id").
@ -367,14 +367,14 @@ func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan tim
} }
if err := q.Scan(ctx, &attachmentIDs); err != nil { if err := q.Scan(ctx, &attachmentIDs); err != nil {
return nil, m.conn.ProcessError(err) return nil, m.db.ProcessError(err)
} }
return m.GetAttachmentsByIDs(ctx, attachmentIDs) return m.GetAttachmentsByIDs(ctx, attachmentIDs)
} }
func (m *mediaDB) CountLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time) (int, db.Error) { func (m *mediaDB) CountLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time) (int, error) {
q := m.conn. q := m.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("media_attachments"), bun.Ident("media_attachment")). TableExpr("? AS ?", bun.Ident("media_attachments"), bun.Ident("media_attachment")).
Column("media_attachment.id"). Column("media_attachment.id").
@ -387,7 +387,7 @@ func (m *mediaDB) CountLocalUnattachedOlderThan(ctx context.Context, olderThan t
count, err := q.Count(ctx) count, err := q.Count(ctx)
if err != nil { if err != nil {
return 0, m.conn.ProcessError(err) return 0, m.db.ProcessError(err)
} }
return count, nil return count, nil

View file

@ -31,21 +31,21 @@ import (
) )
type mentionDB struct { type mentionDB struct {
conn *DBConn db *WrappedDB
state *state.State state *state.State
} }
func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) { func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, error) {
mention, err := m.state.Caches.GTS.Mention().Load("ID", func() (*gtsmodel.Mention, error) { mention, err := m.state.Caches.GTS.Mention().Load("ID", func() (*gtsmodel.Mention, error) {
var mention gtsmodel.Mention var mention gtsmodel.Mention
q := m.conn. q := m.db.
NewSelect(). NewSelect().
Model(&mention). Model(&mention).
Where("? = ?", bun.Ident("mention.id"), id) Where("? = ?", bun.Ident("mention.id"), id)
if err := q.Scan(ctx); err != nil { if err := q.Scan(ctx); err != nil {
return nil, m.conn.ProcessError(err) return nil, m.db.ProcessError(err)
} }
return &mention, nil return &mention, nil
@ -84,7 +84,7 @@ func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mentio
return mention, nil return mention, nil
} }
func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) { func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, error) {
mentions := make([]*gtsmodel.Mention, 0, len(ids)) mentions := make([]*gtsmodel.Mention, 0, len(ids))
for _, id := range ids { for _, id := range ids {
@ -104,8 +104,8 @@ func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.
func (m *mentionDB) PutMention(ctx context.Context, mention *gtsmodel.Mention) error { func (m *mentionDB) PutMention(ctx context.Context, mention *gtsmodel.Mention) error {
return m.state.Caches.GTS.Mention().Store(mention, func() error { return m.state.Caches.GTS.Mention().Store(mention, func() error {
_, err := m.conn.NewInsert().Model(mention).Exec(ctx) _, err := m.db.NewInsert().Model(mention).Exec(ctx)
return m.conn.ProcessError(err) return m.db.ProcessError(err)
}) })
} }
@ -125,9 +125,9 @@ func (m *mentionDB) DeleteMentionByID(ctx context.Context, id string) error {
} }
// Finally delete mention from DB. // Finally delete mention from DB.
_, err = m.conn.NewDelete(). _, err = m.db.NewDelete().
Table("mentions"). Table("mentions").
Where("? = ?", bun.Ident("id"), id). Where("? = ?", bun.Ident("id"), id).
Exec(ctx) Exec(ctx)
return m.conn.ProcessError(err) return m.db.ProcessError(err)
} }

View file

@ -31,19 +31,19 @@ import (
) )
type notificationDB struct { type notificationDB struct {
conn *DBConn db *WrappedDB
state *state.State state *state.State
} }
func (n *notificationDB) GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) { func (n *notificationDB) GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, error) {
return n.state.Caches.GTS.Notification().Load("ID", func() (*gtsmodel.Notification, error) { return n.state.Caches.GTS.Notification().Load("ID", func() (*gtsmodel.Notification, error) {
var notif gtsmodel.Notification var notif gtsmodel.Notification
q := n.conn.NewSelect(). q := n.db.NewSelect().
Model(&notif). Model(&notif).
Where("? = ?", bun.Ident("notification.id"), id) Where("? = ?", bun.Ident("notification.id"), id)
if err := q.Scan(ctx); err != nil { if err := q.Scan(ctx); err != nil {
return nil, n.conn.ProcessError(err) return nil, n.db.ProcessError(err)
} }
return &notif, nil return &notif, nil
@ -56,11 +56,11 @@ func (n *notificationDB) GetNotification(
targetAccountID string, targetAccountID string,
originAccountID string, originAccountID string,
statusID string, statusID string,
) (*gtsmodel.Notification, db.Error) { ) (*gtsmodel.Notification, error) {
return n.state.Caches.GTS.Notification().Load("NotificationType.TargetAccountID.OriginAccountID.StatusID", func() (*gtsmodel.Notification, error) { return n.state.Caches.GTS.Notification().Load("NotificationType.TargetAccountID.OriginAccountID.StatusID", func() (*gtsmodel.Notification, error) {
var notif gtsmodel.Notification var notif gtsmodel.Notification
q := n.conn.NewSelect(). q := n.db.NewSelect().
Model(&notif). Model(&notif).
Where("? = ?", bun.Ident("notification_type"), notificationType). Where("? = ?", bun.Ident("notification_type"), notificationType).
Where("? = ?", bun.Ident("target_account_id"), targetAccountID). Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
@ -68,7 +68,7 @@ func (n *notificationDB) GetNotification(
Where("? = ?", bun.Ident("status_id"), statusID) Where("? = ?", bun.Ident("status_id"), statusID)
if err := q.Scan(ctx); err != nil { if err := q.Scan(ctx); err != nil {
return nil, n.conn.ProcessError(err) return nil, n.db.ProcessError(err)
} }
return &notif, nil return &notif, nil
@ -83,7 +83,7 @@ func (n *notificationDB) GetAccountNotifications(
minID string, minID string,
limit int, limit int,
excludeTypes []string, excludeTypes []string,
) ([]*gtsmodel.Notification, db.Error) { ) ([]*gtsmodel.Notification, error) {
// Ensure reasonable // Ensure reasonable
if limit < 0 { if limit < 0 {
limit = 0 limit = 0
@ -95,7 +95,7 @@ func (n *notificationDB) GetAccountNotifications(
frontToBack = true frontToBack = true
) )
q := n.conn. q := n.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")). TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
Column("notification.id") Column("notification.id")
@ -140,7 +140,7 @@ func (n *notificationDB) GetAccountNotifications(
} }
if err := q.Scan(ctx, &notifIDs); err != nil { if err := q.Scan(ctx, &notifIDs); err != nil {
return nil, n.conn.ProcessError(err) return nil, n.db.ProcessError(err)
} }
if len(notifIDs) == 0 { if len(notifIDs) == 0 {
@ -174,12 +174,12 @@ func (n *notificationDB) GetAccountNotifications(
func (n *notificationDB) PutNotification(ctx context.Context, notif *gtsmodel.Notification) error { func (n *notificationDB) PutNotification(ctx context.Context, notif *gtsmodel.Notification) error {
return n.state.Caches.GTS.Notification().Store(notif, func() error { return n.state.Caches.GTS.Notification().Store(notif, func() error {
_, err := n.conn.NewInsert().Model(notif).Exec(ctx) _, err := n.db.NewInsert().Model(notif).Exec(ctx)
return n.conn.ProcessError(err) return n.db.ProcessError(err)
}) })
} }
func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) db.Error { func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) error {
defer n.state.Caches.GTS.Notification().Invalidate("ID", id) defer n.state.Caches.GTS.Notification().Invalidate("ID", id)
// Load notif into cache before attempting a delete, // Load notif into cache before attempting a delete,
@ -195,21 +195,21 @@ func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string)
} }
// Finally delete notif from DB. // Finally delete notif from DB.
_, err = n.conn.NewDelete(). _, err = n.db.NewDelete().
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")). TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
Where("? = ?", bun.Ident("notification.id"), id). Where("? = ?", bun.Ident("notification.id"), id).
Exec(ctx) Exec(ctx)
return n.conn.ProcessError(err) return n.db.ProcessError(err)
} }
func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) db.Error { func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) error {
if targetAccountID == "" && originAccountID == "" { if targetAccountID == "" && originAccountID == "" {
return errors.New("DeleteNotifications: one of targetAccountID or originAccountID must be set") return errors.New("DeleteNotifications: one of targetAccountID or originAccountID must be set")
} }
var notifIDs []string var notifIDs []string
q := n.conn. q := n.db.
NewSelect(). NewSelect().
Column("id"). Column("id").
Table("notifications") Table("notifications")
@ -227,7 +227,7 @@ func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string
} }
if _, err := q.Exec(ctx, &notifIDs); err != nil { if _, err := q.Exec(ctx, &notifIDs); err != nil {
return n.conn.ProcessError(err) return n.db.ProcessError(err)
} }
defer func() { defer func() {
@ -248,24 +248,24 @@ func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string
} }
// Finally delete all from DB. // Finally delete all from DB.
_, err := n.conn.NewDelete(). _, err := n.db.NewDelete().
Table("notifications"). Table("notifications").
Where("? IN (?)", bun.Ident("id"), bun.In(notifIDs)). Where("? IN (?)", bun.Ident("id"), bun.In(notifIDs)).
Exec(ctx) Exec(ctx)
return n.conn.ProcessError(err) return n.db.ProcessError(err)
} }
func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statusID string) db.Error { func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statusID string) error {
var notifIDs []string var notifIDs []string
q := n.conn. q := n.db.
NewSelect(). NewSelect().
Column("id"). Column("id").
Table("notifications"). Table("notifications").
Where("? = ?", bun.Ident("status_id"), statusID) Where("? = ?", bun.Ident("status_id"), statusID)
if _, err := q.Exec(ctx, &notifIDs); err != nil { if _, err := q.Exec(ctx, &notifIDs); err != nil {
return n.conn.ProcessError(err) return n.db.ProcessError(err)
} }
defer func() { defer func() {
@ -286,9 +286,9 @@ func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statu
} }
// Finally delete all from DB. // Finally delete all from DB.
_, err := n.conn.NewDelete(). _, err := n.db.NewDelete().
Table("notifications"). Table("notifications").
Where("? IN (?)", bun.Ident("id"), bun.In(notifIDs)). Where("? IN (?)", bun.Ident("id"), bun.In(notifIDs)).
Exec(ctx) Exec(ctx)
return n.conn.ProcessError(err) return n.db.ProcessError(err)
} }

View file

@ -30,11 +30,11 @@ import (
) )
type relationshipDB struct { type relationshipDB struct {
conn *DBConn db *WrappedDB
state *state.State state *state.State
} }
func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) { func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error) {
var rel gtsmodel.Relationship var rel gtsmodel.Relationship
rel.ID = targetAccount rel.ID = targetAccount
@ -90,91 +90,91 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount
func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
var followIDs []string var followIDs []string
if err := newSelectFollows(r.conn, accountID). if err := newSelectFollows(r.db, accountID).
Scan(ctx, &followIDs); err != nil { Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err) return nil, r.db.ProcessError(err)
} }
return r.GetFollowsByIDs(ctx, followIDs) return r.GetFollowsByIDs(ctx, followIDs)
} }
func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
var followIDs []string var followIDs []string
if err := newSelectLocalFollows(r.conn, accountID). if err := newSelectLocalFollows(r.db, accountID).
Scan(ctx, &followIDs); err != nil { Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err) return nil, r.db.ProcessError(err)
} }
return r.GetFollowsByIDs(ctx, followIDs) return r.GetFollowsByIDs(ctx, followIDs)
} }
func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
var followIDs []string var followIDs []string
if err := newSelectFollowers(r.conn, accountID). if err := newSelectFollowers(r.db, accountID).
Scan(ctx, &followIDs); err != nil { Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err) return nil, r.db.ProcessError(err)
} }
return r.GetFollowsByIDs(ctx, followIDs) return r.GetFollowsByIDs(ctx, followIDs)
} }
func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
var followIDs []string var followIDs []string
if err := newSelectLocalFollowers(r.conn, accountID). if err := newSelectLocalFollowers(r.db, accountID).
Scan(ctx, &followIDs); err != nil { Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err) return nil, r.db.ProcessError(err)
} }
return r.GetFollowsByIDs(ctx, followIDs) return r.GetFollowsByIDs(ctx, followIDs)
} }
func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) { func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) {
n, err := newSelectFollows(r.conn, accountID).Count(ctx) n, err := newSelectFollows(r.db, accountID).Count(ctx)
return n, r.conn.ProcessError(err) return n, r.db.ProcessError(err)
} }
func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) { func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) {
n, err := newSelectLocalFollows(r.conn, accountID).Count(ctx) n, err := newSelectLocalFollows(r.db, accountID).Count(ctx)
return n, r.conn.ProcessError(err) return n, r.db.ProcessError(err)
} }
func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) { func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) {
n, err := newSelectFollowers(r.conn, accountID).Count(ctx) n, err := newSelectFollowers(r.db, accountID).Count(ctx)
return n, r.conn.ProcessError(err) return n, r.db.ProcessError(err)
} }
func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) { func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) {
n, err := newSelectLocalFollowers(r.conn, accountID).Count(ctx) n, err := newSelectLocalFollowers(r.db, accountID).Count(ctx)
return n, r.conn.ProcessError(err) return n, r.db.ProcessError(err)
} }
func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) {
var followReqIDs []string var followReqIDs []string
if err := newSelectFollowRequests(r.conn, accountID). if err := newSelectFollowRequests(r.db, accountID).
Scan(ctx, &followReqIDs); err != nil { Scan(ctx, &followReqIDs); err != nil {
return nil, r.conn.ProcessError(err) return nil, r.db.ProcessError(err)
} }
return r.GetFollowRequestsByIDs(ctx, followReqIDs) return r.GetFollowRequestsByIDs(ctx, followReqIDs)
} }
func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) {
var followReqIDs []string var followReqIDs []string
if err := newSelectFollowRequesting(r.conn, accountID). if err := newSelectFollowRequesting(r.db, accountID).
Scan(ctx, &followReqIDs); err != nil { Scan(ctx, &followReqIDs); err != nil {
return nil, r.conn.ProcessError(err) return nil, r.db.ProcessError(err)
} }
return r.GetFollowRequestsByIDs(ctx, followReqIDs) return r.GetFollowRequestsByIDs(ctx, followReqIDs)
} }
func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) { func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) {
n, err := newSelectFollowRequests(r.conn, accountID).Count(ctx) n, err := newSelectFollowRequests(r.db, accountID).Count(ctx)
return n, r.conn.ProcessError(err) return n, r.db.ProcessError(err)
} }
func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) { func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) {
n, err := newSelectFollowRequesting(r.conn, accountID).Count(ctx) n, err := newSelectFollowRequesting(r.db, accountID).Count(ctx)
return n, r.conn.ProcessError(err) return n, r.db.ProcessError(err)
} }
// newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID. // newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID.
func newSelectFollowRequests(conn *DBConn, accountID string) *bun.SelectQuery { func newSelectFollowRequests(db *WrappedDB, accountID string) *bun.SelectQuery {
return conn.NewSelect(). return db.NewSelect().
TableExpr("?", bun.Ident("follow_requests")). TableExpr("?", bun.Ident("follow_requests")).
ColumnExpr("?", bun.Ident("id")). ColumnExpr("?", bun.Ident("id")).
Where("? = ?", bun.Ident("target_account_id"), accountID). Where("? = ?", bun.Ident("target_account_id"), accountID).
@ -182,8 +182,8 @@ func newSelectFollowRequests(conn *DBConn, accountID string) *bun.SelectQuery {
} }
// newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID. // newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID.
func newSelectFollowRequesting(conn *DBConn, accountID string) *bun.SelectQuery { func newSelectFollowRequesting(db *WrappedDB, accountID string) *bun.SelectQuery {
return conn.NewSelect(). return db.NewSelect().
TableExpr("?", bun.Ident("follow_requests")). TableExpr("?", bun.Ident("follow_requests")).
ColumnExpr("?", bun.Ident("id")). ColumnExpr("?", bun.Ident("id")).
Where("? = ?", bun.Ident("target_account_id"), accountID). Where("? = ?", bun.Ident("target_account_id"), accountID).
@ -191,8 +191,8 @@ func newSelectFollowRequesting(conn *DBConn, accountID string) *bun.SelectQuery
} }
// newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID. // newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID.
func newSelectFollows(conn *DBConn, accountID string) *bun.SelectQuery { func newSelectFollows(db *WrappedDB, accountID string) *bun.SelectQuery {
return conn.NewSelect(). return db.NewSelect().
Table("follows"). Table("follows").
Column("id"). Column("id").
Where("? = ?", bun.Ident("account_id"), accountID). Where("? = ?", bun.Ident("account_id"), accountID).
@ -201,15 +201,15 @@ func newSelectFollows(conn *DBConn, accountID string) *bun.SelectQuery {
// newSelectLocalFollows returns a new select query for all rows in the follows table with // newSelectLocalFollows returns a new select query for all rows in the follows table with
// account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local). // account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local).
func newSelectLocalFollows(conn *DBConn, accountID string) *bun.SelectQuery { func newSelectLocalFollows(db *WrappedDB, accountID string) *bun.SelectQuery {
return conn.NewSelect(). return db.NewSelect().
Table("follows"). Table("follows").
Column("id"). Column("id").
Where("? = ? AND ? IN (?)", Where("? = ? AND ? IN (?)",
bun.Ident("account_id"), bun.Ident("account_id"),
accountID, accountID,
bun.Ident("target_account_id"), bun.Ident("target_account_id"),
conn.NewSelect(). db.NewSelect().
Table("accounts"). Table("accounts").
Column("id"). Column("id").
Where("? IS NULL", bun.Ident("domain")), Where("? IS NULL", bun.Ident("domain")),
@ -218,8 +218,8 @@ func newSelectLocalFollows(conn *DBConn, accountID string) *bun.SelectQuery {
} }
// newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID. // newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID.
func newSelectFollowers(conn *DBConn, accountID string) *bun.SelectQuery { func newSelectFollowers(db *WrappedDB, accountID string) *bun.SelectQuery {
return conn.NewSelect(). return db.NewSelect().
Table("follows"). Table("follows").
Column("id"). Column("id").
Where("? = ?", bun.Ident("target_account_id"), accountID). Where("? = ?", bun.Ident("target_account_id"), accountID).
@ -228,15 +228,15 @@ func newSelectFollowers(conn *DBConn, accountID string) *bun.SelectQuery {
// newSelectLocalFollowers returns a new select query for all rows in the follows table with // newSelectLocalFollowers returns a new select query for all rows in the follows table with
// target_account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local). // target_account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local).
func newSelectLocalFollowers(conn *DBConn, accountID string) *bun.SelectQuery { func newSelectLocalFollowers(db *WrappedDB, accountID string) *bun.SelectQuery {
return conn.NewSelect(). return db.NewSelect().
Table("follows"). Table("follows").
Column("id"). Column("id").
Where("? = ? AND ? IN (?)", Where("? = ? AND ? IN (?)",
bun.Ident("target_account_id"), bun.Ident("target_account_id"),
accountID, accountID,
bun.Ident("account_id"), bun.Ident("account_id"),
conn.NewSelect(). db.NewSelect().
Table("accounts"). Table("accounts").
Column("id"). Column("id").
Where("? IS NULL", bun.Ident("domain")), Where("? IS NULL", bun.Ident("domain")),

View file

@ -28,7 +28,7 @@ import (
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
func (r *relationshipDB) IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) { func (r *relationshipDB) IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error) {
block, err := r.GetBlock( block, err := r.GetBlock(
gtscontext.SetBarebones(ctx), gtscontext.SetBarebones(ctx),
sourceAccountID, sourceAccountID,
@ -61,7 +61,7 @@ func (r *relationshipDB) GetBlockByID(ctx context.Context, id string) (*gtsmodel
ctx, ctx,
"ID", "ID",
func(block *gtsmodel.Block) error { func(block *gtsmodel.Block) error {
return r.conn.NewSelect().Model(block). return r.db.NewSelect().Model(block).
Where("? = ?", bun.Ident("block.id"), id). Where("? = ?", bun.Ident("block.id"), id).
Scan(ctx) Scan(ctx)
}, },
@ -74,7 +74,7 @@ func (r *relationshipDB) GetBlockByURI(ctx context.Context, uri string) (*gtsmod
ctx, ctx,
"URI", "URI",
func(block *gtsmodel.Block) error { func(block *gtsmodel.Block) error {
return r.conn.NewSelect().Model(block). return r.db.NewSelect().Model(block).
Where("? = ?", bun.Ident("block.uri"), uri). Where("? = ?", bun.Ident("block.uri"), uri).
Scan(ctx) Scan(ctx)
}, },
@ -87,7 +87,7 @@ func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, t
ctx, ctx,
"AccountID.TargetAccountID", "AccountID.TargetAccountID",
func(block *gtsmodel.Block) error { func(block *gtsmodel.Block) error {
return r.conn.NewSelect().Model(block). return r.db.NewSelect().Model(block).
Where("? = ?", bun.Ident("block.account_id"), sourceAccountID). Where("? = ?", bun.Ident("block.account_id"), sourceAccountID).
Where("? = ?", bun.Ident("block.target_account_id"), targetAccountID). Where("? = ?", bun.Ident("block.target_account_id"), targetAccountID).
Scan(ctx) Scan(ctx)
@ -104,7 +104,7 @@ func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery fu
// Not cached! Perform database query // Not cached! Perform database query
if err := dbQuery(&block); err != nil { if err := dbQuery(&block); err != nil {
return nil, r.conn.ProcessError(err) return nil, r.db.ProcessError(err)
} }
return &block, nil return &block, nil
@ -142,8 +142,8 @@ func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery fu
func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) error { func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) error {
return r.state.Caches.GTS.Block().Store(block, func() error { return r.state.Caches.GTS.Block().Store(block, func() error {
_, err := r.conn.NewInsert().Model(block).Exec(ctx) _, err := r.db.NewInsert().Model(block).Exec(ctx)
return r.conn.ProcessError(err) return r.db.ProcessError(err)
}) })
} }
@ -163,11 +163,11 @@ func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error {
} }
// Finally delete block from DB. // Finally delete block from DB.
_, err = r.conn.NewDelete(). _, err = r.db.NewDelete().
Table("blocks"). Table("blocks").
Where("? = ?", bun.Ident("id"), id). Where("? = ?", bun.Ident("id"), id).
Exec(ctx) Exec(ctx)
return r.conn.ProcessError(err) return r.db.ProcessError(err)
} }
func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error { func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error {
@ -186,18 +186,18 @@ func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error
} }
// Finally delete block from DB. // Finally delete block from DB.
_, err = r.conn.NewDelete(). _, err = r.db.NewDelete().
Table("blocks"). Table("blocks").
Where("? = ?", bun.Ident("uri"), uri). Where("? = ?", bun.Ident("uri"), uri).
Exec(ctx) Exec(ctx)
return r.conn.ProcessError(err) return r.db.ProcessError(err)
} }
func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID string) error { func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID string) error {
var blockIDs []string var blockIDs []string
// Get full list of IDs. // Get full list of IDs.
if err := r.conn.NewSelect(). if err := r.db.NewSelect().
Column("id"). Column("id").
Table("blocks"). Table("blocks").
WhereOr("? = ? OR ? = ?", WhereOr("? = ? OR ? = ?",
@ -207,7 +207,7 @@ func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID stri
accountID, accountID,
). ).
Scan(ctx, &blockIDs); err != nil { Scan(ctx, &blockIDs); err != nil {
return r.conn.ProcessError(err) return r.db.ProcessError(err)
} }
defer func() { defer func() {
@ -228,9 +228,9 @@ func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID stri
} }
// Finally delete all from DB. // Finally delete all from DB.
_, err := r.conn.NewDelete(). _, err := r.db.NewDelete().
Table("blocks"). Table("blocks").
Where("? IN (?)", bun.Ident("id"), bun.In(blockIDs)). Where("? IN (?)", bun.Ident("id"), bun.In(blockIDs)).
Exec(ctx) Exec(ctx)
return r.conn.ProcessError(err) return r.db.ProcessError(err)
} }

View file

@ -36,7 +36,7 @@ func (r *relationshipDB) GetFollowByID(ctx context.Context, id string) (*gtsmode
ctx, ctx,
"ID", "ID",
func(follow *gtsmodel.Follow) error { func(follow *gtsmodel.Follow) error {
return r.conn.NewSelect(). return r.db.NewSelect().
Model(follow). Model(follow).
Where("? = ?", bun.Ident("id"), id). Where("? = ?", bun.Ident("id"), id).
Scan(ctx) Scan(ctx)
@ -50,7 +50,7 @@ func (r *relationshipDB) GetFollowByURI(ctx context.Context, uri string) (*gtsmo
ctx, ctx,
"URI", "URI",
func(follow *gtsmodel.Follow) error { func(follow *gtsmodel.Follow) error {
return r.conn.NewSelect(). return r.db.NewSelect().
Model(follow). Model(follow).
Where("? = ?", bun.Ident("uri"), uri). Where("? = ?", bun.Ident("uri"), uri).
Scan(ctx) Scan(ctx)
@ -64,7 +64,7 @@ func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string,
ctx, ctx,
"AccountID.TargetAccountID", "AccountID.TargetAccountID",
func(follow *gtsmodel.Follow) error { func(follow *gtsmodel.Follow) error {
return r.conn.NewSelect(). return r.db.NewSelect().
Model(follow). Model(follow).
Where("? = ?", bun.Ident("account_id"), sourceAccountID). Where("? = ?", bun.Ident("account_id"), sourceAccountID).
Where("? = ?", bun.Ident("target_account_id"), targetAccountID). Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
@ -94,7 +94,7 @@ func (r *relationshipDB) GetFollowsByIDs(ctx context.Context, ids []string) ([]*
return follows, nil return follows, nil
} }
func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) { func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error) {
follow, err := r.GetFollow( follow, err := r.GetFollow(
gtscontext.SetBarebones(ctx), gtscontext.SetBarebones(ctx),
sourceAccountID, sourceAccountID,
@ -106,7 +106,7 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccountID string
return (follow != nil), nil return (follow != nil), nil
} }
func (r *relationshipDB) IsMutualFollowing(ctx context.Context, accountID1 string, accountID2 string) (bool, db.Error) { func (r *relationshipDB) IsMutualFollowing(ctx context.Context, accountID1 string, accountID2 string) (bool, error) {
// make sure account 1 follows account 2 // make sure account 1 follows account 2
f1, err := r.IsFollowing(ctx, f1, err := r.IsFollowing(ctx,
accountID1, accountID1,
@ -135,7 +135,7 @@ func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery f
// Not cached! Perform database query // Not cached! Perform database query
if err := dbQuery(&follow); err != nil { if err := dbQuery(&follow); err != nil {
return nil, r.conn.ProcessError(err) return nil, r.db.ProcessError(err)
} }
return &follow, nil return &follow, nil
@ -190,8 +190,8 @@ func (r *relationshipDB) PopulateFollow(ctx context.Context, follow *gtsmodel.Fo
func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error { func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error {
return r.state.Caches.GTS.Follow().Store(follow, func() error { return r.state.Caches.GTS.Follow().Store(follow, func() error {
_, err := r.conn.NewInsert().Model(follow).Exec(ctx) _, err := r.db.NewInsert().Model(follow).Exec(ctx)
return r.conn.ProcessError(err) return r.db.ProcessError(err)
}) })
} }
@ -203,12 +203,12 @@ func (r *relationshipDB) UpdateFollow(ctx context.Context, follow *gtsmodel.Foll
} }
return r.state.Caches.GTS.Follow().Store(follow, func() error { return r.state.Caches.GTS.Follow().Store(follow, func() error {
if _, err := r.conn.NewUpdate(). if _, err := r.db.NewUpdate().
Model(follow). Model(follow).
Where("? = ?", bun.Ident("follow.id"), follow.ID). Where("? = ?", bun.Ident("follow.id"), follow.ID).
Column(columns...). Column(columns...).
Exec(ctx); err != nil { Exec(ctx); err != nil {
return r.conn.ProcessError(err) return r.db.ProcessError(err)
} }
return nil return nil
@ -217,11 +217,11 @@ func (r *relationshipDB) UpdateFollow(ctx context.Context, follow *gtsmodel.Foll
func (r *relationshipDB) deleteFollow(ctx context.Context, id string) error { func (r *relationshipDB) deleteFollow(ctx context.Context, id string) error {
// Delete the follow itself using the given ID. // Delete the follow itself using the given ID.
if _, err := r.conn.NewDelete(). if _, err := r.db.NewDelete().
Table("follows"). Table("follows").
Where("? = ?", bun.Ident("id"), id). Where("? = ?", bun.Ident("id"), id).
Exec(ctx); err != nil { Exec(ctx); err != nil {
return r.conn.ProcessError(err) return r.db.ProcessError(err)
} }
// Delete every list entry that used this followID. // Delete every list entry that used this followID.
@ -297,7 +297,7 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str
var followIDs []string var followIDs []string
// Get full list of IDs. // Get full list of IDs.
if _, err := r.conn. if _, err := r.db.
NewSelect(). NewSelect().
Column("id"). Column("id").
Table("follows"). Table("follows").
@ -308,7 +308,7 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str
accountID, accountID,
). ).
Exec(ctx, &followIDs); err != nil { Exec(ctx, &followIDs); err != nil {
return r.conn.ProcessError(err) return r.db.ProcessError(err)
} }
defer func() { defer func() {

View file

@ -35,7 +35,7 @@ func (r *relationshipDB) GetFollowRequestByID(ctx context.Context, id string) (*
ctx, ctx,
"ID", "ID",
func(followReq *gtsmodel.FollowRequest) error { func(followReq *gtsmodel.FollowRequest) error {
return r.conn.NewSelect(). return r.db.NewSelect().
Model(followReq). Model(followReq).
Where("? = ?", bun.Ident("id"), id). Where("? = ?", bun.Ident("id"), id).
Scan(ctx) Scan(ctx)
@ -49,7 +49,7 @@ func (r *relationshipDB) GetFollowRequestByURI(ctx context.Context, uri string)
ctx, ctx,
"URI", "URI",
func(followReq *gtsmodel.FollowRequest) error { func(followReq *gtsmodel.FollowRequest) error {
return r.conn.NewSelect(). return r.db.NewSelect().
Model(followReq). Model(followReq).
Where("? = ?", bun.Ident("uri"), uri). Where("? = ?", bun.Ident("uri"), uri).
Scan(ctx) Scan(ctx)
@ -63,7 +63,7 @@ func (r *relationshipDB) GetFollowRequest(ctx context.Context, sourceAccountID s
ctx, ctx,
"AccountID.TargetAccountID", "AccountID.TargetAccountID",
func(followReq *gtsmodel.FollowRequest) error { func(followReq *gtsmodel.FollowRequest) error {
return r.conn.NewSelect(). return r.db.NewSelect().
Model(followReq). Model(followReq).
Where("? = ?", bun.Ident("account_id"), sourceAccountID). Where("? = ?", bun.Ident("account_id"), sourceAccountID).
Where("? = ?", bun.Ident("target_account_id"), targetAccountID). Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
@ -93,7 +93,7 @@ func (r *relationshipDB) GetFollowRequestsByIDs(ctx context.Context, ids []strin
return followReqs, nil return followReqs, nil
} }
func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) { func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error) {
followReq, err := r.GetFollowRequest( followReq, err := r.GetFollowRequest(
gtscontext.SetBarebones(ctx), gtscontext.SetBarebones(ctx),
sourceAccountID, sourceAccountID,
@ -112,7 +112,7 @@ func (r *relationshipDB) getFollowRequest(ctx context.Context, lookup string, db
// Not cached! Perform database query // Not cached! Perform database query
if err := dbQuery(&followReq); err != nil { if err := dbQuery(&followReq); err != nil {
return nil, r.conn.ProcessError(err) return nil, r.db.ProcessError(err)
} }
return &followReq, nil return &followReq, nil
@ -150,8 +150,8 @@ func (r *relationshipDB) getFollowRequest(ctx context.Context, lookup string, db
func (r *relationshipDB) PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error { func (r *relationshipDB) PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error {
return r.state.Caches.GTS.FollowRequest().Store(follow, func() error { return r.state.Caches.GTS.FollowRequest().Store(follow, func() error {
_, err := r.conn.NewInsert().Model(follow).Exec(ctx) _, err := r.db.NewInsert().Model(follow).Exec(ctx)
return r.conn.ProcessError(err) return r.db.ProcessError(err)
}) })
} }
@ -163,19 +163,19 @@ func (r *relationshipDB) UpdateFollowRequest(ctx context.Context, followRequest
} }
return r.state.Caches.GTS.FollowRequest().Store(followRequest, func() error { return r.state.Caches.GTS.FollowRequest().Store(followRequest, func() error {
if _, err := r.conn.NewUpdate(). if _, err := r.db.NewUpdate().
Model(followRequest). Model(followRequest).
Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID). Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
Column(columns...). Column(columns...).
Exec(ctx); err != nil { Exec(ctx); err != nil {
return r.conn.ProcessError(err) return r.db.ProcessError(err)
} }
return nil return nil
}) })
} }
func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) { func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error) {
// Get original follow request. // Get original follow request.
followReq, err := r.GetFollowRequest(ctx, sourceAccountID, targetAccountID) followReq, err := r.GetFollowRequest(ctx, sourceAccountID, targetAccountID)
if err != nil { if err != nil {
@ -198,12 +198,12 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI
if err := r.state.Caches.GTS.Follow().Store(follow, func() error { if err := r.state.Caches.GTS.Follow().Store(follow, func() error {
// If the follow already exists, just // If the follow already exists, just
// replace the URI with the new one. // replace the URI with the new one.
_, err := r.conn. _, err := r.db.
NewInsert(). NewInsert().
Model(follow). Model(follow).
On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI). On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI).
Exec(ctx) Exec(ctx)
return r.conn.ProcessError(err) return r.db.ProcessError(err)
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
@ -212,12 +212,12 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI
defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", followReq.ID) defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", followReq.ID)
// Delete original follow request. // Delete original follow request.
if _, err := r.conn. if _, err := r.db.
NewDelete(). NewDelete().
Table("follow_requests"). Table("follow_requests").
Where("? = ?", bun.Ident("id"), followReq.ID). Where("? = ?", bun.Ident("id"), followReq.ID).
Exec(ctx); err != nil { Exec(ctx); err != nil {
return nil, r.conn.ProcessError(err) return nil, r.db.ProcessError(err)
} }
// Delete original follow request notification // Delete original follow request notification
@ -230,7 +230,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI
return follow, nil return follow, nil
} }
func (r *relationshipDB) RejectFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) db.Error { func (r *relationshipDB) RejectFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) error {
// Delete follow request first. // Delete follow request first.
if err := r.DeleteFollowRequest(ctx, sourceAccountID, targetAccountID); err != nil { if err := r.DeleteFollowRequest(ctx, sourceAccountID, targetAccountID); err != nil {
return err return err
@ -262,11 +262,11 @@ func (r *relationshipDB) DeleteFollowRequest(ctx context.Context, sourceAccountI
} }
// Finally delete followreq from DB. // Finally delete followreq from DB.
_, err = r.conn.NewDelete(). _, err = r.db.NewDelete().
Table("follow_requests"). Table("follow_requests").
Where("? = ?", bun.Ident("id"), follow.ID). Where("? = ?", bun.Ident("id"), follow.ID).
Exec(ctx) Exec(ctx)
return r.conn.ProcessError(err) return r.db.ProcessError(err)
} }
func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string) error { func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string) error {
@ -285,11 +285,11 @@ func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string)
} }
// Finally delete followreq from DB. // Finally delete followreq from DB.
_, err = r.conn.NewDelete(). _, err = r.db.NewDelete().
Table("follow_requests"). Table("follow_requests").
Where("? = ?", bun.Ident("id"), id). Where("? = ?", bun.Ident("id"), id).
Exec(ctx) Exec(ctx)
return r.conn.ProcessError(err) return r.db.ProcessError(err)
} }
func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri string) error { func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri string) error {
@ -308,18 +308,18 @@ func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri strin
} }
// Finally delete followreq from DB. // Finally delete followreq from DB.
_, err = r.conn.NewDelete(). _, err = r.db.NewDelete().
Table("follow_requests"). Table("follow_requests").
Where("? = ?", bun.Ident("uri"), uri). Where("? = ?", bun.Ident("uri"), uri).
Exec(ctx) Exec(ctx)
return r.conn.ProcessError(err) return r.db.ProcessError(err)
} }
func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accountID string) error { func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accountID string) error {
var followReqIDs []string var followReqIDs []string
// Get full list of IDs. // Get full list of IDs.
if _, err := r.conn. if _, err := r.db.
NewSelect(). NewSelect().
Column("id"). Column("id").
Table("follow_requestss"). Table("follow_requestss").
@ -330,7 +330,7 @@ func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accoun
accountID, accountID,
). ).
Exec(ctx, &followReqIDs); err != nil { Exec(ctx, &followReqIDs); err != nil {
return r.conn.ProcessError(err) return r.db.ProcessError(err)
} }
defer func() { defer func() {
@ -351,9 +351,9 @@ func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accoun
} }
// Finally delete all from DB. // Finally delete all from DB.
_, err := r.conn.NewDelete(). _, err := r.db.NewDelete().
Table("follow_requests"). Table("follow_requests").
Where("? IN (?)", bun.Ident("id"), bun.In(followReqIDs)). Where("? IN (?)", bun.Ident("id"), bun.In(followReqIDs)).
Exec(ctx) Exec(ctx)
return r.conn.ProcessError(err) return r.db.ProcessError(err)
} }

View file

@ -32,15 +32,15 @@ import (
) )
type reportDB struct { type reportDB struct {
conn *DBConn db *WrappedDB
state *state.State state *state.State
} }
func (r *reportDB) newReportQ(report interface{}) *bun.SelectQuery { func (r *reportDB) newReportQ(report interface{}) *bun.SelectQuery {
return r.conn.NewSelect().Model(report) return r.db.NewSelect().Model(report)
} }
func (r *reportDB) GetReportByID(ctx context.Context, id string) (*gtsmodel.Report, db.Error) { func (r *reportDB) GetReportByID(ctx context.Context, id string) (*gtsmodel.Report, error) {
return r.getReport( return r.getReport(
ctx, ctx,
"ID", "ID",
@ -51,10 +51,10 @@ func (r *reportDB) GetReportByID(ctx context.Context, id string) (*gtsmodel.Repo
) )
} }
func (r *reportDB) GetReports(ctx context.Context, resolved *bool, accountID string, targetAccountID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Report, db.Error) { func (r *reportDB) GetReports(ctx context.Context, resolved *bool, accountID string, targetAccountID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Report, error) {
reportIDs := []string{} reportIDs := []string{}
q := r.conn. q := r.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("reports"), bun.Ident("report")). TableExpr("? AS ?", bun.Ident("reports"), bun.Ident("report")).
Column("report.id"). Column("report.id").
@ -94,7 +94,7 @@ func (r *reportDB) GetReports(ctx context.Context, resolved *bool, accountID str
} }
if err := q.Scan(ctx, &reportIDs); err != nil { if err := q.Scan(ctx, &reportIDs); err != nil {
return nil, r.conn.ProcessError(err) return nil, r.db.ProcessError(err)
} }
// Catch case of no reports early // Catch case of no reports early
@ -118,14 +118,14 @@ func (r *reportDB) GetReports(ctx context.Context, resolved *bool, accountID str
return reports, nil return reports, nil
} }
func (r *reportDB) getReport(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Report) error, keyParts ...any) (*gtsmodel.Report, db.Error) { func (r *reportDB) getReport(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Report) error, keyParts ...any) (*gtsmodel.Report, error) {
// Fetch report from database cache with loader callback // Fetch report from database cache with loader callback
report, err := r.state.Caches.GTS.Report().Load(lookup, func() (*gtsmodel.Report, error) { report, err := r.state.Caches.GTS.Report().Load(lookup, func() (*gtsmodel.Report, error) {
var report gtsmodel.Report var report gtsmodel.Report
// Not cached! Perform database query // Not cached! Perform database query
if err := dbQuery(&report); err != nil { if err := dbQuery(&report); err != nil {
return nil, r.conn.ProcessError(err) return nil, r.db.ProcessError(err)
} }
return &report, nil return &report, nil
@ -166,34 +166,34 @@ func (r *reportDB) getReport(ctx context.Context, lookup string, dbQuery func(*g
return report, nil return report, nil
} }
func (r *reportDB) PutReport(ctx context.Context, report *gtsmodel.Report) db.Error { func (r *reportDB) PutReport(ctx context.Context, report *gtsmodel.Report) error {
return r.state.Caches.GTS.Report().Store(report, func() error { return r.state.Caches.GTS.Report().Store(report, func() error {
_, err := r.conn.NewInsert().Model(report).Exec(ctx) _, err := r.db.NewInsert().Model(report).Exec(ctx)
return r.conn.ProcessError(err) return r.db.ProcessError(err)
}) })
} }
func (r *reportDB) UpdateReport(ctx context.Context, report *gtsmodel.Report, columns ...string) (*gtsmodel.Report, db.Error) { func (r *reportDB) UpdateReport(ctx context.Context, report *gtsmodel.Report, columns ...string) (*gtsmodel.Report, error) {
// Update the report's last-updated // Update the report's last-updated
report.UpdatedAt = time.Now() report.UpdatedAt = time.Now()
if len(columns) != 0 { if len(columns) != 0 {
columns = append(columns, "updated_at") columns = append(columns, "updated_at")
} }
if _, err := r.conn. if _, err := r.db.
NewUpdate(). NewUpdate().
Model(report). Model(report).
Where("? = ?", bun.Ident("report.id"), report.ID). Where("? = ?", bun.Ident("report.id"), report.ID).
Column(columns...). Column(columns...).
Exec(ctx); err != nil { Exec(ctx); err != nil {
return nil, r.conn.ProcessError(err) return nil, r.db.ProcessError(err)
} }
r.state.Caches.GTS.Report().Invalidate("ID", report.ID) r.state.Caches.GTS.Report().Invalidate("ID", report.ID)
return report, nil return report, nil
} }
func (r *reportDB) DeleteReportByID(ctx context.Context, id string) db.Error { func (r *reportDB) DeleteReportByID(ctx context.Context, id string) error {
defer r.state.Caches.GTS.Report().Invalidate("ID", id) defer r.state.Caches.GTS.Report().Invalidate("ID", id)
// Load status into cache before attempting a delete, // Load status into cache before attempting a delete,
@ -209,9 +209,9 @@ func (r *reportDB) DeleteReportByID(ctx context.Context, id string) db.Error {
} }
// Finally delete report from DB. // Finally delete report from DB.
_, err = r.conn.NewDelete(). _, err = r.db.NewDelete().
TableExpr("? AS ?", bun.Ident("reports"), bun.Ident("report")). TableExpr("? AS ?", bun.Ident("reports"), bun.Ident("report")).
Where("? = ?", bun.Ident("report.id"), id). Where("? = ?", bun.Ident("report.id"), id).
Exec(ctx) Exec(ctx)
return r.conn.ProcessError(err) return r.db.ProcessError(err)
} }

View file

@ -56,7 +56,7 @@ import (
// This isn't ideal, of course, but at least we could cover the most common use case of // This isn't ideal, of course, but at least we could cover the most common use case of
// a caller paging down through results. // a caller paging down through results.
type searchDB struct { type searchDB struct {
conn *DBConn db *WrappedDB
state *state.State state *state.State
} }
@ -89,7 +89,7 @@ func (s *searchDB) SearchForAccounts(
frontToBack = true frontToBack = true
) )
q := s.conn. q := s.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
// Select only IDs from table. // Select only IDs from table.
@ -148,7 +148,7 @@ func (s *searchDB) SearchForAccounts(
} }
if err := q.Scan(ctx, &accountIDs); err != nil { if err := q.Scan(ctx, &accountIDs); err != nil {
return nil, s.conn.ProcessError(err) return nil, s.db.ProcessError(err)
} }
if len(accountIDs) == 0 { if len(accountIDs) == 0 {
@ -183,7 +183,7 @@ func (s *searchDB) SearchForAccounts(
// followedAccounts returns a subquery that selects only IDs // followedAccounts returns a subquery that selects only IDs
// of accounts that are followed by the given accountID. // of accounts that are followed by the given accountID.
func (s *searchDB) followedAccounts(accountID string) *bun.SelectQuery { func (s *searchDB) followedAccounts(accountID string) *bun.SelectQuery {
return s.conn. return s.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
Column("follow.target_account_id"). Column("follow.target_account_id").
@ -196,7 +196,7 @@ func (s *searchDB) followedAccounts(accountID string) *bun.SelectQuery {
// in the concatenation. // in the concatenation.
func (s *searchDB) accountText(following bool) *bun.SelectQuery { func (s *searchDB) accountText(following bool) *bun.SelectQuery {
var ( var (
accountText = s.conn.NewSelect() accountText = s.db.NewSelect()
query string query string
args []interface{} args []interface{}
) )
@ -225,7 +225,7 @@ func (s *searchDB) accountText(following bool) *bun.SelectQuery {
// different number of placeholders depending on // different number of placeholders depending on
// following/not following. COALESCE calls ensure // following/not following. COALESCE calls ensure
// that we're not trying to concatenate null values. // that we're not trying to concatenate null values.
d := s.conn.Dialect().Name() d := s.db.Dialect().Name()
switch { switch {
case d == dialect.SQLite && following: case d == dialect.SQLite && following:
@ -276,7 +276,7 @@ func (s *searchDB) SearchForStatuses(
frontToBack = true frontToBack = true
) )
q := s.conn. q := s.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
// Select only IDs from table // Select only IDs from table
@ -326,7 +326,7 @@ func (s *searchDB) SearchForStatuses(
} }
if err := q.Scan(ctx, &statusIDs); err != nil { if err := q.Scan(ctx, &statusIDs); err != nil {
return nil, s.conn.ProcessError(err) return nil, s.db.ProcessError(err)
} }
if len(statusIDs) == 0 { if len(statusIDs) == 0 {
@ -361,11 +361,11 @@ func (s *searchDB) SearchForStatuses(
// statusText returns a subquery that selects a concatenation // statusText returns a subquery that selects a concatenation
// of status content and content warning as "status_text". // of status content and content warning as "status_text".
func (s *searchDB) statusText() *bun.SelectQuery { func (s *searchDB) statusText() *bun.SelectQuery {
statusText := s.conn.NewSelect() statusText := s.db.NewSelect()
// SQLite and Postgres use different // SQLite and Postgres use different
// syntaxes for concatenation. // syntaxes for concatenation.
switch s.conn.Dialect().Name() { switch s.db.Dialect().Name() {
case dialect.SQLite: case dialect.SQLite:
statusText = statusText.ColumnExpr( statusText = statusText.ColumnExpr(

View file

@ -22,26 +22,25 @@ import (
"crypto/rand" "crypto/rand"
"io" "io"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/id"
) )
type sessionDB struct { type sessionDB struct {
conn *DBConn db *WrappedDB
} }
func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) { func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, error) {
rss := make([]*gtsmodel.RouterSession, 0, 1) rss := make([]*gtsmodel.RouterSession, 0, 1)
// get the first router session in the db or... // get the first router session in the db or...
if err := s.conn. if err := s.db.
NewSelect(). NewSelect().
Model(&rss). Model(&rss).
Limit(1). Limit(1).
Order("router_session.id DESC"). Order("router_session.id DESC").
Scan(ctx); err != nil { Scan(ctx); err != nil {
return nil, s.conn.ProcessError(err) return nil, s.db.ProcessError(err)
} }
// ... create a new one // ... create a new one
@ -52,7 +51,7 @@ func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db
return rss[0], nil return rss[0], nil
} }
func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) { func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession, error) {
buf := make([]byte, 64) buf := make([]byte, 64)
auth := buf[:32] auth := buf[:32]
crypt := buf[32:64] crypt := buf[32:64]
@ -67,11 +66,11 @@ func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession,
Crypt: crypt, Crypt: crypt,
} }
if _, err := s.conn. if _, err := s.db.
NewInsert(). NewInsert().
Model(rs). Model(rs).
Exec(ctx); err != nil { Exec(ctx); err != nil {
return nil, s.conn.ProcessError(err) return nil, s.db.ProcessError(err)
} }
return rs, nil return rs, nil

View file

@ -35,19 +35,19 @@ import (
) )
type statusDB struct { type statusDB struct {
conn *DBConn db *WrappedDB
state *state.State state *state.State
} }
func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery { func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
return s.conn. return s.db.
NewSelect(). NewSelect().
Model(status). Model(status).
Relation("Tags"). Relation("Tags").
Relation("CreatedWithApplication") Relation("CreatedWithApplication")
} }
func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) { func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, error) {
return s.getStatus( return s.getStatus(
ctx, ctx,
"ID", "ID",
@ -76,7 +76,7 @@ func (s *statusDB) GetStatusesByIDs(ctx context.Context, ids []string) ([]*gtsmo
return statuses, nil return statuses, nil
} }
func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) { func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, error) {
return s.getStatus( return s.getStatus(
ctx, ctx,
"URI", "URI",
@ -87,7 +87,7 @@ func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.St
) )
} }
func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) { func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, error) {
return s.getStatus( return s.getStatus(
ctx, ctx,
"URL", "URL",
@ -98,14 +98,14 @@ func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.St
) )
} }
func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Status) error, keyParts ...any) (*gtsmodel.Status, db.Error) { func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Status) error, keyParts ...any) (*gtsmodel.Status, error) {
// Fetch status from database cache with loader callback // Fetch status from database cache with loader callback
status, err := s.state.Caches.GTS.Status().Load(lookup, func() (*gtsmodel.Status, error) { status, err := s.state.Caches.GTS.Status().Load(lookup, func() (*gtsmodel.Status, error) {
var status gtsmodel.Status var status gtsmodel.Status
// Not cached! Perform database query. // Not cached! Perform database query.
if err := dbQuery(&status); err != nil { if err := dbQuery(&status); err != nil {
return nil, s.conn.ProcessError(err) return nil, s.db.ProcessError(err)
} }
return &status, nil return &status, nil
@ -243,12 +243,12 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)
return errs.Combine() return errs.Combine()
} }
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error { func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) error {
return s.state.Caches.GTS.Status().Store(status, func() error { return s.state.Caches.GTS.Status().Store(status, func() error {
// It is safe to run this database transaction within cache.Store // It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook. // as the cache does not attempt a mutex lock until AFTER hook.
// //
return s.conn.RunInTx(ctx, func(tx bun.Tx) error { return s.db.RunInTx(ctx, func(tx bun.Tx) error {
// create links between this status and any emojis it uses // create links between this status and any emojis it uses
for _, i := range status.EmojiIDs { for _, i := range status.EmojiIDs {
if _, err := tx. if _, err := tx.
@ -259,7 +259,7 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
}). }).
On("CONFLICT (?, ?) DO NOTHING", bun.Ident("status_id"), bun.Ident("emoji_id")). On("CONFLICT (?, ?) DO NOTHING", bun.Ident("status_id"), bun.Ident("emoji_id")).
Exec(ctx); err != nil { Exec(ctx); err != nil {
err = s.conn.ProcessError(err) err = s.db.ProcessError(err)
if !errors.Is(err, db.ErrAlreadyExists) { if !errors.Is(err, db.ErrAlreadyExists) {
return err return err
} }
@ -276,7 +276,7 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
}). }).
On("CONFLICT (?, ?) DO NOTHING", bun.Ident("status_id"), bun.Ident("tag_id")). On("CONFLICT (?, ?) DO NOTHING", bun.Ident("status_id"), bun.Ident("tag_id")).
Exec(ctx); err != nil { Exec(ctx); err != nil {
err = s.conn.ProcessError(err) err = s.db.ProcessError(err)
if !errors.Is(err, db.ErrAlreadyExists) { if !errors.Is(err, db.ErrAlreadyExists) {
return err return err
} }
@ -292,7 +292,7 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
Model(a). Model(a).
Where("? = ?", bun.Ident("media_attachment.id"), a.ID). Where("? = ?", bun.Ident("media_attachment.id"), a.ID).
Exec(ctx); err != nil { Exec(ctx); err != nil {
err = s.conn.ProcessError(err) err = s.db.ProcessError(err)
if !errors.Is(err, db.ErrAlreadyExists) { if !errors.Is(err, db.ErrAlreadyExists) {
return err return err
} }
@ -306,7 +306,7 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
}) })
} }
func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, columns ...string) db.Error { func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, columns ...string) error {
status.UpdatedAt = time.Now() status.UpdatedAt = time.Now()
if len(columns) > 0 { if len(columns) > 0 {
// If we're updating by column, ensure "updated_at" is included. // If we're updating by column, ensure "updated_at" is included.
@ -317,7 +317,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co
// It is safe to run this database transaction within cache.Store // It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook. // as the cache does not attempt a mutex lock until AFTER hook.
// //
return s.conn.RunInTx(ctx, func(tx bun.Tx) error { return s.db.RunInTx(ctx, func(tx bun.Tx) error {
// create links between this status and any emojis it uses // create links between this status and any emojis it uses
for _, i := range status.EmojiIDs { for _, i := range status.EmojiIDs {
if _, err := tx. if _, err := tx.
@ -328,7 +328,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co
}). }).
On("CONFLICT (?, ?) DO NOTHING", bun.Ident("status_id"), bun.Ident("emoji_id")). On("CONFLICT (?, ?) DO NOTHING", bun.Ident("status_id"), bun.Ident("emoji_id")).
Exec(ctx); err != nil { Exec(ctx); err != nil {
err = s.conn.ProcessError(err) err = s.db.ProcessError(err)
if !errors.Is(err, db.ErrAlreadyExists) { if !errors.Is(err, db.ErrAlreadyExists) {
return err return err
} }
@ -345,7 +345,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co
}). }).
On("CONFLICT (?, ?) DO NOTHING", bun.Ident("status_id"), bun.Ident("tag_id")). On("CONFLICT (?, ?) DO NOTHING", bun.Ident("status_id"), bun.Ident("tag_id")).
Exec(ctx); err != nil { Exec(ctx); err != nil {
err = s.conn.ProcessError(err) err = s.db.ProcessError(err)
if !errors.Is(err, db.ErrAlreadyExists) { if !errors.Is(err, db.ErrAlreadyExists) {
return err return err
} }
@ -361,7 +361,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co
Model(a). Model(a).
Where("? = ?", bun.Ident("media_attachment.id"), a.ID). Where("? = ?", bun.Ident("media_attachment.id"), a.ID).
Exec(ctx); err != nil { Exec(ctx); err != nil {
err = s.conn.ProcessError(err) err = s.db.ProcessError(err)
if !errors.Is(err, db.ErrAlreadyExists) { if !errors.Is(err, db.ErrAlreadyExists) {
return err return err
} }
@ -380,7 +380,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co
}) })
} }
func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error { func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) error {
defer s.state.Caches.GTS.Status().Invalidate("ID", id) defer s.state.Caches.GTS.Status().Invalidate("ID", id)
// Load status into cache before attempting a delete, // Load status into cache before attempting a delete,
@ -397,7 +397,7 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error {
return err return err
} }
return s.conn.RunInTx(ctx, func(tx bun.Tx) error { return s.db.RunInTx(ctx, func(tx bun.Tx) error {
// delete links between this status and any emojis it uses // delete links between this status and any emojis it uses
if _, err := tx. if _, err := tx.
NewDelete(). NewDelete().
@ -433,7 +433,7 @@ func (s *statusDB) GetStatusesUsingEmoji(ctx context.Context, emojiID string) ([
var statusIDs []string var statusIDs []string
// Create SELECT status query. // Create SELECT status query.
q := s.conn.NewSelect(). q := s.db.NewSelect().
Table("statuses"). Table("statuses").
Column("id") Column("id")
@ -450,14 +450,14 @@ func (s *statusDB) GetStatusesUsingEmoji(ctx context.Context, emojiID string) ([
// Execute the query, scanning destination into statusIDs. // Execute the query, scanning destination into statusIDs.
if _, err := q.Exec(ctx, &statusIDs); err != nil { if _, err := q.Exec(ctx, &statusIDs); err != nil {
return nil, s.conn.ProcessError(err) return nil, s.db.ProcessError(err)
} }
// Convert status IDs into status objects. // Convert status IDs into status objects.
return s.GetStatusesByIDs(ctx, statusIDs) return s.GetStatusesByIDs(ctx, statusIDs)
} }
func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error) {
if onlyDirect { if onlyDirect {
// Only want the direct parent, no further than first level // Only want the direct parent, no further than first level
parent, err := s.GetStatusByID(ctx, status.InReplyToID) parent, err := s.GetStatusByID(ctx, status.InReplyToID)
@ -485,7 +485,7 @@ func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status
return parents, nil return parents, nil
} }
func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) { func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error) {
foundStatuses := &list.List{} foundStatuses := &list.List{}
foundStatuses.PushFront(status) foundStatuses.PushFront(status)
s.statusChildren(ctx, status, foundStatuses, onlyDirect, minID) s.statusChildren(ctx, status, foundStatuses, onlyDirect, minID)
@ -509,7 +509,7 @@ func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Statu
func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) { func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
var childIDs []string var childIDs []string
q := s.conn. q := s.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Column("status.id"). Column("status.id").
@ -554,71 +554,71 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status,
} }
} }
func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, error) {
return s.conn. return s.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Where("? = ?", bun.Ident("status.in_reply_to_id"), status.ID). Where("? = ?", bun.Ident("status.in_reply_to_id"), status.ID).
Count(ctx) Count(ctx)
} }
func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, error) {
return s.conn. return s.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Where("? = ?", bun.Ident("status.boost_of_id"), status.ID). Where("? = ?", bun.Ident("status.boost_of_id"), status.ID).
Count(ctx) Count(ctx)
} }
func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, error) {
return s.conn. return s.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")). TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
Where("? = ?", bun.Ident("status_fave.status_id"), status.ID). Where("? = ?", bun.Ident("status_fave.status_id"), status.ID).
Count(ctx) Count(ctx)
} }
func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) {
q := s.conn. q := s.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")). TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
Where("? = ?", bun.Ident("status_fave.status_id"), status.ID). Where("? = ?", bun.Ident("status_fave.status_id"), status.ID).
Where("? = ?", bun.Ident("status_fave.account_id"), accountID) Where("? = ?", bun.Ident("status_fave.account_id"), accountID)
return s.conn.Exists(ctx, q) return s.db.Exists(ctx, q)
} }
func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) {
q := s.conn. q := s.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Where("? = ?", bun.Ident("status.boost_of_id"), status.ID). Where("? = ?", bun.Ident("status.boost_of_id"), status.ID).
Where("? = ?", bun.Ident("status.account_id"), accountID) Where("? = ?", bun.Ident("status.account_id"), accountID)
return s.conn.Exists(ctx, q) return s.db.Exists(ctx, q)
} }
func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) {
q := s.conn. q := s.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("status_mutes"), bun.Ident("status_mute")). TableExpr("? AS ?", bun.Ident("status_mutes"), bun.Ident("status_mute")).
Where("? = ?", bun.Ident("status_mute.status_id"), status.ID). Where("? = ?", bun.Ident("status_mute.status_id"), status.ID).
Where("? = ?", bun.Ident("status_mute.account_id"), accountID) Where("? = ?", bun.Ident("status_mute.account_id"), accountID)
return s.conn.Exists(ctx, q) return s.db.Exists(ctx, q)
} }
func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) {
q := s.conn. q := s.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")). TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")).
Where("? = ?", bun.Ident("status_bookmark.status_id"), status.ID). Where("? = ?", bun.Ident("status_bookmark.status_id"), status.ID).
Where("? = ?", bun.Ident("status_bookmark.account_id"), accountID) Where("? = ?", bun.Ident("status_bookmark.account_id"), accountID)
return s.conn.Exists(ctx, q) return s.db.Exists(ctx, q)
} }
func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) { func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, error) {
reblogs := []*gtsmodel.Status{} reblogs := []*gtsmodel.Status{}
q := s. q := s.
@ -626,7 +626,7 @@ func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status
Where("? = ?", bun.Ident("status.boost_of_id"), status.ID) Where("? = ?", bun.Ident("status.boost_of_id"), status.ID)
if err := q.Scan(ctx); err != nil { if err := q.Scan(ctx); err != nil {
return nil, s.conn.ProcessError(err) return nil, s.db.ProcessError(err)
} }
return reblogs, nil return reblogs, nil
} }

View file

@ -22,7 +22,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
@ -30,20 +29,20 @@ import (
) )
type statusBookmarkDB struct { type statusBookmarkDB struct {
conn *DBConn db *WrappedDB
state *state.State state *state.State
} }
func (s *statusBookmarkDB) GetStatusBookmark(ctx context.Context, id string) (*gtsmodel.StatusBookmark, db.Error) { func (s *statusBookmarkDB) GetStatusBookmark(ctx context.Context, id string) (*gtsmodel.StatusBookmark, error) {
bookmark := new(gtsmodel.StatusBookmark) bookmark := new(gtsmodel.StatusBookmark)
err := s.conn. err := s.db.
NewSelect(). NewSelect().
Model(bookmark). Model(bookmark).
Where("? = ?", bun.Ident("status_bookmark.id"), id). Where("? = ?", bun.Ident("status_bookmark.id"), id).
Scan(ctx) Scan(ctx)
if err != nil { if err != nil {
return nil, s.conn.ProcessError(err) return nil, s.db.ProcessError(err)
} }
bookmark.Account, err = s.state.DB.GetAccountByID(ctx, bookmark.AccountID) bookmark.Account, err = s.state.DB.GetAccountByID(ctx, bookmark.AccountID)
@ -64,10 +63,10 @@ func (s *statusBookmarkDB) GetStatusBookmark(ctx context.Context, id string) (*g
return bookmark, nil return bookmark, nil
} }
func (s *statusBookmarkDB) GetStatusBookmarkID(ctx context.Context, accountID string, statusID string) (string, db.Error) { func (s *statusBookmarkDB) GetStatusBookmarkID(ctx context.Context, accountID string, statusID string) (string, error) {
var id string var id string
q := s.conn. q := s.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")). TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")).
Column("status_bookmark.id"). Column("status_bookmark.id").
@ -76,13 +75,13 @@ func (s *statusBookmarkDB) GetStatusBookmarkID(ctx context.Context, accountID st
Limit(1) Limit(1)
if err := q.Scan(ctx, &id); err != nil { if err := q.Scan(ctx, &id); err != nil {
return "", s.conn.ProcessError(err) return "", s.db.ProcessError(err)
} }
return id, nil return id, nil
} }
func (s *statusBookmarkDB) GetStatusBookmarks(ctx context.Context, accountID string, limit int, maxID string, minID string) ([]*gtsmodel.StatusBookmark, db.Error) { func (s *statusBookmarkDB) GetStatusBookmarks(ctx context.Context, accountID string, limit int, maxID string, minID string) ([]*gtsmodel.StatusBookmark, error) {
// Ensure reasonable // Ensure reasonable
if limit < 0 { if limit < 0 {
limit = 0 limit = 0
@ -91,7 +90,7 @@ func (s *statusBookmarkDB) GetStatusBookmarks(ctx context.Context, accountID str
// Guess size of IDs based on limit. // Guess size of IDs based on limit.
ids := make([]string, 0, limit) ids := make([]string, 0, limit)
q := s.conn. q := s.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")). TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")).
Column("status_bookmark.id"). Column("status_bookmark.id").
@ -115,7 +114,7 @@ func (s *statusBookmarkDB) GetStatusBookmarks(ctx context.Context, accountID str
} }
if err := q.Scan(ctx, &ids); err != nil { if err := q.Scan(ctx, &ids); err != nil {
return nil, s.conn.ProcessError(err) return nil, s.db.ProcessError(err)
} }
bookmarks := make([]*gtsmodel.StatusBookmark, 0, len(ids)) bookmarks := make([]*gtsmodel.StatusBookmark, 0, len(ids))
@ -133,26 +132,26 @@ func (s *statusBookmarkDB) GetStatusBookmarks(ctx context.Context, accountID str
return bookmarks, nil return bookmarks, nil
} }
func (s *statusBookmarkDB) PutStatusBookmark(ctx context.Context, statusBookmark *gtsmodel.StatusBookmark) db.Error { func (s *statusBookmarkDB) PutStatusBookmark(ctx context.Context, statusBookmark *gtsmodel.StatusBookmark) error {
_, err := s.conn. _, err := s.db.
NewInsert(). NewInsert().
Model(statusBookmark). Model(statusBookmark).
Exec(ctx) Exec(ctx)
return s.conn.ProcessError(err) return s.db.ProcessError(err)
} }
func (s *statusBookmarkDB) DeleteStatusBookmark(ctx context.Context, id string) db.Error { func (s *statusBookmarkDB) DeleteStatusBookmark(ctx context.Context, id string) error {
_, err := s.conn. _, err := s.db.
NewDelete(). NewDelete().
TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")). TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")).
Where("? = ?", bun.Ident("status_bookmark.id"), id). Where("? = ?", bun.Ident("status_bookmark.id"), id).
Exec(ctx) Exec(ctx)
return s.conn.ProcessError(err) return s.db.ProcessError(err)
} }
func (s *statusBookmarkDB) DeleteStatusBookmarks(ctx context.Context, targetAccountID string, originAccountID string) db.Error { func (s *statusBookmarkDB) DeleteStatusBookmarks(ctx context.Context, targetAccountID string, originAccountID string) error {
if targetAccountID == "" && originAccountID == "" { if targetAccountID == "" && originAccountID == "" {
return errors.New("DeleteBookmarks: one of targetAccountID or originAccountID must be set") return errors.New("DeleteBookmarks: one of targetAccountID or originAccountID must be set")
} }
@ -161,7 +160,7 @@ func (s *statusBookmarkDB) DeleteStatusBookmarks(ctx context.Context, targetAcco
// statement (when bookmarks have a cache), // statement (when bookmarks have a cache),
// + use the IDs to invalidate cache entries. // + use the IDs to invalidate cache entries.
q := s.conn. q := s.db.
NewDelete(). NewDelete().
TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")) TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark"))
@ -174,24 +173,24 @@ func (s *statusBookmarkDB) DeleteStatusBookmarks(ctx context.Context, targetAcco
} }
if _, err := q.Exec(ctx); err != nil { if _, err := q.Exec(ctx); err != nil {
return s.conn.ProcessError(err) return s.db.ProcessError(err)
} }
return nil return nil
} }
func (s *statusBookmarkDB) DeleteStatusBookmarksForStatus(ctx context.Context, statusID string) db.Error { func (s *statusBookmarkDB) DeleteStatusBookmarksForStatus(ctx context.Context, statusID string) error {
// TODO: Capture bookmark IDs in a RETURNING // TODO: Capture bookmark IDs in a RETURNING
// statement (when bookmarks have a cache), // statement (when bookmarks have a cache),
// + use the IDs to invalidate cache entries. // + use the IDs to invalidate cache entries.
q := s.conn. q := s.db.
NewDelete(). NewDelete().
TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")). TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")).
Where("? = ?", bun.Ident("status_bookmark.status_id"), statusID) Where("? = ?", bun.Ident("status_bookmark.status_id"), statusID)
if _, err := q.Exec(ctx); err != nil { if _, err := q.Exec(ctx); err != nil {
return s.conn.ProcessError(err) return s.db.ProcessError(err)
} }
return nil return nil

View file

@ -32,16 +32,16 @@ import (
) )
type statusFaveDB struct { type statusFaveDB struct {
conn *DBConn db *WrappedDB
state *state.State state *state.State
} }
func (s *statusFaveDB) GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, db.Error) { func (s *statusFaveDB) GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, error) {
return s.getStatusFave( return s.getStatusFave(
ctx, ctx,
"AccountID.StatusID", "AccountID.StatusID",
func(fave *gtsmodel.StatusFave) error { func(fave *gtsmodel.StatusFave) error {
return s.conn. return s.db.
NewSelect(). NewSelect().
Model(fave). Model(fave).
Where("? = ?", bun.Ident("account_id"), accountID). Where("? = ?", bun.Ident("account_id"), accountID).
@ -53,12 +53,12 @@ func (s *statusFaveDB) GetStatusFave(ctx context.Context, accountID string, stat
) )
} }
func (s *statusFaveDB) GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, db.Error) { func (s *statusFaveDB) GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, error) {
return s.getStatusFave( return s.getStatusFave(
ctx, ctx,
"ID", "ID",
func(fave *gtsmodel.StatusFave) error { func(fave *gtsmodel.StatusFave) error {
return s.conn. return s.db.
NewSelect(). NewSelect().
Model(fave). Model(fave).
Where("? = ?", bun.Ident("id"), id). Where("? = ?", bun.Ident("id"), id).
@ -75,7 +75,7 @@ func (s *statusFaveDB) getStatusFave(ctx context.Context, lookup string, dbQuery
// Not cached! Perform database query. // Not cached! Perform database query.
if err := dbQuery(&fave); err != nil { if err := dbQuery(&fave); err != nil {
return nil, s.conn.ProcessError(err) return nil, s.db.ProcessError(err)
} }
return &fave, nil return &fave, nil
@ -119,16 +119,16 @@ func (s *statusFaveDB) getStatusFave(ctx context.Context, lookup string, dbQuery
return fave, nil return fave, nil
} }
func (s *statusFaveDB) GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, db.Error) { func (s *statusFaveDB) GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, error) {
ids := []string{} ids := []string{}
if err := s.conn. if err := s.db.
NewSelect(). NewSelect().
Table("status_faves"). Table("status_faves").
Column("id"). Column("id").
Where("? = ?", bun.Ident("status_id"), statusID). Where("? = ?", bun.Ident("status_id"), statusID).
Scan(ctx, &ids); err != nil { Scan(ctx, &ids); err != nil {
return nil, s.conn.ProcessError(err) return nil, s.db.ProcessError(err)
} }
faves := make([]*gtsmodel.StatusFave, 0, len(ids)) faves := make([]*gtsmodel.StatusFave, 0, len(ids))
@ -188,17 +188,17 @@ func (s *statusFaveDB) PopulateStatusFave(ctx context.Context, statusFave *gtsmo
return errs.Combine() return errs.Combine()
} }
func (s *statusFaveDB) PutStatusFave(ctx context.Context, fave *gtsmodel.StatusFave) db.Error { func (s *statusFaveDB) PutStatusFave(ctx context.Context, fave *gtsmodel.StatusFave) error {
return s.state.Caches.GTS.StatusFave().Store(fave, func() error { return s.state.Caches.GTS.StatusFave().Store(fave, func() error {
_, err := s.conn. _, err := s.db.
NewInsert(). NewInsert().
Model(fave). Model(fave).
Exec(ctx) Exec(ctx)
return s.conn.ProcessError(err) return s.db.ProcessError(err)
}) })
} }
func (s *statusFaveDB) DeleteStatusFaveByID(ctx context.Context, id string) db.Error { func (s *statusFaveDB) DeleteStatusFaveByID(ctx context.Context, id string) error {
defer s.state.Caches.GTS.StatusFave().Invalidate("ID", id) defer s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
// Load fave into cache before attempting a delete, // Load fave into cache before attempting a delete,
@ -214,21 +214,21 @@ func (s *statusFaveDB) DeleteStatusFaveByID(ctx context.Context, id string) db.E
} }
// Finally delete fave from DB. // Finally delete fave from DB.
_, err = s.conn.NewDelete(). _, err = s.db.NewDelete().
Table("status_faves"). Table("status_faves").
Where("? = ?", bun.Ident("id"), id). Where("? = ?", bun.Ident("id"), id).
Exec(ctx) Exec(ctx)
return s.conn.ProcessError(err) return s.db.ProcessError(err)
} }
func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID string, originAccountID string) db.Error { func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID string, originAccountID string) error {
if targetAccountID == "" && originAccountID == "" { if targetAccountID == "" && originAccountID == "" {
return errors.New("DeleteStatusFaves: one of targetAccountID or originAccountID must be set") return errors.New("DeleteStatusFaves: one of targetAccountID or originAccountID must be set")
} }
var faveIDs []string var faveIDs []string
q := s.conn. q := s.db.
NewSelect(). NewSelect().
Column("id"). Column("id").
Table("status_faves") Table("status_faves")
@ -242,7 +242,7 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st
} }
if _, err := q.Exec(ctx, &faveIDs); err != nil { if _, err := q.Exec(ctx, &faveIDs); err != nil {
return s.conn.ProcessError(err) return s.db.ProcessError(err)
} }
defer func() { defer func() {
@ -263,24 +263,24 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st
} }
// Finally delete all from DB. // Finally delete all from DB.
_, err := s.conn.NewDelete(). _, err := s.db.NewDelete().
Table("status_faves"). Table("status_faves").
Where("? IN (?)", bun.Ident("id"), bun.In(faveIDs)). Where("? IN (?)", bun.Ident("id"), bun.In(faveIDs)).
Exec(ctx) Exec(ctx)
return s.conn.ProcessError(err) return s.db.ProcessError(err)
} }
func (s *statusFaveDB) DeleteStatusFavesForStatus(ctx context.Context, statusID string) db.Error { func (s *statusFaveDB) DeleteStatusFavesForStatus(ctx context.Context, statusID string) error {
// Capture fave IDs in a RETURNING statement. // Capture fave IDs in a RETURNING statement.
var faveIDs []string var faveIDs []string
q := s.conn. q := s.db.
NewSelect(). NewSelect().
Column("id"). Column("id").
Table("status_faves"). Table("status_faves").
Where("? = ?", bun.Ident("status_id"), statusID) Where("? = ?", bun.Ident("status_id"), statusID)
if _, err := q.Exec(ctx, &faveIDs); err != nil { if _, err := q.Exec(ctx, &faveIDs); err != nil {
return s.conn.ProcessError(err) return s.db.ProcessError(err)
} }
defer func() { defer func() {
@ -301,9 +301,9 @@ func (s *statusFaveDB) DeleteStatusFavesForStatus(ctx context.Context, statusID
} }
// Finally delete all from DB. // Finally delete all from DB.
_, err := s.conn.NewDelete(). _, err := s.db.NewDelete().
Table("status_faves"). Table("status_faves").
Where("? IN (?)", bun.Ident("id"), bun.In(faveIDs)). Where("? IN (?)", bun.Ident("id"), bun.In(faveIDs)).
Exec(ctx) Exec(ctx)
return s.conn.ProcessError(err) return s.db.ProcessError(err)
} }

View file

@ -33,11 +33,11 @@ import (
) )
type timelineDB struct { type timelineDB struct {
conn *DBConn db *WrappedDB
state *state.State state *state.State
} }
func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) {
// Ensure reasonable // Ensure reasonable
if limit < 0 { if limit < 0 {
limit = 0 limit = 0
@ -49,7 +49,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
frontToBack = true frontToBack = true
) )
q := t.conn. q := t.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
// Select only IDs from table // Select only IDs from table
@ -103,7 +103,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
// Subquery to select target (followed) account // Subquery to select target (followed) account
// IDs from follows owned by given accountID. // IDs from follows owned by given accountID.
subQ := t.conn. subQ := t.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
Column("follow.target_account_id"). Column("follow.target_account_id").
@ -119,7 +119,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
}) })
if err := q.Scan(ctx, &statusIDs); err != nil { if err := q.Scan(ctx, &statusIDs); err != nil {
return nil, t.conn.ProcessError(err) return nil, t.db.ProcessError(err)
} }
if len(statusIDs) == 0 { if len(statusIDs) == 0 {
@ -151,7 +151,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
return statuses, nil return statuses, nil
} }
func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) {
// Ensure reasonable // Ensure reasonable
if limit < 0 { if limit < 0 {
limit = 0 limit = 0
@ -160,7 +160,7 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceI
// Make educated guess for slice size // Make educated guess for slice size
statusIDs := make([]string, 0, limit) statusIDs := make([]string, 0, limit)
q := t.conn. q := t.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Column("status.id"). Column("status.id").
@ -202,7 +202,7 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceI
} }
if err := q.Scan(ctx, &statusIDs); err != nil { if err := q.Scan(ctx, &statusIDs); err != nil {
return nil, t.conn.ProcessError(err) return nil, t.db.ProcessError(err)
} }
statuses := make([]*gtsmodel.Status, 0, len(statusIDs)) statuses := make([]*gtsmodel.Status, 0, len(statusIDs))
@ -224,7 +224,7 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceI
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20! // TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds. // It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) { func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, error) {
// Ensure reasonable // Ensure reasonable
if limit < 0 { if limit < 0 {
limit = 0 limit = 0
@ -233,7 +233,7 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max
// Make educated guess for slice size // Make educated guess for slice size
faves := make([]*gtsmodel.StatusFave, 0, limit) faves := make([]*gtsmodel.StatusFave, 0, limit)
fq := t.conn. fq := t.db.
NewSelect(). NewSelect().
Model(&faves). Model(&faves).
Where("? = ?", bun.Ident("status_fave.account_id"), accountID). Where("? = ?", bun.Ident("status_fave.account_id"), accountID).
@ -253,7 +253,7 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max
err := fq.Scan(ctx) err := fq.Scan(ctx)
if err != nil { if err != nil {
return nil, "", "", t.conn.ProcessError(err) return nil, "", "", t.db.ProcessError(err)
} }
if len(faves) == 0 { if len(faves) == 0 {
@ -322,7 +322,7 @@ func (t *timelineDB) GetListTimeline(
} }
// Select target account IDs from follows. // Select target account IDs from follows.
subQ := t.conn. subQ := t.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
Column("follow.target_account_id"). Column("follow.target_account_id").
@ -330,7 +330,7 @@ func (t *timelineDB) GetListTimeline(
// Select only status IDs created // Select only status IDs created
// by one of the followed accounts. // by one of the followed accounts.
q := t.conn. q := t.db.
NewSelect(). NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
// Select only IDs from table // Select only IDs from table
@ -379,7 +379,7 @@ func (t *timelineDB) GetListTimeline(
} }
if err := q.Scan(ctx, &statusIDs); err != nil { if err := q.Scan(ctx, &statusIDs); err != nil {
return nil, t.conn.ProcessError(err) return nil, t.db.ProcessError(err)
} }
if len(statusIDs) == 0 { if len(statusIDs) == 0 {

View file

@ -27,28 +27,28 @@ import (
) )
type tombstoneDB struct { type tombstoneDB struct {
conn *DBConn db *WrappedDB
state *state.State state *state.State
} }
func (t *tombstoneDB) GetTombstoneByURI(ctx context.Context, uri string) (*gtsmodel.Tombstone, db.Error) { func (t *tombstoneDB) GetTombstoneByURI(ctx context.Context, uri string) (*gtsmodel.Tombstone, error) {
return t.state.Caches.GTS.Tombstone().Load("URI", func() (*gtsmodel.Tombstone, error) { return t.state.Caches.GTS.Tombstone().Load("URI", func() (*gtsmodel.Tombstone, error) {
var tomb gtsmodel.Tombstone var tomb gtsmodel.Tombstone
q := t.conn. q := t.db.
NewSelect(). NewSelect().
Model(&tomb). Model(&tomb).
Where("? = ?", bun.Ident("tombstone.uri"), uri) Where("? = ?", bun.Ident("tombstone.uri"), uri)
if err := q.Scan(ctx); err != nil { if err := q.Scan(ctx); err != nil {
return nil, t.conn.ProcessError(err) return nil, t.db.ProcessError(err)
} }
return &tomb, nil return &tomb, nil
}, uri) }, uri)
} }
func (t *tombstoneDB) TombstoneExistsWithURI(ctx context.Context, uri string) (bool, db.Error) { func (t *tombstoneDB) TombstoneExistsWithURI(ctx context.Context, uri string) (bool, error) {
tomb, err := t.GetTombstoneByURI(ctx, uri) tomb, err := t.GetTombstoneByURI(ctx, uri)
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
err = nil err = nil
@ -56,23 +56,23 @@ func (t *tombstoneDB) TombstoneExistsWithURI(ctx context.Context, uri string) (b
return (tomb != nil), err return (tomb != nil), err
} }
func (t *tombstoneDB) PutTombstone(ctx context.Context, tombstone *gtsmodel.Tombstone) db.Error { func (t *tombstoneDB) PutTombstone(ctx context.Context, tombstone *gtsmodel.Tombstone) error {
return t.state.Caches.GTS.Tombstone().Store(tombstone, func() error { return t.state.Caches.GTS.Tombstone().Store(tombstone, func() error {
_, err := t.conn. _, err := t.db.
NewInsert(). NewInsert().
Model(tombstone). Model(tombstone).
Exec(ctx) Exec(ctx)
return t.conn.ProcessError(err) return t.db.ProcessError(err)
}) })
} }
func (t *tombstoneDB) DeleteTombstone(ctx context.Context, id string) db.Error { func (t *tombstoneDB) DeleteTombstone(ctx context.Context, id string) error {
defer t.state.Caches.GTS.Tombstone().Invalidate("ID", id) defer t.state.Caches.GTS.Tombstone().Invalidate("ID", id)
// Delete tombstone from DB. // Delete tombstone from DB.
_, err := t.conn.NewDelete(). _, err := t.db.NewDelete().
TableExpr("? AS ?", bun.Ident("tombstones"), bun.Ident("tombstone")). TableExpr("? AS ?", bun.Ident("tombstones"), bun.Ident("tombstone")).
Where("? = ?", bun.Ident("tombstone.id"), id). Where("? = ?", bun.Ident("tombstone.id"), id).
Exec(ctx) Exec(ctx)
return t.conn.ProcessError(err) return t.db.ProcessError(err)
} }

View file

@ -30,125 +30,125 @@ import (
) )
type userDB struct { type userDB struct {
conn *DBConn db *WrappedDB
state *state.State state *state.State
} }
func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db.Error) { func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, error) {
return u.state.Caches.GTS.User().Load("ID", func() (*gtsmodel.User, error) { return u.state.Caches.GTS.User().Load("ID", func() (*gtsmodel.User, error) {
var user gtsmodel.User var user gtsmodel.User
q := u.conn. q := u.db.
NewSelect(). NewSelect().
Model(&user). Model(&user).
Relation("Account"). Relation("Account").
Where("? = ?", bun.Ident("user.id"), id) Where("? = ?", bun.Ident("user.id"), id)
if err := q.Scan(ctx); err != nil { if err := q.Scan(ctx); err != nil {
return nil, u.conn.ProcessError(err) return nil, u.db.ProcessError(err)
} }
return &user, nil return &user, nil
}, id) }, id)
} }
func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, db.Error) { func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, error) {
return u.state.Caches.GTS.User().Load("AccountID", func() (*gtsmodel.User, error) { return u.state.Caches.GTS.User().Load("AccountID", func() (*gtsmodel.User, error) {
var user gtsmodel.User var user gtsmodel.User
q := u.conn. q := u.db.
NewSelect(). NewSelect().
Model(&user). Model(&user).
Relation("Account"). Relation("Account").
Where("? = ?", bun.Ident("user.account_id"), accountID) Where("? = ?", bun.Ident("user.account_id"), accountID)
if err := q.Scan(ctx); err != nil { if err := q.Scan(ctx); err != nil {
return nil, u.conn.ProcessError(err) return nil, u.db.ProcessError(err)
} }
return &user, nil return &user, nil
}, accountID) }, accountID)
} }
func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, db.Error) { func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, error) {
return u.state.Caches.GTS.User().Load("Email", func() (*gtsmodel.User, error) { return u.state.Caches.GTS.User().Load("Email", func() (*gtsmodel.User, error) {
var user gtsmodel.User var user gtsmodel.User
q := u.conn. q := u.db.
NewSelect(). NewSelect().
Model(&user). Model(&user).
Relation("Account"). Relation("Account").
Where("? = ?", bun.Ident("user.email"), emailAddress) Where("? = ?", bun.Ident("user.email"), emailAddress)
if err := q.Scan(ctx); err != nil { if err := q.Scan(ctx); err != nil {
return nil, u.conn.ProcessError(err) return nil, u.db.ProcessError(err)
} }
return &user, nil return &user, nil
}, emailAddress) }, emailAddress)
} }
func (u *userDB) GetUserByExternalID(ctx context.Context, id string) (*gtsmodel.User, db.Error) { func (u *userDB) GetUserByExternalID(ctx context.Context, id string) (*gtsmodel.User, error) {
return u.state.Caches.GTS.User().Load("ExternalID", func() (*gtsmodel.User, error) { return u.state.Caches.GTS.User().Load("ExternalID", func() (*gtsmodel.User, error) {
var user gtsmodel.User var user gtsmodel.User
q := u.conn. q := u.db.
NewSelect(). NewSelect().
Model(&user). Model(&user).
Relation("Account"). Relation("Account").
Where("? = ?", bun.Ident("user.external_id"), id) Where("? = ?", bun.Ident("user.external_id"), id)
if err := q.Scan(ctx); err != nil { if err := q.Scan(ctx); err != nil {
return nil, u.conn.ProcessError(err) return nil, u.db.ProcessError(err)
} }
return &user, nil return &user, nil
}, id) }, id)
} }
func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, db.Error) { func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, error) {
return u.state.Caches.GTS.User().Load("ConfirmationToken", func() (*gtsmodel.User, error) { return u.state.Caches.GTS.User().Load("ConfirmationToken", func() (*gtsmodel.User, error) {
var user gtsmodel.User var user gtsmodel.User
q := u.conn. q := u.db.
NewSelect(). NewSelect().
Model(&user). Model(&user).
Relation("Account"). Relation("Account").
Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken) Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken)
if err := q.Scan(ctx); err != nil { if err := q.Scan(ctx); err != nil {
return nil, u.conn.ProcessError(err) return nil, u.db.ProcessError(err)
} }
return &user, nil return &user, nil
}, confirmationToken) }, confirmationToken)
} }
func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, db.Error) { func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error) {
var users []*gtsmodel.User var users []*gtsmodel.User
q := u.conn. q := u.db.
NewSelect(). NewSelect().
Model(&users). Model(&users).
Relation("Account") Relation("Account")
if err := q.Scan(ctx); err != nil { if err := q.Scan(ctx); err != nil {
return nil, u.conn.ProcessError(err) return nil, u.db.ProcessError(err)
} }
return users, nil return users, nil
} }
func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) db.Error { func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) error {
return u.state.Caches.GTS.User().Store(user, func() error { return u.state.Caches.GTS.User().Store(user, func() error {
_, err := u.conn. _, err := u.db.
NewInsert(). NewInsert().
Model(user). Model(user).
Exec(ctx) Exec(ctx)
return u.conn.ProcessError(err) return u.db.ProcessError(err)
}) })
} }
func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) db.Error { func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) error {
// Update the user's last-updated // Update the user's last-updated
user.UpdatedAt = time.Now() user.UpdatedAt = time.Now()
@ -158,17 +158,17 @@ func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ..
} }
return u.state.Caches.GTS.User().Store(user, func() error { return u.state.Caches.GTS.User().Store(user, func() error {
_, err := u.conn. _, err := u.db.
NewUpdate(). NewUpdate().
Model(user). Model(user).
Where("? = ?", bun.Ident("user.id"), user.ID). Where("? = ?", bun.Ident("user.id"), user.ID).
Column(columns...). Column(columns...).
Exec(ctx) Exec(ctx)
return u.conn.ProcessError(err) return u.db.ProcessError(err)
}) })
} }
func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error { func (u *userDB) DeleteUserByID(ctx context.Context, userID string) error {
defer u.state.Caches.GTS.User().Invalidate("ID", userID) defer u.state.Caches.GTS.User().Invalidate("ID", userID)
// Load user into cache before attempting a delete, // Load user into cache before attempting a delete,
@ -184,9 +184,9 @@ func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error {
} }
// Finally delete user from DB. // Finally delete user from DB.
_, err = u.conn.NewDelete(). _, err = u.db.NewDelete().
TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")). TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")).
Where("? = ?", bun.Ident("user.id"), userID). Where("? = ?", bun.Ident("user.id"), userID).
Exec(ctx) Exec(ctx)
return u.conn.ProcessError(err) return u.db.ProcessError(err)
} }

258
internal/db/bundb/wrap.go Normal file
View file

@ -0,0 +1,258 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"database/sql"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect"
)
// WrappedDB wraps a bun database instance
// to provide common per-dialect SQL error
// conversions to common types, and retries
// on returned busy errors (SQLite only for now).
type WrappedDB struct {
errHook func(error) error
*bun.DB // underlying conn
}
// WrapDB wraps a bun database instance in our own WrappedDB type.
func WrapDB(db *bun.DB) *WrappedDB {
var errProc func(error) error
switch name := db.Dialect().Name(); name {
case dialect.PG:
errProc = processPostgresError
case dialect.SQLite:
errProc = processSQLiteError
default:
panic("unknown dialect name: " + name.String())
}
return &WrappedDB{
errHook: errProc,
DB: db,
}
}
// BeginTx wraps bun.DB.BeginTx() with retry-busy timeout.
func (db *WrappedDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (tx bun.Tx, err error) {
err = retryOnBusy(ctx, func() error {
tx, err = db.DB.BeginTx(ctx, opts)
err = db.ProcessError(err)
return err
})
return
}
// ExecContext wraps bun.DB.ExecContext() with retry-busy timeout.
func (db *WrappedDB) ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) {
err = retryOnBusy(ctx, func() error {
result, err = db.DB.ExecContext(ctx, query, args...)
err = db.ProcessError(err)
return err
})
return
}
// QueryContext wraps bun.DB.QueryContext() with retry-busy timeout.
func (db *WrappedDB) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) {
err = retryOnBusy(ctx, func() error {
rows, err = db.DB.QueryContext(ctx, query, args...)
err = db.ProcessError(err)
return err
})
return
}
// QueryRowContext wraps bun.DB.QueryRowContext() with retry-busy timeout.
func (db *WrappedDB) QueryRowContext(ctx context.Context, query string, args ...any) (row *sql.Row) {
_ = retryOnBusy(ctx, func() error {
row = db.DB.QueryRowContext(ctx, query, args...)
err := db.ProcessError(row.Err())
return err
})
return
}
// RunInTx is functionally the same as bun.DB.RunInTx() but with retry-busy timeouts.
func (db *WrappedDB) RunInTx(ctx context.Context, fn func(bun.Tx) error) error {
// Attempt to start new transaction.
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
var done bool
defer func() {
if !done {
// Rollback (with retry-backoff).
_ = retryOnBusy(ctx, func() error {
err := tx.Rollback()
return db.errHook(err)
})
}
}()
// Perform supplied transaction
if err := fn(tx); err != nil {
return db.errHook(err)
}
// Commit (with retry-backoff).
err = retryOnBusy(ctx, func() error {
err := tx.Commit()
return db.errHook(err)
})
done = true
return err
}
func (db *WrappedDB) NewValues(model interface{}) *bun.ValuesQuery {
return bun.NewValuesQuery(db.DB, model).Conn(db)
}
func (db *WrappedDB) NewMerge() *bun.MergeQuery {
return bun.NewMergeQuery(db.DB).Conn(db)
}
func (db *WrappedDB) NewSelect() *bun.SelectQuery {
return bun.NewSelectQuery(db.DB).Conn(db)
}
func (db *WrappedDB) NewInsert() *bun.InsertQuery {
return bun.NewInsertQuery(db.DB).Conn(db)
}
func (db *WrappedDB) NewUpdate() *bun.UpdateQuery {
return bun.NewUpdateQuery(db.DB).Conn(db)
}
func (db *WrappedDB) NewDelete() *bun.DeleteQuery {
return bun.NewDeleteQuery(db.DB).Conn(db)
}
func (db *WrappedDB) NewRaw(query string, args ...interface{}) *bun.RawQuery {
return bun.NewRawQuery(db.DB, query, args...).Conn(db)
}
func (db *WrappedDB) NewCreateTable() *bun.CreateTableQuery {
return bun.NewCreateTableQuery(db.DB).Conn(db)
}
func (db *WrappedDB) NewDropTable() *bun.DropTableQuery {
return bun.NewDropTableQuery(db.DB).Conn(db)
}
func (db *WrappedDB) NewCreateIndex() *bun.CreateIndexQuery {
return bun.NewCreateIndexQuery(db.DB).Conn(db)
}
func (db *WrappedDB) NewDropIndex() *bun.DropIndexQuery {
return bun.NewDropIndexQuery(db.DB).Conn(db)
}
func (db *WrappedDB) NewTruncateTable() *bun.TruncateTableQuery {
return bun.NewTruncateTableQuery(db.DB).Conn(db)
}
func (db *WrappedDB) NewAddColumn() *bun.AddColumnQuery {
return bun.NewAddColumnQuery(db.DB).Conn(db)
}
func (db *WrappedDB) NewDropColumn() *bun.DropColumnQuery {
return bun.NewDropColumnQuery(db.DB).Conn(db)
}
// ProcessError processes an error to replace any known values with our own error types,
// making it easier to catch specific situations (e.g. no rows, already exists, etc)
func (db *WrappedDB) ProcessError(err error) error {
if err == nil {
return nil
}
return db.errHook(err)
}
// Exists checks the results of a SelectQuery for the existence of the data in question, masking ErrNoEntries errors
func (db *WrappedDB) Exists(ctx context.Context, query *bun.SelectQuery) (bool, error) {
exists, err := query.Exists(ctx)
switch err {
case nil:
return exists, nil
case sql.ErrNoRows:
return false, nil
default:
return false, err
}
}
// NotExists is the functional opposite of conn.Exists()
func (db *WrappedDB) NotExists(ctx context.Context, query *bun.SelectQuery) (bool, error) {
exists, err := db.Exists(ctx, query)
return !exists, err
}
// retryOnBusy will retry given function on returned 'errBusy'.
func retryOnBusy(ctx context.Context, fn func() error) error {
var backoff time.Duration
for i := 0; ; i++ {
// Perform func.
err := fn()
if err != errBusy {
// May be nil, or may be
// some other error, either
// way return here.
return err
}
// backoff according to a multiplier of 2ms * 2^2n,
// up to a maximum possible backoff time of 5 minutes.
//
// this works out as the following:
// 4ms
// 16ms
// 64ms
// 256ms
// 1.024s
// 4.096s
// 16.384s
// 1m5.536s
// 4m22.144s
backoff = 2 * time.Millisecond * (1 << (2*i + 1))
if backoff >= 5*time.Minute {
break
}
select {
// Context cancelled.
case <-ctx.Done():
// Backoff for some time.
case <-time.After(backoff):
}
}
return gtserror.Newf("%w (waited > %s)", db.ErrBusyTimeout, backoff)
}

View file

@ -27,29 +27,29 @@ import (
// Domain contains DB functions related to domains and domain blocks. // Domain contains DB functions related to domains and domain blocks.
type Domain interface { type Domain interface {
// CreateDomainBlock puts the given instance-level domain block into the database. // CreateDomainBlock puts the given instance-level domain block into the database.
CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) Error CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) error
// GetDomainBlock returns one instance-level domain block with the given domain, if it exists. // GetDomainBlock returns one instance-level domain block with the given domain, if it exists.
GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, Error) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, error)
// GetDomainBlockByID returns one instance-level domain block with the given id, if it exists. // GetDomainBlockByID returns one instance-level domain block with the given id, if it exists.
GetDomainBlockByID(ctx context.Context, id string) (*gtsmodel.DomainBlock, Error) GetDomainBlockByID(ctx context.Context, id string) (*gtsmodel.DomainBlock, error)
// GetDomainBlocks returns all instance-level domain blocks currently enforced by this instance. // GetDomainBlocks returns all instance-level domain blocks currently enforced by this instance.
GetDomainBlocks(ctx context.Context) ([]*gtsmodel.DomainBlock, error) GetDomainBlocks(ctx context.Context) ([]*gtsmodel.DomainBlock, error)
// DeleteDomainBlock deletes an instance-level domain block with the given domain, if it exists. // DeleteDomainBlock deletes an instance-level domain block with the given domain, if it exists.
DeleteDomainBlock(ctx context.Context, domain string) Error DeleteDomainBlock(ctx context.Context, domain string) error
// IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`). // IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`).
IsDomainBlocked(ctx context.Context, domain string) (bool, Error) IsDomainBlocked(ctx context.Context, domain string) (bool, error)
// AreDomainsBlocked checks if an instance-level domain block exists for any of the given domains strings, and returns true if even one is found. // AreDomainsBlocked checks if an instance-level domain block exists for any of the given domains strings, and returns true if even one is found.
AreDomainsBlocked(ctx context.Context, domains []string) (bool, Error) AreDomainsBlocked(ctx context.Context, domains []string) (bool, error)
// IsURIBlocked checks if an instance-level domain block exists for the `host` in the given URI (eg., `https://example.org/users/whatever`). // IsURIBlocked checks if an instance-level domain block exists for the `host` in the given URI (eg., `https://example.org/users/whatever`).
IsURIBlocked(ctx context.Context, uri *url.URL) (bool, Error) IsURIBlocked(ctx context.Context, uri *url.URL) (bool, error)
// AreURIsBlocked checks if an instance-level domain block exists for any `host` in the given URI slice, and returns true if even one is found. // AreURIsBlocked checks if an instance-level domain block exists for any `host` in the given URI slice, and returns true if even one is found.
AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, Error) AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, error)
} }

View file

@ -31,16 +31,16 @@ const EmojiAllDomains string = "all"
// Emoji contains functions for getting emoji in the database. // Emoji contains functions for getting emoji in the database.
type Emoji interface { type Emoji interface {
// PutEmoji puts one emoji in the database. // PutEmoji puts one emoji in the database.
PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) Error PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) error
// UpdateEmoji updates the given columns of one emoji. // UpdateEmoji updates the given columns of one emoji.
// If no columns are specified, every column is updated. // If no columns are specified, every column is updated.
UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, columns ...string) error UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, columns ...string) error
// DeleteEmojiByID deletes one emoji by its database ID. // DeleteEmojiByID deletes one emoji by its database ID.
DeleteEmojiByID(ctx context.Context, id string) Error DeleteEmojiByID(ctx context.Context, id string) error
// GetEmojisByIDs gets emojis for the given IDs. // GetEmojisByIDs gets emojis for the given IDs.
GetEmojisByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Emoji, Error) GetEmojisByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Emoji, error)
// GetUseableEmojis gets all emojis which are useable by accounts on this instance. // GetUseableEmojis gets all emojis which are useable by accounts on this instance.
GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, Error) GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, error)
// GetEmojis fetches all emojis with IDs less than 'maxID', up to a maximum of 'limit' emojis. // GetEmojis fetches all emojis with IDs less than 'maxID', up to a maximum of 'limit' emojis.
GetEmojis(ctx context.Context, maxID string, limit int) ([]*gtsmodel.Emoji, error) GetEmojis(ctx context.Context, maxID string, limit int) ([]*gtsmodel.Emoji, error)
@ -54,22 +54,22 @@ type Emoji interface {
// GetEmojisBy gets emojis based on given parameters. Useful for admin actions. // GetEmojisBy gets emojis based on given parameters. Useful for admin actions.
GetEmojisBy(ctx context.Context, domain string, includeDisabled bool, includeEnabled bool, shortcode string, maxShortcodeDomain string, minShortcodeDomain string, limit int) ([]*gtsmodel.Emoji, error) GetEmojisBy(ctx context.Context, domain string, includeDisabled bool, includeEnabled bool, shortcode string, maxShortcodeDomain string, minShortcodeDomain string, limit int) ([]*gtsmodel.Emoji, error)
// GetEmojiByID gets a specific emoji by its database ID. // GetEmojiByID gets a specific emoji by its database ID.
GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, Error) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, error)
// GetEmojiByShortcodeDomain gets an emoji based on its shortcode and domain. // GetEmojiByShortcodeDomain gets an emoji based on its shortcode and domain.
// For local emoji, domain should be an empty string. // For local emoji, domain should be an empty string.
GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, Error) GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, error)
// GetEmojiByURI returns one emoji based on its ActivityPub URI. // GetEmojiByURI returns one emoji based on its ActivityPub URI.
GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoji, Error) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoji, error)
// GetEmojiByStaticURL gets an emoji using the URL of the static version of the emoji image. // GetEmojiByStaticURL gets an emoji using the URL of the static version of the emoji image.
GetEmojiByStaticURL(ctx context.Context, imageStaticURL string) (*gtsmodel.Emoji, Error) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string) (*gtsmodel.Emoji, error)
// PutEmojiCategory puts one new emoji category in the database. // PutEmojiCategory puts one new emoji category in the database.
PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) Error PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) error
// GetEmojiCategoriesByIDs gets emoji categories for given IDs. // GetEmojiCategoriesByIDs gets emoji categories for given IDs.
GetEmojiCategoriesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.EmojiCategory, Error) GetEmojiCategoriesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.EmojiCategory, error)
// GetEmojiCategories gets a slice of the names of all existing emoji categories. // GetEmojiCategories gets a slice of the names of all existing emoji categories.
GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCategory, Error) GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCategory, error)
// GetEmojiCategory gets one emoji category by its id. // GetEmojiCategory gets one emoji category by its id.
GetEmojiCategory(ctx context.Context, id string) (*gtsmodel.EmojiCategory, Error) GetEmojiCategory(ctx context.Context, id string) (*gtsmodel.EmojiCategory, error)
// GetEmojiCategoryByName gets one emoji category by its name. // GetEmojiCategoryByName gets one emoji category by its name.
GetEmojiCategoryByName(ctx context.Context, name string) (*gtsmodel.EmojiCategory, Error) GetEmojiCategoryByName(ctx context.Context, name string) (*gtsmodel.EmojiCategory, error)
} }

View file

@ -17,18 +17,20 @@
package db package db
import "fmt" import (
"database/sql"
// Error denotes a database error. "errors"
type Error error )
var ( var (
// ErrNoEntries is returned when a caller expected an entry for a query, but none was found. // ErrNoEntries is a direct ptr to sql.ErrNoRows since that is returned regardless
ErrNoEntries Error = fmt.Errorf("no entries") // of DB dialect. It is returned when no rows (entries) can be found for a query.
// ErrMultipleEntries is returned when a caller expected ONE entry for a query, but multiples were found. ErrNoEntries = sql.ErrNoRows
ErrMultipleEntries Error = fmt.Errorf("multiple entries")
// ErrAlreadyExists is returned when a conflict was encountered in the db when doing an insert. // ErrAlreadyExists is returned when a conflict was encountered in the db when doing an insert.
ErrAlreadyExists Error = fmt.Errorf("already exists") ErrAlreadyExists = errors.New("already exists")
// ErrUnknown denotes an unknown database error.
ErrUnknown Error = fmt.Errorf("unknown error") // ErrBusyTimeout is returned if the database connection indicates the connection is too busy
// to complete the supplied query. This is generally intended to be handled internally by the DB.
ErrBusyTimeout = errors.New("busy timeout")
) )

View file

@ -26,16 +26,16 @@ import (
// Instance contains functions for instance-level actions (counting instance users etc.). // Instance contains functions for instance-level actions (counting instance users etc.).
type Instance interface { type Instance interface {
// CountInstanceUsers returns the number of known accounts registered with the given domain. // CountInstanceUsers returns the number of known accounts registered with the given domain.
CountInstanceUsers(ctx context.Context, domain string) (int, Error) CountInstanceUsers(ctx context.Context, domain string) (int, error)
// CountInstanceStatuses returns the number of known statuses posted from the given domain. // CountInstanceStatuses returns the number of known statuses posted from the given domain.
CountInstanceStatuses(ctx context.Context, domain string) (int, Error) CountInstanceStatuses(ctx context.Context, domain string) (int, error)
// CountInstanceDomains returns the number of known instances known that the given domain federates with. // CountInstanceDomains returns the number of known instances known that the given domain federates with.
CountInstanceDomains(ctx context.Context, domain string) (int, Error) CountInstanceDomains(ctx context.Context, domain string) (int, error)
// GetInstance returns the instance entry for the given domain, if it exists. // GetInstance returns the instance entry for the given domain, if it exists.
GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, Error) GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, error)
// GetInstanceByID returns the instance entry corresponding to the given id, if it exists. // GetInstanceByID returns the instance entry corresponding to the given id, if it exists.
GetInstanceByID(ctx context.Context, id string) (*gtsmodel.Instance, error) GetInstanceByID(ctx context.Context, id string) (*gtsmodel.Instance, error)
@ -47,12 +47,12 @@ type Instance interface {
UpdateInstance(ctx context.Context, instance *gtsmodel.Instance, columns ...string) error UpdateInstance(ctx context.Context, instance *gtsmodel.Instance, columns ...string) error
// GetInstanceAccounts returns a slice of accounts from the given instance, arranged by ID. // GetInstanceAccounts returns a slice of accounts from the given instance, arranged by ID.
GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, Error) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, error)
// GetInstancePeers returns a slice of instances that the host instance knows about. // GetInstancePeers returns a slice of instances that the host instance knows about.
GetInstancePeers(ctx context.Context, includeSuspended bool) ([]*gtsmodel.Instance, Error) GetInstancePeers(ctx context.Context, includeSuspended bool) ([]*gtsmodel.Instance, error)
// GetInstanceModeratorAddresses returns a slice of email addresses belonging to active // GetInstanceModeratorAddresses returns a slice of email addresses belonging to active
// (as in, not suspended) moderators + admins on this instance. // (as in, not suspended) moderators + admins on this instance.
GetInstanceModeratorAddresses(ctx context.Context) ([]string, Error) GetInstanceModeratorAddresses(ctx context.Context) ([]string, error)
} }

View file

@ -27,7 +27,7 @@ import (
// Media contains functions related to creating/getting/removing media attachments. // Media contains functions related to creating/getting/removing media attachments.
type Media interface { type Media interface {
// GetAttachmentByID gets a single attachment by its ID. // GetAttachmentByID gets a single attachment by its ID.
GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, Error) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, error)
// GetAttachmentsByIDs fetches a list of media attachments for given IDs. // GetAttachmentsByIDs fetches a list of media attachments for given IDs.
GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, error) GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, error)
@ -49,25 +49,25 @@ type Media interface {
// GetCachedAttachmentsOlderThan gets limit n remote attachments (including avatars and headers) older than // GetCachedAttachmentsOlderThan gets limit n remote attachments (including avatars and headers) older than
// the given time. These will be returned in order of attachment.created_at descending (i.e. newest to oldest). // the given time. These will be returned in order of attachment.created_at descending (i.e. newest to oldest).
GetCachedAttachmentsOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, Error) GetCachedAttachmentsOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, error)
// CountRemoteOlderThan is like GetRemoteOlderThan, except instead of getting limit n attachments, // CountRemoteOlderThan is like GetRemoteOlderThan, except instead of getting limit n attachments,
// it just counts how many remote attachments in the database (including avatars and headers) meet // it just counts how many remote attachments in the database (including avatars and headers) meet
// the olderThan criteria. // the olderThan criteria.
CountRemoteOlderThan(ctx context.Context, olderThan time.Time) (int, Error) CountRemoteOlderThan(ctx context.Context, olderThan time.Time) (int, error)
// GetAvatarsAndHeaders fetches limit n avatars and headers with an id < maxID. These headers // GetAvatarsAndHeaders fetches limit n avatars and headers with an id < maxID. These headers
// and avis may be in use or not; the caller should check this if it's important. // and avis may be in use or not; the caller should check this if it's important.
GetAvatarsAndHeaders(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, Error) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, error)
// GetLocalUnattachedOlderThan fetches limit n local media attachments (including avatars and headers), older than // GetLocalUnattachedOlderThan fetches limit n local media attachments (including avatars and headers), older than
// the given time, which aren't header or avatars, and aren't attached to a status. In other words, attachments which were // the given time, which aren't header or avatars, and aren't attached to a status. In other words, attachments which were
// uploaded but never used for whatever reason, or attachments that were attached to a status which was subsequently deleted. // uploaded but never used for whatever reason, or attachments that were attached to a status which was subsequently deleted.
// //
// These will be returned in order of attachment.created_at descending (newest to oldest in other words). // These will be returned in order of attachment.created_at descending (newest to oldest in other words).
GetLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, Error) GetLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, error)
// CountLocalUnattachedOlderThan is like GetLocalUnattachedOlderThan, except instead of getting limit n attachments, // CountLocalUnattachedOlderThan is like GetLocalUnattachedOlderThan, except instead of getting limit n attachments,
// it just counts how many local attachments in the database meet the olderThan criteria. // it just counts how many local attachments in the database meet the olderThan criteria.
CountLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time) (int, Error) CountLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time) (int, error)
} }

View file

@ -26,10 +26,10 @@ import (
// Mention contains functions for getting/creating mentions in the database. // Mention contains functions for getting/creating mentions in the database.
type Mention interface { type Mention interface {
// GetMention gets a single mention by ID // GetMention gets a single mention by ID
GetMention(ctx context.Context, id string) (*gtsmodel.Mention, Error) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, error)
// GetMentions gets multiple mentions. // GetMentions gets multiple mentions.
GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, Error) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, error)
// PutMention will insert the given mention into the database. // PutMention will insert the given mention into the database.
PutMention(ctx context.Context, mention *gtsmodel.Mention) error PutMention(ctx context.Context, mention *gtsmodel.Mention) error

View file

@ -28,21 +28,21 @@ type Notification interface {
// GetNotifications returns a slice of notifications that pertain to the given accountID. // GetNotifications returns a slice of notifications that pertain to the given accountID.
// //
// Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest). // Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest).
GetAccountNotifications(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, excludeTypes []string) ([]*gtsmodel.Notification, Error) GetAccountNotifications(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, excludeTypes []string) ([]*gtsmodel.Notification, error)
// GetNotification returns one notification according to its id. // GetNotification returns one notification according to its id.
GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, Error) GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, error)
// GetNotification gets one notification according to the provided parameters, if it exists. // GetNotification gets one notification according to the provided parameters, if it exists.
// Since not all notifications are about a status, statusID can be an empty string. // Since not all notifications are about a status, statusID can be an empty string.
GetNotification(ctx context.Context, notificationType gtsmodel.NotificationType, targetAccountID string, originAccountID string, statusID string) (*gtsmodel.Notification, Error) GetNotification(ctx context.Context, notificationType gtsmodel.NotificationType, targetAccountID string, originAccountID string, statusID string) (*gtsmodel.Notification, error)
// PutNotification will insert the given notification into the database. // PutNotification will insert the given notification into the database.
PutNotification(ctx context.Context, notif *gtsmodel.Notification) error PutNotification(ctx context.Context, notif *gtsmodel.Notification) error
// DeleteNotificationByID deletes one notification according to its id, // DeleteNotificationByID deletes one notification according to its id,
// and removes that notification from the in-memory cache. // and removes that notification from the in-memory cache.
DeleteNotificationByID(ctx context.Context, id string) Error DeleteNotificationByID(ctx context.Context, id string) error
// DeleteNotifications mass deletes notifications targeting targetAccountID // DeleteNotifications mass deletes notifications targeting targetAccountID
// and/or originating from originAccountID. // and/or originating from originAccountID.
@ -57,10 +57,10 @@ type Notification interface {
// originate from originAccountID will be deleted. // originate from originAccountID will be deleted.
// //
// At least one parameter must not be an empty string. // At least one parameter must not be an empty string.
DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) Error DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) error
// DeleteNotificationsForStatus deletes all notifications that relate to // DeleteNotificationsForStatus deletes all notifications that relate to
// the given statusID. This function is useful when a status has been deleted, // the given statusID. This function is useful when a status has been deleted,
// and so notifications relating to that status must also be deleted. // and so notifications relating to that status must also be deleted.
DeleteNotificationsForStatus(ctx context.Context, statusID string) Error DeleteNotificationsForStatus(ctx context.Context, statusID string) error
} }

View file

@ -26,7 +26,7 @@ import (
// Relationship contains functions for getting or modifying the relationship between two accounts. // Relationship contains functions for getting or modifying the relationship between two accounts.
type Relationship interface { type Relationship interface {
// IsBlocked checks whether source account has a block in place against target. // IsBlocked checks whether source account has a block in place against target.
IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error) IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error)
// IsEitherBlocked checks whether there is a block in place between either of account1 and account2. // IsEitherBlocked checks whether there is a block in place between either of account1 and account2.
IsEitherBlocked(ctx context.Context, accountID1 string, accountID2 string) (bool, error) IsEitherBlocked(ctx context.Context, accountID1 string, accountID2 string) (bool, error)
@ -53,7 +53,7 @@ type Relationship interface {
DeleteAccountBlocks(ctx context.Context, accountID string) error DeleteAccountBlocks(ctx context.Context, accountID string) error
// GetRelationship retrieves the relationship of the targetAccount to the requestingAccount. // GetRelationship retrieves the relationship of the targetAccount to the requestingAccount.
GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error)
// GetFollowByID fetches follow with given ID from the database. // GetFollowByID fetches follow with given ID from the database.
GetFollowByID(ctx context.Context, id string) (*gtsmodel.Follow, error) GetFollowByID(ctx context.Context, id string) (*gtsmodel.Follow, error)
@ -77,13 +77,13 @@ type Relationship interface {
GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error) GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error)
// IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out. // IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error) IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error)
// IsMutualFollowing returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out. // IsMutualFollowing returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out.
IsMutualFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error) IsMutualFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error)
// IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out. // IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error) IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error)
// PutFollow attempts to place the given account follow in the database. // PutFollow attempts to place the given account follow in the database.
PutFollow(ctx context.Context, follow *gtsmodel.Follow) error PutFollow(ctx context.Context, follow *gtsmodel.Follow) error
@ -125,10 +125,10 @@ type Relationship interface {
// In other words, it should create the follow, and delete the existing follow request. // In other words, it should create the follow, and delete the existing follow request.
// //
// It will return the newly created follow for further processing. // It will return the newly created follow for further processing.
AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, error)
// RejectFollowRequest fetches a follow request from the database, and then deletes it. // RejectFollowRequest fetches a follow request from the database, and then deletes it.
RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) Error RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) error
// GetAccountFollows returns a slice of follows owned by the given accountID. // GetAccountFollows returns a slice of follows owned by the given accountID.
GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)

View file

@ -26,18 +26,18 @@ import (
// Report handles getting/creation/deletion/updating of user reports/flags. // Report handles getting/creation/deletion/updating of user reports/flags.
type Report interface { type Report interface {
// GetReportByID gets one report by its db id // GetReportByID gets one report by its db id
GetReportByID(ctx context.Context, id string) (*gtsmodel.Report, Error) GetReportByID(ctx context.Context, id string) (*gtsmodel.Report, error)
// GetReports gets limit n reports using the given parameters. // GetReports gets limit n reports using the given parameters.
// Parameters that are empty / zero are ignored. // Parameters that are empty / zero are ignored.
GetReports(ctx context.Context, resolved *bool, accountID string, targetAccountID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Report, Error) GetReports(ctx context.Context, resolved *bool, accountID string, targetAccountID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Report, error)
// PutReport puts the given report in the database. // PutReport puts the given report in the database.
PutReport(ctx context.Context, report *gtsmodel.Report) Error PutReport(ctx context.Context, report *gtsmodel.Report) error
// UpdateReport updates one report by its db id. // UpdateReport updates one report by its db id.
// The given columns will be updated; if no columns are // The given columns will be updated; if no columns are
// provided, then all columns will be updated. // provided, then all columns will be updated.
// updated_at will also be updated, no need to pass this // updated_at will also be updated, no need to pass this
// as a specific column. // as a specific column.
UpdateReport(ctx context.Context, report *gtsmodel.Report, columns ...string) (*gtsmodel.Report, Error) UpdateReport(ctx context.Context, report *gtsmodel.Report, columns ...string) (*gtsmodel.Report, error)
// DeleteReportByID deletes report with the given id. // DeleteReportByID deletes report with the given id.
DeleteReportByID(ctx context.Context, id string) Error DeleteReportByID(ctx context.Context, id string) error
} }

View file

@ -25,5 +25,5 @@ import (
// Session handles getting/creation of router sessions. // Session handles getting/creation of router sessions.
type Session interface { type Session interface {
GetSession(ctx context.Context) (*gtsmodel.RouterSession, Error) GetSession(ctx context.Context) (*gtsmodel.RouterSession, error)
} }

View file

@ -26,34 +26,34 @@ import (
// Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses. // Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses.
type Status interface { type Status interface {
// GetStatusByID returns one status from the database, with no rel fields populated, only their linking ID / URIs // GetStatusByID returns one status from the database, with no rel fields populated, only their linking ID / URIs
GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, Error) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, error)
// GetStatusByURI returns one status from the database, with no rel fields populated, only their linking ID / URIs // GetStatusByURI returns one status from the database, with no rel fields populated, only their linking ID / URIs
GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, Error) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, error)
// GetStatusByURL returns one status from the database, with no rel fields populated, only their linking ID / URIs // GetStatusByURL returns one status from the database, with no rel fields populated, only their linking ID / URIs
GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, Error) GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, error)
// PopulateStatus ensures that all sub-models of a status are populated (e.g. mentions, attachments, etc). // PopulateStatus ensures that all sub-models of a status are populated (e.g. mentions, attachments, etc).
PopulateStatus(ctx context.Context, status *gtsmodel.Status) error PopulateStatus(ctx context.Context, status *gtsmodel.Status) error
// PutStatus stores one status in the database. // PutStatus stores one status in the database.
PutStatus(ctx context.Context, status *gtsmodel.Status) Error PutStatus(ctx context.Context, status *gtsmodel.Status) error
// UpdateStatus updates one status in the database. // UpdateStatus updates one status in the database.
UpdateStatus(ctx context.Context, status *gtsmodel.Status, columns ...string) Error UpdateStatus(ctx context.Context, status *gtsmodel.Status, columns ...string) error
// DeleteStatusByID deletes one status from the database. // DeleteStatusByID deletes one status from the database.
DeleteStatusByID(ctx context.Context, id string) Error DeleteStatusByID(ctx context.Context, id string) error
// CountStatusReplies returns the amount of replies recorded for a status, or an error if something goes wrong // CountStatusReplies returns the amount of replies recorded for a status, or an error if something goes wrong
CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, Error) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, error)
// CountStatusReblogs returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong // CountStatusReblogs returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong
CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, Error) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, error)
// CountStatusFaves returns the amount of faves/likes recorded for a status, or an error if something goes wrong // CountStatusFaves returns the amount of faves/likes recorded for a status, or an error if something goes wrong
CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, Error) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, error)
// GetStatuses gets a slice of statuses corresponding to the given status IDs. // GetStatuses gets a slice of statuses corresponding to the given status IDs.
GetStatusesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Status, error) GetStatusesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Status, error)
@ -64,26 +64,26 @@ type Status interface {
// GetStatusParents gets the parent statuses of a given status. // GetStatusParents gets the parent statuses of a given status.
// //
// If onlyDirect is true, only the immediate parent will be returned. // If onlyDirect is true, only the immediate parent will be returned.
GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, Error) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error)
// GetStatusChildren gets the child statuses of a given status. // GetStatusChildren gets the child statuses of a given status.
// //
// If onlyDirect is true, only the immediate children will be returned. // If onlyDirect is true, only the immediate children will be returned.
GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, Error) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error)
// IsStatusFavedBy checks if a given status has been faved by a given account ID // IsStatusFavedBy checks if a given status has been faved by a given account ID
IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error)
// IsStatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID // IsStatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID
IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error)
// IsStatusMutedBy checks if a given status has been muted by a given account ID // IsStatusMutedBy checks if a given status has been muted by a given account ID
IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error)
// IsStatusBookmarkedBy checks if a given status has been bookmarked by a given account ID // IsStatusBookmarkedBy checks if a given status has been bookmarked by a given account ID
IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error)
// GetStatusReblogs returns a slice of statuses that are a boost/reblog of the given status. // GetStatusReblogs returns a slice of statuses that are a boost/reblog of the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user. // This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, Error) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, error)
} }

View file

@ -25,24 +25,24 @@ import (
type StatusBookmark interface { type StatusBookmark interface {
// GetStatusBookmark gets one status bookmark with the given ID. // GetStatusBookmark gets one status bookmark with the given ID.
GetStatusBookmark(ctx context.Context, id string) (*gtsmodel.StatusBookmark, Error) GetStatusBookmark(ctx context.Context, id string) (*gtsmodel.StatusBookmark, error)
// GetStatusBookmarkID is a shortcut function for returning just the database ID // GetStatusBookmarkID is a shortcut function for returning just the database ID
// of a status bookmark created by the given accountID, targeting the given statusID. // of a status bookmark created by the given accountID, targeting the given statusID.
GetStatusBookmarkID(ctx context.Context, accountID string, statusID string) (string, Error) GetStatusBookmarkID(ctx context.Context, accountID string, statusID string) (string, error)
// GetStatusBookmarks retrieves status bookmarks created by the given accountID, // GetStatusBookmarks retrieves status bookmarks created by the given accountID,
// and using the provided parameters. If limit is < 0 then no limit will be set. // and using the provided parameters. If limit is < 0 then no limit will be set.
// //
// This function is primarily useful for paging through bookmarks in a sort of // This function is primarily useful for paging through bookmarks in a sort of
// timeline view. // timeline view.
GetStatusBookmarks(ctx context.Context, accountID string, limit int, maxID string, minID string) ([]*gtsmodel.StatusBookmark, Error) GetStatusBookmarks(ctx context.Context, accountID string, limit int, maxID string, minID string) ([]*gtsmodel.StatusBookmark, error)
// PutStatusBookmark inserts the given statusBookmark into the database. // PutStatusBookmark inserts the given statusBookmark into the database.
PutStatusBookmark(ctx context.Context, statusBookmark *gtsmodel.StatusBookmark) Error PutStatusBookmark(ctx context.Context, statusBookmark *gtsmodel.StatusBookmark) error
// DeleteStatusBookmark deletes one status bookmark with the given ID. // DeleteStatusBookmark deletes one status bookmark with the given ID.
DeleteStatusBookmark(ctx context.Context, id string) Error DeleteStatusBookmark(ctx context.Context, id string) error
// DeleteStatusBookmarks mass deletes status bookmarks targeting targetAccountID // DeleteStatusBookmarks mass deletes status bookmarks targeting targetAccountID
// and/or originating from originAccountID and/or bookmarking statusID. // and/or originating from originAccountID and/or bookmarking statusID.
@ -57,10 +57,10 @@ type StatusBookmark interface {
// originate from originAccountID will be deleted. // originate from originAccountID will be deleted.
// //
// At least one parameter must not be an empty string. // At least one parameter must not be an empty string.
DeleteStatusBookmarks(ctx context.Context, targetAccountID string, originAccountID string) Error DeleteStatusBookmarks(ctx context.Context, targetAccountID string, originAccountID string) error
// DeleteStatusBookmarksForStatus deletes all status bookmarks that target the // DeleteStatusBookmarksForStatus deletes all status bookmarks that target the
// given status ID. This is useful when a status has been deleted, and you need // given status ID. This is useful when a status has been deleted, and you need
// to clean up after it. // to clean up after it.
DeleteStatusBookmarksForStatus(ctx context.Context, statusID string) Error DeleteStatusBookmarksForStatus(ctx context.Context, statusID string) error
} }

View file

@ -26,23 +26,23 @@ import (
type StatusFave interface { type StatusFave interface {
// GetStatusFaveByAccountID gets one status fave created by the given // GetStatusFaveByAccountID gets one status fave created by the given
// accountID, targeting the given statusID. // accountID, targeting the given statusID.
GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, Error) GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, error)
// GetStatusFave returns one status fave with the given id. // GetStatusFave returns one status fave with the given id.
GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, Error) GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, error)
// GetStatusFaves returns a slice of faves/likes of the given status. // GetStatusFaves returns a slice of faves/likes of the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user. // This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, Error) GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, error)
// PopulateStatusFave ensures that all sub-models of a fave are populated (account, status, etc). // PopulateStatusFave ensures that all sub-models of a fave are populated (account, status, etc).
PopulateStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) error PopulateStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) error
// PutStatusFave inserts the given statusFave into the database. // PutStatusFave inserts the given statusFave into the database.
PutStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) Error PutStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) error
// DeleteStatusFave deletes one status fave with the given id. // DeleteStatusFave deletes one status fave with the given id.
DeleteStatusFaveByID(ctx context.Context, id string) Error DeleteStatusFaveByID(ctx context.Context, id string) error
// DeleteStatusFaves mass deletes status faves targeting targetAccountID // DeleteStatusFaves mass deletes status faves targeting targetAccountID
// and/or originating from originAccountID and/or faving statusID. // and/or originating from originAccountID and/or faving statusID.
@ -57,10 +57,10 @@ type StatusFave interface {
// originate from originAccountID will be deleted. // originate from originAccountID will be deleted.
// //
// At least one parameter must not be an empty string. // At least one parameter must not be an empty string.
DeleteStatusFaves(ctx context.Context, targetAccountID string, originAccountID string) Error DeleteStatusFaves(ctx context.Context, targetAccountID string, originAccountID string) error
// DeleteStatusFavesForStatus deletes all status faves that target the // DeleteStatusFavesForStatus deletes all status faves that target the
// given status ID. This is useful when a status has been deleted, and you need // given status ID. This is useful when a status has been deleted, and you need
// to clean up after it. // to clean up after it.
DeleteStatusFavesForStatus(ctx context.Context, statusID string) Error DeleteStatusFavesForStatus(ctx context.Context, statusID string) error
} }

View file

@ -28,13 +28,13 @@ type Timeline interface {
// GetHomeTimeline returns a slice of statuses from accounts that are followed by the given account id. // GetHomeTimeline returns a slice of statuses from accounts that are followed by the given account id.
// //
// Statuses should be returned in descending order of when they were created (newest first). // Statuses should be returned in descending order of when they were created (newest first).
GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error)
// GetPublicTimeline fetches the account's PUBLIC timeline -- ie., posts and replies that are public. // GetPublicTimeline fetches the account's PUBLIC timeline -- ie., posts and replies that are public.
// It will use the given filters and try to return as many statuses as possible up to the limit. // It will use the given filters and try to return as many statuses as possible up to the limit.
// //
// Statuses should be returned in descending order of when they were created (newest first). // Statuses should be returned in descending order of when they were created (newest first).
GetPublicTimeline(ctx context.Context, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error) GetPublicTimeline(ctx context.Context, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error)
// GetFavedTimeline fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved. // GetFavedTimeline fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved.
// It will use the given filters and try to return as many statuses as possible up to the limit. // It will use the given filters and try to return as many statuses as possible up to the limit.
@ -43,7 +43,7 @@ type Timeline interface {
// In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created. // In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created.
// //
// Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers. // Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers.
GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, error)
// GetListTimeline returns a slice of statuses from followed accounts collected within the list with the given listID. // GetListTimeline returns a slice of statuses from followed accounts collected within the list with the given listID.
// Statuses should be returned in descending order of when they were created (newest first). // Statuses should be returned in descending order of when they were created (newest first).

View file

@ -26,14 +26,14 @@ import (
// Tombstone contains functionality for storing + retrieving tombstones for remote AP Activities + Objects. // Tombstone contains functionality for storing + retrieving tombstones for remote AP Activities + Objects.
type Tombstone interface { type Tombstone interface {
// GetTombstoneByURI attempts to fetch a tombstone by the given URI. // GetTombstoneByURI attempts to fetch a tombstone by the given URI.
GetTombstoneByURI(ctx context.Context, uri string) (*gtsmodel.Tombstone, Error) GetTombstoneByURI(ctx context.Context, uri string) (*gtsmodel.Tombstone, error)
// TombstoneExistsWithURI returns true if a tombstone with the given URI exists. // TombstoneExistsWithURI returns true if a tombstone with the given URI exists.
TombstoneExistsWithURI(ctx context.Context, uri string) (bool, Error) TombstoneExistsWithURI(ctx context.Context, uri string) (bool, error)
// PutTombstone creates a new tombstone in the database. // PutTombstone creates a new tombstone in the database.
PutTombstone(ctx context.Context, tombstone *gtsmodel.Tombstone) Error PutTombstone(ctx context.Context, tombstone *gtsmodel.Tombstone) error
// DeleteTombstone deletes a tombstone with the given ID. // DeleteTombstone deletes a tombstone with the given ID.
DeleteTombstone(ctx context.Context, id string) Error DeleteTombstone(ctx context.Context, id string) error
} }

View file

@ -26,21 +26,21 @@ import (
// User contains functions related to user getting/setting/creation. // User contains functions related to user getting/setting/creation.
type User interface { type User interface {
// GetAllUsers returns all local user accounts, or an error if something goes wrong. // GetAllUsers returns all local user accounts, or an error if something goes wrong.
GetAllUsers(ctx context.Context) ([]*gtsmodel.User, Error) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error)
// GetUserByID returns one user with the given ID, or an error if something goes wrong. // GetUserByID returns one user with the given ID, or an error if something goes wrong.
GetUserByID(ctx context.Context, id string) (*gtsmodel.User, Error) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, error)
// GetUserByAccountID returns one user by its account ID, or an error if something goes wrong. // GetUserByAccountID returns one user by its account ID, or an error if something goes wrong.
GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, Error) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, error)
// GetUserByID returns one user with the given email address, or an error if something goes wrong. // GetUserByID returns one user with the given email address, or an error if something goes wrong.
GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, Error) GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, error)
// GetUserByExternalID returns one user with the given external id, or an error if something goes wrong. // GetUserByExternalID returns one user with the given external id, or an error if something goes wrong.
GetUserByExternalID(ctx context.Context, id string) (*gtsmodel.User, Error) GetUserByExternalID(ctx context.Context, id string) (*gtsmodel.User, error)
// GetUserByConfirmationToken returns one user by its confirmation token, or an error if something goes wrong. // GetUserByConfirmationToken returns one user by its confirmation token, or an error if something goes wrong.
GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, Error) GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, error)
// PutUser will attempt to place user in the database // PutUser will attempt to place user in the database
PutUser(ctx context.Context, user *gtsmodel.User) Error PutUser(ctx context.Context, user *gtsmodel.User) error
// UpdateUser updates one user by its primary key, updating either only the specified columns, or all of them. // UpdateUser updates one user by its primary key, updating either only the specified columns, or all of them.
UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) Error UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) error
// DeleteUserByID deletes one user by its ID. // DeleteUserByID deletes one user by its ID.
DeleteUserByID(ctx context.Context, userID string) Error DeleteUserByID(ctx context.Context, userID string) error
} }

View file

@ -25,6 +25,7 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -175,7 +176,7 @@ func (suite *AccountTestSuite) TestDereferenceLocalAccountWithUnknownUsername()
config.GetHost(), config.GetHost(),
) )
suite.True(gtserror.Unretrievable(err)) suite.True(gtserror.Unretrievable(err))
suite.EqualError(err, "no entries") suite.EqualError(err, db.ErrNoEntries.Error())
suite.Nil(fetchedAccount) suite.Nil(fetchedAccount)
} }
@ -189,7 +190,7 @@ func (suite *AccountTestSuite) TestDereferenceLocalAccountWithUnknownUsernameDom
"localhost:8080", "localhost:8080",
) )
suite.True(gtserror.Unretrievable(err)) suite.True(gtserror.Unretrievable(err))
suite.EqualError(err, "no entries") suite.EqualError(err, db.ErrNoEntries.Error())
suite.Nil(fetchedAccount) suite.Nil(fetchedAccount)
} }
@ -202,7 +203,7 @@ func (suite *AccountTestSuite) TestDereferenceLocalAccountWithUnknownUserURI() {
testrig.URLMustParse("http://localhost:8080/users/thisaccountdoesnotexist"), testrig.URLMustParse("http://localhost:8080/users/thisaccountdoesnotexist"),
) )
suite.True(gtserror.Unretrievable(err)) suite.True(gtserror.Unretrievable(err))
suite.EqualError(err, "no entries") suite.EqualError(err, db.ErrNoEntries.Error())
suite.Nil(fetchedAccount) suite.Nil(fetchedAccount)
} }

View file

@ -232,7 +232,7 @@ func (c *Client) DoSigned(r *http.Request, sign SignFunc) (rsp *http.Response, e
return nil, err return nil, err
} }
l.Infof("performing request") l.Info("performing request")
// Perform the request. // Perform the request.
rsp, err = c.do(r) rsp, err = c.do(r)

View file

@ -22,7 +22,6 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
@ -48,7 +47,7 @@ const (
// context for use down the line. // context for use down the line.
// //
// In case of an error, the request will be aborted with http code 500. // In case of an error, the request will be aborted with http code 500.
func SignatureCheck(uriBlocked func(context.Context, *url.URL) (bool, db.Error)) func(*gin.Context) { func SignatureCheck(uriBlocked func(context.Context, *url.URL) (bool, error)) func(*gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
ctx := c.Request.Context() ctx := c.Request.Context()

View file

@ -22,6 +22,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
) )
type AuthorizeTestSuite struct { type AuthorizeTestSuite struct {
@ -38,7 +39,7 @@ func (suite *AuthorizeTestSuite) TestAuthorize() {
suite.Equal(suite.testAccounts["local_account_2"].ID, account2.ID) suite.Equal(suite.testAccounts["local_account_2"].ID, account2.ID)
noAccount, err := suite.streamProcessor.Authorize(context.Background(), "aaaaaaaaaaaaaaaaaaaaa!!") noAccount, err := suite.streamProcessor.Authorize(context.Background(), "aaaaaaaaaaaaaaaaaaaaa!!")
suite.EqualError(err, "could not load access token: no entries") suite.EqualError(err, "could not load access token: "+db.ErrNoEntries.Error())
suite.Nil(noAccount) suite.Nil(noAccount)
} }

View file

@ -62,7 +62,7 @@ const (
type Module struct { type Module struct {
processor *processing.Processor processor *processing.Processor
eTagCache cache.Cache[string, eTagCacheEntry] eTagCache cache.Cache[string, eTagCacheEntry]
isURIBlocked func(context.Context, *url.URL) (bool, db.Error) isURIBlocked func(context.Context, *url.URL) (bool, error)
} }
func New(db db.DB, processor *processing.Processor) *Module { func New(db db.DB, processor *processing.Processor) *Module {