forked from mirrors/gotosocial
[chore] Standardize database queries, use bun.Ident()
properly (#886)
* use bun.Ident for user queries * use bun.Ident for account queries * use bun.Ident for media queries * add DeleteAccount func * remove CaseInsensitive in Where+use Ident ipv Safe * update admin db * update domain, use ident * update emoji, use ident * update instance queries, use bun.Ident * fix media * update mentions, use bun ident * update relationship + tests * use tableexpr * add test follows to bun db test suite * update notifications * updatebyprimarykey => updatebyid * fix session * prefer explicit ID to pk * fix little fucky wucky * remove workaround * use proper db func for attachment selection * update status db * add m2m entries in test rig * fix up timeline * go fmt * fix status put issue * update GetAccountStatuses
This commit is contained in:
parent
e58a6a2da3
commit
aa07750bdb
45 changed files with 1074 additions and 570 deletions
|
@ -101,7 +101,7 @@ var Confirm action.GTSAction = func(ctx context.Context) error {
|
|||
u.Email = u.UnconfirmedEmail
|
||||
u.ConfirmedAt = time.Now()
|
||||
u.UpdatedAt = time.Now()
|
||||
if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil {
|
||||
if err := dbConn.UpdateByID(ctx, u, u.ID, updatingColumns...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
5
internal/cache/account.go
vendored
5
internal/cache/account.go
vendored
|
@ -101,6 +101,11 @@ func (c *AccountCache) Put(account *gtsmodel.Account) {
|
|||
c.cache.Set(account.ID, copyAccount(account))
|
||||
}
|
||||
|
||||
// Invalidate removes (invalidates) one account from the cache by its ID.
|
||||
func (c *AccountCache) Invalidate(id string) {
|
||||
c.cache.Invalidate(id)
|
||||
}
|
||||
|
||||
// copyAccount performs a surface-level copy of account, only keeping attached IDs intact, not the objects.
|
||||
// due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr)
|
||||
// this should be a relatively cheap process
|
||||
|
|
|
@ -48,6 +48,11 @@ type Account interface {
|
|||
// UpdateAccount updates one account by ID.
|
||||
UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error)
|
||||
|
||||
// DeleteAccount deletes one account from the database by its ID.
|
||||
// DO NOT USE THIS WHEN SUSPENDING ACCOUNTS! In that case you should mark the
|
||||
// account as suspended instead, rather than deleting from the db entirely.
|
||||
DeleteAccount(ctx context.Context, id string) Error
|
||||
|
||||
// GetAccountCustomCSSByUsername returns the custom css of an account on this instance with the given username.
|
||||
GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, Error)
|
||||
|
||||
|
|
|
@ -62,11 +62,11 @@ type Basic interface {
|
|||
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
|
||||
Put(ctx context.Context, i interface{}) Error
|
||||
|
||||
// UpdateByPrimaryKey updates values of i based on its primary key.
|
||||
// UpdateByID updates values of i based on its id.
|
||||
// If any columns are specified, these will be updated exclusively.
|
||||
// Otherwise, the whole model will be updated.
|
||||
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
|
||||
UpdateByPrimaryKey(ctx context.Context, i interface{}, columns ...string) Error
|
||||
UpdateByID(ctx context.Context, i interface{}, id string, columns ...string) Error
|
||||
|
||||
// UpdateWhere updates column key of interface i with the given value, where the given parameters apply.
|
||||
UpdateWhere(ctx context.Context, where []Where, key string, value interface{}, i interface{}) Error
|
||||
|
|
|
@ -21,7 +21,6 @@ package bundb
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -56,7 +55,7 @@ func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Ac
|
|||
return a.cache.GetByID(id)
|
||||
},
|
||||
func(account *gtsmodel.Account) error {
|
||||
return a.newAccountQ(account).Where("account.id = ?", id).Scan(ctx)
|
||||
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -68,7 +67,7 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.
|
|||
return a.cache.GetByURI(uri)
|
||||
},
|
||||
func(account *gtsmodel.Account) error {
|
||||
return a.newAccountQ(account).Where("account.uri = ?", uri).Scan(ctx)
|
||||
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -80,7 +79,7 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.
|
|||
return a.cache.GetByURL(url)
|
||||
},
|
||||
func(account *gtsmodel.Account) error {
|
||||
return a.newAccountQ(account).Where("account.url = ?", url).Scan(ctx)
|
||||
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -95,11 +94,11 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str
|
|||
q := a.newAccountQ(account)
|
||||
|
||||
if domain != "" {
|
||||
q = q.Where("account.username = ?", username)
|
||||
q = q.Where("account.domain = ?", domain)
|
||||
q = q.Where("? = ?", bun.Ident("account.username"), username)
|
||||
q = q.Where("? = ?", bun.Ident("account.domain"), domain)
|
||||
} else {
|
||||
q = q.Where("account.username = ?", strings.ToLower(username))
|
||||
q = q.Where("account.domain IS NULL")
|
||||
q = q.Where("? = ?", bun.Ident("account.username"), strings.ToLower(username))
|
||||
q = q.Where("? IS NULL", bun.Ident("account.domain"))
|
||||
}
|
||||
|
||||
return q.Scan(ctx)
|
||||
|
@ -114,7 +113,7 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo
|
|||
return a.cache.GetByPubkeyID(id)
|
||||
},
|
||||
func(account *gtsmodel.Account) error {
|
||||
return a.newAccountQ(account).Where("account.public_key_uri = ?", id).Scan(ctx)
|
||||
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -169,26 +168,36 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
|
|||
if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error {
|
||||
// create links between this account and any emojis it uses
|
||||
// first clear out any old emoji links
|
||||
if _, err := tx.NewDelete().
|
||||
Model(&[]*gtsmodel.AccountToEmoji{}).
|
||||
Where("account_id = ?", account.ID).
|
||||
if _, err := tx.
|
||||
NewDelete().
|
||||
TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")).
|
||||
Where("? = ?", bun.Ident("account_to_emoji.account_id"), account.ID).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// now populate new emoji links
|
||||
for _, i := range account.EmojiIDs {
|
||||
if _, err := tx.NewInsert().Model(>smodel.AccountToEmoji{
|
||||
AccountID: account.ID,
|
||||
EmojiID: i,
|
||||
}).Exec(ctx); err != nil {
|
||||
if _, err := tx.
|
||||
NewInsert().
|
||||
Model(>smodel.AccountToEmoji{
|
||||
AccountID: account.ID,
|
||||
EmojiID: i,
|
||||
}).Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// update the account
|
||||
_, err := tx.NewUpdate().Model(account).WherePK().Exec(ctx)
|
||||
return err
|
||||
if _, err := tx.
|
||||
NewUpdate().
|
||||
Model(account).
|
||||
Where("? = ?", bun.Ident("account.id"), account.ID).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, a.conn.ProcessError(err)
|
||||
}
|
||||
|
@ -197,6 +206,32 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
|
|||
return account, nil
|
||||
}
|
||||
|
||||
func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
|
||||
if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error {
|
||||
// clear out any emoji links
|
||||
if _, err := tx.
|
||||
NewDelete().
|
||||
TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")).
|
||||
Where("? = ?", bun.Ident("account_to_emoji.account_id"), id).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// delete the account
|
||||
_, err := tx.
|
||||
NewUpdate().
|
||||
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
|
||||
Where("? = ?", bun.Ident("account.id"), id).
|
||||
Exec(ctx)
|
||||
return err
|
||||
}); err != nil {
|
||||
return a.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
a.cache.Invalidate(id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
|
||||
account := new(gtsmodel.Account)
|
||||
|
||||
|
@ -204,11 +239,11 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts
|
|||
|
||||
if domain != "" {
|
||||
q = q.
|
||||
Where("account.username = ?", domain).
|
||||
Where("account.domain = ?", domain)
|
||||
Where("? = ?", bun.Ident("account.username"), domain).
|
||||
Where("? = ?", bun.Ident("account.domain"), domain)
|
||||
} else {
|
||||
q = q.
|
||||
Where("account.username = ?", config.GetHost()).
|
||||
Where("? = ?", bun.Ident("account.username"), config.GetHost()).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("domain"))
|
||||
}
|
||||
|
||||
|
@ -224,10 +259,10 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string)
|
|||
q := a.conn.
|
||||
NewSelect().
|
||||
Model(status).
|
||||
Order("id DESC").
|
||||
Limit(1).
|
||||
Where("account_id = ?", accountID).
|
||||
Column("created_at")
|
||||
Column("status.created_at").
|
||||
Where("? = ?", bun.Ident("status.account_id"), accountID).
|
||||
Order("status.id DESC").
|
||||
Limit(1)
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
return time.Time{}, a.conn.ProcessError(err)
|
||||
|
@ -240,12 +275,12 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen
|
|||
return errors.New("one media attachment cannot be both header and avatar")
|
||||
}
|
||||
|
||||
var headerOrAVI string
|
||||
var column bun.Ident
|
||||
switch {
|
||||
case *mediaAttachment.Avatar:
|
||||
headerOrAVI = "avatar"
|
||||
column = bun.Ident("account.avatar_media_attachment_id")
|
||||
case *mediaAttachment.Header:
|
||||
headerOrAVI = "header"
|
||||
column = bun.Ident("account.header_media_attachment_id")
|
||||
default:
|
||||
return errors.New("given media attachment was neither a header nor an avatar")
|
||||
}
|
||||
|
@ -257,11 +292,12 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen
|
|||
Exec(ctx); err != nil {
|
||||
return a.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
if _, err := a.conn.
|
||||
NewUpdate().
|
||||
Model(>smodel.Account{}).
|
||||
Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).
|
||||
Where("id = ?", accountID).
|
||||
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
|
||||
Set("? = ?", column, mediaAttachment.ID).
|
||||
Where("? = ?", bun.Ident("account.id"), accountID).
|
||||
Exec(ctx); err != nil {
|
||||
return a.conn.ProcessError(err)
|
||||
}
|
||||
|
@ -284,7 +320,7 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g
|
|||
if err := a.conn.
|
||||
NewSelect().
|
||||
Model(faves).
|
||||
Where("account_id = ?", accountID).
|
||||
Where("? = ?", bun.Ident("status_fave.account_id"), accountID).
|
||||
Scan(ctx); err != nil {
|
||||
return nil, a.conn.ProcessError(err)
|
||||
}
|
||||
|
@ -295,8 +331,8 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g
|
|||
func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) {
|
||||
return a.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.Status{}).
|
||||
Where("account_id = ?", accountID).
|
||||
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
|
||||
Where("? = ?", bun.Ident("status.account_id"), accountID).
|
||||
Count(ctx)
|
||||
}
|
||||
|
||||
|
@ -305,12 +341,12 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
|
|||
|
||||
q := a.conn.
|
||||
NewSelect().
|
||||
Table("statuses").
|
||||
Column("id").
|
||||
Order("id DESC")
|
||||
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
|
||||
Column("status.id").
|
||||
Order("status.id DESC")
|
||||
|
||||
if accountID != "" {
|
||||
q = q.Where("account_id = ?", accountID)
|
||||
q = q.Where("? = ?", bun.Ident("status.account_id"), accountID)
|
||||
}
|
||||
|
||||
if limit != 0 {
|
||||
|
@ -321,27 +357,27 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
|
|||
// include self-replies (threads)
|
||||
whereGroup := func(*bun.SelectQuery) *bun.SelectQuery {
|
||||
return q.
|
||||
WhereOr("in_reply_to_account_id = ?", accountID).
|
||||
WhereGroup(" OR ", whereEmptyOrNull("in_reply_to_uri"))
|
||||
WhereOr("? = ?", bun.Ident("status.in_reply_to_account_id"), accountID).
|
||||
WhereGroup(" OR ", whereEmptyOrNull("status.in_reply_to_uri"))
|
||||
}
|
||||
|
||||
q = q.WhereGroup(" AND ", whereGroup)
|
||||
}
|
||||
|
||||
if excludeReblogs {
|
||||
q = q.WhereGroup(" AND ", whereEmptyOrNull("boost_of_id"))
|
||||
q = q.WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id"))
|
||||
}
|
||||
|
||||
if maxID != "" {
|
||||
q = q.Where("id < ?", maxID)
|
||||
q = q.Where("? < ?", bun.Ident("status.id"), maxID)
|
||||
}
|
||||
|
||||
if minID != "" {
|
||||
q = q.Where("id > ?", minID)
|
||||
q = q.Where("? > ?", bun.Ident("status.id"), minID)
|
||||
}
|
||||
|
||||
if pinnedOnly {
|
||||
q = q.Where("pinned = ?", true)
|
||||
q = q.Where("? = ?", bun.Ident("status.pinned"), true)
|
||||
}
|
||||
|
||||
if mediaOnly {
|
||||
|
@ -352,15 +388,15 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
|
|||
switch a.conn.Dialect().Name() {
|
||||
case dialect.PG:
|
||||
return q.
|
||||
Where("? IS NOT NULL", bun.Ident("attachments")).
|
||||
Where("? != '{}'", bun.Ident("attachments"))
|
||||
Where("? IS NOT NULL", bun.Ident("status.attachments")).
|
||||
Where("? != '{}'", bun.Ident("status.attachments"))
|
||||
case dialect.SQLite:
|
||||
return q.
|
||||
Where("? IS NOT NULL", bun.Ident("attachments")).
|
||||
Where("? != ''", bun.Ident("attachments")).
|
||||
Where("? != 'null'", bun.Ident("attachments")).
|
||||
Where("? != '{}'", bun.Ident("attachments")).
|
||||
Where("? != '[]'", bun.Ident("attachments"))
|
||||
Where("? IS NOT NULL", bun.Ident("status.attachments")).
|
||||
Where("? != ''", bun.Ident("status.attachments")).
|
||||
Where("? != 'null'", bun.Ident("status.attachments")).
|
||||
Where("? != '{}'", bun.Ident("status.attachments")).
|
||||
Where("? != '[]'", bun.Ident("status.attachments"))
|
||||
default:
|
||||
log.Panic("db dialect was neither pg nor sqlite")
|
||||
return q
|
||||
|
@ -369,7 +405,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
|
|||
}
|
||||
|
||||
if publicOnly {
|
||||
q = q.Where("visibility = ?", gtsmodel.VisibilityPublic)
|
||||
q = q.Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic)
|
||||
}
|
||||
|
||||
if err := q.Scan(ctx, &statusIDs); err != nil {
|
||||
|
@ -384,19 +420,19 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string,
|
|||
|
||||
q := a.conn.
|
||||
NewSelect().
|
||||
Table("statuses").
|
||||
Column("id").
|
||||
Where("account_id = ?", accountID).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("in_reply_to_uri")).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("boost_of_id")).
|
||||
Where("visibility = ?", gtsmodel.VisibilityPublic).
|
||||
Where("federated = ?", true)
|
||||
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
|
||||
Column("status.id").
|
||||
Where("? = ?", bun.Ident("status.account_id"), accountID).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")).
|
||||
Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic).
|
||||
Where("? = ?", bun.Ident("status.federated"), true)
|
||||
|
||||
if maxID != "" {
|
||||
q = q.Where("id < ?", maxID)
|
||||
q = q.Where("? < ?", bun.Ident("status.id"), maxID)
|
||||
}
|
||||
|
||||
q = q.Limit(limit).Order("id DESC")
|
||||
q = q.Limit(limit).Order("status.id DESC")
|
||||
|
||||
if err := q.Scan(ctx, &statusIDs); err != nil {
|
||||
return nil, a.conn.ProcessError(err)
|
||||
|
@ -411,16 +447,16 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI
|
|||
fq := a.conn.
|
||||
NewSelect().
|
||||
Model(&blocks).
|
||||
Where("block.account_id = ?", accountID).
|
||||
Where("? = ?", bun.Ident("block.account_id"), accountID).
|
||||
Relation("TargetAccount").
|
||||
Order("block.id DESC")
|
||||
|
||||
if maxID != "" {
|
||||
fq = fq.Where("block.id < ?", maxID)
|
||||
fq = fq.Where("? < ?", bun.Ident("block.id"), maxID)
|
||||
}
|
||||
|
||||
if sinceID != "" {
|
||||
fq = fq.Where("block.id > ?", sinceID)
|
||||
fq = fq.Where("? > ?", bun.Ident("block.id"), sinceID)
|
||||
}
|
||||
|
||||
if limit > 0 {
|
||||
|
|
|
@ -42,6 +42,18 @@ func (suite *AccountTestSuite) TestGetAccountStatuses() {
|
|||
suite.Len(statuses, 5)
|
||||
}
|
||||
|
||||
func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogs() {
|
||||
statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, true, true, "", "", false, false, false)
|
||||
suite.NoError(err)
|
||||
suite.Len(statuses, 5)
|
||||
}
|
||||
|
||||
func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogsPublicOnly() {
|
||||
statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, true, true, "", "", false, false, true)
|
||||
suite.NoError(err)
|
||||
suite.Len(statuses, 1)
|
||||
}
|
||||
|
||||
func (suite *AccountTestSuite) TestGetAccountStatusesMediaOnly() {
|
||||
statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, false, false, "", "", false, true, false)
|
||||
suite.NoError(err)
|
||||
|
@ -99,7 +111,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() {
|
|||
err = dbService.GetConn().
|
||||
NewSelect().
|
||||
Model(noCache).
|
||||
Where("account.id = ?", bun.Ident(testAccount.ID)).
|
||||
Where("? = ?", bun.Ident("account.id"), testAccount.ID).
|
||||
Relation("AvatarMediaAttachment").
|
||||
Relation("HeaderMediaAttachment").
|
||||
Relation("Emojis").
|
||||
|
@ -127,7 +139,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() {
|
|||
err = dbService.GetConn().
|
||||
NewSelect().
|
||||
Model(noCache).
|
||||
Where("account.id = ?", bun.Ident(testAccount.ID)).
|
||||
Where("? = ?", bun.Ident("account.id"), testAccount.ID).
|
||||
Relation("AvatarMediaAttachment").
|
||||
Relation("HeaderMediaAttachment").
|
||||
Relation("Emojis").
|
||||
|
|
|
@ -22,7 +22,6 @@ import (
|
|||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/mail"
|
||||
|
@ -37,21 +36,26 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/uris"
|
||||
"github.com/uptrace/bun"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// generate RSA keys of this length
|
||||
const rsaKeyBits = 2048
|
||||
|
||||
type adminDB struct {
|
||||
conn *DBConn
|
||||
userCache *cache.UserCache
|
||||
conn *DBConn
|
||||
userCache *cache.UserCache
|
||||
accountCache *cache.AccountCache
|
||||
}
|
||||
|
||||
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
|
||||
q := a.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.Account{}).
|
||||
Where("username = ?", username).
|
||||
Where("domain = ?", nil)
|
||||
|
||||
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
|
||||
Column("account.id").
|
||||
Where("? = ?", bun.Ident("account.username"), username).
|
||||
Where("? IS NULL", bun.Ident("account.domain"))
|
||||
return a.conn.NotExists(ctx, q)
|
||||
}
|
||||
|
||||
|
@ -64,29 +68,31 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.
|
|||
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
|
||||
|
||||
// check if the email domain is blocked
|
||||
if err := a.conn.
|
||||
emailDomainBlockedQ := a.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.EmailDomainBlock{}).
|
||||
Where("domain = ?", domain).
|
||||
Scan(ctx); err == nil {
|
||||
// fail because we found something
|
||||
TableExpr("? AS ?", bun.Ident("email_domain_blocks"), bun.Ident("email_domain_block")).
|
||||
Column("email_domain_block.id").
|
||||
Where("? = ?", bun.Ident("email_domain_block.domain"), domain)
|
||||
emailDomainBlocked, err := a.conn.Exists(ctx, emailDomainBlockedQ)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if emailDomainBlocked {
|
||||
return false, fmt.Errorf("email domain %s is blocked", domain)
|
||||
} else if err != sql.ErrNoRows {
|
||||
return false, a.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// check if this email is associated with a user already
|
||||
q := a.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.User{}).
|
||||
Where("email = ?", email).
|
||||
WhereOr("unconfirmed_email = ?", email)
|
||||
|
||||
TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")).
|
||||
Column("user.id").
|
||||
Where("? = ?", bun.Ident("user.email"), email).
|
||||
WhereOr("? = ?", bun.Ident("user.unconfirmed_email"), email)
|
||||
return a.conn.NotExists(ctx, q)
|
||||
}
|
||||
|
||||
func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
key, err := rsa.GenerateKey(rand.Reader, rsaKeyBits)
|
||||
if err != nil {
|
||||
log.Errorf("error creating new rsa key: %s", err)
|
||||
return nil, err
|
||||
|
@ -94,13 +100,20 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
|
|||
|
||||
// if something went wrong while creating a user, we might already have an account, so check here first...
|
||||
acct := >smodel.Account{}
|
||||
q := a.conn.NewSelect().
|
||||
if err := a.conn.
|
||||
NewSelect().
|
||||
Model(acct).
|
||||
Where("username = ?", username).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("domain"))
|
||||
Where("? = ?", bun.Ident("account.username"), username).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("account.domain")).
|
||||
Scan(ctx); err != nil {
|
||||
err = a.conn.ProcessError(err)
|
||||
if err != db.ErrNoEntries {
|
||||
log.Errorf("error checking for existing account: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
// we just don't have an account yet so create one before we proceed
|
||||
// if we have db.ErrNoEntries, we just don't have an
|
||||
// account yet so create one before we proceed
|
||||
accountURIs := uris.GenerateURIsForAccount(username)
|
||||
accountID, err := id.NewRandomULID()
|
||||
if err != nil {
|
||||
|
@ -126,14 +139,19 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
|
|||
FeaturedCollectionURI: accountURIs.CollectionURI,
|
||||
}
|
||||
|
||||
// insert the new account!
|
||||
if _, err = a.conn.
|
||||
NewInsert().
|
||||
Model(acct).
|
||||
Exec(ctx); err != nil {
|
||||
return nil, a.conn.ProcessError(err)
|
||||
}
|
||||
a.accountCache.Put(acct)
|
||||
}
|
||||
|
||||
// we either created or already had an account by now,
|
||||
// so proceed with creating a user for that account
|
||||
|
||||
pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error hashing password: %s", err)
|
||||
|
@ -171,6 +189,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
|
|||
u.Moderator = &moderator
|
||||
}
|
||||
|
||||
// insert the user!
|
||||
if _, err = a.conn.
|
||||
NewInsert().
|
||||
Model(u).
|
||||
|
@ -187,9 +206,10 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
|
|||
|
||||
q := a.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.Account{}).
|
||||
Where("username = ?", username).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("domain"))
|
||||
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
|
||||
Column("account.id").
|
||||
Where("? = ?", bun.Ident("account.username"), username).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("account.domain"))
|
||||
|
||||
exists, err := a.conn.Exists(ctx, q)
|
||||
if err != nil {
|
||||
|
@ -200,7 +220,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
|
|||
return nil
|
||||
}
|
||||
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
key, err := rsa.GenerateKey(rand.Reader, rsaKeyBits)
|
||||
if err != nil {
|
||||
log.Errorf("error creating new rsa key: %s", err)
|
||||
return err
|
||||
|
@ -237,6 +257,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
|
|||
return a.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
a.accountCache.Put(acct)
|
||||
log.Infof("instance account %s CREATED with id %s", username, acct.ID)
|
||||
return nil
|
||||
}
|
||||
|
@ -248,8 +269,9 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
|
|||
// check if instance entry already exists
|
||||
q := a.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.Instance{}).
|
||||
Where("domain = ?", host)
|
||||
Column("instance.id").
|
||||
TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")).
|
||||
Where("? = ?", bun.Ident("instance.domain"), host)
|
||||
|
||||
exists, err := a.conn.Exists(ctx, q)
|
||||
if err != nil {
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
gtsmodel "github.com/superseriousbusiness/gotosocial/internal/db/bundb/migrations/20211113114307_init"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
|
@ -30,6 +31,44 @@ type AdminTestSuite struct {
|
|||
BunDBStandardTestSuite
|
||||
}
|
||||
|
||||
func (suite *AdminTestSuite) TestIsUsernameAvailableNo() {
|
||||
available, err := suite.db.IsUsernameAvailable(context.Background(), "the_mighty_zork")
|
||||
suite.NoError(err)
|
||||
suite.False(available)
|
||||
}
|
||||
|
||||
func (suite *AdminTestSuite) TestIsUsernameAvailableYes() {
|
||||
available, err := suite.db.IsUsernameAvailable(context.Background(), "someone_completely_different")
|
||||
suite.NoError(err)
|
||||
suite.True(available)
|
||||
}
|
||||
|
||||
func (suite *AdminTestSuite) TestIsEmailAvailableNo() {
|
||||
available, err := suite.db.IsEmailAvailable(context.Background(), "zork@example.org")
|
||||
suite.NoError(err)
|
||||
suite.False(available)
|
||||
}
|
||||
|
||||
func (suite *AdminTestSuite) TestIsEmailAvailableYes() {
|
||||
available, err := suite.db.IsEmailAvailable(context.Background(), "someone@somewhere.com")
|
||||
suite.NoError(err)
|
||||
suite.True(available)
|
||||
}
|
||||
|
||||
func (suite *AdminTestSuite) TestIsEmailAvailableDomainBlocked() {
|
||||
if err := suite.db.Put(context.Background(), >smodel.EmailDomainBlock{
|
||||
ID: "01GEEV2R2YC5GRSN96761YJE47",
|
||||
Domain: "somewhere.com",
|
||||
CreatedByAccountID: suite.testAccounts["admin_account"].ID,
|
||||
}); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
available, err := suite.db.IsEmailAvailable(context.Background(), "someone@somewhere.com")
|
||||
suite.EqualError(err, "email domain somewhere.com is blocked")
|
||||
suite.False(available)
|
||||
}
|
||||
|
||||
func (suite *AdminTestSuite) TestCreateInstanceAccount() {
|
||||
// we need to take an empty db for this...
|
||||
testrig.StandardDBTeardown(suite.db)
|
||||
|
|
|
@ -94,12 +94,12 @@ func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface
|
|||
return b.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
func (b *basicDB) UpdateByPrimaryKey(ctx context.Context, i interface{}, columns ...string) db.Error {
|
||||
func (b *basicDB) UpdateByID(ctx context.Context, i interface{}, id string, columns ...string) db.Error {
|
||||
q := b.conn.
|
||||
NewUpdate().
|
||||
Model(i).
|
||||
Column(columns...).
|
||||
WherePK()
|
||||
Where("? = ?", bun.Ident("id"), id)
|
||||
|
||||
_, err := q.Exec(ctx)
|
||||
return b.conn.ProcessError(err)
|
||||
|
@ -110,7 +110,7 @@ func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string,
|
|||
|
||||
updateWhere(q, where)
|
||||
|
||||
q = q.Set("? = ?", bun.Safe(key), value)
|
||||
q = q.Set("? = ?", bun.Ident(key), value)
|
||||
|
||||
_, err := q.Exec(ctx)
|
||||
return b.conn.ProcessError(err)
|
||||
|
|
|
@ -159,17 +159,11 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
|
|||
return nil, fmt.Errorf("db migration error: %s", err)
|
||||
}
|
||||
|
||||
// Create DB structs that require ptrs to each other
|
||||
accounts := &accountDB{conn: conn, cache: cache.NewAccountCache()}
|
||||
status := &statusDB{conn: conn, cache: cache.NewStatusCache()}
|
||||
emoji := &emojiDB{conn: conn, cache: cache.NewEmojiCache()}
|
||||
timeline := &timelineDB{conn: conn}
|
||||
|
||||
// Setup DB cross-referencing
|
||||
accounts.status = status
|
||||
status.accounts = accounts
|
||||
timeline.status = status
|
||||
// Prepare caches required by more than one struct
|
||||
userCache := cache.NewUserCache()
|
||||
accountCache := cache.NewAccountCache()
|
||||
|
||||
// Prepare other caches
|
||||
// Prepare mentions cache
|
||||
// TODO: move into internal/cache
|
||||
mentionCache := grufcache.New[string, *gtsmodel.Mention]()
|
||||
|
@ -182,22 +176,30 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
|
|||
notifCache.SetTTL(time.Minute*5, false)
|
||||
notifCache.Start(time.Second * 10)
|
||||
|
||||
// Prepare other caches
|
||||
blockCache := cache.NewDomainBlockCache()
|
||||
userCache := cache.NewUserCache()
|
||||
// Create DB structs that require ptrs to each other
|
||||
accounts := &accountDB{conn: conn, cache: accountCache}
|
||||
status := &statusDB{conn: conn, cache: cache.NewStatusCache()}
|
||||
emoji := &emojiDB{conn: conn, cache: cache.NewEmojiCache()}
|
||||
timeline := &timelineDB{conn: conn}
|
||||
|
||||
// Setup DB cross-referencing
|
||||
accounts.status = status
|
||||
status.accounts = accounts
|
||||
timeline.status = status
|
||||
|
||||
ps := &DBService{
|
||||
Account: accounts,
|
||||
Admin: &adminDB{
|
||||
conn: conn,
|
||||
userCache: userCache,
|
||||
conn: conn,
|
||||
userCache: userCache,
|
||||
accountCache: accountCache,
|
||||
},
|
||||
Basic: &basicDB{
|
||||
conn: conn,
|
||||
},
|
||||
Domain: &domainDB{
|
||||
conn: conn,
|
||||
cache: blockCache,
|
||||
cache: cache.NewDomainBlockCache(),
|
||||
},
|
||||
Emoji: emoji,
|
||||
Instance: &instanceDB{
|
||||
|
|
|
@ -40,6 +40,7 @@ type BunDBStandardTestSuite struct {
|
|||
testStatuses map[string]*gtsmodel.Status
|
||||
testTags map[string]*gtsmodel.Tag
|
||||
testMentions map[string]*gtsmodel.Mention
|
||||
testFollows map[string]*gtsmodel.Follow
|
||||
}
|
||||
|
||||
func (suite *BunDBStandardTestSuite) SetupSuite() {
|
||||
|
@ -52,6 +53,7 @@ func (suite *BunDBStandardTestSuite) SetupSuite() {
|
|||
suite.testStatuses = testrig.NewTestStatuses()
|
||||
suite.testTags = testrig.NewTestTags()
|
||||
suite.testMentions = testrig.NewTestMentions()
|
||||
suite.testFollows = testrig.NewTestFollows()
|
||||
}
|
||||
|
||||
func (suite *BunDBStandardTestSuite) SetupTest() {
|
||||
|
|
|
@ -28,6 +28,7 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/uptrace/bun"
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
|
@ -95,7 +96,7 @@ func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel
|
|||
q := d.conn.
|
||||
NewSelect().
|
||||
Model(block).
|
||||
Where("domain = ?", domain).
|
||||
Where("? = ?", bun.Ident("domain_block.domain"), domain).
|
||||
Limit(1)
|
||||
|
||||
// Query database for domain block
|
||||
|
@ -126,7 +127,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro
|
|||
// Attempt to delete domain block
|
||||
if _, err := d.conn.NewDelete().
|
||||
Model((*gtsmodel.DomainBlock)(nil)).
|
||||
Where("domain = ?", domain).
|
||||
Where("? = ?", bun.Ident("domain_block.domain"), domain).
|
||||
Exec(ctx); err != nil {
|
||||
return d.conn.ProcessError(err)
|
||||
}
|
||||
|
|
|
@ -54,12 +54,12 @@ func (e *emojiDB) GetCustomEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.Er
|
|||
|
||||
q := e.conn.
|
||||
NewSelect().
|
||||
Table("emojis").
|
||||
Column("id").
|
||||
Where("visible_in_picker = true").
|
||||
Where("disabled = false").
|
||||
Where("domain IS NULL").
|
||||
Order("shortcode ASC")
|
||||
TableExpr("? AS ?", bun.Ident("emojis"), bun.Ident("emoji")).
|
||||
Column("emoji.id").
|
||||
Where("? = ?", bun.Ident("emoji.visible_in_picker"), true).
|
||||
Where("? = ?", bun.Ident("emoji.disabled"), false).
|
||||
Where("? IS NULL", bun.Ident("emoji.domain")).
|
||||
Order("emoji.shortcode ASC")
|
||||
|
||||
if err := q.Scan(ctx, &emojiIDs); err != nil {
|
||||
return nil, e.conn.ProcessError(err)
|
||||
|
@ -75,7 +75,7 @@ func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji,
|
|||
return e.cache.GetByID(id)
|
||||
},
|
||||
func(emoji *gtsmodel.Emoji) error {
|
||||
return e.newEmojiQ(emoji).Where("emoji.id = ?", id).Scan(ctx)
|
||||
return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.id"), id).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -87,7 +87,7 @@ func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoj
|
|||
return e.cache.GetByURI(uri)
|
||||
},
|
||||
func(emoji *gtsmodel.Emoji) error {
|
||||
return e.newEmojiQ(emoji).Where("emoji.uri = ?", uri).Scan(ctx)
|
||||
return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.uri"), uri).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -102,11 +102,11 @@ func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode strin
|
|||
q := e.newEmojiQ(emoji)
|
||||
|
||||
if domain != "" {
|
||||
q = q.Where("emoji.shortcode = ?", shortcode)
|
||||
q = q.Where("emoji.domain = ?", domain)
|
||||
q = q.Where("? = ?", bun.Ident("emoji.shortcode"), shortcode)
|
||||
q = q.Where("? = ?", bun.Ident("emoji.domain"), domain)
|
||||
} else {
|
||||
q = q.Where("emoji.shortcode = ?", strings.ToLower(shortcode))
|
||||
q = q.Where("emoji.domain IS NULL")
|
||||
q = q.Where("? = ?", bun.Ident("emoji.shortcode"), strings.ToLower(shortcode))
|
||||
q = q.Where("? IS NULL", bun.Ident("emoji.domain"))
|
||||
}
|
||||
|
||||
return q.Scan(ctx)
|
||||
|
|
|
@ -24,7 +24,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
|
@ -35,15 +34,16 @@ type instanceDB struct {
|
|||
func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) {
|
||||
q := i.conn.
|
||||
NewSelect().
|
||||
Model(&[]*gtsmodel.Account{}).
|
||||
Where("username != ?", domain).
|
||||
Where("? IS NULL", bun.Ident("suspended_at"))
|
||||
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
|
||||
Column("account.id").
|
||||
Where("? != ?", bun.Ident("account.username"), domain).
|
||||
Where("? IS NULL", bun.Ident("account.suspended_at"))
|
||||
|
||||
if domain == config.GetHost() {
|
||||
if domain == config.GetHost() || domain == config.GetAccountDomain() {
|
||||
// if the domain is *this* domain, just count where the domain field is null
|
||||
q = q.WhereGroup(" AND ", whereEmptyOrNull("domain"))
|
||||
q = q.WhereGroup(" AND ", whereEmptyOrNull("account.domain"))
|
||||
} else {
|
||||
q = q.Where("domain = ?", domain)
|
||||
q = q.Where("? = ?", bun.Ident("account.domain"), domain)
|
||||
}
|
||||
|
||||
count, err := q.Count(ctx)
|
||||
|
@ -56,15 +56,16 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int
|
|||
func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) {
|
||||
q := i.conn.
|
||||
NewSelect().
|
||||
Model(&[]*gtsmodel.Status{})
|
||||
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status"))
|
||||
|
||||
if domain == config.GetHost() {
|
||||
if domain == config.GetHost() || domain == config.GetAccountDomain() {
|
||||
// if the domain is *this* domain, just count where local is true
|
||||
q = q.Where("local = ?", true)
|
||||
q = q.Where("? = ?", bun.Ident("status.local"), true)
|
||||
} else {
|
||||
// join on the domain of the account
|
||||
q = q.Join("JOIN accounts AS account ON account.id = status.account_id").
|
||||
Where("account.domain = ?", domain)
|
||||
q = q.
|
||||
Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("account.id"), bun.Ident("status.account_id")).
|
||||
Where("? = ?", bun.Ident("account.domain"), domain)
|
||||
}
|
||||
|
||||
count, err := q.Count(ctx)
|
||||
|
@ -77,14 +78,14 @@ func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (
|
|||
func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) {
|
||||
q := i.conn.
|
||||
NewSelect().
|
||||
Model(&[]*gtsmodel.Instance{})
|
||||
TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance"))
|
||||
|
||||
if domain == config.GetHost() {
|
||||
// if the domain is *this* domain, just count other instances it knows about
|
||||
// exclude domains that are blocked
|
||||
q = q.
|
||||
Where("domain != ?", domain).
|
||||
Where("? IS NULL", bun.Ident("suspended_at"))
|
||||
Where("? != ?", bun.Ident("instance.domain"), domain).
|
||||
Where("? IS NULL", bun.Ident("instance.suspended_at"))
|
||||
} else {
|
||||
// TODO: implement federated domain counting properly for remote domains
|
||||
return 0, nil
|
||||
|
@ -103,10 +104,10 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool
|
|||
q := i.conn.
|
||||
NewSelect().
|
||||
Model(&instances).
|
||||
Where("domain != ?", config.GetHost())
|
||||
Where("? != ?", bun.Ident("instance.domain"), config.GetHost())
|
||||
|
||||
if !includeSuspended {
|
||||
q = q.Where("? IS NULL", bun.Ident("suspended_at"))
|
||||
q = q.Where("? IS NULL", bun.Ident("instance.suspended_at"))
|
||||
}
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
|
@ -117,17 +118,15 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool
|
|||
}
|
||||
|
||||
func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
|
||||
log.Debug("GetAccountsForInstance")
|
||||
|
||||
accounts := []*gtsmodel.Account{}
|
||||
|
||||
q := i.conn.NewSelect().
|
||||
Model(&accounts).
|
||||
Where("domain = ?", domain).
|
||||
Order("id DESC")
|
||||
Where("? = ?", bun.Ident("account.domain"), domain).
|
||||
Order("account.id DESC")
|
||||
|
||||
if maxID != "" {
|
||||
q = q.Where("id < ?", maxID)
|
||||
q = q.Where("? < ?", bun.Ident("account.id"), maxID)
|
||||
}
|
||||
|
||||
if limit > 0 {
|
||||
|
|
83
internal/db/bundb/instance_test.go
Normal file
83
internal/db/bundb/instance_test.go
Normal file
|
@ -0,0 +1,83 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package bundb_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
)
|
||||
|
||||
type InstanceTestSuite struct {
|
||||
BunDBStandardTestSuite
|
||||
}
|
||||
|
||||
func (suite *InstanceTestSuite) TestCountInstanceUsers() {
|
||||
count, err := suite.db.CountInstanceUsers(context.Background(), config.GetHost())
|
||||
suite.NoError(err)
|
||||
suite.Equal(4, count)
|
||||
}
|
||||
|
||||
func (suite *InstanceTestSuite) TestCountInstanceUsersRemote() {
|
||||
count, err := suite.db.CountInstanceUsers(context.Background(), "fossbros-anonymous.io")
|
||||
suite.NoError(err)
|
||||
suite.Equal(1, count)
|
||||
}
|
||||
|
||||
func (suite *InstanceTestSuite) TestCountInstanceStatuses() {
|
||||
count, err := suite.db.CountInstanceStatuses(context.Background(), config.GetHost())
|
||||
suite.NoError(err)
|
||||
suite.Equal(16, count)
|
||||
}
|
||||
|
||||
func (suite *InstanceTestSuite) TestCountInstanceStatusesRemote() {
|
||||
count, err := suite.db.CountInstanceStatuses(context.Background(), "fossbros-anonymous.io")
|
||||
suite.NoError(err)
|
||||
suite.Equal(1, count)
|
||||
}
|
||||
|
||||
func (suite *InstanceTestSuite) TestCountInstanceDomains() {
|
||||
count, err := suite.db.CountInstanceDomains(context.Background(), config.GetHost())
|
||||
suite.NoError(err)
|
||||
suite.Equal(2, count)
|
||||
}
|
||||
|
||||
func (suite *InstanceTestSuite) TestGetInstancePeers() {
|
||||
peers, err := suite.db.GetInstancePeers(context.Background(), false)
|
||||
suite.NoError(err)
|
||||
suite.Len(peers, 2)
|
||||
}
|
||||
|
||||
func (suite *InstanceTestSuite) TestGetInstancePeersIncludeSuspended() {
|
||||
peers, err := suite.db.GetInstancePeers(context.Background(), true)
|
||||
suite.NoError(err)
|
||||
suite.Len(peers, 2)
|
||||
}
|
||||
|
||||
func (suite *InstanceTestSuite) TestGetInstanceAccounts() {
|
||||
accounts, err := suite.db.GetInstanceAccounts(context.Background(), "fossbros-anonymous.io", "", 10)
|
||||
suite.NoError(err)
|
||||
suite.Len(accounts, 1)
|
||||
}
|
||||
|
||||
func TestInstanceTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(InstanceTestSuite))
|
||||
}
|
|
@ -42,7 +42,7 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M
|
|||
attachment := >smodel.MediaAttachment{}
|
||||
|
||||
q := m.newMediaQ(attachment).
|
||||
Where("media_attachment.id = ?", id)
|
||||
Where("? = ?", bun.Ident("media_attachment.id"), id)
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
return nil, m.conn.ProcessError(err)
|
||||
|
@ -56,10 +56,10 @@ func (m *mediaDB) GetRemoteOlderThan(ctx context.Context, olderThan time.Time, l
|
|||
q := m.conn.
|
||||
NewSelect().
|
||||
Model(&attachments).
|
||||
Where("media_attachment.cached = true").
|
||||
Where("media_attachment.avatar = false").
|
||||
Where("media_attachment.header = false").
|
||||
Where("media_attachment.created_at < ?", olderThan).
|
||||
Where("? = ?", bun.Ident("media_attachment.cached"), true).
|
||||
Where("? = ?", bun.Ident("media_attachment.avatar"), false).
|
||||
Where("? = ?", bun.Ident("media_attachment.header"), false).
|
||||
Where("? < ?", bun.Ident("media_attachment.created_at"), olderThan).
|
||||
WhereGroup(" AND ", whereNotEmptyAndNotNull("media_attachment.remote_url")).
|
||||
Order("media_attachment.created_at DESC")
|
||||
|
||||
|
@ -79,13 +79,13 @@ func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit
|
|||
q := m.newMediaQ(&attachments).
|
||||
WhereGroup(" AND ", func(innerQ *bun.SelectQuery) *bun.SelectQuery {
|
||||
return innerQ.
|
||||
WhereOr("media_attachment.avatar = true").
|
||||
WhereOr("media_attachment.header = true")
|
||||
WhereOr("? = ?", bun.Ident("media_attachment.avatar"), true).
|
||||
WhereOr("? = ?", bun.Ident("media_attachment.header"), true)
|
||||
}).
|
||||
Order("media_attachment.id DESC")
|
||||
|
||||
if maxID != "" {
|
||||
q = q.Where("media_attachment.id < ?", maxID)
|
||||
q = q.Where("? < ?", bun.Ident("media_attachment.id"), maxID)
|
||||
}
|
||||
|
||||
if limit != 0 {
|
||||
|
@ -103,15 +103,15 @@ func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan tim
|
|||
attachments := []*gtsmodel.MediaAttachment{}
|
||||
|
||||
q := m.newMediaQ(&attachments).
|
||||
Where("media_attachment.cached = true").
|
||||
Where("media_attachment.avatar = false").
|
||||
Where("media_attachment.header = false").
|
||||
Where("media_attachment.created_at < ?", olderThan).
|
||||
Where("media_attachment.remote_url IS NULL").
|
||||
Where("media_attachment.status_id IS NULL")
|
||||
Where("? = ?", bun.Ident("media_attachment.cached"), true).
|
||||
Where("? = ?", bun.Ident("media_attachment.avatar"), false).
|
||||
Where("? = ?", bun.Ident("media_attachment.header"), false).
|
||||
Where("? < ?", bun.Ident("media_attachment.created_at"), olderThan).
|
||||
Where("? IS NULL", bun.Ident("media_attachment.remote_url")).
|
||||
Where("? IS NULL", bun.Ident("media_attachment.status_id"))
|
||||
|
||||
if maxID != "" {
|
||||
q = q.Where("media_attachment.id < ?", maxID)
|
||||
q = q.Where("? < ?", bun.Ident("media_attachment.id"), maxID)
|
||||
}
|
||||
|
||||
if limit != 0 {
|
||||
|
|
|
@ -46,7 +46,7 @@ func (m *mentionDB) getMentionDB(ctx context.Context, id string) (*gtsmodel.Ment
|
|||
mention := gtsmodel.Mention{}
|
||||
|
||||
q := m.newMentionQ(&mention).
|
||||
Where("mention.id = ?", id)
|
||||
Where("? = ?", bun.Ident("mention.id"), id)
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
return nil, m.conn.ProcessError(err)
|
||||
|
|
|
@ -47,8 +47,8 @@ func init() {
|
|||
}
|
||||
|
||||
if _, err := tx.NewDelete().
|
||||
Model(a).
|
||||
WherePK().
|
||||
TableExpr("? AS ?", bun.Ident("media_attachments"), bun.Ident("media_attachment")).
|
||||
Where("? = ?", bun.Ident("media_attachment.id"), a.ID).
|
||||
Exec(ctx); err != nil {
|
||||
l.Errorf("error deleting attachment with id %s: %s", a.ID, err)
|
||||
} else {
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type notificationDB struct {
|
||||
|
@ -44,7 +45,7 @@ func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmo
|
|||
Relation("OriginAccount").
|
||||
Relation("TargetAccount").
|
||||
Relation("Status").
|
||||
WherePK()
|
||||
Where("? = ?", bun.Ident("notification.id"), id)
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
return nil, n.conn.ProcessError(err)
|
||||
|
@ -67,24 +68,24 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
|
|||
|
||||
q := n.conn.
|
||||
NewSelect().
|
||||
Table("notifications").
|
||||
Column("id")
|
||||
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
|
||||
Column("notification.id")
|
||||
|
||||
if maxID != "" {
|
||||
q = q.Where("id < ?", maxID)
|
||||
q = q.Where("? < ?", bun.Ident("notification.id"), maxID)
|
||||
}
|
||||
|
||||
if sinceID != "" {
|
||||
q = q.Where("id > ?", sinceID)
|
||||
q = q.Where("? > ?", bun.Ident("notification.id"), sinceID)
|
||||
}
|
||||
|
||||
for _, excludeType := range excludeTypes {
|
||||
q = q.Where("notification_type != ?", excludeType)
|
||||
q = q.Where("? != ?", bun.Ident("notification.notification_type"), excludeType)
|
||||
}
|
||||
|
||||
q = q.
|
||||
Where("target_account_id = ?", accountID).
|
||||
Order("id DESC")
|
||||
Where("? = ?", bun.Ident("notification.target_account_id"), accountID).
|
||||
Order("notification.id DESC")
|
||||
|
||||
if limit != 0 {
|
||||
q = q.Limit(limit)
|
||||
|
@ -116,13 +117,12 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
|
|||
func (n *notificationDB) ClearNotifications(ctx context.Context, accountID string) db.Error {
|
||||
if _, err := n.conn.
|
||||
NewDelete().
|
||||
Table("notifications").
|
||||
Where("target_account_id = ?", accountID).
|
||||
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
|
||||
Where("? = ?", bun.Ident("notification.target_account_id"), accountID).
|
||||
Exec(ctx); err != nil {
|
||||
return n.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
n.cache.Clear()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -51,26 +51,25 @@ func (r *relationshipDB) newFollowQ(follow interface{}) *bun.SelectQuery {
|
|||
func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
|
||||
q := r.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.Block{}).
|
||||
ExcludeColumn("id", "created_at", "updated_at", "uri").
|
||||
Limit(1)
|
||||
TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
|
||||
Column("block.id")
|
||||
|
||||
if eitherDirection {
|
||||
q = q.
|
||||
WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery {
|
||||
return inner.
|
||||
Where("account_id = ?", account1).
|
||||
Where("target_account_id = ?", account2)
|
||||
Where("? = ?", bun.Ident("block.account_id"), account1).
|
||||
Where("? = ?", bun.Ident("block.target_account_id"), account2)
|
||||
}).
|
||||
WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery {
|
||||
return inner.
|
||||
Where("account_id = ?", account2).
|
||||
Where("target_account_id = ?", account1)
|
||||
Where("? = ?", bun.Ident("block.account_id"), account2).
|
||||
Where("? = ?", bun.Ident("block.target_account_id"), account1)
|
||||
})
|
||||
} else {
|
||||
q = q.
|
||||
Where("account_id = ?", account1).
|
||||
Where("target_account_id = ?", account2)
|
||||
Where("? = ?", bun.Ident("block.account_id"), account1).
|
||||
Where("? = ?", bun.Ident("block.target_account_id"), account2)
|
||||
}
|
||||
|
||||
return r.conn.Exists(ctx, q)
|
||||
|
@ -80,8 +79,8 @@ func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2
|
|||
block := >smodel.Block{}
|
||||
|
||||
q := r.newBlockQ(block).
|
||||
Where("block.account_id = ?", account1).
|
||||
Where("block.target_account_id = ?", account2)
|
||||
Where("? = ?", bun.Ident("block.account_id"), account1).
|
||||
Where("? = ?", bun.Ident("block.target_account_id"), account2)
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
|
@ -99,13 +98,13 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount
|
|||
if err := r.conn.
|
||||
NewSelect().
|
||||
Model(follow).
|
||||
Where("account_id = ?", requestingAccount).
|
||||
Where("target_account_id = ?", targetAccount).
|
||||
Column("follow.show_reblogs", "follow.notify").
|
||||
Where("? = ?", bun.Ident("follow.account_id"), requestingAccount).
|
||||
Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount).
|
||||
Limit(1).
|
||||
Scan(ctx); err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
// a proper error
|
||||
return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err)
|
||||
if err := r.conn.ProcessError(err); err != db.ErrNoEntries {
|
||||
return nil, fmt.Errorf("GetRelationship: error fetching follow: %s", err)
|
||||
}
|
||||
// no follow exists so these are all false
|
||||
rel.Following = false
|
||||
|
@ -119,55 +118,56 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount
|
|||
}
|
||||
|
||||
// check if the target account follows the requesting account
|
||||
count, err := r.conn.
|
||||
followedByQ := r.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.Follow{}).
|
||||
Where("account_id = ?", targetAccount).
|
||||
Where("target_account_id = ?", requestingAccount).
|
||||
Limit(1).
|
||||
Count(ctx)
|
||||
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
|
||||
Column("follow.id").
|
||||
Where("? = ?", bun.Ident("follow.account_id"), targetAccount).
|
||||
Where("? = ?", bun.Ident("follow.target_account_id"), requestingAccount)
|
||||
followedBy, err := r.conn.Exists(ctx, followedByQ)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err)
|
||||
return nil, fmt.Errorf("GetRelationship: error checking followedBy: %s", err)
|
||||
}
|
||||
rel.FollowedBy = count > 0
|
||||
|
||||
// check if the requesting account blocks the target account
|
||||
count, err = r.conn.NewSelect().
|
||||
Model(>smodel.Block{}).
|
||||
Where("account_id = ?", requestingAccount).
|
||||
Where("target_account_id = ?", targetAccount).
|
||||
Limit(1).
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err)
|
||||
}
|
||||
rel.Blocking = count > 0
|
||||
|
||||
// check if the target account blocks the requesting account
|
||||
count, err = r.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.Block{}).
|
||||
Where("account_id = ?", targetAccount).
|
||||
Where("target_account_id = ?", requestingAccount).
|
||||
Limit(1).
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
|
||||
}
|
||||
rel.BlockedBy = count > 0
|
||||
rel.FollowedBy = followedBy
|
||||
|
||||
// check if there's a pending following request from requesting account to target account
|
||||
count, err = r.conn.
|
||||
requestedQ := r.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.FollowRequest{}).
|
||||
Where("account_id = ?", requestingAccount).
|
||||
Where("target_account_id = ?", targetAccount).
|
||||
Limit(1).
|
||||
Count(ctx)
|
||||
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
|
||||
Column("follow_request.id").
|
||||
Where("? = ?", bun.Ident("follow_request.account_id"), requestingAccount).
|
||||
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount)
|
||||
requested, err := r.conn.Exists(ctx, requestedQ)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
|
||||
return nil, fmt.Errorf("GetRelationship: error checking requested: %s", err)
|
||||
}
|
||||
rel.Requested = count > 0
|
||||
rel.Requested = requested
|
||||
|
||||
// check if the requesting account is blocking the target account
|
||||
blockingQ := r.conn.
|
||||
NewSelect().
|
||||
TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
|
||||
Column("block.id").
|
||||
Where("? = ?", bun.Ident("block.account_id"), requestingAccount).
|
||||
Where("? = ?", bun.Ident("block.target_account_id"), targetAccount)
|
||||
blocking, err := r.conn.Exists(ctx, blockingQ)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetRelationship: error checking blocking: %s", err)
|
||||
}
|
||||
rel.Blocking = blocking
|
||||
|
||||
// check if the requesting account is blocked by the target account
|
||||
blockedByQ := r.conn.
|
||||
NewSelect().
|
||||
TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
|
||||
Column("block.id").
|
||||
Where("? = ?", bun.Ident("block.account_id"), targetAccount).
|
||||
Where("? = ?", bun.Ident("block.target_account_id"), requestingAccount)
|
||||
blockedBy, err := r.conn.Exists(ctx, blockedByQ)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %s", err)
|
||||
}
|
||||
rel.BlockedBy = blockedBy
|
||||
|
||||
return rel, nil
|
||||
}
|
||||
|
@ -179,10 +179,10 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmode
|
|||
|
||||
q := r.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.Follow{}).
|
||||
Where("account_id = ?", sourceAccount.ID).
|
||||
Where("target_account_id = ?", targetAccount.ID).
|
||||
Limit(1)
|
||||
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
|
||||
Column("follow.id").
|
||||
Where("? = ?", bun.Ident("follow.account_id"), sourceAccount.ID).
|
||||
Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount.ID)
|
||||
|
||||
return r.conn.Exists(ctx, q)
|
||||
}
|
||||
|
@ -194,9 +194,10 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *g
|
|||
|
||||
q := r.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.FollowRequest{}).
|
||||
Where("account_id = ?", sourceAccount.ID).
|
||||
Where("target_account_id = ?", targetAccount.ID)
|
||||
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
|
||||
Column("follow_request.id").
|
||||
Where("? = ?", bun.Ident("follow_request.account_id"), sourceAccount.ID).
|
||||
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount.ID)
|
||||
|
||||
return r.conn.Exists(ctx, q)
|
||||
}
|
||||
|
@ -222,82 +223,98 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmod
|
|||
}
|
||||
|
||||
func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
|
||||
// make sure the original follow request exists
|
||||
fr := >smodel.FollowRequest{}
|
||||
if err := r.conn.
|
||||
NewSelect().
|
||||
Model(fr).
|
||||
Where("account_id = ?", originAccountID).
|
||||
Where("target_account_id = ?", targetAccountID).
|
||||
Scan(ctx); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// create a new follow to 'replace' the request with
|
||||
follow := >smodel.Follow{
|
||||
ID: fr.ID,
|
||||
AccountID: originAccountID,
|
||||
TargetAccountID: targetAccountID,
|
||||
URI: fr.URI,
|
||||
}
|
||||
|
||||
// if the follow already exists, just update the URI -- we don't need to do anything else
|
||||
if _, err := r.conn.
|
||||
NewInsert().
|
||||
Model(follow).
|
||||
On("CONFLICT (account_id,target_account_id) DO UPDATE set uri = ?", follow.URI).
|
||||
Exec(ctx); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// now remove the follow request
|
||||
if _, err := r.conn.
|
||||
NewDelete().
|
||||
Model(>smodel.FollowRequest{}).
|
||||
Where("account_id = ?", originAccountID).
|
||||
Where("target_account_id = ?", targetAccountID).
|
||||
Exec(ctx); err != nil {
|
||||
var follow *gtsmodel.Follow
|
||||
|
||||
if err := r.conn.RunInTx(ctx, func(tx bun.Tx) error {
|
||||
// get original follow request
|
||||
followRequest := >smodel.FollowRequest{}
|
||||
if err := tx.
|
||||
NewSelect().
|
||||
Model(followRequest).
|
||||
Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
|
||||
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
|
||||
Scan(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// create a new follow to 'replace' the request with
|
||||
follow = >smodel.Follow{
|
||||
ID: followRequest.ID,
|
||||
AccountID: originAccountID,
|
||||
TargetAccountID: targetAccountID,
|
||||
URI: followRequest.URI,
|
||||
}
|
||||
|
||||
// if the follow already exists, just update the URI -- we don't need to do anything else
|
||||
if _, err := tx.
|
||||
NewInsert().
|
||||
Model(follow).
|
||||
On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// now remove the follow request
|
||||
if _, err := tx.
|
||||
NewDelete().
|
||||
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
|
||||
Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// return the new follow
|
||||
return follow, nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, db.Error) {
|
||||
// first get the follow request out of the database
|
||||
fr := >smodel.FollowRequest{}
|
||||
if err := r.conn.
|
||||
NewSelect().
|
||||
Model(fr).
|
||||
Where("account_id = ?", originAccountID).
|
||||
Where("target_account_id = ?", targetAccountID).
|
||||
Scan(ctx); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
followRequest := >smodel.FollowRequest{}
|
||||
|
||||
// now delete it from the database by ID
|
||||
if _, err := r.conn.
|
||||
NewDelete().
|
||||
Model(>smodel.FollowRequest{ID: fr.ID}).
|
||||
WherePK().
|
||||
Exec(ctx); err != nil {
|
||||
if err := r.conn.RunInTx(ctx, func(tx bun.Tx) error {
|
||||
// get original follow request
|
||||
if err := tx.
|
||||
NewSelect().
|
||||
Model(followRequest).
|
||||
Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
|
||||
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
|
||||
Scan(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// now delete it from the database by ID
|
||||
if _, err := tx.
|
||||
NewDelete().
|
||||
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
|
||||
Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// return the deleted follow request
|
||||
return fr, nil
|
||||
return followRequest, nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, db.Error) {
|
||||
followRequests := []*gtsmodel.FollowRequest{}
|
||||
|
||||
q := r.newFollowQ(&followRequests).
|
||||
Where("target_account_id = ?", accountID).
|
||||
Where("? = ?", bun.Ident("follow_request.target_account_id"), accountID).
|
||||
Order("follow_request.updated_at DESC")
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return followRequests, nil
|
||||
}
|
||||
|
||||
|
@ -305,21 +322,31 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string
|
|||
follows := []*gtsmodel.Follow{}
|
||||
|
||||
q := r.newFollowQ(&follows).
|
||||
Where("account_id = ?", accountID).
|
||||
Where("? = ?", bun.Ident("follow.account_id"), accountID).
|
||||
Order("follow.updated_at DESC")
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return follows, nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
|
||||
return r.conn.
|
||||
q := r.conn.
|
||||
NewSelect().
|
||||
Model(&[]*gtsmodel.Follow{}).
|
||||
Where("account_id = ?", accountID).
|
||||
Count(ctx)
|
||||
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow"))
|
||||
|
||||
if localOnly {
|
||||
q = q.
|
||||
Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.target_account_id"), bun.Ident("account.id")).
|
||||
Where("? = ?", bun.Ident("follow.account_id"), accountID).
|
||||
Where("? IS NULL", bun.Ident("account.domain"))
|
||||
} else {
|
||||
q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID)
|
||||
}
|
||||
|
||||
return q.Count(ctx)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
|
||||
|
@ -331,12 +358,12 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str
|
|||
Order("follow.updated_at DESC")
|
||||
|
||||
if localOnly {
|
||||
q = q.ColumnExpr("follow.*").
|
||||
Join("JOIN accounts AS a ON follow.account_id = CAST(a.id as TEXT)").
|
||||
Where("follow.target_account_id = ?", accountID).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("a.domain"))
|
||||
q = q.
|
||||
Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.account_id"), bun.Ident("account.id")).
|
||||
Where("? = ?", bun.Ident("follow.target_account_id"), accountID).
|
||||
Where("? IS NULL", bun.Ident("account.domain"))
|
||||
} else {
|
||||
q = q.Where("target_account_id = ?", accountID)
|
||||
q = q.Where("? = ?", bun.Ident("follow.target_account_id"), accountID)
|
||||
}
|
||||
|
||||
err := q.Scan(ctx)
|
||||
|
@ -347,9 +374,18 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str
|
|||
}
|
||||
|
||||
func (r *relationshipDB) CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
|
||||
return r.conn.
|
||||
q := r.conn.
|
||||
NewSelect().
|
||||
Model(&[]*gtsmodel.Follow{}).
|
||||
Where("target_account_id = ?", accountID).
|
||||
Count(ctx)
|
||||
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow"))
|
||||
|
||||
if localOnly {
|
||||
q = q.
|
||||
Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.account_id"), bun.Ident("account.id")).
|
||||
Where("? = ?", bun.Ident("follow.target_account_id"), accountID).
|
||||
Where("? IS NULL", bun.Ident("account.domain"))
|
||||
} else {
|
||||
q = q.Where("? = ?", bun.Ident("follow.target_account_id"), accountID)
|
||||
}
|
||||
|
||||
return q.Count(ctx)
|
||||
}
|
||||
|
|
|
@ -20,7 +20,6 @@ package bundb_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
@ -48,12 +47,14 @@ func (suite *RelationshipTestSuite) TestIsBlocked() {
|
|||
suite.False(blocked)
|
||||
|
||||
// have account1 block account2
|
||||
suite.db.Put(ctx, >smodel.Block{
|
||||
if err := suite.db.Put(ctx, >smodel.Block{
|
||||
ID: "01G202BCSXXJZ70BHB5KCAHH8C",
|
||||
URI: "http://localhost:8080/some_block_uri_1",
|
||||
AccountID: account1,
|
||||
TargetAccountID: account2,
|
||||
})
|
||||
}); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
// account 1 now blocks account 2
|
||||
blocked, err = suite.db.IsBlocked(ctx, account1, account2, false)
|
||||
|
@ -75,62 +76,242 @@ func (suite *RelationshipTestSuite) TestIsBlocked() {
|
|||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestGetBlock() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
ctx := context.Background()
|
||||
|
||||
account1 := suite.testAccounts["local_account_1"].ID
|
||||
account2 := suite.testAccounts["local_account_2"].ID
|
||||
|
||||
if err := suite.db.Put(ctx, >smodel.Block{
|
||||
ID: "01G202BCSXXJZ70BHB5KCAHH8C",
|
||||
URI: "http://localhost:8080/some_block_uri_1",
|
||||
AccountID: account1,
|
||||
TargetAccountID: account2,
|
||||
}); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
block, err := suite.db.GetBlock(ctx, account1, account2)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(block)
|
||||
suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestGetRelationship() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
requestingAccount := suite.testAccounts["local_account_1"]
|
||||
targetAccount := suite.testAccounts["admin_account"]
|
||||
|
||||
relationship, err := suite.db.GetRelationship(context.Background(), requestingAccount.ID, targetAccount.ID)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(relationship)
|
||||
|
||||
suite.True(relationship.Following)
|
||||
suite.True(relationship.ShowingReblogs)
|
||||
suite.False(relationship.Notifying)
|
||||
suite.True(relationship.FollowedBy)
|
||||
suite.False(relationship.Blocking)
|
||||
suite.False(relationship.BlockedBy)
|
||||
suite.False(relationship.Muting)
|
||||
suite.False(relationship.MutingNotifications)
|
||||
suite.False(relationship.Requested)
|
||||
suite.False(relationship.DomainBlocking)
|
||||
suite.False(relationship.Endorsed)
|
||||
suite.Empty(relationship.Note)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestIsFollowing() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
func (suite *RelationshipTestSuite) TestIsFollowingYes() {
|
||||
requestingAccount := suite.testAccounts["local_account_1"]
|
||||
targetAccount := suite.testAccounts["admin_account"]
|
||||
isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount)
|
||||
suite.NoError(err)
|
||||
suite.True(isFollowing)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestIsFollowingNo() {
|
||||
requestingAccount := suite.testAccounts["admin_account"]
|
||||
targetAccount := suite.testAccounts["local_account_2"]
|
||||
isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount)
|
||||
suite.NoError(err)
|
||||
suite.False(isFollowing)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestIsMutualFollowing() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
requestingAccount := suite.testAccounts["local_account_1"]
|
||||
targetAccount := suite.testAccounts["admin_account"]
|
||||
isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount)
|
||||
suite.NoError(err)
|
||||
suite.True(isMutualFollowing)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) AcceptFollowRequest() {
|
||||
for _, account := range suite.testAccounts {
|
||||
_, err := suite.db.AcceptFollowRequest(context.Background(), account.ID, "NON-EXISTENT-ID")
|
||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
suite.Suite.Fail("error accepting follow request: %v", err)
|
||||
}
|
||||
func (suite *RelationshipTestSuite) TestIsMutualFollowingNo() {
|
||||
requestingAccount := suite.testAccounts["local_account_1"]
|
||||
targetAccount := suite.testAccounts["local_account_2"]
|
||||
isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount)
|
||||
suite.NoError(err)
|
||||
suite.True(isMutualFollowing)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestAcceptFollowRequestOK() {
|
||||
ctx := context.Background()
|
||||
account := suite.testAccounts["admin_account"]
|
||||
targetAccount := suite.testAccounts["local_account_2"]
|
||||
|
||||
followRequest := >smodel.FollowRequest{
|
||||
ID: "01GEF753FWHCHRDWR0QEHBXM8W",
|
||||
URI: "http://localhost:8080/weeeeeeeeeeeeeeeee",
|
||||
AccountID: account.ID,
|
||||
TargetAccountID: targetAccount.ID,
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) GetAccountFollowRequests() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) GetAccountFollows() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) CountAccountFollows() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) GetAccountFollowedBy() {
|
||||
// TODO: more comprehensive tests here
|
||||
|
||||
for _, account := range suite.testAccounts {
|
||||
var err error
|
||||
|
||||
_, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, false)
|
||||
if err != nil {
|
||||
suite.Suite.Fail("error checking accounts followed by: %v", err)
|
||||
}
|
||||
|
||||
_, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, true)
|
||||
if err != nil {
|
||||
suite.Suite.Fail("error checking localOnly accounts followed by: %v", err)
|
||||
}
|
||||
if err := suite.db.Put(ctx, followRequest); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(follow)
|
||||
suite.Equal(followRequest.URI, follow.URI)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) CountAccountFollowedBy() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
func (suite *RelationshipTestSuite) TestAcceptFollowRequestNotExisting() {
|
||||
ctx := context.Background()
|
||||
account := suite.testAccounts["admin_account"]
|
||||
targetAccount := suite.testAccounts["local_account_2"]
|
||||
|
||||
follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID)
|
||||
suite.ErrorIs(err, db.ErrNoEntries)
|
||||
suite.Nil(follow)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestAcceptFollowRequestFollowAlreadyExists() {
|
||||
ctx := context.Background()
|
||||
account := suite.testAccounts["local_account_1"]
|
||||
targetAccount := suite.testAccounts["admin_account"]
|
||||
|
||||
// follow already exists in the db from local_account_1 -> admin_account
|
||||
existingFollow := >smodel.Follow{}
|
||||
if err := suite.db.GetByID(ctx, suite.testFollows["local_account_1_admin_account"].ID, existingFollow); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
followRequest := >smodel.FollowRequest{
|
||||
ID: "01GEF753FWHCHRDWR0QEHBXM8W",
|
||||
URI: "http://localhost:8080/weeeeeeeeeeeeeeeee",
|
||||
AccountID: account.ID,
|
||||
TargetAccountID: targetAccount.ID,
|
||||
}
|
||||
|
||||
if err := suite.db.Put(ctx, followRequest); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(follow)
|
||||
|
||||
// uri should be equal to value of new/overlapping follow request
|
||||
suite.NotEqual(followRequest.URI, existingFollow.URI)
|
||||
suite.Equal(followRequest.URI, follow.URI)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestRejectFollowRequestOK() {
|
||||
ctx := context.Background()
|
||||
account := suite.testAccounts["admin_account"]
|
||||
targetAccount := suite.testAccounts["local_account_2"]
|
||||
|
||||
followRequest := >smodel.FollowRequest{
|
||||
ID: "01GEF753FWHCHRDWR0QEHBXM8W",
|
||||
URI: "http://localhost:8080/weeeeeeeeeeeeeeeee",
|
||||
AccountID: account.ID,
|
||||
TargetAccountID: targetAccount.ID,
|
||||
}
|
||||
|
||||
if err := suite.db.Put(ctx, followRequest); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(rejectedFollowRequest)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestRejectFollowRequestNotExisting() {
|
||||
ctx := context.Background()
|
||||
account := suite.testAccounts["admin_account"]
|
||||
targetAccount := suite.testAccounts["local_account_2"]
|
||||
|
||||
rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
|
||||
suite.ErrorIs(err, db.ErrNoEntries)
|
||||
suite.Nil(rejectedFollowRequest)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() {
|
||||
ctx := context.Background()
|
||||
account := suite.testAccounts["admin_account"]
|
||||
targetAccount := suite.testAccounts["local_account_2"]
|
||||
|
||||
followRequest := >smodel.FollowRequest{
|
||||
ID: "01GEF753FWHCHRDWR0QEHBXM8W",
|
||||
URI: "http://localhost:8080/weeeeeeeeeeeeeeeee",
|
||||
AccountID: account.ID,
|
||||
TargetAccountID: targetAccount.ID,
|
||||
}
|
||||
|
||||
if err := suite.db.Put(ctx, followRequest); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID)
|
||||
suite.NoError(err)
|
||||
suite.Len(followRequests, 1)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestGetAccountFollows() {
|
||||
account := suite.testAccounts["local_account_1"]
|
||||
follows, err := suite.db.GetAccountFollows(context.Background(), account.ID)
|
||||
suite.NoError(err)
|
||||
suite.Len(follows, 2)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestCountAccountFollowsLocalOnly() {
|
||||
account := suite.testAccounts["local_account_1"]
|
||||
followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID, true)
|
||||
suite.NoError(err)
|
||||
suite.Equal(2, followsCount)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestCountAccountFollows() {
|
||||
account := suite.testAccounts["local_account_1"]
|
||||
followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID, false)
|
||||
suite.NoError(err)
|
||||
suite.Equal(2, followsCount)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestGetAccountFollowedBy() {
|
||||
account := suite.testAccounts["local_account_1"]
|
||||
follows, err := suite.db.GetAccountFollowedBy(context.Background(), account.ID, false)
|
||||
suite.NoError(err)
|
||||
suite.Len(follows, 2)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestGetAccountFollowedByLocalOnly() {
|
||||
account := suite.testAccounts["local_account_1"]
|
||||
follows, err := suite.db.GetAccountFollowedBy(context.Background(), account.ID, true)
|
||||
suite.NoError(err)
|
||||
suite.Len(follows, 2)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestCountAccountFollowedBy() {
|
||||
account := suite.testAccounts["local_account_1"]
|
||||
followsCount, err := suite.db.CountAccountFollowedBy(context.Background(), account.ID, false)
|
||||
suite.NoError(err)
|
||||
suite.Equal(2, followsCount)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestCountAccountFollowedByLocalOnly() {
|
||||
account := suite.testAccounts["local_account_1"]
|
||||
followsCount, err := suite.db.CountAccountFollowedBy(context.Background(), account.ID, true)
|
||||
suite.NoError(err)
|
||||
suite.Equal(2, followsCount)
|
||||
}
|
||||
|
||||
func TestRelationshipTestSuite(t *testing.T) {
|
||||
|
|
|
@ -21,7 +21,6 @@ package bundb
|
|||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
|
@ -35,29 +34,22 @@ type sessionDB struct {
|
|||
func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
|
||||
rss := make([]*gtsmodel.RouterSession, 0, 1)
|
||||
|
||||
_, err := s.conn.
|
||||
// get the first router session in the db or...
|
||||
if err := s.conn.
|
||||
NewSelect().
|
||||
Model(&rss).
|
||||
Limit(1).
|
||||
Order("id DESC").
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
Order("router_session.id DESC").
|
||||
Scan(ctx); err != nil {
|
||||
return nil, s.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// ... create a new one
|
||||
if len(rss) == 0 {
|
||||
// no session created yet, so make one
|
||||
return s.createSession(ctx)
|
||||
}
|
||||
|
||||
if len(rss) != 1 {
|
||||
// we asked for 1 so we should get 1
|
||||
return nil, errors.New("more than 1 router session was returned")
|
||||
}
|
||||
|
||||
// return the one session found
|
||||
rs := rss[0]
|
||||
return rs, nil
|
||||
return rss[0], nil
|
||||
}
|
||||
|
||||
func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
|
||||
|
@ -71,24 +63,23 @@ func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession,
|
|||
return nil, err
|
||||
}
|
||||
|
||||
rid, err := id.NewULID()
|
||||
id, err := id.NewULID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rs := >smodel.RouterSession{
|
||||
ID: rid,
|
||||
ID: id,
|
||||
Auth: auth,
|
||||
Crypt: crypt,
|
||||
}
|
||||
|
||||
q := s.conn.
|
||||
if _, err := s.conn.
|
||||
NewInsert().
|
||||
Model(rs)
|
||||
|
||||
_, err = q.Exec(ctx)
|
||||
if err != nil {
|
||||
Model(rs).
|
||||
Exec(ctx); err != nil {
|
||||
return nil, s.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return rs, nil
|
||||
}
|
||||
|
|
|
@ -37,14 +37,13 @@ func (suite *SessionTestSuite) TestGetSession() {
|
|||
suite.NotEmpty(session.Crypt)
|
||||
suite.NotEmpty(session.ID)
|
||||
|
||||
// TODO -- the same session should be returned with consecutive selects
|
||||
// right now there's an issue with bytea in bun, so uncomment this when that issue is fixed: https://github.com/uptrace/bun/issues/122
|
||||
// session2, err := suite.db.GetSession(context.Background())
|
||||
// suite.NoError(err)
|
||||
// suite.NotNil(session2)
|
||||
// suite.Equal(session.Auth, session2.Auth)
|
||||
// suite.Equal(session.Crypt, session2.Crypt)
|
||||
// suite.Equal(session.ID, session2.ID)
|
||||
// the same session should be returned with consecutive selects
|
||||
session2, err := suite.db.GetSession(context.Background())
|
||||
suite.NoError(err)
|
||||
suite.NotNil(session2)
|
||||
suite.Equal(session.Auth, session2.Auth)
|
||||
suite.Equal(session.Crypt, session2.Crypt)
|
||||
suite.Equal(session.ID, session2.ID)
|
||||
}
|
||||
|
||||
func TestSessionTestSuite(t *testing.T) {
|
||||
|
|
|
@ -72,7 +72,7 @@ func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Stat
|
|||
return s.cache.GetByID(id)
|
||||
},
|
||||
func(status *gtsmodel.Status) error {
|
||||
return s.newStatusQ(status).Where("status.id = ?", id).Scan(ctx)
|
||||
return s.newStatusQ(status).Where("? = ?", bun.Ident("status.id"), id).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -84,7 +84,7 @@ func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.St
|
|||
return s.cache.GetByURI(uri)
|
||||
},
|
||||
func(status *gtsmodel.Status) error {
|
||||
return s.newStatusQ(status).Where("status.uri = ?", uri).Scan(ctx)
|
||||
return s.newStatusQ(status).Where("? = ?", bun.Ident("status.uri"), uri).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -96,7 +96,7 @@ func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.St
|
|||
return s.cache.GetByURL(url)
|
||||
},
|
||||
func(status *gtsmodel.Status) error {
|
||||
return s.newStatusQ(status).Where("status.url = ?", url).Scan(ctx)
|
||||
return s.newStatusQ(status).Where("? = ?", bun.Ident("status.url"), url).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -109,8 +109,7 @@ func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Sta
|
|||
status = >smodel.Status{}
|
||||
|
||||
// Not cached! Perform database query
|
||||
err := dbQuery(status)
|
||||
if err != nil {
|
||||
if err := dbQuery(status); err != nil {
|
||||
return nil, s.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
|
@ -138,52 +137,15 @@ func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Sta
|
|||
}
|
||||
|
||||
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
|
||||
return s.conn.RunInTx(ctx, func(tx bun.Tx) error {
|
||||
// create links between this status and any emojis it uses
|
||||
for _, i := range status.EmojiIDs {
|
||||
if _, err := tx.NewInsert().Model(>smodel.StatusToEmoji{
|
||||
StatusID: status.ID,
|
||||
EmojiID: i,
|
||||
}).Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// create links between this status and any tags it uses
|
||||
for _, i := range status.TagIDs {
|
||||
if _, err := tx.NewInsert().Model(>smodel.StatusToTag{
|
||||
StatusID: status.ID,
|
||||
TagID: i,
|
||||
}).Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// change the status ID of the media attachments to the new status
|
||||
for _, a := range status.Attachments {
|
||||
a.StatusID = status.ID
|
||||
a.UpdatedAt = time.Now()
|
||||
if _, err := tx.NewUpdate().Model(a).
|
||||
Where("id = ?", a.ID).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, insert the status
|
||||
_, err := tx.NewInsert().Model(status).Exec(ctx)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*gtsmodel.Status, db.Error) {
|
||||
err := s.conn.RunInTx(ctx, func(tx bun.Tx) error {
|
||||
// create links between this status and any emojis it uses
|
||||
for _, i := range status.EmojiIDs {
|
||||
if _, err := tx.NewInsert().Model(>smodel.StatusToEmoji{
|
||||
StatusID: status.ID,
|
||||
EmojiID: i,
|
||||
}).Exec(ctx); err != nil {
|
||||
if _, err := tx.
|
||||
NewInsert().
|
||||
Model(>smodel.StatusToEmoji{
|
||||
StatusID: status.ID,
|
||||
EmojiID: i,
|
||||
}).Exec(ctx); err != nil {
|
||||
err = s.conn.errProc(err)
|
||||
if !errors.Is(err, db.ErrAlreadyExists) {
|
||||
return err
|
||||
|
@ -193,10 +155,78 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*
|
|||
|
||||
// create links between this status and any tags it uses
|
||||
for _, i := range status.TagIDs {
|
||||
if _, err := tx.NewInsert().Model(>smodel.StatusToTag{
|
||||
StatusID: status.ID,
|
||||
TagID: i,
|
||||
}).Exec(ctx); err != nil {
|
||||
if _, err := tx.
|
||||
NewInsert().
|
||||
Model(>smodel.StatusToTag{
|
||||
StatusID: status.ID,
|
||||
TagID: i,
|
||||
}).Exec(ctx); err != nil {
|
||||
err = s.conn.errProc(err)
|
||||
if !errors.Is(err, db.ErrAlreadyExists) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// change the status ID of the media attachments to the new status
|
||||
for _, a := range status.Attachments {
|
||||
a.StatusID = status.ID
|
||||
a.UpdatedAt = time.Now()
|
||||
if _, err := tx.
|
||||
NewUpdate().
|
||||
Model(a).
|
||||
Where("? = ?", bun.Ident("media_attachment.id"), a.ID).
|
||||
Exec(ctx); err != nil {
|
||||
err = s.conn.errProc(err)
|
||||
if !errors.Is(err, db.ErrAlreadyExists) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, insert the status
|
||||
if _, err := tx.
|
||||
NewInsert().
|
||||
Model(status).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return s.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
s.cache.Put(status)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*gtsmodel.Status, db.Error) {
|
||||
err := s.conn.RunInTx(ctx, func(tx bun.Tx) error {
|
||||
// create links between this status and any emojis it uses
|
||||
for _, i := range status.EmojiIDs {
|
||||
if _, err := tx.
|
||||
NewInsert().
|
||||
Model(>smodel.StatusToEmoji{
|
||||
StatusID: status.ID,
|
||||
EmojiID: i,
|
||||
}).Exec(ctx); err != nil {
|
||||
err = s.conn.errProc(err)
|
||||
if !errors.Is(err, db.ErrAlreadyExists) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// create links between this status and any tags it uses
|
||||
for _, i := range status.TagIDs {
|
||||
if _, err := tx.
|
||||
NewInsert().
|
||||
Model(>smodel.StatusToTag{
|
||||
StatusID: status.ID,
|
||||
TagID: i,
|
||||
}).Exec(ctx); err != nil {
|
||||
err = s.conn.errProc(err)
|
||||
if !errors.Is(err, db.ErrAlreadyExists) {
|
||||
return err
|
||||
|
@ -208,23 +238,32 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*
|
|||
for _, a := range status.Attachments {
|
||||
a.StatusID = status.ID
|
||||
a.UpdatedAt = time.Now()
|
||||
if _, err := tx.NewUpdate().Model(a).
|
||||
Where("id = ?", a.ID).
|
||||
if _, err := tx.
|
||||
NewUpdate().
|
||||
Model(a).
|
||||
Where("? = ?", bun.Ident("media_attachment.id"), a.ID).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, update the status itself
|
||||
if _, err := tx.NewUpdate().Model(status).WherePK().Exec(ctx); err != nil {
|
||||
if _, err := tx.
|
||||
NewUpdate().
|
||||
Model(status).
|
||||
Where("? = ?", bun.Ident("status.id"), status.ID).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.cache.Put(status)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, s.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return status, err
|
||||
s.cache.Put(status)
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error {
|
||||
|
@ -232,8 +271,8 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error {
|
|||
// delete links between this status and any emojis it uses
|
||||
if _, err := tx.
|
||||
NewDelete().
|
||||
Model(>smodel.StatusToEmoji{}).
|
||||
Where("status_id = ?", id).
|
||||
TableExpr("? AS ?", bun.Ident("status_to_emojis"), bun.Ident("status_to_emoji")).
|
||||
Where("? = ?", bun.Ident("status_to_emoji.status_id"), id).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -241,8 +280,8 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error {
|
|||
// delete links between this status and any tags it uses
|
||||
if _, err := tx.
|
||||
NewDelete().
|
||||
Model(>smodel.StatusToTag{}).
|
||||
Where("status_id = ?", id).
|
||||
TableExpr("? AS ?", bun.Ident("status_to_tags"), bun.Ident("status_to_tag")).
|
||||
Where("? = ?", bun.Ident("status_to_tag.status_id"), id).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -250,17 +289,20 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error {
|
|||
// delete the status itself
|
||||
if _, err := tx.
|
||||
NewDelete().
|
||||
Model(>smodel.Status{ID: id}).
|
||||
WherePK().
|
||||
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
|
||||
Where("? = ?", bun.Ident("status.id"), id).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.cache.Invalidate(id)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return s.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return s.conn.ProcessError(err)
|
||||
s.cache.Invalidate(id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
|
||||
|
@ -312,11 +354,11 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status,
|
|||
|
||||
q := s.conn.
|
||||
NewSelect().
|
||||
Table("statuses").
|
||||
Column("id").
|
||||
Where("in_reply_to_id = ?", status.ID)
|
||||
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
|
||||
Column("status.id").
|
||||
Where("? = ?", bun.Ident("status.in_reply_to_id"), status.ID)
|
||||
if minID != "" {
|
||||
q = q.Where("id > ?", minID)
|
||||
q = q.Where("? > ?", bun.Ident("status.id"), minID)
|
||||
}
|
||||
|
||||
if err := q.Scan(ctx, &childIDs); err != nil {
|
||||
|
@ -356,23 +398,35 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status,
|
|||
}
|
||||
|
||||
func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
|
||||
return s.conn.NewSelect().Model(>smodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count(ctx)
|
||||
return s.conn.
|
||||
NewSelect().
|
||||
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
|
||||
Where("? = ?", bun.Ident("status.in_reply_to_id"), status.ID).
|
||||
Count(ctx)
|
||||
}
|
||||
|
||||
func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
|
||||
return s.conn.NewSelect().Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Count(ctx)
|
||||
return s.conn.
|
||||
NewSelect().
|
||||
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
|
||||
Where("? = ?", bun.Ident("status.boost_of_id"), status.ID).
|
||||
Count(ctx)
|
||||
}
|
||||
|
||||
func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
|
||||
return s.conn.NewSelect().Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Count(ctx)
|
||||
return s.conn.
|
||||
NewSelect().
|
||||
TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
|
||||
Where("? = ?", bun.Ident("status_fave.status_id"), status.ID).
|
||||
Count(ctx)
|
||||
}
|
||||
|
||||
func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
|
||||
q := s.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.StatusFave{}).
|
||||
Where("status_id = ?", status.ID).
|
||||
Where("account_id = ?", accountID)
|
||||
TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
|
||||
Where("? = ?", bun.Ident("status_fave.status_id"), status.ID).
|
||||
Where("? = ?", bun.Ident("status_fave.account_id"), accountID)
|
||||
|
||||
return s.conn.Exists(ctx, q)
|
||||
}
|
||||
|
@ -380,9 +434,9 @@ func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status,
|
|||
func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
|
||||
q := s.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.Status{}).
|
||||
Where("boost_of_id = ?", status.ID).
|
||||
Where("account_id = ?", accountID)
|
||||
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
|
||||
Where("? = ?", bun.Ident("status.boost_of_id"), status.ID).
|
||||
Where("? = ?", bun.Ident("status.account_id"), accountID)
|
||||
|
||||
return s.conn.Exists(ctx, q)
|
||||
}
|
||||
|
@ -390,9 +444,9 @@ func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Sta
|
|||
func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
|
||||
q := s.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.StatusMute{}).
|
||||
Where("status_id = ?", status.ID).
|
||||
Where("account_id = ?", accountID)
|
||||
TableExpr("? AS ?", bun.Ident("status_mutes"), bun.Ident("status_mute")).
|
||||
Where("? = ?", bun.Ident("status_mute.status_id"), status.ID).
|
||||
Where("? = ?", bun.Ident("status_mute.account_id"), accountID)
|
||||
|
||||
return s.conn.Exists(ctx, q)
|
||||
}
|
||||
|
@ -400,9 +454,9 @@ func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status,
|
|||
func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
|
||||
q := s.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.StatusBookmark{}).
|
||||
Where("status_id = ?", status.ID).
|
||||
Where("account_id = ?", accountID)
|
||||
TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")).
|
||||
Where("? = ?", bun.Ident("status_bookmark.status_id"), status.ID).
|
||||
Where("? = ?", bun.Ident("status_bookmark.account_id"), accountID)
|
||||
|
||||
return s.conn.Exists(ctx, q)
|
||||
}
|
||||
|
@ -410,8 +464,9 @@ func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.St
|
|||
func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
|
||||
faves := []*gtsmodel.StatusFave{}
|
||||
|
||||
q := s.newFaveQ(&faves).
|
||||
Where("status_id = ?", status.ID)
|
||||
q := s.
|
||||
newFaveQ(&faves).
|
||||
Where("? = ?", bun.Ident("status_fave.status_id"), status.ID)
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
return nil, s.conn.ProcessError(err)
|
||||
|
@ -422,8 +477,9 @@ func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status)
|
|||
func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
|
||||
reblogs := []*gtsmodel.Status{}
|
||||
|
||||
q := s.newStatusQ(&reblogs).
|
||||
Where("boost_of_id = ?", status.ID)
|
||||
q := s.
|
||||
newStatusQ(&reblogs).
|
||||
Where("? = ?", bun.Ident("status.boost_of_id"), status.ID)
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
return nil, s.conn.ProcessError(err)
|
||||
|
|
|
@ -108,14 +108,14 @@ func (suite *StatusTestSuite) TestGetStatusTwice() {
|
|||
suite.NoError(err)
|
||||
after1 := time.Now()
|
||||
duration1 := after1.Sub(before1)
|
||||
fmt.Println(duration1.Milliseconds())
|
||||
fmt.Println(duration1.Microseconds())
|
||||
|
||||
before2 := time.Now()
|
||||
_, err = suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
|
||||
suite.NoError(err)
|
||||
after2 := time.Now()
|
||||
duration2 := after2.Sub(before2)
|
||||
fmt.Println(duration2.Milliseconds())
|
||||
fmt.Println(duration2.Microseconds())
|
||||
|
||||
// second retrieval should be several orders faster since it will be cached now
|
||||
suite.Less(duration2, duration1)
|
||||
|
|
|
@ -34,38 +34,48 @@ type timelineDB struct {
|
|||
}
|
||||
|
||||
func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
|
||||
// Ensure reasonable
|
||||
if limit < 0 {
|
||||
limit = 0
|
||||
}
|
||||
|
||||
// Make educated guess for slice size
|
||||
statusIDs := make([]string, 0, limit)
|
||||
|
||||
q := t.conn.
|
||||
NewSelect().
|
||||
Table("statuses").
|
||||
|
||||
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
|
||||
// Select only IDs from table
|
||||
Column("statuses.id").
|
||||
Column("status.id").
|
||||
// Find out who accountID follows.
|
||||
Join("LEFT JOIN follows ON follows.target_account_id = statuses.account_id AND follows.account_id = ?", accountID).
|
||||
Join("LEFT JOIN ? AS ? ON ? = ? AND ? = ?",
|
||||
bun.Ident("follows"),
|
||||
bun.Ident("follow"),
|
||||
bun.Ident("follow.target_account_id"),
|
||||
bun.Ident("status.account_id"),
|
||||
bun.Ident("follow.account_id"),
|
||||
accountID).
|
||||
// Sort by highest ID (newest) to lowest ID (oldest)
|
||||
Order("statuses.id DESC")
|
||||
Order("status.id DESC")
|
||||
|
||||
if maxID != "" {
|
||||
// return only statuses LOWER (ie., older) than maxID
|
||||
q = q.Where("statuses.id < ?", maxID)
|
||||
q = q.Where("? < ?", bun.Ident("status.id"), maxID)
|
||||
}
|
||||
|
||||
if sinceID != "" {
|
||||
// return only statuses HIGHER (ie., newer) than sinceID
|
||||
q = q.Where("statuses.id > ?", sinceID)
|
||||
q = q.Where("? > ?", bun.Ident("status.id"), sinceID)
|
||||
}
|
||||
|
||||
if minID != "" {
|
||||
// return only statuses HIGHER (ie., newer) than minID
|
||||
q = q.Where("statuses.id > ?", minID)
|
||||
q = q.Where("? > ?", bun.Ident("status.id"), minID)
|
||||
}
|
||||
|
||||
if local {
|
||||
// return only statuses posted by local account havers
|
||||
q = q.Where("statuses.local = ?", local)
|
||||
q = q.Where("? = ?", bun.Ident("status.local"), local)
|
||||
}
|
||||
|
||||
if limit > 0 {
|
||||
|
@ -78,13 +88,11 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
|
|||
//
|
||||
// This is equivalent to something like WHERE ... AND (... OR ...)
|
||||
// See: https://bun.uptrace.dev/guide/queries.html#select
|
||||
whereGroup := func(*bun.SelectQuery) *bun.SelectQuery {
|
||||
q = q.WhereGroup(" AND ", func(*bun.SelectQuery) *bun.SelectQuery {
|
||||
return q.
|
||||
WhereOr("follows.account_id = ?", accountID).
|
||||
WhereOr("statuses.account_id = ?", accountID)
|
||||
}
|
||||
|
||||
q = q.WhereGroup(" AND ", whereGroup)
|
||||
WhereOr("? = ?", bun.Ident("follow.account_id"), accountID).
|
||||
WhereOr("? = ?", bun.Ident("status.account_id"), accountID)
|
||||
})
|
||||
|
||||
if err := q.Scan(ctx, &statusIDs); err != nil {
|
||||
return nil, t.conn.ProcessError(err)
|
||||
|
@ -118,28 +126,28 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, ma
|
|||
|
||||
q := t.conn.
|
||||
NewSelect().
|
||||
Table("statuses").
|
||||
Column("statuses.id").
|
||||
Where("statuses.visibility = ?", gtsmodel.VisibilityPublic).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("statuses.in_reply_to_id")).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("statuses.in_reply_to_uri")).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("statuses.boost_of_id")).
|
||||
Order("statuses.id DESC")
|
||||
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
|
||||
Column("status.id").
|
||||
Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_id")).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")).
|
||||
Order("status.id DESC")
|
||||
|
||||
if maxID != "" {
|
||||
q = q.Where("statuses.id < ?", maxID)
|
||||
q = q.Where("? < ?", bun.Ident("status.id"), maxID)
|
||||
}
|
||||
|
||||
if sinceID != "" {
|
||||
q = q.Where("statuses.id > ?", sinceID)
|
||||
q = q.Where("? > ?", bun.Ident("status.id"), sinceID)
|
||||
}
|
||||
|
||||
if minID != "" {
|
||||
q = q.Where("statuses.id > ?", minID)
|
||||
q = q.Where("? > ?", bun.Ident("status.id"), minID)
|
||||
}
|
||||
|
||||
if local {
|
||||
q = q.Where("statuses.local = ?", local)
|
||||
q = q.Where("? = ?", bun.Ident("status.local"), local)
|
||||
}
|
||||
|
||||
if limit > 0 {
|
||||
|
@ -181,15 +189,15 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max
|
|||
fq := t.conn.
|
||||
NewSelect().
|
||||
Model(&faves).
|
||||
Where("account_id = ?", accountID).
|
||||
Order("id DESC")
|
||||
Where("? = ?", bun.Ident("status_fave.account_id"), accountID).
|
||||
Order("status_fave.id DESC")
|
||||
|
||||
if maxID != "" {
|
||||
fq = fq.Where("id < ?", maxID)
|
||||
fq = fq.Where("? < ?", bun.Ident("status_fave.id"), maxID)
|
||||
}
|
||||
|
||||
if minID != "" {
|
||||
fq = fq.Where("id > ?", minID)
|
||||
fq = fq.Where("? > ?", bun.Ident("status_fave.id"), minID)
|
||||
}
|
||||
|
||||
if limit > 0 {
|
||||
|
|
|
@ -38,6 +38,15 @@ func (suite *TimelineTestSuite) TestGetPublicTimeline() {
|
|||
suite.Len(s, 6)
|
||||
}
|
||||
|
||||
func (suite *TimelineTestSuite) TestGetHomeTimeline() {
|
||||
viewingAccount := suite.testAccounts["local_account_1"]
|
||||
|
||||
s, err := suite.db.GetHomeTimeline(context.Background(), viewingAccount.ID, "", "", "", 20, false)
|
||||
suite.NoError(err)
|
||||
|
||||
suite.Len(s, 16)
|
||||
}
|
||||
|
||||
func TestTimelineTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(TimelineTestSuite))
|
||||
}
|
||||
|
|
|
@ -67,7 +67,7 @@ func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db
|
|||
return u.cache.GetByID(id)
|
||||
},
|
||||
func(user *gtsmodel.User) error {
|
||||
return u.newUserQ(user).Where("user.id = ?", id).Scan(ctx)
|
||||
return u.newUserQ(user).Where("? = ?", bun.Ident("user.id"), id).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -79,7 +79,7 @@ func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gts
|
|||
return u.cache.GetByAccountID(accountID)
|
||||
},
|
||||
func(user *gtsmodel.User) error {
|
||||
return u.newUserQ(user).Where("user.account_id = ?", accountID).Scan(ctx)
|
||||
return u.newUserQ(user).Where("? = ?", bun.Ident("user.account_id"), accountID).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -91,7 +91,7 @@ func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string)
|
|||
return u.cache.GetByEmail(emailAddress)
|
||||
},
|
||||
func(user *gtsmodel.User) error {
|
||||
return u.newUserQ(user).Where("user.email = ?", emailAddress).Scan(ctx)
|
||||
return u.newUserQ(user).Where("? = ?", bun.Ident("user.email"), emailAddress).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -103,7 +103,7 @@ func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationTok
|
|||
return u.cache.GetByConfirmationToken(confirmationToken)
|
||||
},
|
||||
func(user *gtsmodel.User) error {
|
||||
return u.newUserQ(user).Where("user.confirmation_token = ?", confirmationToken).Scan(ctx)
|
||||
return u.newUserQ(user).Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -127,7 +127,7 @@ func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ..
|
|||
if _, err := u.conn.
|
||||
NewUpdate().
|
||||
Model(user).
|
||||
WherePK().
|
||||
Where("? = ?", bun.Ident("user.id"), user.ID).
|
||||
Column(columns...).
|
||||
Exec(ctx); err != nil {
|
||||
return nil, u.conn.ProcessError(err)
|
||||
|
@ -140,8 +140,8 @@ func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ..
|
|||
func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error {
|
||||
if _, err := u.conn.
|
||||
NewDelete().
|
||||
Model(>smodel.User{ID: userID}).
|
||||
WherePK().
|
||||
TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")).
|
||||
Where("? = ?", bun.Ident("user.id"), userID).
|
||||
Exec(ctx); err != nil {
|
||||
return u.conn.ProcessError(err)
|
||||
}
|
||||
|
|
|
@ -85,14 +85,8 @@ func parseWhere(w db.Where) (query string, args []interface{}) {
|
|||
return
|
||||
}
|
||||
|
||||
if w.CaseInsensitive {
|
||||
query = "LOWER(?) != LOWER(?)"
|
||||
args = []interface{}{bun.Safe(w.Key), w.Value}
|
||||
return
|
||||
}
|
||||
|
||||
query = "? != ?"
|
||||
args = []interface{}{bun.Safe(w.Key), w.Value}
|
||||
args = []interface{}{bun.Ident(w.Key), w.Value}
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -102,13 +96,7 @@ func parseWhere(w db.Where) (query string, args []interface{}) {
|
|||
return
|
||||
}
|
||||
|
||||
if w.CaseInsensitive {
|
||||
query = "LOWER(?) = LOWER(?)"
|
||||
args = []interface{}{bun.Safe(w.Key), w.Value}
|
||||
return
|
||||
}
|
||||
|
||||
query = "? = ?"
|
||||
args = []interface{}{bun.Safe(w.Key), w.Value}
|
||||
args = []interface{}{bun.Ident(w.Key), w.Value}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -24,9 +24,6 @@ type Where struct {
|
|||
Key string
|
||||
// The value to match.
|
||||
Value interface{}
|
||||
// Whether the value (if a string) should be case sensitive or not.
|
||||
// Defaults to false.
|
||||
CaseInsensitive bool
|
||||
// If set, reverse the where.
|
||||
// `WHERE k = v` becomes `WHERE k != v`.
|
||||
// `WHERE k IS NULL` becomes `WHERE k IS NOT NULL`
|
||||
|
|
|
@ -101,7 +101,7 @@ func (p *ProcessingMedia) LoadAttachment(ctx context.Context) (*gtsmodel.MediaAt
|
|||
if !p.insertedInDB {
|
||||
if p.recache {
|
||||
// if it's a recache we should only need to update
|
||||
if err := p.database.UpdateByPrimaryKey(ctx, p.attachment); err != nil {
|
||||
if err := p.database.UpdateByID(ctx, p.attachment, p.attachment.ID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -40,7 +40,7 @@ func (suite *PruneMetaTestSuite) TestPruneMeta() {
|
|||
zork := suite.testAccounts["local_account_1"]
|
||||
zork.AvatarMediaAttachmentID = ""
|
||||
zork.HeaderMediaAttachmentID = ""
|
||||
if err := suite.db.UpdateByPrimaryKey(ctx, zork, "avatar_media_attachment_id", "header_media_attachment_id"); err != nil {
|
||||
if err := suite.db.UpdateByID(ctx, zork, zork.ID, "avatar_media_attachment_id", "header_media_attachment_id"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
|
@ -72,7 +72,7 @@ func (suite *PruneMetaTestSuite) TestPruneMetaTwice() {
|
|||
zork := suite.testAccounts["local_account_1"]
|
||||
zork.AvatarMediaAttachmentID = ""
|
||||
zork.HeaderMediaAttachmentID = ""
|
||||
if err := suite.db.UpdateByPrimaryKey(ctx, zork, "avatar_media_attachment_id", "header_media_attachment_id"); err != nil {
|
||||
if err := suite.db.UpdateByID(ctx, zork, zork.ID, "avatar_media_attachment_id", "header_media_attachment_id"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
|
@ -95,14 +95,14 @@ func (suite *PruneMetaTestSuite) TestPruneMetaMultipleAccounts() {
|
|||
zork := suite.testAccounts["local_account_1"]
|
||||
zork.AvatarMediaAttachmentID = ""
|
||||
zork.HeaderMediaAttachmentID = ""
|
||||
if err := suite.db.UpdateByPrimaryKey(ctx, zork, "avatar_media_attachment_id", "header_media_attachment_id"); err != nil {
|
||||
if err := suite.db.UpdateByID(ctx, zork, zork.ID, "avatar_media_attachment_id", "header_media_attachment_id"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// set zork's unused header as belonging to turtle
|
||||
turtle := suite.testAccounts["local_account_1"]
|
||||
zorkOldHeader.AccountID = turtle.ID
|
||||
if err := suite.db.UpdateByPrimaryKey(ctx, zorkOldHeader, "account_id"); err != nil {
|
||||
if err := suite.db.UpdateByID(ctx, zorkOldHeader, zorkOldHeader.ID, "account_id"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
|
|
|
@ -90,7 +90,7 @@ func (m *manager) pruneOneRemote(ctx context.Context, attachment *gtsmodel.Media
|
|||
|
||||
// update the attachment to reflect that we no longer have it cached
|
||||
if changed {
|
||||
return m.db.UpdateByPrimaryKey(ctx, attachment, "updated_at", "cached")
|
||||
return m.db.UpdateByID(ctx, attachment, attachment.ID, "updated_at", "cached")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -128,15 +128,17 @@ func (p *processor) initiateDomainBlockSideEffects(ctx context.Context, account
|
|||
instance.ContactAccountUsername = ""
|
||||
instance.ContactAccountID = ""
|
||||
instance.Version = ""
|
||||
if err := p.db.UpdateByPrimaryKey(ctx, instance, updatingColumns...); err != nil {
|
||||
if err := p.db.UpdateByID(ctx, instance, instance.ID, updatingColumns...); err != nil {
|
||||
l.Errorf("domainBlockProcessSideEffects: db error updating instance: %s", err)
|
||||
}
|
||||
l.Debug("domainBlockProcessSideEffects: instance entry updated")
|
||||
}
|
||||
|
||||
// if we have an instance account for this instance, delete it
|
||||
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "username", Value: block.Domain, CaseInsensitive: true}}, >smodel.Account{}); err != nil {
|
||||
l.Errorf("domainBlockProcessSideEffects: db error removing instance account: %s", err)
|
||||
if instanceAccount, err := p.db.GetAccountByUsernameDomain(ctx, block.Domain, block.Domain); err == nil {
|
||||
if err := p.db.DeleteAccount(ctx, instanceAccount.ID); err != nil {
|
||||
l.Errorf("domainBlockProcessSideEffects: db error deleting instance account: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// delete accounts through the normal account deletion system (which should also delete media + posts + remove posts from timelines)
|
||||
|
|
|
@ -55,14 +55,14 @@ func (p *processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Acc
|
|||
// remove the domain block reference from the instance, if we have an entry for it
|
||||
i := >smodel.Instance{}
|
||||
if err := p.db.GetWhere(ctx, []db.Where{
|
||||
{Key: "domain", Value: domainBlock.Domain, CaseInsensitive: true},
|
||||
{Key: "domain", Value: domainBlock.Domain},
|
||||
{Key: "domain_block_id", Value: id},
|
||||
}, i); err == nil {
|
||||
updatingColumns := []string{"suspended_at", "domain_block_id", "updated_at"}
|
||||
i.SuspendedAt = time.Time{}
|
||||
i.DomainBlockID = ""
|
||||
i.UpdatedAt = time.Now()
|
||||
if err := p.db.UpdateByPrimaryKey(ctx, i, updatingColumns...); err != nil {
|
||||
if err := p.db.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("couldn't update database entry for instance %s: %s", domainBlock.Domain, err))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -224,7 +224,7 @@ func (p *processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
|
|||
}
|
||||
}
|
||||
|
||||
if err := p.db.UpdateByPrimaryKey(ctx, i, updatingColumns...); err != nil {
|
||||
if err := p.db.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance %s: %s", host, err))
|
||||
}
|
||||
|
||||
|
|
|
@ -69,7 +69,7 @@ func (suite *GetFileTestSuite) TestGetRemoteFileUncached() {
|
|||
// uncache the file from local
|
||||
testAttachment := suite.testAttachments["remote_account_1_status_1_attachment_1"]
|
||||
testAttachment.Cached = testrig.FalseBool()
|
||||
err := suite.db.UpdateByPrimaryKey(ctx, testAttachment, "cached")
|
||||
err := suite.db.UpdateByID(ctx, testAttachment, testAttachment.ID, "cached")
|
||||
suite.NoError(err)
|
||||
err = suite.storage.Delete(ctx, testAttachment.File.Path)
|
||||
suite.NoError(err)
|
||||
|
@ -124,7 +124,7 @@ func (suite *GetFileTestSuite) TestGetRemoteFileUncachedInterrupted() {
|
|||
// uncache the file from local
|
||||
testAttachment := suite.testAttachments["remote_account_1_status_1_attachment_1"]
|
||||
testAttachment.Cached = testrig.FalseBool()
|
||||
err := suite.db.UpdateByPrimaryKey(ctx, testAttachment, "cached")
|
||||
err := suite.db.UpdateByID(ctx, testAttachment, testAttachment.ID, "cached")
|
||||
suite.NoError(err)
|
||||
err = suite.storage.Delete(ctx, testAttachment.File.Path)
|
||||
suite.NoError(err)
|
||||
|
@ -179,7 +179,7 @@ func (suite *GetFileTestSuite) TestGetRemoteFileThumbnailUncached() {
|
|||
|
||||
// uncache the file from local
|
||||
testAttachment.Cached = testrig.FalseBool()
|
||||
err = suite.db.UpdateByPrimaryKey(ctx, testAttachment, "cached")
|
||||
err = suite.db.UpdateByID(ctx, testAttachment, testAttachment.ID, "cached")
|
||||
suite.NoError(err)
|
||||
err = suite.storage.Delete(ctx, testAttachment.File.Path)
|
||||
suite.NoError(err)
|
||||
|
|
|
@ -47,7 +47,7 @@ func (p *processor) Unattach(ctx context.Context, account *gtsmodel.Account, med
|
|||
attachment.UpdatedAt = time.Now()
|
||||
attachment.StatusID = ""
|
||||
|
||||
if err := p.db.UpdateByPrimaryKey(ctx, attachment, updatingColumns...); err != nil {
|
||||
if err := p.db.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("db error updating attachment: %s", err))
|
||||
}
|
||||
|
||||
|
|
|
@ -61,7 +61,7 @@ func (p *processor) Update(ctx context.Context, account *gtsmodel.Account, media
|
|||
updatingColumns = append(updatingColumns, "focus_x", "focus_y")
|
||||
}
|
||||
|
||||
if err := p.db.UpdateByPrimaryKey(ctx, attachment, updatingColumns...); err != nil {
|
||||
if err := p.db.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error updating media: %s", err))
|
||||
}
|
||||
|
||||
|
|
|
@ -162,27 +162,28 @@ func (p *processor) ProcessMediaIDs(ctx context.Context, form *apimodel.Advanced
|
|||
return nil
|
||||
}
|
||||
|
||||
gtsMediaAttachments := []*gtsmodel.MediaAttachment{}
|
||||
attachments := []string{}
|
||||
attachments := []*gtsmodel.MediaAttachment{}
|
||||
attachmentIDs := []string{}
|
||||
for _, mediaID := range form.MediaIDs {
|
||||
// check these attachments exist
|
||||
a := >smodel.MediaAttachment{}
|
||||
if err := p.db.GetByID(ctx, mediaID, a); err != nil {
|
||||
return fmt.Errorf("invalid media type or media not found for media id %s", mediaID)
|
||||
attachment, err := p.db.GetAttachmentByID(ctx, mediaID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ProcessMediaIDs: invalid media type or media not found for media id %s", mediaID)
|
||||
}
|
||||
// check they belong to the requesting account id
|
||||
if a.AccountID != thisAccountID {
|
||||
return fmt.Errorf("media with id %s does not belong to account %s", mediaID, thisAccountID)
|
||||
|
||||
if attachment.AccountID != thisAccountID {
|
||||
return fmt.Errorf("ProcessMediaIDs: media with id %s does not belong to account %s", mediaID, thisAccountID)
|
||||
}
|
||||
// check they're not already used in a status
|
||||
if a.StatusID != "" || a.ScheduledStatusID != "" {
|
||||
return fmt.Errorf("media with id %s is already attached to a status", mediaID)
|
||||
|
||||
if attachment.StatusID != "" || attachment.ScheduledStatusID != "" {
|
||||
return fmt.Errorf("ProcessMediaIDs: media with id %s is already attached to a status", mediaID)
|
||||
}
|
||||
gtsMediaAttachments = append(gtsMediaAttachments, a)
|
||||
attachments = append(attachments, a.ID)
|
||||
|
||||
attachments = append(attachments, attachment)
|
||||
attachmentIDs = append(attachmentIDs, attachment.ID)
|
||||
}
|
||||
status.Attachments = gtsMediaAttachments
|
||||
status.AttachmentIDs = attachments
|
||||
|
||||
status.Attachments = attachments
|
||||
status.AttachmentIDs = attachmentIDs
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -45,7 +45,7 @@ func (p *processor) ChangePassword(ctx context.Context, user *gtsmodel.User, old
|
|||
user.EncryptedPassword = string(newPasswordHash)
|
||||
user.UpdatedAt = time.Now()
|
||||
|
||||
if err := p.db.UpdateByPrimaryKey(ctx, user, "encrypted_password", "updated_at"); err != nil {
|
||||
if err := p.db.UpdateByID(ctx, user, user.ID, "encrypted_password", "updated_at"); err != nil {
|
||||
return gtserror.NewErrorInternalError(err, "database error")
|
||||
}
|
||||
|
||||
|
|
|
@ -77,7 +77,7 @@ func (p *processor) SendConfirmEmail(ctx context.Context, user *gtsmodel.User, u
|
|||
user.LastEmailedAt = time.Now()
|
||||
user.UpdatedAt = time.Now()
|
||||
|
||||
if err := p.db.UpdateByPrimaryKey(ctx, user, updatingColumns...); err != nil {
|
||||
if err := p.db.UpdateByID(ctx, user, user.ID, updatingColumns...); err != nil {
|
||||
return fmt.Errorf("SendConfirmEmail: error updating user entry after email sent: %s", err)
|
||||
}
|
||||
|
||||
|
@ -126,7 +126,7 @@ func (p *processor) ConfirmEmail(ctx context.Context, token string) (*gtsmodel.U
|
|||
user.ConfirmationToken = ""
|
||||
user.UpdatedAt = time.Now()
|
||||
|
||||
if err := p.db.UpdateByPrimaryKey(ctx, user, updatingColumns...); err != nil {
|
||||
if err := p.db.UpdateByID(ctx, user, user.ID, updatingColumns...); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
|
|
|
@ -74,7 +74,7 @@ func (suite *EmailConfirmTestSuite) TestConfirmEmail() {
|
|||
user.ConfirmationSentAt = time.Now().Add(-5 * time.Minute)
|
||||
user.ConfirmationToken = "1d1aa44b-afa4-49c8-ac4b-eceb61715cc6"
|
||||
|
||||
err := suite.db.UpdateByPrimaryKey(ctx, user, updatingColumns...)
|
||||
err := suite.db.UpdateByID(ctx, user, user.ID, updatingColumns...)
|
||||
suite.NoError(err)
|
||||
|
||||
// confirm with the token set above
|
||||
|
@ -102,7 +102,7 @@ func (suite *EmailConfirmTestSuite) TestConfirmEmailOldToken() {
|
|||
user.ConfirmationSentAt = time.Now().Add(-192 * time.Hour)
|
||||
user.ConfirmationToken = "1d1aa44b-afa4-49c8-ac4b-eceb61715cc6"
|
||||
|
||||
err := suite.db.UpdateByPrimaryKey(ctx, user, updatingColumns...)
|
||||
err := suite.db.UpdateByID(ctx, user, user.ID, updatingColumns...)
|
||||
suite.NoError(err)
|
||||
|
||||
// confirm with the token set above
|
||||
|
|
|
@ -187,7 +187,7 @@ func StandardDBSetup(db db.DB, accounts map[string]*gtsmodel.Account) {
|
|||
}
|
||||
|
||||
for _, v := range NewTestStatuses() {
|
||||
if err := db.PutStatus(ctx, v); err != nil {
|
||||
if err := db.Put(ctx, v); err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
}
|
||||
|
@ -198,12 +198,24 @@ func StandardDBSetup(db db.DB, accounts map[string]*gtsmodel.Account) {
|
|||
}
|
||||
}
|
||||
|
||||
for _, v := range NewTestStatusToEmojis() {
|
||||
if err := db.Put(ctx, v); err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range NewTestTags() {
|
||||
if err := db.Put(ctx, v); err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range NewTestStatusToTags() {
|
||||
if err := db.Put(ctx, v); err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range NewTestMentions() {
|
||||
if err := db.Put(ctx, v); err != nil {
|
||||
log.Panic(err)
|
||||
|
|
|
@ -977,6 +977,15 @@ func NewTestEmojis() map[string]*gtsmodel.Emoji {
|
|||
}
|
||||
}
|
||||
|
||||
func NewTestStatusToEmojis() map[string]*gtsmodel.StatusToEmoji {
|
||||
return map[string]*gtsmodel.StatusToEmoji{
|
||||
"admin_account_status_1_rainbow": {
|
||||
StatusID: "01F8MH75CBF9JFX4ZAD54N0W0R",
|
||||
EmojiID: "01F8MH9H8E4VG3KDYJR9EGPXCQ",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewTestInstances() map[string]*gtsmodel.Instance {
|
||||
return map[string]*gtsmodel.Instance{
|
||||
"localhost:8080": {
|
||||
|
@ -1540,6 +1549,15 @@ func NewTestTags() map[string]*gtsmodel.Tag {
|
|||
}
|
||||
}
|
||||
|
||||
func NewTestStatusToTags() map[string]*gtsmodel.StatusToTag {
|
||||
return map[string]*gtsmodel.StatusToTag{
|
||||
"admin_account_status_1_welcome": {
|
||||
StatusID: "01F8MH75CBF9JFX4ZAD54N0W0R",
|
||||
TagID: "01F8MHA1A2NF9MJ3WCCQ3K8BSZ",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewTestMentions returns a map of gts model mentions keyed by their name.
|
||||
func NewTestMentions() map[string]*gtsmodel.Mention {
|
||||
return map[string]*gtsmodel.Mention{
|
||||
|
|
Loading…
Reference in a new issue