diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 705e1b118..4b4c78726 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -37,7 +37,7 @@ import ( ) type accountDB struct { - db *DB + db *bun.DB state *state.State } @@ -334,7 +334,7 @@ func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) e // It is safe to run this database transaction within cache.Store // as the cache does not attempt a mutex lock until AFTER hook. // - return a.db.RunInTx(ctx, func(tx Tx) error { + return a.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { // create links between this account and any emojis it uses for _, i := range account.EmojiIDs { if _, err := tx.NewInsert().Model(>smodel.AccountToEmoji{ @@ -363,7 +363,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account // It is safe to run this database transaction within cache.Store // as the cache does not attempt a mutex lock until AFTER hook. // - return a.db.RunInTx(ctx, func(tx Tx) error { + return a.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { // create links between this account and any emojis it uses // first clear out any old emoji links if _, err := tx. @@ -411,7 +411,7 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) error { return err } - return a.db.RunInTx(ctx, func(tx Tx) error { + return a.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { // clear out any emoji links if _, err := tx. NewDelete(). diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go index e189c508e..70ae68026 100644 --- a/internal/db/bundb/admin.go +++ b/internal/db/bundb/admin.go @@ -45,7 +45,7 @@ import ( const rsaKeyBits = 2048 type adminDB struct { - db *DB + db *bun.DB state *state.State } @@ -56,7 +56,7 @@ func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (boo Column("account.id"). Where("? = ?", bun.Ident("account.username"), username). Where("? IS NULL", bun.Ident("account.domain")) - return a.db.NotExists(ctx, q) + return notExists(ctx, q) } func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, error) { @@ -73,7 +73,7 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, err TableExpr("? AS ?", bun.Ident("email_domain_blocks"), bun.Ident("email_domain_block")). Column("email_domain_block.id"). Where("? = ?", bun.Ident("email_domain_block.domain"), domain) - emailDomainBlocked, err := a.db.Exists(ctx, emailDomainBlockedQ) + emailDomainBlocked, err := exists(ctx, emailDomainBlockedQ) if err != nil { return false, err } @@ -88,7 +88,7 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, err Column("user.id"). Where("? = ?", bun.Ident("user.email"), email). WhereOr("? = ?", bun.Ident("user.unconfirmed_email"), email) - return a.db.NotExists(ctx, q) + return notExists(ctx, q) } func (a *adminDB) NewSignup(ctx context.Context, newSignup gtsmodel.NewSignup) (*gtsmodel.User, error) { @@ -229,7 +229,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) error { Where("? = ?", bun.Ident("account.username"), username). Where("? IS NULL", bun.Ident("account.domain")) - exists, err := a.db.Exists(ctx, q) + exists, err := exists(ctx, q) if err != nil { return err } @@ -287,7 +287,7 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) error { TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")). Where("? = ?", bun.Ident("instance.domain"), host) - exists, err := a.db.Exists(ctx, q) + exists, err := exists(ctx, q) if err != nil { return err } diff --git a/internal/db/bundb/application.go b/internal/db/bundb/application.go index 2e17a0e94..f02632793 100644 --- a/internal/db/bundb/application.go +++ b/internal/db/bundb/application.go @@ -26,7 +26,7 @@ import ( ) type applicationDB struct { - db *DB + db *bun.DB state *state.State } diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go index 488f59ad5..7b523f309 100644 --- a/internal/db/bundb/basic.go +++ b/internal/db/bundb/basic.go @@ -27,7 +27,7 @@ import ( ) type basicDB struct { - db *DB + db *bun.DB } func (b *basicDB) Put(ctx context.Context, i interface{}) error { diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 048474782..4ecbec7b9 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -52,13 +52,6 @@ import ( "modernc.org/sqlite" ) -var registerTables = []interface{}{ - >smodel.AccountToEmoji{}, - >smodel.StatusToEmoji{}, - >smodel.StatusToTag{}, - >smodel.ThreadToStatus{}, -} - // DBService satisfies the DB interface type DBService struct { db.Account @@ -88,12 +81,12 @@ type DBService struct { db.Timeline db.User db.Tombstone - db *DB + db *bun.DB } // GetDB returns the underlying database connection pool. // Should only be used in testing + exceptional circumstance. -func (dbService *DBService) DB() *DB { +func (dbService *DBService) DB() *bun.DB { return dbService.db } @@ -129,18 +122,18 @@ func doMigration(ctx context.Context, db *bun.DB) error { // NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface. // Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection. func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { - var db *DB + var db *bun.DB var err error t := strings.ToLower(config.GetDbType()) switch t { case "postgres": - db, err = pgConn(ctx) + db, err = pgConn(ctx, state) if err != nil { return nil, err } case "sqlite": - db, err = sqliteConn(ctx) + db, err = sqliteConn(ctx, state) if err != nil { return nil, err } @@ -159,14 +152,19 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { // table registration is needed for many-to-many, see: // https://bun.uptrace.dev/orm/many-to-many-relation/ - for _, t := range registerTables { + for _, t := range []interface{}{ + >smodel.AccountToEmoji{}, + >smodel.StatusToEmoji{}, + >smodel.StatusToTag{}, + >smodel.ThreadToStatus{}, + } { db.RegisterModel(t) } // perform any pending database migrations: this includes // the very first 'migration' on startup which just creates // necessary tables - if err := doMigration(ctx, db.bun); err != nil { + if err := doMigration(ctx, db); err != nil { return nil, fmt.Errorf("db migration error: %s", err) } @@ -284,13 +282,18 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { return ps, nil } -func pgConn(ctx context.Context) (*DB, error) { +func pgConn(ctx context.Context, state *state.State) (*bun.DB, error) { opts, err := deriveBunDBPGOptions() //nolint:contextcheck if err != nil { - return nil, fmt.Errorf("could not create bundb postgres options: %s", err) + return nil, fmt.Errorf("could not create bundb postgres options: %w", err) } - sqldb := stdlib.OpenDB(*opts) + cfg := stdlib.RegisterConnConfig(opts) + + sqldb, err := sql.Open("pgx-gts", cfg) + if err != nil { + return nil, fmt.Errorf("could not open postgres db: %w", err) + } // Tune db connections for postgres, see: // - https://bun.uptrace.dev/guide/running-bun-in-production.html#database-sql @@ -299,18 +302,18 @@ func pgConn(ctx context.Context) (*DB, error) { sqldb.SetMaxIdleConns(2) // assume default 2; if max idle is less than max open, it will be automatically adjusted sqldb.SetConnMaxLifetime(5 * time.Minute) // fine to kill old connections - db := WrapDB(bun.NewDB(sqldb, pgdialect.New())) + db := bun.NewDB(sqldb, pgdialect.New()) // ping to check the db is there and listening if err := db.PingContext(ctx); err != nil { - return nil, fmt.Errorf("postgres ping: %s", err) + return nil, fmt.Errorf("postgres ping: %w", err) } log.Info(ctx, "connected to POSTGRES database") return db, nil } -func sqliteConn(ctx context.Context) (*DB, error) { +func sqliteConn(ctx context.Context, state *state.State) (*bun.DB, error) { // validate db address has actually been set address := config.GetDbAddress() if address == "" { @@ -321,7 +324,7 @@ func sqliteConn(ctx context.Context) (*DB, error) { address = buildSQLiteAddress(address) // Open new DB instance - sqldb, err := sql.Open("sqlite", address) + sqldb, err := sql.Open("sqlite-gts", address) if err != nil { if errWithCode, ok := err.(*sqlite.Error); ok { err = errors.New(sqlite.ErrorCodeString[errWithCode.Code()]) @@ -336,15 +339,14 @@ func sqliteConn(ctx context.Context) (*DB, error) { sqldb.SetMaxIdleConns(1) // only keep max 1 idle connection around sqldb.SetConnMaxLifetime(0) // don't kill connections due to age - // Wrap Bun database conn in our own wrapper - db := WrapDB(bun.NewDB(sqldb, sqlitedialect.New())) + db := bun.NewDB(sqldb, sqlitedialect.New()) // ping to check the db is there and listening if err := db.PingContext(ctx); err != nil { if errWithCode, ok := err.(*sqlite.Error); ok { err = errors.New(sqlite.ErrorCodeString[errWithCode.Code()]) } - return nil, fmt.Errorf("sqlite ping: %s", err) + return nil, fmt.Errorf("sqlite ping: %w", err) } log.Infof(ctx, "connected to SQLITE database with address %s", address) @@ -418,7 +420,7 @@ func deriveBunDBPGOptions() (*pgx.ConnConfig, error) { // parse the PEM block into the certificate caCert, err := x509.ParseCertificate(caPem.Bytes) if err != nil { - return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", certPath, err) + return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %w", certPath, err) } // we're happy, add it to the existing pool and then use this pool in our tls config diff --git a/internal/db/bundb/db.go b/internal/db/bundb/db.go deleted file mode 100644 index 2b19ba0c4..000000000 --- a/internal/db/bundb/db.go +++ /dev/null @@ -1,578 +0,0 @@ -// GoToSocial -// Copyright (C) GoToSocial Authors admin@gotosocial.org -// SPDX-License-Identifier: AGPL-3.0-or-later -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package bundb - -import ( - "context" - "database/sql" - "time" - "unsafe" - - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtserror" - "github.com/uptrace/bun" - "github.com/uptrace/bun/dialect" - "github.com/uptrace/bun/schema" -) - -// DB wraps a bun database instance -// to provide common per-dialect SQL error -// conversions to common types, and retries -// on returned busy (SQLite only). -type DB struct { - // our own wrapped db type - // with retry backoff support. - // kept separate to the *bun.DB - // type to be passed into query - // builders as bun.IConn iface - // (this prevents double firing - // bun query hooks). - // - // also holds per-dialect - // error hook function. - raw rawdb - - // bun DB interface we use - // for dialects, and improved - // struct marshal/unmarshaling. - bun *bun.DB -} - -// WrapDB wraps a bun database instance in our database type. -func WrapDB(db *bun.DB) *DB { - var errProc func(error) error - switch name := db.Dialect().Name(); name { - case dialect.PG: - errProc = processPostgresError - case dialect.SQLite: - errProc = processSQLiteError - default: - panic("unknown dialect name: " + name.String()) - } - return &DB{ - raw: rawdb{ - errHook: errProc, - db: db.DB, - }, - bun: db, - } -} - -// Dialect is a direct call-through to bun.DB.Dialect(). -func (db *DB) Dialect() schema.Dialect { return db.bun.Dialect() } - -// AddQueryHook is a direct call-through to bun.DB.AddQueryHook(). -func (db *DB) AddQueryHook(hook bun.QueryHook) { db.bun.AddQueryHook(hook) } - -// RegisterModels is a direct call-through to bun.DB.RegisterModels(). -func (db *DB) RegisterModel(models ...any) { db.bun.RegisterModel(models...) } - -// PingContext is a direct call-through to bun.DB.PingContext(). -func (db *DB) PingContext(ctx context.Context) error { return db.bun.PingContext(ctx) } - -// Close is a direct call-through to bun.DB.Close(). -func (db *DB) Close() error { return db.bun.Close() } - -// ExecContext wraps bun.DB.ExecContext() with retry-busy timeout and our own error processing. -func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) { - bundb := db.bun // use underlying *bun.DB interface for their query formatting - err = retryOnBusy(ctx, func() error { - result, err = bundb.ExecContext(ctx, query, args...) - err = db.raw.errHook(err) - return err - }) - return -} - -// QueryContext wraps bun.DB.ExecContext() with retry-busy timeout and our own error processing. -func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) { - bundb := db.bun // use underlying *bun.DB interface for their query formatting - err = retryOnBusy(ctx, func() error { - rows, err = bundb.QueryContext(ctx, query, args...) - err = db.raw.errHook(err) - return err - }) - return -} - -// QueryRowContext wraps bun.DB.ExecContext() with retry-busy timeout and our own error processing. -func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) (row *sql.Row) { - bundb := db.bun // use underlying *bun.DB interface for their query formatting - _ = retryOnBusy(ctx, func() error { - row = bundb.QueryRowContext(ctx, query, args...) - if err := db.raw.errHook(row.Err()); err != nil { - updateRowError(row, err) // set new error - } - return row.Err() - }) - return -} - -// BeginTx wraps bun.DB.BeginTx() with retry-busy timeout and our own error processing. -func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (tx Tx, err error) { - var buntx bun.Tx // captured bun.Tx - bundb := db.bun // use *bun.DB interface to return bun.Tx type - - err = retryOnBusy(ctx, func() error { - buntx, err = bundb.BeginTx(ctx, opts) - err = db.raw.errHook(err) - return err - }) - - if err == nil { - // Wrap bun.Tx in our type. - tx = wrapTx(db, &buntx) - } - - return -} - -// RunInTx is functionally the same as bun.DB.RunInTx() but with retry-busy timeouts. -func (db *DB) RunInTx(ctx context.Context, fn func(Tx) error) error { - // Attempt to start new transaction. - tx, err := db.BeginTx(ctx, nil) - if err != nil { - return err - } - - var done bool - - defer func() { - if !done { - // Rollback tx. - _ = tx.Rollback() - } - }() - - // Perform supplied transaction - if err := fn(tx); err != nil { - return err - } - - // Commit tx. - err = tx.Commit() - done = true - return err -} - -func (db *DB) NewValues(model interface{}) *bun.ValuesQuery { - // note: passing in rawdb as conn iface so no double query-hook - // firing when passed through the bun.DB.Query___() functions. - return bun.NewValuesQuery(db.bun, model).Conn(&db.raw) -} - -func (db *DB) NewMerge() *bun.MergeQuery { - // note: passing in rawdb as conn iface so no double query-hook - // firing when passed through the bun.DB.Query___() functions. - return bun.NewMergeQuery(db.bun).Conn(&db.raw) -} - -func (db *DB) NewSelect() *bun.SelectQuery { - // note: passing in rawdb as conn iface so no double query-hook - // firing when passed through the bun.DB.Query___() functions. - return bun.NewSelectQuery(db.bun).Conn(&db.raw) -} - -func (db *DB) NewInsert() *bun.InsertQuery { - // note: passing in rawdb as conn iface so no double query-hook - // firing when passed through the bun.DB.Query___() functions. - return bun.NewInsertQuery(db.bun).Conn(&db.raw) -} - -func (db *DB) NewUpdate() *bun.UpdateQuery { - // note: passing in rawdb as conn iface so no double query-hook - // firing when passed through the bun.DB.Query___() functions. - return bun.NewUpdateQuery(db.bun).Conn(&db.raw) -} - -func (db *DB) NewDelete() *bun.DeleteQuery { - // note: passing in rawdb as conn iface so no double query-hook - // firing when passed through the bun.DB.Query___() functions. - return bun.NewDeleteQuery(db.bun).Conn(&db.raw) -} - -func (db *DB) NewRaw(query string, args ...interface{}) *bun.RawQuery { - // note: passing in rawdb as conn iface so no double query-hook - // firing when passed through the bun.DB.Query___() functions. - return bun.NewRawQuery(db.bun, query, args...).Conn(&db.raw) -} - -func (db *DB) NewCreateTable() *bun.CreateTableQuery { - // note: passing in rawdb as conn iface so no double query-hook - // firing when passed through the bun.DB.Query___() functions. - return bun.NewCreateTableQuery(db.bun).Conn(&db.raw) -} - -func (db *DB) NewDropTable() *bun.DropTableQuery { - // note: passing in rawdb as conn iface so no double query-hook - // firing when passed through the bun.DB.Query___() functions. - return bun.NewDropTableQuery(db.bun).Conn(&db.raw) -} - -func (db *DB) NewCreateIndex() *bun.CreateIndexQuery { - // note: passing in rawdb as conn iface so no double query-hook - // firing when passed through the bun.DB.Query___() functions. - return bun.NewCreateIndexQuery(db.bun).Conn(&db.raw) -} - -func (db *DB) NewDropIndex() *bun.DropIndexQuery { - // note: passing in rawdb as conn iface so no double query-hook - // firing when passed through the bun.DB.Query___() functions. - return bun.NewDropIndexQuery(db.bun).Conn(&db.raw) -} - -func (db *DB) NewTruncateTable() *bun.TruncateTableQuery { - // note: passing in rawdb as conn iface so no double query-hook - // firing when passed through the bun.DB.Query___() functions. - return bun.NewTruncateTableQuery(db.bun).Conn(&db.raw) -} - -func (db *DB) NewAddColumn() *bun.AddColumnQuery { - // note: passing in rawdb as conn iface so no double query-hook - // firing when passed through the bun.DB.Query___() functions. - return bun.NewAddColumnQuery(db.bun).Conn(&db.raw) -} - -func (db *DB) NewDropColumn() *bun.DropColumnQuery { - // note: passing in rawdb as conn iface so no double query-hook - // firing when passed through the bun.DB.Query___() functions. - return bun.NewDropColumnQuery(db.bun).Conn(&db.raw) -} - -// Exists checks the results of a SelectQuery for the existence of the data in question, masking ErrNoEntries errors. -func (db *DB) Exists(ctx context.Context, query *bun.SelectQuery) (bool, error) { - exists, err := query.Exists(ctx) - switch err { - case nil: - return exists, nil - case sql.ErrNoRows: - return false, nil - default: - return false, err - } -} - -// NotExists checks the results of a SelectQuery for the non-existence of the data in question, masking ErrNoEntries errors. -func (db *DB) NotExists(ctx context.Context, query *bun.SelectQuery) (bool, error) { - exists, err := db.Exists(ctx, query) - return !exists, err -} - -type rawdb struct { - // dialect specific error - // processing function hook. - errHook func(error) error - - // embedded raw - // db interface - db *sql.DB -} - -// ExecContext wraps sql.DB.ExecContext() with retry-busy timeout and our own error processing. -func (db *rawdb) ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) { - err = retryOnBusy(ctx, func() error { - result, err = db.db.ExecContext(ctx, query, args...) - err = db.errHook(err) - return err - }) - return -} - -// QueryContext wraps sql.DB.QueryContext() with retry-busy timeout and our own error processing. -func (db *rawdb) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) { - err = retryOnBusy(ctx, func() error { - rows, err = db.db.QueryContext(ctx, query, args...) - err = db.errHook(err) - return err - }) - return -} - -// QueryRowContext wraps sql.DB.QueryRowContext() with retry-busy timeout and our own error processing. -func (db *rawdb) QueryRowContext(ctx context.Context, query string, args ...any) (row *sql.Row) { - _ = retryOnBusy(ctx, func() error { - row = db.db.QueryRowContext(ctx, query, args...) - err := db.errHook(row.Err()) - return err - }) - return -} - -// Tx wraps a bun transaction instance -// to provide common per-dialect SQL error -// conversions to common types, and retries -// on busy commit/rollback (SQLite only). -type Tx struct { - // our own wrapped Tx type - // kept separate to the *bun.Tx - // type to be passed into query - // builders as bun.IConn iface - // (this prevents double firing - // bun query hooks). - // - // also holds per-dialect - // error hook function. - raw rawtx - - // bun Tx interface we use - // for dialects, and improved - // struct marshal/unmarshaling. - bun *bun.Tx -} - -// wrapTx wraps a given bun.Tx in our own wrapping Tx type. -func wrapTx(db *DB, tx *bun.Tx) Tx { - return Tx{ - raw: rawtx{ - errHook: db.raw.errHook, - tx: tx.Tx, - }, - bun: tx, - } -} - -// ExecContext wraps bun.Tx.ExecContext() with our own error processing, WITHOUT retry-busy timeouts (as will be mid-transaction). -func (tx Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { - buntx := tx.bun // use underlying *bun.Tx interface for their query formatting - res, err := buntx.ExecContext(ctx, query, args...) - err = tx.raw.errHook(err) - return res, err -} - -// QueryContext wraps bun.Tx.QueryContext() with our own error processing, WITHOUT retry-busy timeouts (as will be mid-transaction). -func (tx Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { - buntx := tx.bun // use underlying *bun.Tx interface for their query formatting - rows, err := buntx.QueryContext(ctx, query, args...) - err = tx.raw.errHook(err) - return rows, err -} - -// QueryRowContext wraps bun.Tx.QueryRowContext() with our own error processing, WITHOUT retry-busy timeouts (as will be mid-transaction). -func (tx Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - buntx := tx.bun // use underlying *bun.Tx interface for their query formatting - row := buntx.QueryRowContext(ctx, query, args...) - if err := tx.raw.errHook(row.Err()); err != nil { - updateRowError(row, err) // set new error - } - return row -} - -// Commit wraps bun.Tx.Commit() with retry-busy timeout and our own error processing. -func (tx Tx) Commit() (err error) { - buntx := tx.bun // use *bun.Tx interface - err = retryOnBusy(context.TODO(), func() error { - err = buntx.Commit() - err = tx.raw.errHook(err) - return err - }) - return -} - -// Rollback wraps bun.Tx.Rollback() with retry-busy timeout and our own error processing. -func (tx Tx) Rollback() (err error) { - buntx := tx.bun // use *bun.Tx interface - err = retryOnBusy(context.TODO(), func() error { - err = buntx.Rollback() - err = tx.raw.errHook(err) - return err - }) - return -} - -// Dialect is a direct call-through to bun.DB.Dialect(). -func (tx Tx) Dialect() schema.Dialect { - return tx.bun.Dialect() -} - -func (tx Tx) NewValues(model interface{}) *bun.ValuesQuery { - // note: passing in rawtx as conn iface so no double query-hook - // firing when passed through the bun.Tx.Query___() functions. - return tx.bun.NewValues(model).Conn(&tx.raw) -} - -func (tx Tx) NewMerge() *bun.MergeQuery { - // note: passing in rawtx as conn iface so no double query-hook - // firing when passed through the bun.Tx.Query___() functions. - return tx.bun.NewMerge().Conn(&tx.raw) -} - -func (tx Tx) NewSelect() *bun.SelectQuery { - // note: passing in rawtx as conn iface so no double query-hook - // firing when passed through the bun.Tx.Query___() functions. - return tx.bun.NewSelect().Conn(&tx.raw) -} - -func (tx Tx) NewInsert() *bun.InsertQuery { - // note: passing in rawtx as conn iface so no double query-hook - // firing when passed through the bun.Tx.Query___() functions. - return tx.bun.NewInsert().Conn(&tx.raw) -} - -func (tx Tx) NewUpdate() *bun.UpdateQuery { - // note: passing in rawtx as conn iface so no double query-hook - // firing when passed through the bun.Tx.Query___() functions. - return tx.bun.NewUpdate().Conn(&tx.raw) -} - -func (tx Tx) NewDelete() *bun.DeleteQuery { - // note: passing in rawtx as conn iface so no double query-hook - // firing when passed through the bun.Tx.Query___() functions. - return tx.bun.NewDelete().Conn(&tx.raw) -} - -func (tx Tx) NewRaw(query string, args ...interface{}) *bun.RawQuery { - // note: passing in rawtx as conn iface so no double query-hook - // firing when passed through the bun.Tx.Query___() functions. - return tx.bun.NewRaw(query, args...).Conn(&tx.raw) -} - -func (tx Tx) NewCreateTable() *bun.CreateTableQuery { - // note: passing in rawtx as conn iface so no double query-hook - // firing when passed through the bun.Tx.Query___() functions. - return tx.bun.NewCreateTable().Conn(&tx.raw) -} - -func (tx Tx) NewDropTable() *bun.DropTableQuery { - // note: passing in rawtx as conn iface so no double query-hook - // firing when passed through the bun.Tx.Query___() functions. - return tx.bun.NewDropTable().Conn(&tx.raw) -} - -func (tx Tx) NewCreateIndex() *bun.CreateIndexQuery { - // note: passing in rawtx as conn iface so no double query-hook - // firing when passed through the bun.Tx.Query___() functions. - return tx.bun.NewCreateIndex().Conn(&tx.raw) -} - -func (tx Tx) NewDropIndex() *bun.DropIndexQuery { - // note: passing in rawtx as conn iface so no double query-hook - // firing when passed through the bun.Tx.Query___() functions. - return tx.bun.NewDropIndex().Conn(&tx.raw) -} - -func (tx Tx) NewTruncateTable() *bun.TruncateTableQuery { - // note: passing in rawtx as conn iface so no double query-hook - // firing when passed through the bun.Tx.Query___() functions. - return tx.bun.NewTruncateTable().Conn(&tx.raw) -} - -func (tx Tx) NewAddColumn() *bun.AddColumnQuery { - // note: passing in rawtx as conn iface so no double query-hook - // firing when passed through the bun.Tx.Query___() functions. - return tx.bun.NewAddColumn().Conn(&tx.raw) -} - -func (tx Tx) NewDropColumn() *bun.DropColumnQuery { - // note: passing in rawtx as conn iface so no double query-hook - // firing when passed through the bun.Tx.Query___() functions. - return tx.bun.NewDropColumn().Conn(&tx.raw) -} - -type rawtx struct { - // dialect specific error - // processing function hook. - errHook func(error) error - - // embedded raw - // tx interface - tx *sql.Tx -} - -// ExecContext wraps sql.Tx.ExecContext() with our own error processing, WITHOUT retry-busy timeouts (as will be mid-transaction). -func (tx *rawtx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { - res, err := tx.tx.ExecContext(ctx, query, args...) - err = tx.errHook(err) - return res, err -} - -// QueryContext wraps sql.Tx.QueryContext() with our own error processing, WITHOUT retry-busy timeouts (as will be mid-transaction). -func (tx *rawtx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { - rows, err := tx.tx.QueryContext(ctx, query, args...) - err = tx.errHook(err) - return rows, err -} - -// QueryRowContext wraps sql.Tx.QueryRowContext() with our own error processing, WITHOUT retry-busy timeouts (as will be mid-transaction). -func (tx *rawtx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - row := tx.tx.QueryRowContext(ctx, query, args...) - if err := tx.errHook(row.Err()); err != nil { - updateRowError(row, err) // set new error - } - return row -} - -// updateRowError updates an sql.Row's internal error field using the unsafe package. -func updateRowError(sqlrow *sql.Row, err error) { - type row struct { - err error - rows *sql.Rows - } - - // compile-time check to ensure sql.Row not changed. - if unsafe.Sizeof(row{}) != unsafe.Sizeof(sql.Row{}) { - panic("sql.Row has changed definition") - } - - // this code is awful and i must be shamed for this. - (*row)(unsafe.Pointer(sqlrow)).err = err -} - -// retryOnBusy will retry given function on returned 'errBusy'. -func retryOnBusy(ctx context.Context, fn func() error) error { - var backoff time.Duration - - for i := 0; ; i++ { - // Perform func. - err := fn() - - if err != errBusy { - // May be nil, or may be - // some other error, either - // way return here. - return err - } - - // backoff according to a multiplier of 2ms * 2^2n, - // up to a maximum possible backoff time of 5 minutes. - // - // this works out as the following: - // 4ms - // 16ms - // 64ms - // 256ms - // 1.024s - // 4.096s - // 16.384s - // 1m5.536s - // 4m22.144s - backoff = 2 * time.Millisecond * (1 << (2*i + 1)) - if backoff >= 5*time.Minute { - break - } - - select { - // Context cancelled. - case <-ctx.Done(): - - // Backoff for some time. - case <-time.After(backoff): - } - } - - return gtserror.Newf("%w (waited > %s)", db.ErrBusyTimeout, backoff) -} diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index 2398e52c2..1254d79c8 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -31,7 +31,7 @@ import ( ) type domainDB struct { - db *DB + db *bun.DB state *state.State } diff --git a/internal/db/bundb/drivers.go b/internal/db/bundb/drivers.go new file mode 100644 index 000000000..14d84e6fa --- /dev/null +++ b/internal/db/bundb/drivers.go @@ -0,0 +1,267 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package bundb + +import ( + "context" + "database/sql" + "database/sql/driver" + "time" + _ "unsafe" // linkname shenanigans + + pgx "github.com/jackc/pgx/v5/stdlib" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "modernc.org/sqlite" +) + +var ( + // global SQL driver instances. + postgresDriver = pgx.GetDefaultDriver() + sqliteDriver = getSQLiteDriver() +) + +func init() { + sql.Register("pgx-gts", &PostgreSQLDriver{}) + sql.Register("sqlite-gts", &SQLiteDriver{}) +} + +//go:linkname getSQLiteDriver modernc.org/sqlite.newDriver +func getSQLiteDriver() *sqlite.Driver + +// PostgreSQLDriver is our own wrapper around the +// pgx/stdlib.Driver{} type in order to wrap further +// SQL driver types with our own err processing. +type PostgreSQLDriver struct{} + +func (d *PostgreSQLDriver) Open(name string) (driver.Conn, error) { + c, err := postgresDriver.Open(name) + if err != nil { + return nil, err + } + return &PostgreSQLConn{conn: c.(conn)}, nil +} + +type PostgreSQLConn struct{ conn } + +func (c *PostgreSQLConn) Begin() (driver.Tx, error) { + return c.BeginTx(context.Background(), driver.TxOptions{}) +} + +func (c *PostgreSQLConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + tx, err := c.conn.BeginTx(ctx, opts) + err = processPostgresError(err) + return tx, err +} + +func (c *PostgreSQLConn) Prepare(query string) (driver.Stmt, error) { + return c.PrepareContext(context.Background(), query) +} + +func (c *PostgreSQLConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + stmt, err := c.conn.PrepareContext(ctx, query) + err = processPostgresError(err) + return stmt, err +} + +func (c *PostgreSQLConn) Exec(query string, args []driver.NamedValue) (driver.Result, error) { + return c.ExecContext(context.Background(), query, args) +} + +func (c *PostgreSQLConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + result, err := c.conn.ExecContext(ctx, query, args) + err = processPostgresError(err) + return result, err +} + +func (c *PostgreSQLConn) Query(query string, args []driver.NamedValue) (driver.Rows, error) { + return c.QueryContext(context.Background(), query, args) +} + +func (c *PostgreSQLConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + rows, err := c.conn.QueryContext(ctx, query, args) + err = processPostgresError(err) + return rows, err +} + +func (c *PostgreSQLConn) Close() error { + return c.conn.Close() +} + +type PostgreSQLTx struct{ driver.Tx } + +func (tx *PostgreSQLTx) Commit() error { + err := tx.Tx.Commit() + return processPostgresError(err) +} + +func (tx *PostgreSQLTx) Rollback() error { + err := tx.Tx.Rollback() + return processPostgresError(err) +} + +// SQLiteDriver is our own wrapper around the +// sqlite.Driver{} type in order to wrap further +// SQL driver types with our own functionality, +// e.g. hooks, retries and err processing. +type SQLiteDriver struct{} + +func (d *SQLiteDriver) Open(name string) (driver.Conn, error) { + c, err := sqliteDriver.Open(name) + if err != nil { + return nil, err + } + return &SQLiteConn{conn: c.(conn)}, nil +} + +type SQLiteConn struct{ conn } + +func (c *SQLiteConn) Begin() (driver.Tx, error) { + return c.BeginTx(context.Background(), driver.TxOptions{}) +} + +func (c *SQLiteConn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx, err error) { + err = retryOnBusy(ctx, func() error { + tx, err = c.conn.BeginTx(ctx, opts) + err = processSQLiteError(err) + return err + }) + return &SQLiteTx{Context: ctx, Tx: tx}, nil +} + +func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) { + return c.PrepareContext(context.Background(), query) +} + +func (c *SQLiteConn) PrepareContext(ctx context.Context, query string) (stmt driver.Stmt, err error) { + err = retryOnBusy(ctx, func() error { + stmt, err = c.conn.PrepareContext(ctx, query) + err = processSQLiteError(err) + return err + }) + return +} + +func (c *SQLiteConn) Exec(query string, args []driver.NamedValue) (driver.Result, error) { + return c.ExecContext(context.Background(), query, args) +} + +func (c *SQLiteConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) { + err = retryOnBusy(ctx, func() error { + result, err = c.conn.ExecContext(ctx, query, args) + err = processSQLiteError(err) + return err + }) + return +} + +func (c *SQLiteConn) Query(query string, args []driver.NamedValue) (driver.Rows, error) { + return c.QueryContext(context.Background(), query, args) +} + +func (c *SQLiteConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { + err = retryOnBusy(ctx, func() error { + rows, err = c.conn.QueryContext(ctx, query, args) + err = processSQLiteError(err) + return err + }) + return +} + +func (c *SQLiteConn) Close() error { + // see: https://www.sqlite.org/pragma.html#pragma_optimize + const onClose = "PRAGMA analysis_limit=1000; PRAGMA optimize;" + _, _ = c.conn.ExecContext(context.Background(), onClose, nil) + return c.conn.Close() +} + +type SQLiteTx struct { + context.Context + driver.Tx +} + +func (tx *SQLiteTx) Commit() (err error) { + err = retryOnBusy(tx.Context, func() error { + err = tx.Tx.Commit() + err = processSQLiteError(err) + return err + }) + return +} + +func (tx *SQLiteTx) Rollback() (err error) { + err = retryOnBusy(tx.Context, func() error { + err = tx.Tx.Rollback() + err = processSQLiteError(err) + return err + }) + return +} + +type conn interface { + driver.Conn + driver.ConnPrepareContext + driver.ExecerContext + driver.QueryerContext + driver.ConnBeginTx +} + +// retryOnBusy will retry given function on returned 'errBusy'. +func retryOnBusy(ctx context.Context, fn func() error) error { + var backoff time.Duration + + for i := 0; ; i++ { + // Perform func. + err := fn() + + if err != errBusy { + // May be nil, or may be + // some other error, either + // way return here. + return err + } + + // backoff according to a multiplier of 2ms * 2^2n, + // up to a maximum possible backoff time of 5 minutes. + // + // this works out as the following: + // 4ms + // 16ms + // 64ms + // 256ms + // 1.024s + // 4.096s + // 16.384s + // 1m5.536s + // 4m22.144s + backoff = 2 * time.Millisecond * (1 << (2*i + 1)) + if backoff >= 5*time.Minute { + break + } + + select { + // Context cancelled. + case <-ctx.Done(): + + // Backoff for some time. + case <-time.After(backoff): + } + } + + return gtserror.Newf("%w (waited > %s)", db.ErrBusyTimeout, backoff) +} diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go index 608cb6417..69d33eede 100644 --- a/internal/db/bundb/emoji.go +++ b/internal/db/bundb/emoji.go @@ -38,7 +38,7 @@ import ( ) type emojiDB struct { - db *DB + db *bun.DB state *state.State } @@ -109,7 +109,7 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error { return err } - return e.db.RunInTx(ctx, func(tx Tx) error { + return e.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { // Delete relational links between this emoji // and any statuses using it, returning the // status IDs so we can later update them. diff --git a/internal/db/bundb/headerfilter.go b/internal/db/bundb/headerfilter.go index 087b65c82..b02d9249e 100644 --- a/internal/db/bundb/headerfilter.go +++ b/internal/db/bundb/headerfilter.go @@ -29,7 +29,7 @@ import ( ) type headerFilterDB struct { - db *DB + db *bun.DB state *state.State } diff --git a/internal/db/bundb/instance.go b/internal/db/bundb/instance.go index d506e0a31..5f96f9a26 100644 --- a/internal/db/bundb/instance.go +++ b/internal/db/bundb/instance.go @@ -34,7 +34,7 @@ import ( ) type instanceDB struct { - db *DB + db *bun.DB state *state.State } diff --git a/internal/db/bundb/list.go b/internal/db/bundb/list.go index 5f95d3c24..fb97c8fe7 100644 --- a/internal/db/bundb/list.go +++ b/internal/db/bundb/list.go @@ -35,7 +35,7 @@ import ( ) type listDB struct { - db *DB + db *bun.DB state *state.State } @@ -198,7 +198,7 @@ func (l *listDB) DeleteListByID(ctx context.Context, id string) error { } }() - return l.db.RunInTx(ctx, func(tx Tx) error { + return l.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { // Delete all entries attached to list. if _, err := tx.NewDelete(). Table("list_entries"). @@ -515,7 +515,7 @@ func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEnt }() // Finally, insert each list entry into the database. - return l.db.RunInTx(ctx, func(tx Tx) error { + return l.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { for _, entry := range entries { entry := entry // rescope if err := l.state.Caches.GTS.ListEntry.Store(entry, func() error { diff --git a/internal/db/bundb/marker.go b/internal/db/bundb/marker.go index b1dedb4f1..0ae50f269 100644 --- a/internal/db/bundb/marker.go +++ b/internal/db/bundb/marker.go @@ -30,7 +30,7 @@ import ( ) type markerDB struct { - db *DB + db *bun.DB state *state.State } @@ -85,7 +85,7 @@ func (m *markerDB) UpdateMarker(ctx context.Context, marker *gtsmodel.Marker) er // Optimistic concurrency control: start a transaction, try to update a row with a previously retrieved version. // If the update in the transaction fails to actually change anything, another update happened concurrently, and // this update should be retried by the caller, which in this case involves sending HTTP 409 to the API client. - return m.db.RunInTx(ctx, func(tx Tx) error { + return m.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { result, err := tx.NewUpdate(). Model(marker). WherePK(). diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go index ced38a588..99ef30d22 100644 --- a/internal/db/bundb/media.go +++ b/internal/db/bundb/media.go @@ -34,7 +34,7 @@ import ( ) type mediaDB struct { - db *DB + db *bun.DB state *state.State } @@ -151,7 +151,7 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error { defer m.state.Caches.GTS.Media.Invalidate("ID", id) // Delete media attachment in new transaction. - err = m.db.RunInTx(ctx, func(tx Tx) error { + err = m.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { if media.AccountID != "" { var account gtsmodel.Account diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go index b069423bb..156469544 100644 --- a/internal/db/bundb/mention.go +++ b/internal/db/bundb/mention.go @@ -33,7 +33,7 @@ import ( ) type mentionDB struct { - db *DB + db *bun.DB state *state.State } diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go index ed34222fb..3f3d5fbd6 100644 --- a/internal/db/bundb/notification.go +++ b/internal/db/bundb/notification.go @@ -34,7 +34,7 @@ import ( ) type notificationDB struct { - db *DB + db *bun.DB state *state.State } diff --git a/internal/db/bundb/poll.go b/internal/db/bundb/poll.go index 0dfb15621..37a1f26ab 100644 --- a/internal/db/bundb/poll.go +++ b/internal/db/bundb/poll.go @@ -34,7 +34,7 @@ import ( ) type pollDB struct { - db *DB + db *bun.DB state *state.State } @@ -154,7 +154,7 @@ func (p *pollDB) UpdatePoll(ctx context.Context, poll *gtsmodel.Poll, cols ...st poll.CheckVotes() return p.state.Caches.GTS.Poll.Store(poll, func() error { - return p.db.RunInTx(ctx, func(tx Tx) error { + return p.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { // Update the status' "updated_at" field. if _, err := tx.NewUpdate(). Table("statuses"). @@ -362,7 +362,7 @@ func (p *pollDB) PopulatePollVote(ctx context.Context, vote *gtsmodel.PollVote) func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error { return p.state.Caches.GTS.PollVote.Store(vote, func() error { - return p.db.RunInTx(ctx, func(tx Tx) error { + return p.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { // Try insert vote into database. if _, err := tx.NewInsert(). Model(vote). @@ -398,7 +398,7 @@ func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error } func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error { - err := p.db.RunInTx(ctx, func(tx Tx) error { + err := p.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { // Delete all votes in poll. res, err := tx.NewDelete(). Table("poll_votes"). @@ -469,7 +469,7 @@ func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error { } func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID string) error { - err := p.db.RunInTx(ctx, func(tx Tx) error { + err := p.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { // Slice should only ever be of length // 0 or 1; it's a slice of slices only // because we can't LIMIT deletes to 1. @@ -569,7 +569,7 @@ func (p *pollDB) DeletePollVotesByAccountID(ctx context.Context, accountID strin } // newSelectPollVotes returns a new select query for all rows in the poll_votes table with poll_id = pollID. -func newSelectPollVotes(db *DB, pollID string) *bun.SelectQuery { +func newSelectPollVotes(db *bun.DB, pollID string) *bun.SelectQuery { return db.NewSelect(). TableExpr("?", bun.Ident("poll_votes")). ColumnExpr("?", bun.Ident("id")). diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index 4c50862a1..71ae37545 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -31,7 +31,7 @@ import ( ) type relationshipDB struct { - db *DB + db *bun.DB state *state.State } @@ -299,7 +299,7 @@ func (r *relationshipDB) getAccountBlockIDs(ctx context.Context, accountID strin } // newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID. -func newSelectFollowRequests(db *DB, accountID string) *bun.SelectQuery { +func newSelectFollowRequests(db *bun.DB, accountID string) *bun.SelectQuery { return db.NewSelect(). TableExpr("?", bun.Ident("follow_requests")). ColumnExpr("?", bun.Ident("id")). @@ -308,7 +308,7 @@ func newSelectFollowRequests(db *DB, accountID string) *bun.SelectQuery { } // newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID. -func newSelectFollowRequesting(db *DB, accountID string) *bun.SelectQuery { +func newSelectFollowRequesting(db *bun.DB, accountID string) *bun.SelectQuery { return db.NewSelect(). TableExpr("?", bun.Ident("follow_requests")). ColumnExpr("?", bun.Ident("id")). @@ -317,7 +317,7 @@ func newSelectFollowRequesting(db *DB, accountID string) *bun.SelectQuery { } // newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID. -func newSelectFollows(db *DB, accountID string) *bun.SelectQuery { +func newSelectFollows(db *bun.DB, accountID string) *bun.SelectQuery { return db.NewSelect(). Table("follows"). Column("id"). @@ -327,7 +327,7 @@ func newSelectFollows(db *DB, accountID string) *bun.SelectQuery { // newSelectLocalFollows returns a new select query for all rows in the follows table with // account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local). -func newSelectLocalFollows(db *DB, accountID string) *bun.SelectQuery { +func newSelectLocalFollows(db *bun.DB, accountID string) *bun.SelectQuery { return db.NewSelect(). Table("follows"). Column("id"). @@ -344,7 +344,7 @@ func newSelectLocalFollows(db *DB, accountID string) *bun.SelectQuery { } // newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID. -func newSelectFollowers(db *DB, accountID string) *bun.SelectQuery { +func newSelectFollowers(db *bun.DB, accountID string) *bun.SelectQuery { return db.NewSelect(). Table("follows"). Column("id"). @@ -354,7 +354,7 @@ func newSelectFollowers(db *DB, accountID string) *bun.SelectQuery { // newSelectLocalFollowers returns a new select query for all rows in the follows table with // target_account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local). -func newSelectLocalFollowers(db *DB, accountID string) *bun.SelectQuery { +func newSelectLocalFollowers(db *bun.DB, accountID string) *bun.SelectQuery { return db.NewSelect(). Table("follows"). Column("id"). @@ -371,7 +371,7 @@ func newSelectLocalFollowers(db *DB, accountID string) *bun.SelectQuery { } // newSelectBlocks returns a new select query for all rows in the blocks table with account_id = accountID. -func newSelectBlocks(db *DB, accountID string) *bun.SelectQuery { +func newSelectBlocks(db *bun.DB, accountID string) *bun.SelectQuery { return db.NewSelect(). TableExpr("?", bun.Ident("blocks")). ColumnExpr("?", bun.Ident("id")). diff --git a/internal/db/bundb/report.go b/internal/db/bundb/report.go index 5b0ae17f3..486bf09f0 100644 --- a/internal/db/bundb/report.go +++ b/internal/db/bundb/report.go @@ -32,7 +32,7 @@ import ( ) type reportDB struct { - db *DB + db *bun.DB state *state.State } diff --git a/internal/db/bundb/rule.go b/internal/db/bundb/rule.go index ebfa89d15..e36053c38 100644 --- a/internal/db/bundb/rule.go +++ b/internal/db/bundb/rule.go @@ -32,7 +32,7 @@ import ( ) type ruleDB struct { - db *DB + db *bun.DB state *state.State } diff --git a/internal/db/bundb/search.go b/internal/db/bundb/search.go index f9c2df1f8..f8ae529f7 100644 --- a/internal/db/bundb/search.go +++ b/internal/db/bundb/search.go @@ -57,7 +57,7 @@ import ( // This isn't ideal, of course, but at least we could cover the most common use case of // a caller paging down through results. type searchDB struct { - db *DB + db *bun.DB state *state.State } diff --git a/internal/db/bundb/session.go b/internal/db/bundb/session.go index 9310a6463..2177a57ae 100644 --- a/internal/db/bundb/session.go +++ b/internal/db/bundb/session.go @@ -24,10 +24,11 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" + "github.com/uptrace/bun" ) type sessionDB struct { - db *DB + db *bun.DB } func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, error) { diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index 07a09050a..6d1788b5d 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -34,7 +34,7 @@ import ( ) type statusDB struct { - db *DB + db *bun.DB state *state.State } @@ -330,7 +330,7 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) error // It is safe to run this database transaction within cache.Store // as the cache does not attempt a mutex lock until AFTER hook. // - return s.db.RunInTx(ctx, func(tx Tx) error { + return s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { // create links between this status and any emojis it uses for _, i := range status.EmojiIDs { if _, err := tx. @@ -414,7 +414,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co // It is safe to run this database transaction within cache.Store // as the cache does not attempt a mutex lock until AFTER hook. // - return s.db.RunInTx(ctx, func(tx Tx) error { + return s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { // create links between this status and any emojis it uses for _, i := range status.EmojiIDs { if _, err := tx. @@ -509,7 +509,7 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) error { // On return ensure status invalidated from cache. defer s.state.Caches.GTS.Status.Invalidate("ID", id) - return s.db.RunInTx(ctx, func(tx Tx) error { + return s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { // delete links between this status and any emojis it uses if _, err := tx. NewDelete(). @@ -697,6 +697,5 @@ func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.St TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")). Where("? = ?", bun.Ident("status_bookmark.status_id"), status.ID). Where("? = ?", bun.Ident("status_bookmark.account_id"), accountID) - - return s.db.Exists(ctx, q) + return exists(ctx, q) } diff --git a/internal/db/bundb/statusbookmark.go b/internal/db/bundb/statusbookmark.go index 742c13966..73fced9c3 100644 --- a/internal/db/bundb/statusbookmark.go +++ b/internal/db/bundb/statusbookmark.go @@ -29,7 +29,7 @@ import ( ) type statusBookmarkDB struct { - db *DB + db *bun.DB state *state.State } diff --git a/internal/db/bundb/statusfave.go b/internal/db/bundb/statusfave.go index e0f018b68..d04578076 100644 --- a/internal/db/bundb/statusfave.go +++ b/internal/db/bundb/statusfave.go @@ -35,7 +35,7 @@ import ( ) type statusFaveDB struct { - db *DB + db *bun.DB state *state.State } diff --git a/internal/db/bundb/tag.go b/internal/db/bundb/tag.go index 66ee8cb3a..e6297d2ab 100644 --- a/internal/db/bundb/tag.go +++ b/internal/db/bundb/tag.go @@ -28,7 +28,7 @@ import ( ) type tagDB struct { - db *DB + db *bun.DB state *state.State } diff --git a/internal/db/bundb/thread.go b/internal/db/bundb/thread.go index 34c5f783a..a75515062 100644 --- a/internal/db/bundb/thread.go +++ b/internal/db/bundb/thread.go @@ -28,7 +28,7 @@ import ( ) type threadDB struct { - db *DB + db *bun.DB state *state.State } diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go index f2ba2a9d1..e6c7e482d 100644 --- a/internal/db/bundb/timeline.go +++ b/internal/db/bundb/timeline.go @@ -34,7 +34,7 @@ import ( ) type timelineDB struct { - db *DB + db *bun.DB state *state.State } diff --git a/internal/db/bundb/tombstone.go b/internal/db/bundb/tombstone.go index c0e439720..64169213e 100644 --- a/internal/db/bundb/tombstone.go +++ b/internal/db/bundb/tombstone.go @@ -27,7 +27,7 @@ import ( ) type tombstoneDB struct { - db *DB + db *bun.DB state *state.State } diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go index a6fa142f2..2854c0caa 100644 --- a/internal/db/bundb/user.go +++ b/internal/db/bundb/user.go @@ -31,7 +31,7 @@ import ( ) type userDB struct { - db *DB + db *bun.DB state *state.State } diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go index cee20bbe1..e2dd392dc 100644 --- a/internal/db/bundb/util.go +++ b/internal/db/bundb/util.go @@ -18,6 +18,8 @@ package bundb import ( + "context" + "database/sql" "slices" "strings" @@ -113,6 +115,25 @@ func whereStartsLike( ) } +// exists checks the results of a SelectQuery for the existence of the data in question, masking ErrNoEntries errors. +func exists(ctx context.Context, query *bun.SelectQuery) (bool, error) { + exists, err := query.Exists(ctx) + switch err { + case nil: + return exists, nil + case sql.ErrNoRows: + return false, nil + default: + return false, err + } +} + +// notExists checks the results of a SelectQuery for the non-existence of the data in question, masking ErrNoEntries errors. +func notExists(ctx context.Context, query *bun.SelectQuery) (bool, error) { + exists, err := exists(ctx, query) + return !exists, err +} + // loadPagedIDs loads a page of IDs from given SliceCache by `key`, resorting to `loadDESC` if required. Uses `page` to sort + page resulting IDs. // NOTE: IDs returned from `cache` / `loadDESC` MUST be in descending order, otherwise paging will not work correctly / return things out of order. func loadPagedIDs(cache *cache.SliceCache[string], key string, page *paging.Page, loadDESC func() ([]string, error)) ([]string, error) {