mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-01-25 23:48:09 +00:00
Improve GetRemoteStatus and db.GetStatus() logic (#174)
* only fetch status parents / children if explicity requested when dereferencing Signed-off-by: kim (grufwub) <grufwub@gmail.com> * Remove recursive DB GetStatus logic, don't fetch parent unless requested Signed-off-by: kim (grufwub) <grufwub@gmail.com> * StatusCache copies status so there are no thread-safety issues with modified status objects Signed-off-by: kim (grufwub) <grufwub@gmail.com> * remove sqlite test files Signed-off-by: kim (grufwub) <grufwub@gmail.com> * fix bugs introduced by previous commit Signed-off-by: kim (grufwub) <grufwub@gmail.com> * fix not continue on error in loop Signed-off-by: kim (grufwub) <grufwub@gmail.com> * use our own RunInTx implementation (possible fix for nested tx error) Signed-off-by: kim (grufwub) <grufwub@gmail.com> * fix cast statement to work with SQLite Signed-off-by: kim (grufwub) <grufwub@gmail.com> * be less strict about valid status in cache Signed-off-by: kim (grufwub) <grufwub@gmail.com> * add cache=shared ALWAYS for SQLite db instances Signed-off-by: kim (grufwub) <grufwub@gmail.com> * Fix EnrichRemoteAccount when updating account fails Signed-off-by: kim (grufwub) <grufwub@gmail.com> * add nolint tag Signed-off-by: kim (grufwub) <grufwub@gmail.com> * ensure file: prefixes the filename in sqlite addr Signed-off-by: kim (grufwub) <grufwub@gmail.com> * add an account cache, add status author account from db Signed-off-by: kim (grufwub) <grufwub@gmail.com> * Fix incompatible SQLite query Signed-off-by: kim (grufwub) <grufwub@gmail.com> * *actually* use the new getAccount() function in accountsDB Signed-off-by: kim (grufwub) <grufwub@gmail.com> * update cache tests to use test suite Signed-off-by: kim (grufwub) <grufwub@gmail.com> * add RelationshipTestSuite, add tests for methods with changed SQL Signed-off-by: kim (grufwub) <grufwub@gmail.com>
This commit is contained in:
parent
ed46224573
commit
7d193de25f
36 changed files with 660 additions and 234 deletions
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
157
internal/cache/account.go
vendored
Normal file
157
internal/cache/account.go
vendored
Normal file
|
@ -0,0 +1,157 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/ReneKroon/ttlcache"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
)
|
||||
|
||||
// AccountCache is a wrapper around ttlcache.Cache to provide URL and URI lookups for gtsmodel.Account
|
||||
type AccountCache struct {
|
||||
cache *ttlcache.Cache // map of IDs -> cached accounts
|
||||
urls map[string]string // map of account URLs -> IDs
|
||||
uris map[string]string // map of account URIs -> IDs
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// NewAccountCache returns a new instantiated AccountCache object
|
||||
func NewAccountCache() *AccountCache {
|
||||
c := AccountCache{
|
||||
cache: ttlcache.NewCache(),
|
||||
urls: make(map[string]string, 100),
|
||||
uris: make(map[string]string, 100),
|
||||
mutex: sync.Mutex{},
|
||||
}
|
||||
|
||||
// Set callback to purge lookup maps on expiration
|
||||
c.cache.SetExpirationCallback(func(key string, value interface{}) {
|
||||
account := value.(*gtsmodel.Account)
|
||||
|
||||
c.mutex.Lock()
|
||||
delete(c.urls, account.URL)
|
||||
delete(c.uris, account.URI)
|
||||
c.mutex.Unlock()
|
||||
})
|
||||
|
||||
return &c
|
||||
}
|
||||
|
||||
// GetByID attempts to fetch a account from the cache by its ID, you will receive a copy for thread-safety
|
||||
func (c *AccountCache) GetByID(id string) (*gtsmodel.Account, bool) {
|
||||
c.mutex.Lock()
|
||||
account, ok := c.getByID(id)
|
||||
c.mutex.Unlock()
|
||||
return account, ok
|
||||
}
|
||||
|
||||
// GetByURL attempts to fetch a account from the cache by its URL, you will receive a copy for thread-safety
|
||||
func (c *AccountCache) GetByURL(url string) (*gtsmodel.Account, bool) {
|
||||
// Perform safe ID lookup
|
||||
c.mutex.Lock()
|
||||
id, ok := c.urls[url]
|
||||
|
||||
// Not found, unlock early
|
||||
if !ok {
|
||||
c.mutex.Unlock()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Attempt account lookup
|
||||
account, ok := c.getByID(id)
|
||||
c.mutex.Unlock()
|
||||
return account, ok
|
||||
}
|
||||
|
||||
// GetByURI attempts to fetch a account from the cache by its URI, you will receive a copy for thread-safety
|
||||
func (c *AccountCache) GetByURI(uri string) (*gtsmodel.Account, bool) {
|
||||
// Perform safe ID lookup
|
||||
c.mutex.Lock()
|
||||
id, ok := c.uris[uri]
|
||||
|
||||
// Not found, unlock early
|
||||
if !ok {
|
||||
c.mutex.Unlock()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Attempt account lookup
|
||||
account, ok := c.getByID(id)
|
||||
c.mutex.Unlock()
|
||||
return account, ok
|
||||
}
|
||||
|
||||
// getByID performs an unsafe (no mutex locks) lookup of account by ID, returning a copy of account in cache
|
||||
func (c *AccountCache) getByID(id string) (*gtsmodel.Account, bool) {
|
||||
v, ok := c.cache.Get(id)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return copyAccount(v.(*gtsmodel.Account)), true
|
||||
}
|
||||
|
||||
// Put places a account in the cache, ensuring that the object place is a copy for thread-safety
|
||||
func (c *AccountCache) Put(account *gtsmodel.Account) {
|
||||
if account == nil || account.ID == "" {
|
||||
panic("invalid account")
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
c.cache.Set(account.ID, copyAccount(account))
|
||||
if account.URL != "" {
|
||||
c.urls[account.URL] = account.ID
|
||||
}
|
||||
if account.URI != "" {
|
||||
c.uris[account.URI] = account.ID
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
// copyAccount performs a surface-level copy of account, only keeping attached IDs intact, not the objects.
|
||||
// due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr)
|
||||
// this should be a relatively cheap process
|
||||
func copyAccount(account *gtsmodel.Account) *gtsmodel.Account {
|
||||
return >smodel.Account{
|
||||
ID: account.ID,
|
||||
Username: account.Username,
|
||||
Domain: account.Domain,
|
||||
AvatarMediaAttachmentID: account.AvatarMediaAttachmentID,
|
||||
AvatarMediaAttachment: nil,
|
||||
AvatarRemoteURL: account.AvatarRemoteURL,
|
||||
HeaderMediaAttachmentID: account.HeaderMediaAttachmentID,
|
||||
HeaderMediaAttachment: nil,
|
||||
HeaderRemoteURL: account.HeaderRemoteURL,
|
||||
DisplayName: account.DisplayName,
|
||||
Fields: account.Fields,
|
||||
Note: account.Note,
|
||||
Memorial: account.Memorial,
|
||||
MovedToAccountID: account.MovedToAccountID,
|
||||
CreatedAt: account.CreatedAt,
|
||||
UpdatedAt: account.UpdatedAt,
|
||||
Bot: account.Bot,
|
||||
Reason: account.Reason,
|
||||
Locked: account.Locked,
|
||||
Discoverable: account.Discoverable,
|
||||
Privacy: account.Privacy,
|
||||
Sensitive: account.Sensitive,
|
||||
Language: account.Language,
|
||||
URI: account.URI,
|
||||
URL: account.URL,
|
||||
LastWebfingeredAt: account.LastWebfingeredAt,
|
||||
InboxURI: account.InboxURI,
|
||||
OutboxURI: account.OutboxURI,
|
||||
FollowingURI: account.FollowingURI,
|
||||
FollowersURI: account.FollowersURI,
|
||||
FeaturedCollectionURI: account.FeaturedCollectionURI,
|
||||
ActorType: account.ActorType,
|
||||
AlsoKnownAs: account.AlsoKnownAs,
|
||||
PrivateKey: account.PrivateKey,
|
||||
PublicKey: account.PublicKey,
|
||||
PublicKeyURI: account.PublicKeyURI,
|
||||
SensitizedAt: account.SensitizedAt,
|
||||
SilencedAt: account.SilencedAt,
|
||||
SuspendedAt: account.SuspendedAt,
|
||||
HideCollections: account.HideCollections,
|
||||
SuspensionOrigin: account.SuspensionOrigin,
|
||||
}
|
||||
}
|
63
internal/cache/account_test.go
vendored
Normal file
63
internal/cache/account_test.go
vendored
Normal file
|
@ -0,0 +1,63 @@
|
|||
package cache_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/cache"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
type AccountCacheTestSuite struct {
|
||||
suite.Suite
|
||||
data map[string]*gtsmodel.Account
|
||||
cache *cache.AccountCache
|
||||
}
|
||||
|
||||
func (suite *AccountCacheTestSuite) SetupSuite() {
|
||||
suite.data = testrig.NewTestAccounts()
|
||||
}
|
||||
|
||||
func (suite *AccountCacheTestSuite) SetupTest() {
|
||||
suite.cache = cache.NewAccountCache()
|
||||
}
|
||||
|
||||
func (suite *AccountCacheTestSuite) TearDownTest() {
|
||||
suite.data = nil
|
||||
suite.cache = nil
|
||||
}
|
||||
|
||||
func (suite *AccountCacheTestSuite) TestAccountCache() {
|
||||
for _, account := range suite.data {
|
||||
// Place in the cache
|
||||
suite.cache.Put(account)
|
||||
}
|
||||
|
||||
for _, account := range suite.data {
|
||||
var ok bool
|
||||
var check *gtsmodel.Account
|
||||
|
||||
// Check we can retrieve
|
||||
check, ok = suite.cache.GetByID(account.ID)
|
||||
if !ok && !accountIs(account, check) {
|
||||
suite.Fail("Failed to fetch expected account with ID: %s", account.ID)
|
||||
}
|
||||
check, ok = suite.cache.GetByURI(account.URI)
|
||||
if account.URI != "" && !ok && !accountIs(account, check) {
|
||||
suite.Fail("Failed to fetch expected account with URI: %s", account.URI)
|
||||
}
|
||||
check, ok = suite.cache.GetByURL(account.URL)
|
||||
if account.URL != "" && !ok && !accountIs(account, check) {
|
||||
suite.Fail("Failed to fetch expected account with URL: %s", account.URL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountCache(t *testing.T) {
|
||||
suite.Run(t, &AccountCacheTestSuite{})
|
||||
}
|
||||
|
||||
func accountIs(account1, account2 *gtsmodel.Account) bool {
|
||||
return account1.ID == account2.ID && account1.URI == account2.URI && account1.URL == account2.URL
|
||||
}
|
66
internal/cache/status.go
vendored
66
internal/cache/status.go
vendored
|
@ -37,7 +37,7 @@ func NewStatusCache() *StatusCache {
|
|||
return &c
|
||||
}
|
||||
|
||||
// GetByID attempts to fetch a status from the cache by its ID
|
||||
// GetByID attempts to fetch a status from the cache by its ID, you will receive a copy for thread-safety
|
||||
func (c *StatusCache) GetByID(id string) (*gtsmodel.Status, bool) {
|
||||
c.mutex.Lock()
|
||||
status, ok := c.getByID(id)
|
||||
|
@ -45,7 +45,7 @@ func (c *StatusCache) GetByID(id string) (*gtsmodel.Status, bool) {
|
|||
return status, ok
|
||||
}
|
||||
|
||||
// GetByURL attempts to fetch a status from the cache by its URL
|
||||
// GetByURL attempts to fetch a status from the cache by its URL, you will receive a copy for thread-safety
|
||||
func (c *StatusCache) GetByURL(url string) (*gtsmodel.Status, bool) {
|
||||
// Perform safe ID lookup
|
||||
c.mutex.Lock()
|
||||
|
@ -63,7 +63,7 @@ func (c *StatusCache) GetByURL(url string) (*gtsmodel.Status, bool) {
|
|||
return status, ok
|
||||
}
|
||||
|
||||
// GetByURI attempts to fetch a status from the cache by its URI
|
||||
// GetByURI attempts to fetch a status from the cache by its URI, you will receive a copy for thread-safety
|
||||
func (c *StatusCache) GetByURI(uri string) (*gtsmodel.Status, bool) {
|
||||
// Perform safe ID lookup
|
||||
c.mutex.Lock()
|
||||
|
@ -81,26 +81,72 @@ func (c *StatusCache) GetByURI(uri string) (*gtsmodel.Status, bool) {
|
|||
return status, ok
|
||||
}
|
||||
|
||||
// getByID performs an unsafe (no mutex locks) lookup of status by ID
|
||||
// getByID performs an unsafe (no mutex locks) lookup of status by ID, returning a copy of status in cache
|
||||
func (c *StatusCache) getByID(id string) (*gtsmodel.Status, bool) {
|
||||
v, ok := c.cache.Get(id)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return v.(*gtsmodel.Status), true
|
||||
return copyStatus(v.(*gtsmodel.Status)), true
|
||||
}
|
||||
|
||||
// Put places a status in the cache
|
||||
// Put places a status in the cache, ensuring that the object place is a copy for thread-safety
|
||||
func (c *StatusCache) Put(status *gtsmodel.Status) {
|
||||
if status == nil || status.ID == "" ||
|
||||
status.URL == "" ||
|
||||
status.URI == "" {
|
||||
if status == nil || status.ID == "" {
|
||||
panic("invalid status")
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
c.cache.Set(status.ID, status)
|
||||
c.cache.Set(status.ID, copyStatus(status))
|
||||
if status.URL != "" {
|
||||
c.urls[status.URL] = status.ID
|
||||
}
|
||||
if status.URI != "" {
|
||||
c.uris[status.URI] = status.ID
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
// copyStatus performs a surface-level copy of status, only keeping attached IDs intact, not the objects.
|
||||
// due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr)
|
||||
// this should be a relatively cheap process
|
||||
func copyStatus(status *gtsmodel.Status) *gtsmodel.Status {
|
||||
return >smodel.Status{
|
||||
ID: status.ID,
|
||||
URI: status.URI,
|
||||
URL: status.URL,
|
||||
Content: status.Content,
|
||||
AttachmentIDs: status.AttachmentIDs,
|
||||
Attachments: nil,
|
||||
TagIDs: status.TagIDs,
|
||||
Tags: nil,
|
||||
MentionIDs: status.MentionIDs,
|
||||
Mentions: nil,
|
||||
EmojiIDs: status.EmojiIDs,
|
||||
Emojis: nil,
|
||||
CreatedAt: status.CreatedAt,
|
||||
UpdatedAt: status.UpdatedAt,
|
||||
Local: status.Local,
|
||||
AccountID: status.AccountID,
|
||||
Account: nil,
|
||||
AccountURI: status.AccountURI,
|
||||
InReplyToID: status.InReplyToID,
|
||||
InReplyTo: nil,
|
||||
InReplyToURI: status.InReplyToURI,
|
||||
InReplyToAccountID: status.InReplyToAccountID,
|
||||
InReplyToAccount: nil,
|
||||
BoostOfID: status.BoostOfID,
|
||||
BoostOf: nil,
|
||||
BoostOfAccountID: status.BoostOfAccountID,
|
||||
BoostOfAccount: nil,
|
||||
ContentWarning: status.ContentWarning,
|
||||
Visibility: status.Visibility,
|
||||
Sensitive: status.Sensitive,
|
||||
Language: status.Language,
|
||||
CreatedWithApplicationID: status.CreatedWithApplicationID,
|
||||
VisibilityAdvanced: status.VisibilityAdvanced,
|
||||
ActivityStreamsType: status.ActivityStreamsType,
|
||||
Text: status.Text,
|
||||
Pinned: status.Pinned,
|
||||
}
|
||||
}
|
||||
|
|
56
internal/cache/status_test.go
vendored
56
internal/cache/status_test.go
vendored
|
@ -3,37 +3,59 @@ package cache_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/cache"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
func TestStatusCache(t *testing.T) {
|
||||
cache := cache.NewStatusCache()
|
||||
type StatusCacheTestSuite struct {
|
||||
suite.Suite
|
||||
data map[string]*gtsmodel.Status
|
||||
cache *cache.StatusCache
|
||||
}
|
||||
|
||||
// Attempt to place a status
|
||||
status := gtsmodel.Status{
|
||||
ID: "id",
|
||||
URI: "uri",
|
||||
URL: "url",
|
||||
func (suite *StatusCacheTestSuite) SetupSuite() {
|
||||
suite.data = testrig.NewTestStatuses()
|
||||
}
|
||||
|
||||
func (suite *StatusCacheTestSuite) SetupTest() {
|
||||
suite.cache = cache.NewStatusCache()
|
||||
}
|
||||
|
||||
func (suite *StatusCacheTestSuite) TearDownTest() {
|
||||
suite.data = nil
|
||||
suite.cache = nil
|
||||
}
|
||||
|
||||
func (suite *StatusCacheTestSuite) TestStatusCache() {
|
||||
for _, status := range suite.data {
|
||||
// Place in the cache
|
||||
suite.cache.Put(status)
|
||||
}
|
||||
cache.Put(&status)
|
||||
|
||||
for _, status := range suite.data {
|
||||
var ok bool
|
||||
var check *gtsmodel.Status
|
||||
|
||||
// Check we can retrieve
|
||||
check, ok = cache.GetByID(status.ID)
|
||||
if !ok || !statusIs(&status, check) {
|
||||
t.Fatal("Could not find expected status")
|
||||
check, ok = suite.cache.GetByID(status.ID)
|
||||
if !ok && !statusIs(status, check) {
|
||||
suite.Fail("Failed to fetch expected account with ID: %s", status.ID)
|
||||
}
|
||||
check, ok = cache.GetByURI(status.URI)
|
||||
if !ok || !statusIs(&status, check) {
|
||||
t.Fatal("Could not find expected status")
|
||||
check, ok = suite.cache.GetByURI(status.URI)
|
||||
if status.URI != "" && !ok && !statusIs(status, check) {
|
||||
suite.Fail("Failed to fetch expected account with URI: %s", status.URI)
|
||||
}
|
||||
check, ok = cache.GetByURL(status.URL)
|
||||
if !ok || !statusIs(&status, check) {
|
||||
t.Fatal("Could not find expected status")
|
||||
check, ok = suite.cache.GetByURL(status.URL)
|
||||
if status.URL != "" && !ok && !statusIs(status, check) {
|
||||
suite.Fail("Failed to fetch expected account with URL: %s", status.URL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusCache(t *testing.T) {
|
||||
suite.Run(t, &StatusCacheTestSuite{})
|
||||
}
|
||||
|
||||
func statusIs(status1, status2 *gtsmodel.Status) bool {
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/cache"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
|
@ -34,6 +35,7 @@ import (
|
|||
type accountDB struct {
|
||||
config *config.Config
|
||||
conn *DBConn
|
||||
cache *cache.AccountCache
|
||||
}
|
||||
|
||||
func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
|
||||
|
@ -45,60 +47,80 @@ func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
|
|||
}
|
||||
|
||||
func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
|
||||
account := new(gtsmodel.Account)
|
||||
|
||||
q := a.newAccountQ(account).
|
||||
Where("account.id = ?", id)
|
||||
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, a.conn.ProcessError(err)
|
||||
}
|
||||
return account, nil
|
||||
return a.getAccount(
|
||||
ctx,
|
||||
func() (*gtsmodel.Account, bool) {
|
||||
return a.cache.GetByID(id)
|
||||
},
|
||||
func(account *gtsmodel.Account) error {
|
||||
return a.newAccountQ(account).Where("account.id = ?", id).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
|
||||
account := new(gtsmodel.Account)
|
||||
|
||||
q := a.newAccountQ(account).
|
||||
Where("account.uri = ?", uri)
|
||||
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, a.conn.ProcessError(err)
|
||||
}
|
||||
return account, nil
|
||||
return a.getAccount(
|
||||
ctx,
|
||||
func() (*gtsmodel.Account, bool) {
|
||||
return a.cache.GetByURI(uri)
|
||||
},
|
||||
func(account *gtsmodel.Account) error {
|
||||
return a.newAccountQ(account).Where("account.uri = ?", uri).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
|
||||
account := new(gtsmodel.Account)
|
||||
func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, db.Error) {
|
||||
return a.getAccount(
|
||||
ctx,
|
||||
func() (*gtsmodel.Account, bool) {
|
||||
return a.cache.GetByURL(url)
|
||||
},
|
||||
func(account *gtsmodel.Account) error {
|
||||
return a.newAccountQ(account).Where("account.url = ?", url).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
q := a.newAccountQ(account).
|
||||
Where("account.url = ?", uri)
|
||||
func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) {
|
||||
// Attempt to fetch cached account
|
||||
account, cached := cacheGet()
|
||||
|
||||
err := q.Scan(ctx)
|
||||
if !cached {
|
||||
account = >smodel.Account{}
|
||||
|
||||
// Not cached! Perform database query
|
||||
err := dbQuery(account)
|
||||
if err != nil {
|
||||
return nil, a.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// Place in the cache
|
||||
a.cache.Put(account)
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {
|
||||
if strings.TrimSpace(account.ID) == "" {
|
||||
// TODO: we should not need this check here
|
||||
return nil, errors.New("account had no ID")
|
||||
}
|
||||
|
||||
// Update the account's last-used
|
||||
account.UpdatedAt = time.Now()
|
||||
|
||||
q := a.conn.
|
||||
NewUpdate().
|
||||
Model(account).
|
||||
WherePK()
|
||||
|
||||
_, err := q.Exec(ctx)
|
||||
// Update the account model in the DB
|
||||
_, err := a.conn.NewUpdate().Model(account).WherePK().Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, a.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// Place updated account in cache
|
||||
// (this will replace existing, i.e. invalidating)
|
||||
a.cache.Put(account)
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -91,6 +91,15 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)
|
|||
conn = WrapDBConn(bun.NewDB(sqldb, pgdialect.New()), log)
|
||||
case dbTypeSqlite:
|
||||
// SQLITE
|
||||
|
||||
// Drop anything fancy from DB address
|
||||
c.DBConfig.Address = strings.Split(c.DBConfig.Address, "?")[0]
|
||||
c.DBConfig.Address = strings.TrimPrefix(c.DBConfig.Address, "file:")
|
||||
|
||||
// Append our own SQLite preferences
|
||||
c.DBConfig.Address = "file:" + c.DBConfig.Address + "?cache=shared"
|
||||
|
||||
// Open new DB instance
|
||||
var err error
|
||||
sqldb, err = sql.Open("sqlite", c.DBConfig.Address)
|
||||
if err != nil {
|
||||
|
@ -98,7 +107,7 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)
|
|||
}
|
||||
conn = WrapDBConn(bun.NewDB(sqldb, sqlitedialect.New()), log)
|
||||
|
||||
if strings.HasPrefix(strings.TrimPrefix(c.DBConfig.Address, "file:"), ":memory:") {
|
||||
if c.DBConfig.Address == "file::memory:?cache=shared" {
|
||||
log.Warn("sqlite in-memory database should only be used for debugging")
|
||||
|
||||
// don't close connections on disconnect -- otherwise
|
||||
|
@ -121,11 +130,10 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)
|
|||
conn.RegisterModel(t)
|
||||
}
|
||||
|
||||
accounts := &accountDB{config: c, conn: conn, cache: cache.NewAccountCache()}
|
||||
|
||||
ps := &bunDBService{
|
||||
Account: &accountDB{
|
||||
config: c,
|
||||
conn: conn,
|
||||
},
|
||||
Account: accounts,
|
||||
Admin: &adminDB{
|
||||
config: c,
|
||||
conn: conn,
|
||||
|
@ -168,6 +176,7 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)
|
|||
config: c,
|
||||
conn: conn,
|
||||
cache: cache.NewStatusCache(),
|
||||
accounts: accounts,
|
||||
},
|
||||
Timeline: &timelineDB{
|
||||
config: c,
|
||||
|
|
|
@ -12,6 +12,8 @@ import (
|
|||
|
||||
// dbConn wrapps a bun.DB conn to provide SQL-type specific additional functionality
|
||||
type DBConn struct {
|
||||
// TODO: move *Config here, no need to be in each struct type
|
||||
|
||||
errProc func(error) db.Error // errProc is the SQL-type specific error processor
|
||||
log *logrus.Logger // log is the logger passed with this DBConn
|
||||
*bun.DB // DB is the underlying bun.DB connection
|
||||
|
@ -35,6 +37,24 @@ func WrapDBConn(dbConn *bun.DB, log *logrus.Logger) *DBConn {
|
|||
}
|
||||
}
|
||||
|
||||
func (conn *DBConn) RunInTx(ctx context.Context, fn func(bun.Tx) error) db.Error {
|
||||
// Acquire a new transaction
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// Perform supplied transaction
|
||||
if err = fn(tx); err != nil {
|
||||
tx.Rollback() //nolint
|
||||
return conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// Finally, commit transaction
|
||||
err = tx.Commit()
|
||||
return conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// ProcessError processes an error to replace any known values with our own db.Error types,
|
||||
// making it easier to catch specific situations (e.g. no rows, already exists, etc)
|
||||
func (conn *DBConn) ProcessError(err error) db.Error {
|
||||
|
|
|
@ -237,7 +237,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI
|
|||
if _, err := r.conn.
|
||||
NewInsert().
|
||||
Model(follow).
|
||||
On("CONFLICT ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).
|
||||
On("CONFLICT (account_id,target_account_id) DO UPDATE set uri = ?", follow.URI).
|
||||
Exec(ctx); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
@ -298,7 +298,7 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str
|
|||
|
||||
if localOnly {
|
||||
q = q.ColumnExpr("follow.*").
|
||||
Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)").
|
||||
Join("JOIN accounts AS a ON follow.account_id = CAST(a.id as TEXT)").
|
||||
Where("follow.target_account_id = ?", accountID).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("a.domain"))
|
||||
} else {
|
||||
|
|
124
internal/db/bundb/relationship_test.go
Normal file
124
internal/db/bundb/relationship_test.go
Normal file
|
@ -0,0 +1,124 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package bundb_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
type RelationshipTestSuite struct {
|
||||
BunDBStandardTestSuite
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) SetupSuite() {
|
||||
suite.testTokens = testrig.NewTestTokens()
|
||||
suite.testClients = testrig.NewTestClients()
|
||||
suite.testApplications = testrig.NewTestApplications()
|
||||
suite.testUsers = testrig.NewTestUsers()
|
||||
suite.testAccounts = testrig.NewTestAccounts()
|
||||
suite.testAttachments = testrig.NewTestAttachments()
|
||||
suite.testStatuses = testrig.NewTestStatuses()
|
||||
suite.testTags = testrig.NewTestTags()
|
||||
suite.testMentions = testrig.NewTestMentions()
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) SetupTest() {
|
||||
suite.config = testrig.NewTestConfig()
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.log = testrig.NewTestLog()
|
||||
|
||||
testrig.StandardDBSetup(suite.db, suite.testAccounts)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TearDownTest() {
|
||||
testrig.StandardDBTeardown(suite.db)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestIsBlocked() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestGetBlock() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestGetRelationship() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestIsFollowing() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestIsMutualFollowing() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) AcceptFollowRequest() {
|
||||
for _, account := range suite.testAccounts {
|
||||
_, err := suite.db.AcceptFollowRequest(context.Background(), account.ID, "NON-EXISTENT-ID")
|
||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
suite.Suite.Fail("error accepting follow request: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) GetAccountFollowRequests() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) GetAccountFollows() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) CountAccountFollows() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) GetAccountFollowedBy() {
|
||||
// TODO: more comprehensive tests here
|
||||
|
||||
for _, account := range suite.testAccounts {
|
||||
var err error
|
||||
|
||||
_, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, false)
|
||||
if err != nil {
|
||||
suite.Suite.Fail("error checking accounts followed by: %v", err)
|
||||
}
|
||||
|
||||
_, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, true)
|
||||
if err != nil {
|
||||
suite.Suite.Fail("error checking localOnly accounts followed by: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) CountAccountFollowedBy() {
|
||||
suite.Suite.T().Skip("TODO: implement")
|
||||
}
|
||||
|
||||
func TestRelationshipTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(RelationshipTestSuite))
|
||||
}
|
Binary file not shown.
|
@ -21,7 +21,6 @@ package bundb
|
|||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/cache"
|
||||
|
@ -35,6 +34,11 @@ type statusDB struct {
|
|||
config *config.Config
|
||||
conn *DBConn
|
||||
cache *cache.StatusCache
|
||||
|
||||
// TODO: keep method definitions in same place but instead have receiver
|
||||
// all point to one single "db" type, so they can all share methods
|
||||
// and caches where necessary
|
||||
accounts *accountDB
|
||||
}
|
||||
|
||||
func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
|
||||
|
@ -51,30 +55,6 @@ func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
|
|||
Relation("CreatedWithApplication")
|
||||
}
|
||||
|
||||
func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Status) *gtsmodel.Status {
|
||||
if status.InReplyToID != "" && status.InReplyTo == nil {
|
||||
// TODO: do we want to keep this possibly recursive strategy?
|
||||
|
||||
if inReplyTo, cached := s.cache.GetByID(status.InReplyToID); cached {
|
||||
status.InReplyTo = inReplyTo
|
||||
} else if inReplyTo, err := s.GetStatusByID(ctx, status.InReplyToID); err == nil {
|
||||
status.InReplyTo = inReplyTo
|
||||
}
|
||||
}
|
||||
|
||||
if status.BoostOfID != "" && status.BoostOf == nil {
|
||||
// TODO: do we want to keep this possibly recursive strategy?
|
||||
|
||||
if boostOf, cached := s.cache.GetByID(status.BoostOfID); cached {
|
||||
status.BoostOf = boostOf
|
||||
} else if boostOf, err := s.GetStatusByID(ctx, status.BoostOfID); err == nil {
|
||||
status.BoostOf = boostOf
|
||||
}
|
||||
}
|
||||
|
||||
return status
|
||||
}
|
||||
|
||||
func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {
|
||||
return s.conn.
|
||||
NewSelect().
|
||||
|
@ -85,64 +65,79 @@ func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {
|
|||
}
|
||||
|
||||
func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) {
|
||||
if status, cached := s.cache.GetByID(id); cached {
|
||||
return status, nil
|
||||
}
|
||||
|
||||
status := >smodel.Status{}
|
||||
|
||||
q := s.newStatusQ(status).
|
||||
Where("status.id = ?", id)
|
||||
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, s.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
s.cache.Put(status)
|
||||
return s.getAttachedStatuses(ctx, status), nil
|
||||
return s.getStatus(
|
||||
ctx,
|
||||
func() (*gtsmodel.Status, bool) {
|
||||
return s.cache.GetByID(id)
|
||||
},
|
||||
func(status *gtsmodel.Status) error {
|
||||
return s.newStatusQ(status).Where("status.id = ?", id).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
|
||||
if status, cached := s.cache.GetByURI(uri); cached {
|
||||
return status, nil
|
||||
}
|
||||
|
||||
status := >smodel.Status{}
|
||||
|
||||
q := s.newStatusQ(status).
|
||||
Where("LOWER(status.uri) = LOWER(?)", uri)
|
||||
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, s.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
s.cache.Put(status)
|
||||
return s.getAttachedStatuses(ctx, status), nil
|
||||
return s.getStatus(
|
||||
ctx,
|
||||
func() (*gtsmodel.Status, bool) {
|
||||
return s.cache.GetByURI(uri)
|
||||
},
|
||||
func(status *gtsmodel.Status) error {
|
||||
return s.newStatusQ(status).Where("LOWER(status.uri) = LOWER(?)", uri).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) {
|
||||
if status, cached := s.cache.GetByURL(url); cached {
|
||||
return status, nil
|
||||
}
|
||||
return s.getStatus(
|
||||
ctx,
|
||||
func() (*gtsmodel.Status, bool) {
|
||||
return s.cache.GetByURL(url)
|
||||
},
|
||||
func(status *gtsmodel.Status) error {
|
||||
return s.newStatusQ(status).Where("LOWER(status.url) = LOWER(?)", url).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
status := >smodel.Status{}
|
||||
func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Status, bool), dbQuery func(*gtsmodel.Status) error) (*gtsmodel.Status, db.Error) {
|
||||
// Attempt to fetch cached status
|
||||
status, cached := cacheGet()
|
||||
|
||||
q := s.newStatusQ(status).
|
||||
Where("LOWER(status.url) = LOWER(?)", url)
|
||||
if !cached {
|
||||
status = >smodel.Status{}
|
||||
|
||||
err := q.Scan(ctx)
|
||||
// Not cached! Perform database query
|
||||
err := dbQuery(status)
|
||||
if err != nil {
|
||||
return nil, s.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// If there is boosted, fetch from DB also
|
||||
if status.BoostOfID != "" {
|
||||
boostOf, err := s.GetStatusByID(ctx, status.BoostOfID)
|
||||
if err == nil {
|
||||
status.BoostOf = boostOf
|
||||
}
|
||||
}
|
||||
|
||||
// Place in the cache
|
||||
s.cache.Put(status)
|
||||
return s.getAttachedStatuses(ctx, status), nil
|
||||
}
|
||||
|
||||
// Set the status author account
|
||||
author, err := s.accounts.GetAccountByID(ctx, status.AccountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return the prepared status
|
||||
status.Account = author
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
|
||||
transaction := func(ctx context.Context, tx bun.Tx) error {
|
||||
return s.conn.RunInTx(ctx, func(tx bun.Tx) error {
|
||||
// create links between this status and any emojis it uses
|
||||
for _, i := range status.EmojiIDs {
|
||||
if _, err := tx.NewInsert().Model(>smodel.StatusToEmoji{
|
||||
|
@ -174,10 +169,10 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
|
|||
}
|
||||
}
|
||||
|
||||
// Finally, insert the status
|
||||
_, err := tx.NewInsert().Model(status).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
return s.conn.ProcessError(s.conn.RunInTx(ctx, nil, transaction))
|
||||
})
|
||||
}
|
||||
|
||||
func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
|
||||
|
@ -210,12 +205,8 @@ func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Statu
|
|||
|
||||
children := []*gtsmodel.Status{}
|
||||
for e := foundStatuses.Front(); e != nil; e = e.Next() {
|
||||
entry, ok := e.Value.(*gtsmodel.Status)
|
||||
if !ok {
|
||||
panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
|
||||
}
|
||||
|
||||
// only append children, not the overall parent status
|
||||
entry := e.Value.(*gtsmodel.Status)
|
||||
if entry.ID != status.ID {
|
||||
children = append(children, entry)
|
||||
}
|
||||
|
@ -242,11 +233,7 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status,
|
|||
for _, child := range immediateChildren {
|
||||
insertLoop:
|
||||
for e := foundStatuses.Front(); e != nil; e = e.Next() {
|
||||
entry, ok := e.Value.(*gtsmodel.Status)
|
||||
if !ok {
|
||||
panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
|
||||
}
|
||||
|
||||
entry := e.Value.(*gtsmodel.Status)
|
||||
if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID {
|
||||
foundStatuses.InsertAfter(child, e)
|
||||
break insertLoop
|
||||
|
|
|
@ -105,10 +105,9 @@ func (suite *StatusTestSuite) TestGetStatusWithMention() {
|
|||
suite.NotNil(status)
|
||||
suite.NotNil(status.Account)
|
||||
suite.NotNil(status.CreatedWithApplication)
|
||||
suite.NotEmpty(status.Mentions)
|
||||
suite.NotEmpty(status.MentionIDs)
|
||||
suite.NotNil(status.InReplyTo)
|
||||
suite.NotNil(status.InReplyToAccount)
|
||||
suite.NotEmpty(status.InReplyToID)
|
||||
suite.NotEmpty(status.InReplyToAccountID)
|
||||
}
|
||||
|
||||
func (suite *StatusTestSuite) TestGetStatusTwice() {
|
||||
|
|
|
@ -26,13 +26,13 @@ import (
|
|||
|
||||
// Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses.
|
||||
type Status interface {
|
||||
// GetStatusByID returns one status from the database, with all rel fields populated (if possible).
|
||||
// GetStatusByID returns one status from the database, with no rel fields populated, only their linking ID / URIs
|
||||
GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, Error)
|
||||
|
||||
// GetStatusByURI returns one status from the database, with all rel fields populated (if possible).
|
||||
// GetStatusByURI returns one status from the database, with no rel fields populated, only their linking ID / URIs
|
||||
GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, Error)
|
||||
|
||||
// GetStatusByURL returns one status from the database, with all rel fields populated (if possible).
|
||||
// GetStatusByURL returns one status from the database, with no rel fields populated, only their linking ID / URIs
|
||||
GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, Error)
|
||||
|
||||
// PutStatus stores one status in the database.
|
||||
|
|
|
@ -34,12 +34,12 @@ func (f *federator) EnrichRemoteAccount(ctx context.Context, username string, ac
|
|||
return f.dereferencer.EnrichRemoteAccount(ctx, username, account)
|
||||
}
|
||||
|
||||
func (f *federator) GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) {
|
||||
return f.dereferencer.GetRemoteStatus(ctx, username, remoteStatusID, refresh)
|
||||
func (f *federator) GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh, includeParent, includeChilds bool) (*gtsmodel.Status, ap.Statusable, bool, error) {
|
||||
return f.dereferencer.GetRemoteStatus(ctx, username, remoteStatusID, refresh, includeParent, includeChilds)
|
||||
}
|
||||
|
||||
func (f *federator) EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error) {
|
||||
return f.dereferencer.EnrichRemoteStatus(ctx, username, status)
|
||||
func (f *federator) EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status, includeParent, includeChilds bool) (*gtsmodel.Status, error) {
|
||||
return f.dereferencer.EnrichRemoteStatus(ctx, username, status, includeParent, includeChilds)
|
||||
}
|
||||
|
||||
func (f *federator) DereferenceRemoteThread(ctx context.Context, username string, statusIRI *url.URL) error {
|
||||
|
|
|
@ -48,7 +48,6 @@ func instanceAccount(account *gtsmodel.Account) bool {
|
|||
// EnrichRemoteAccount is mostly useful for calling after an account has been initially created by
|
||||
// the federatingDB's Create function, or during the federated authorization flow.
|
||||
func (d *deref) EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error) {
|
||||
|
||||
// if we're dealing with an instance account, we don't need to update anything
|
||||
if instanceAccount(account) {
|
||||
return account, nil
|
||||
|
@ -58,13 +57,13 @@ func (d *deref) EnrichRemoteAccount(ctx context.Context, username string, accoun
|
|||
return nil, err
|
||||
}
|
||||
|
||||
var err error
|
||||
account, err = d.db.UpdateAccount(ctx, account)
|
||||
updated, err := d.db.UpdateAccount(ctx, account)
|
||||
if err != nil {
|
||||
d.log.Errorf("EnrichRemoteAccount: error updating account: %s", err)
|
||||
return account, nil
|
||||
}
|
||||
|
||||
return account, nil
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// GetRemoteAccount completely dereferences a remote account, converts it to a GtS model account,
|
||||
|
|
|
@ -46,7 +46,7 @@ func (d *deref) DereferenceAnnounce(ctx context.Context, announce *gtsmodel.Stat
|
|||
return fmt.Errorf("DereferenceAnnounce: error dereferencing thread of boosted status: %s", err)
|
||||
}
|
||||
|
||||
boostedStatus, _, _, err := d.GetRemoteStatus(ctx, requestingUsername, boostedStatusURI, false)
|
||||
boostedStatus, _, _, err := d.GetRemoteStatus(ctx, requestingUsername, boostedStatusURI, false, false, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("DereferenceAnnounce: error dereferencing remote status with id %s: %s", announce.BoostOf.URI, err)
|
||||
}
|
||||
|
|
|
@ -38,8 +38,8 @@ type Dereferencer interface {
|
|||
GetRemoteAccount(ctx context.Context, username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error)
|
||||
EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error)
|
||||
|
||||
GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error)
|
||||
EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error)
|
||||
GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh, includeParent, includeChilds bool) (*gtsmodel.Status, ap.Statusable, bool, error)
|
||||
EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status, includeParent, includeChilds bool) (*gtsmodel.Status, error)
|
||||
|
||||
GetRemoteInstance(ctx context.Context, username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error)
|
||||
|
||||
|
|
Binary file not shown.
|
@ -39,8 +39,8 @@ import (
|
|||
//
|
||||
// EnrichRemoteStatus is mostly useful for calling after a status has been initially created by
|
||||
// the federatingDB's Create function, but additional dereferencing is needed on it.
|
||||
func (d *deref) EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error) {
|
||||
if err := d.populateStatusFields(ctx, status, username); err != nil {
|
||||
func (d *deref) EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status, includeParent, includeChilds bool) (*gtsmodel.Status, error) {
|
||||
if err := d.populateStatusFields(ctx, status, username, includeParent, includeChilds); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -62,7 +62,7 @@ func (d *deref) EnrichRemoteStatus(ctx context.Context, username string, status
|
|||
// If a dereference was performed, then the function also returns the ap.Statusable representation for further processing.
|
||||
//
|
||||
// SIDE EFFECTS: remote status will be stored in the database, and the remote status owner will also be stored.
|
||||
func (d *deref) GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error) {
|
||||
func (d *deref) GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh, includeParent, includeChilds bool) (*gtsmodel.Status, ap.Statusable, bool, error) {
|
||||
new := true
|
||||
|
||||
// check if we already have the status in our db
|
||||
|
@ -105,7 +105,7 @@ func (d *deref) GetRemoteStatus(ctx context.Context, username string, remoteStat
|
|||
}
|
||||
gtsStatus.ID = ulid
|
||||
|
||||
if err := d.populateStatusFields(ctx, gtsStatus, username); err != nil {
|
||||
if err := d.populateStatusFields(ctx, gtsStatus, username, includeParent, includeChilds); err != nil {
|
||||
return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error populating status fields: %s", err)
|
||||
}
|
||||
|
||||
|
@ -115,7 +115,7 @@ func (d *deref) GetRemoteStatus(ctx context.Context, username string, remoteStat
|
|||
} else {
|
||||
gtsStatus.ID = maybeStatus.ID
|
||||
|
||||
if err := d.populateStatusFields(ctx, gtsStatus, username); err != nil {
|
||||
if err := d.populateStatusFields(ctx, gtsStatus, username, includeParent, includeChilds); err != nil {
|
||||
return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error populating status fields: %s", err)
|
||||
}
|
||||
|
||||
|
@ -235,7 +235,7 @@ func (d *deref) dereferenceStatusable(ctx context.Context, username string, remo
|
|||
// This function will deference all of the above, insert them in the database as necessary,
|
||||
// and attach them to the status. The status itself will not be added to the database yet,
|
||||
// that's up the caller to do.
|
||||
func (d *deref) populateStatusFields(ctx context.Context, status *gtsmodel.Status, requestingUsername string) error {
|
||||
func (d *deref) populateStatusFields(ctx context.Context, status *gtsmodel.Status, requestingUsername string, includeParent, includeChilds bool) error {
|
||||
l := d.log.WithFields(logrus.Fields{
|
||||
"func": "dereferenceStatusFields",
|
||||
"status": fmt.Sprintf("%+v", status),
|
||||
|
@ -275,15 +275,20 @@ func (d *deref) populateStatusFields(ctx context.Context, status *gtsmodel.Statu
|
|||
// 3. Emojis
|
||||
// TODO
|
||||
|
||||
// 4. Mentions
|
||||
// 4. Mentions (only if requested)
|
||||
// TODO: do we need to handle removing empty mention objects and just using mention IDs slice?
|
||||
if includeChilds {
|
||||
if err := d.populateStatusMentions(ctx, status, requestingUsername); err != nil {
|
||||
return fmt.Errorf("populateStatusFields: error populating status mentions: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Replied-to-status.
|
||||
// 5. Replied-to-status (only if requested)
|
||||
if includeParent {
|
||||
if err := d.populateStatusRepliedTo(ctx, status, requestingUsername); err != nil {
|
||||
return fmt.Errorf("populateStatusFields: error populating status repliedTo: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -391,7 +396,6 @@ func (d *deref) populateStatusAttachments(ctx context.Context, status *gtsmodel.
|
|||
attachments := []*gtsmodel.MediaAttachment{}
|
||||
|
||||
for _, a := range status.Attachments {
|
||||
|
||||
aURL, err := url.Parse(a.RemoteURL)
|
||||
if err != nil {
|
||||
l.Errorf("populateStatusAttachments: couldn't parse attachment url %s: %s", a.RemoteURL, err)
|
||||
|
@ -401,6 +405,7 @@ func (d *deref) populateStatusAttachments(ctx context.Context, status *gtsmodel.
|
|||
attachment, err := d.GetRemoteAttachment(ctx, requestingUsername, aURL, status.AccountID, status.ID, a.File.ContentType)
|
||||
if err != nil {
|
||||
l.Errorf("populateStatusAttachments: couldn't get remote attachment %s: %s", a.RemoteURL, err)
|
||||
continue
|
||||
}
|
||||
|
||||
attachmentIDs = append(attachmentIDs, attachment.ID)
|
||||
|
@ -420,27 +425,14 @@ func (d *deref) populateStatusRepliedTo(ctx context.Context, status *gtsmodel.St
|
|||
return err
|
||||
}
|
||||
|
||||
var replyToStatus *gtsmodel.Status
|
||||
errs := []string{}
|
||||
|
||||
// see if we have the status in our db already
|
||||
if s, err := d.db.GetStatusByURI(ctx, status.InReplyToURI); err != nil {
|
||||
errs = append(errs, err.Error())
|
||||
} else {
|
||||
replyToStatus = s
|
||||
replyToStatus, err := d.db.GetStatusByURI(ctx, status.InReplyToURI)
|
||||
if err != nil {
|
||||
// Status was not in the DB, try fetch
|
||||
replyToStatus, _, _, err = d.GetRemoteStatus(ctx, requestingUsername, statusURI, false, false, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("populateStatusRepliedTo: couldn't get reply to status with uri %s: %s", status.InReplyToURI, err)
|
||||
}
|
||||
|
||||
if replyToStatus == nil {
|
||||
// didn't find the status in our db, try to get it remotely
|
||||
if s, _, _, err := d.GetRemoteStatus(ctx, requestingUsername, statusURI, false); err != nil {
|
||||
errs = append(errs, err.Error())
|
||||
} else {
|
||||
replyToStatus = s
|
||||
}
|
||||
}
|
||||
|
||||
if replyToStatus == nil {
|
||||
return fmt.Errorf("populateStatusRepliedTo: couldn't get reply to status with uri %s: %s", statusURI, strings.Join(errs, " : "))
|
||||
}
|
||||
|
||||
// we have the status
|
||||
|
|
|
@ -119,7 +119,7 @@ func (suite *StatusTestSuite) TestDereferenceSimpleStatus() {
|
|||
fetchingAccount := suite.testAccounts["local_account_1"]
|
||||
|
||||
statusURL := testrig.URLMustParse("https://unknown-instance.com/users/brand_new_person/statuses/01FE4NTHKWW7THT67EF10EB839")
|
||||
status, statusable, new, err := suite.dereferencer.GetRemoteStatus(context.Background(), fetchingAccount.Username, statusURL, false)
|
||||
status, statusable, new, err := suite.dereferencer.GetRemoteStatus(context.Background(), fetchingAccount.Username, statusURL, false, false, false)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(status)
|
||||
suite.NotNil(statusable)
|
||||
|
@ -157,7 +157,7 @@ func (suite *StatusTestSuite) TestDereferenceStatusWithMention() {
|
|||
fetchingAccount := suite.testAccounts["local_account_1"]
|
||||
|
||||
statusURL := testrig.URLMustParse("https://unknown-instance.com/users/brand_new_person/statuses/01FE5Y30E3W4P7TRE0R98KAYQV")
|
||||
status, statusable, new, err := suite.dereferencer.GetRemoteStatus(context.Background(), fetchingAccount.Username, statusURL, false)
|
||||
status, statusable, new, err := suite.dereferencer.GetRemoteStatus(context.Background(), fetchingAccount.Username, statusURL, false, false, true)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(status)
|
||||
suite.NotNil(statusable)
|
||||
|
|
|
@ -49,7 +49,7 @@ func (d *deref) DereferenceThread(ctx context.Context, username string, statusIR
|
|||
}
|
||||
|
||||
// first make sure we have this status in our db
|
||||
_, statusable, _, err := d.GetRemoteStatus(ctx, username, statusIRI, true)
|
||||
_, statusable, _, err := d.GetRemoteStatus(ctx, username, statusIRI, true, false, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("DereferenceThread: error getting status with id %s: %s", statusIRI.String(), err)
|
||||
}
|
||||
|
@ -104,7 +104,7 @@ func (d *deref) iterateAncestors(ctx context.Context, username string, statusIRI
|
|||
|
||||
// If we reach here, we're looking at a remote status -- make sure we have it in our db by calling GetRemoteStatus
|
||||
// We call it with refresh to true because we want the statusable representation to parse inReplyTo from.
|
||||
status, statusable, _, err := d.GetRemoteStatus(ctx, username, &statusIRI, true)
|
||||
_, statusable, _, err := d.GetRemoteStatus(ctx, username, &statusIRI, true, false, false)
|
||||
if err != nil {
|
||||
l.Debugf("error getting remote status: %s", err)
|
||||
return nil
|
||||
|
@ -116,18 +116,6 @@ func (d *deref) iterateAncestors(ctx context.Context, username string, statusIRI
|
|||
return nil
|
||||
}
|
||||
|
||||
// get the ancestor status into our database if we don't have it yet
|
||||
if _, _, _, err := d.GetRemoteStatus(ctx, username, inReplyTo, false); err != nil {
|
||||
l.Debugf("error getting remote status: %s", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// now enrich the current status, since we should have the ancestor in the db
|
||||
if _, err := d.EnrichRemoteStatus(ctx, username, status); err != nil {
|
||||
l.Debugf("error enriching remote status: %s", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// now move up to the next ancestor
|
||||
return d.iterateAncestors(ctx, username, *inReplyTo)
|
||||
}
|
||||
|
@ -226,7 +214,7 @@ pageLoop:
|
|||
foundReplies = foundReplies + 1
|
||||
|
||||
// get the remote statusable and put it in the db
|
||||
_, statusable, new, err := d.GetRemoteStatus(ctx, username, itemURI, false)
|
||||
_, statusable, new, err := d.GetRemoteStatus(ctx, username, itemURI, false, false, false)
|
||||
if new && err == nil && statusable != nil {
|
||||
// now iterate descendants of *that* status
|
||||
if err := d.iterateDescendants(ctx, username, *itemURI, statusable); err != nil {
|
||||
|
|
|
@ -62,8 +62,8 @@ type Federator interface {
|
|||
GetRemoteAccount(ctx context.Context, username string, remoteAccountID *url.URL, refresh bool) (*gtsmodel.Account, bool, error)
|
||||
EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error)
|
||||
|
||||
GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh bool) (*gtsmodel.Status, ap.Statusable, bool, error)
|
||||
EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status) (*gtsmodel.Status, error)
|
||||
GetRemoteStatus(ctx context.Context, username string, remoteStatusID *url.URL, refresh, includeParent, includeChilds bool) (*gtsmodel.Status, ap.Statusable, bool, error)
|
||||
EnrichRemoteStatus(ctx context.Context, username string, status *gtsmodel.Status, includeParent, includeChilds bool) (*gtsmodel.Status, error)
|
||||
|
||||
GetRemoteInstance(ctx context.Context, username string, remoteInstanceURI *url.URL) (*gtsmodel.Instance, error)
|
||||
|
||||
|
@ -88,7 +88,6 @@ type federator struct {
|
|||
|
||||
// NewFederator returns a new federator
|
||||
func NewFederator(db db.DB, federatingDB federatingdb.DB, transportController transport.Controller, config *config.Config, log *logrus.Logger, typeConverter typeutils.TypeConverter, mediaHandler media.Handler) Federator {
|
||||
|
||||
dereferencer := dereferencing.NewDereferencer(config, db, typeConverter, transportController, mediaHandler, log)
|
||||
|
||||
clock := &Clock{}
|
||||
|
|
Binary file not shown.
Binary file not shown.
|
@ -49,7 +49,7 @@ func (p *processor) processFromFederator(ctx context.Context, federatorMsg gtsmo
|
|||
return errors.New("note was not parseable as *gtsmodel.Status")
|
||||
}
|
||||
|
||||
status, err := p.federator.EnrichRemoteStatus(ctx, federatorMsg.ReceivingAccount.Username, incomingStatus)
|
||||
status, err := p.federator.EnrichRemoteStatus(ctx, federatorMsg.ReceivingAccount.Username, incomingStatus, false, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -130,7 +130,7 @@ func (p *processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, u
|
|||
|
||||
// we don't have it locally so dereference it if we're allowed to
|
||||
if resolve {
|
||||
status, _, _, err := p.federator.GetRemoteStatus(ctx, authed.Account.Username, uri, true)
|
||||
status, _, _, err := p.federator.GetRemoteStatus(ctx, authed.Account.Username, uri, true, false, false)
|
||||
if err == nil {
|
||||
if err := p.federator.DereferenceRemoteThread(ctx, authed.Account.Username, uri); err != nil {
|
||||
// try to deref the thread while we're here
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -339,7 +339,6 @@ func (c *converter) ASStatusToStatus(ctx context.Context, statusable ap.Statusab
|
|||
}
|
||||
|
||||
func (c *converter) ASFollowToFollowRequest(ctx context.Context, followable ap.Followable) (*gtsmodel.FollowRequest, error) {
|
||||
|
||||
idProp := followable.GetJSONLDId()
|
||||
if idProp == nil || !idProp.IsIRI() {
|
||||
return nil, errors.New("no id property set on follow, or was not an iri")
|
||||
|
|
Binary file not shown.
Loading…
Reference in a new issue