forked from mirrors/gotosocial
[bugfix] Fix potential dereference of accounts on own instance (#757)
* add GetAccountByUsernameDomain * simplify search * add escape to not deref accounts on own domain * check if local + we have account by ap uri
This commit is contained in:
parent
2ca234f42e
commit
570fa7c359
8 changed files with 243 additions and 92 deletions
15
internal/cache/account.go
vendored
15
internal/cache/account.go
vendored
|
@ -37,6 +37,7 @@ func NewAccountCache() *AccountCache {
|
|||
RegisterLookups: func(lm *cache.LookupMap[string, string]) {
|
||||
lm.RegisterLookup("uri")
|
||||
lm.RegisterLookup("url")
|
||||
lm.RegisterLookup("usernamedomain")
|
||||
},
|
||||
|
||||
AddLookups: func(lm *cache.LookupMap[string, string], acc *gtsmodel.Account) {
|
||||
|
@ -46,6 +47,7 @@ func NewAccountCache() *AccountCache {
|
|||
if url := acc.URL; url != "" {
|
||||
lm.Set("url", url, acc.ID)
|
||||
}
|
||||
lm.Set("usernamedomain", usernameDomainKey(acc.Username, acc.Domain), acc.ID)
|
||||
},
|
||||
|
||||
DeleteLookups: func(lm *cache.LookupMap[string, string], acc *gtsmodel.Account) {
|
||||
|
@ -55,6 +57,7 @@ func NewAccountCache() *AccountCache {
|
|||
if url := acc.URL; url != "" {
|
||||
lm.Delete("url", url)
|
||||
}
|
||||
lm.Delete("usernamedomain", usernameDomainKey(acc.Username, acc.Domain))
|
||||
},
|
||||
})
|
||||
c.cache.SetTTL(time.Minute*5, false)
|
||||
|
@ -77,6 +80,10 @@ func (c *AccountCache) GetByURI(uri string) (*gtsmodel.Account, bool) {
|
|||
return c.cache.GetBy("uri", uri)
|
||||
}
|
||||
|
||||
func (c *AccountCache) GetByUsernameDomain(username string, domain string) (*gtsmodel.Account, bool) {
|
||||
return c.cache.GetBy("usernamedomain", usernameDomainKey(username, domain))
|
||||
}
|
||||
|
||||
// Put places a account in the cache, ensuring that the object place is a copy for thread-safety
|
||||
func (c *AccountCache) Put(account *gtsmodel.Account) {
|
||||
if account == nil || account.ID == "" {
|
||||
|
@ -135,3 +142,11 @@ func copyAccount(account *gtsmodel.Account) *gtsmodel.Account {
|
|||
SuspensionOrigin: account.SuspensionOrigin,
|
||||
}
|
||||
}
|
||||
|
||||
func usernameDomainKey(username string, domain string) string {
|
||||
u := "@" + username
|
||||
if domain != "" {
|
||||
return u + "@" + domain
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
|
4
internal/cache/account_test.go
vendored
4
internal/cache/account_test.go
vendored
|
@ -69,6 +69,10 @@ func (suite *AccountCacheTestSuite) TestAccountCache() {
|
|||
if account.URL != "" && !ok && !accountIs(account, check) {
|
||||
suite.Fail("Failed to fetch expected account with URL: %s", account.URL)
|
||||
}
|
||||
check, ok = suite.cache.GetByUsernameDomain(account.Username, account.Domain)
|
||||
if !ok && !accountIs(account, check) {
|
||||
suite.Fail("Failed to fetch expected account with username/domain: %s/%s", account.Username, account.Domain)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -36,6 +36,9 @@ type Account interface {
|
|||
// GetAccountByURL returns one account with the given URL, or an error if something goes wrong.
|
||||
GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error)
|
||||
|
||||
// GetAccountByUsernameDomain returns one account with the given username and domain, or an error if something goes wrong.
|
||||
GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, Error)
|
||||
|
||||
// UpdateAccount updates one account by ID.
|
||||
UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error)
|
||||
|
||||
|
|
|
@ -84,6 +84,26 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.
|
|||
)
|
||||
}
|
||||
|
||||
func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, db.Error) {
|
||||
return a.getAccount(
|
||||
ctx,
|
||||
func() (*gtsmodel.Account, bool) {
|
||||
return a.cache.GetByUsernameDomain(username, domain)
|
||||
},
|
||||
func(account *gtsmodel.Account) error {
|
||||
q := a.newAccountQ(account).Where("account.username = ?", username)
|
||||
|
||||
if domain != "" {
|
||||
q = q.Where("account.domain = ?", domain)
|
||||
} else {
|
||||
q = q.Where("account.domain IS NULL")
|
||||
}
|
||||
|
||||
return q.Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) {
|
||||
// Attempt to fetch cached account
|
||||
account, cached := cacheGet()
|
||||
|
|
|
@ -58,6 +58,18 @@ func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
|
|||
suite.NotEmpty(account.HeaderMediaAttachment.URL)
|
||||
}
|
||||
|
||||
func (suite *AccountTestSuite) TestGetAccountByUsernameDomain() {
|
||||
testAccount1 := suite.testAccounts["local_account_1"]
|
||||
account1, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount1.Username, testAccount1.Domain)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(account1)
|
||||
|
||||
testAccount2 := suite.testAccounts["remote_account_1"]
|
||||
account2, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount2.Username, testAccount2.Domain)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(account2)
|
||||
}
|
||||
|
||||
func (suite *AccountTestSuite) TestUpdateAccount() {
|
||||
testAccount := suite.testAccounts["local_account_1"]
|
||||
|
||||
|
|
|
@ -32,6 +32,7 @@ import (
|
|||
"github.com/superseriousbusiness/activity/streams"
|
||||
"github.com/superseriousbusiness/activity/streams/vocab"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/ap"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||
|
@ -78,7 +79,10 @@ type GetRemoteAccountParams struct {
|
|||
|
||||
// GetRemoteAccount completely dereferences a remote account, converts it to a GtS model account,
|
||||
// puts or updates it in the database (if necessary), and returns it to a caller.
|
||||
func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountParams) (remoteAccount *gtsmodel.Account, err error) {
|
||||
//
|
||||
// If a local account is passed into this function for whatever reason (hey, it happens!), then it
|
||||
// will be returned from the database without making any remote calls.
|
||||
func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountParams) (foundAccount *gtsmodel.Account, err error) {
|
||||
/*
|
||||
In this function we want to retrieve a gtsmodel representation of a remote account, with its proper
|
||||
accountDomain set, while making as few calls to remote instances as possible to save time and bandwidth.
|
||||
|
@ -99,23 +103,40 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
|
|||
from that.
|
||||
*/
|
||||
|
||||
// first check if we can retrieve the account locally just with what we've been given
|
||||
skipResolve := params.SkipResolve
|
||||
|
||||
// this first step checks if we have the
|
||||
// account in the database somewhere already
|
||||
switch {
|
||||
case params.RemoteAccountID != nil:
|
||||
// try with uri
|
||||
if a, dbErr := d.db.GetAccountByURI(ctx, params.RemoteAccountID.String()); dbErr == nil {
|
||||
remoteAccount = a
|
||||
uri := params.RemoteAccountID
|
||||
host := uri.Host
|
||||
if host == config.GetHost() || host == config.GetAccountDomain() {
|
||||
// this is actually a local account,
|
||||
// make sure we don't try to resolve
|
||||
skipResolve = true
|
||||
}
|
||||
|
||||
if a, dbErr := d.db.GetAccountByURI(ctx, uri.String()); dbErr == nil {
|
||||
foundAccount = a
|
||||
} else if dbErr != db.ErrNoEntries {
|
||||
err = fmt.Errorf("GetRemoteAccount: database error looking for account %s: %s", params.RemoteAccountID, err)
|
||||
err = fmt.Errorf("GetRemoteAccount: database error looking for account with uri %s: %s", uri, err)
|
||||
}
|
||||
case params.RemoteAccountUsername != "" && (params.RemoteAccountHost == "" || params.RemoteAccountHost == config.GetHost() || params.RemoteAccountHost == config.GetAccountDomain()):
|
||||
// either no domain is provided or this seems
|
||||
// to be a local account, so don't resolve
|
||||
skipResolve = true
|
||||
|
||||
if a, dbErr := d.db.GetLocalAccountByUsername(ctx, params.RemoteAccountUsername); dbErr == nil {
|
||||
foundAccount = a
|
||||
} else if dbErr != db.ErrNoEntries {
|
||||
err = fmt.Errorf("GetRemoteAccount: database error looking for local account with username %s: %s", params.RemoteAccountUsername, err)
|
||||
}
|
||||
case params.RemoteAccountUsername != "" && params.RemoteAccountHost != "":
|
||||
// try with username/host
|
||||
a := >smodel.Account{}
|
||||
where := []db.Where{{Key: "username", Value: params.RemoteAccountUsername}, {Key: "domain", Value: params.RemoteAccountHost}}
|
||||
if dbErr := d.db.GetWhere(ctx, where, a); dbErr == nil {
|
||||
remoteAccount = a
|
||||
if a, dbErr := d.db.GetAccountByUsernameDomain(ctx, params.RemoteAccountUsername, params.RemoteAccountHost); dbErr == nil {
|
||||
foundAccount = a
|
||||
} else if dbErr != db.ErrNoEntries {
|
||||
err = fmt.Errorf("GetRemoteAccount: database error looking for account with username %s and host %s: %s", params.RemoteAccountUsername, params.RemoteAccountHost, err)
|
||||
err = fmt.Errorf("GetRemoteAccount: database error looking for account with username %s and domain %s: %s", params.RemoteAccountUsername, params.RemoteAccountHost, err)
|
||||
}
|
||||
default:
|
||||
err = errors.New("GetRemoteAccount: no identifying parameters were set so we cannot get account")
|
||||
|
@ -125,10 +146,11 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
|
|||
return
|
||||
}
|
||||
|
||||
if params.SkipResolve {
|
||||
// if we can't resolve, return already since there's nothing more we can do
|
||||
if remoteAccount == nil {
|
||||
err = errors.New("GetRemoteAccount: error retrieving account with skipResolve set true")
|
||||
if skipResolve {
|
||||
// if we can't resolve, return already
|
||||
// since there's nothing more we can do
|
||||
if foundAccount == nil {
|
||||
err = errors.New("GetRemoteAccount: couldn't retrieve account locally and won't try to resolve it")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -141,8 +163,8 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
|
|||
// ... but we still need the username so we can do a finger for the accountDomain
|
||||
|
||||
// check if we had the account stored already and got it earlier
|
||||
if remoteAccount != nil {
|
||||
params.RemoteAccountUsername = remoteAccount.Username
|
||||
if foundAccount != nil {
|
||||
params.RemoteAccountUsername = foundAccount.Username
|
||||
} else {
|
||||
// if we didn't already have it, we have dereference it from remote and just...
|
||||
accountable, err = d.dereferenceAccountable(ctx, params.RequestingUsername, params.RemoteAccountID)
|
||||
|
@ -167,8 +189,8 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
|
|||
// already about what the account domain might be; this var will be overwritten later if necessary
|
||||
var accountDomain string
|
||||
switch {
|
||||
case remoteAccount != nil:
|
||||
accountDomain = remoteAccount.Domain
|
||||
case foundAccount != nil:
|
||||
accountDomain = foundAccount.Domain
|
||||
case params.RemoteAccountID != nil:
|
||||
accountDomain = params.RemoteAccountID.Host
|
||||
default:
|
||||
|
@ -178,7 +200,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
|
|||
// to save on remote calls: only webfinger if we don't have a remoteAccount yet, or if we haven't
|
||||
// fingered the remote account for at least 2 days; don't finger instance accounts
|
||||
var fingered time.Time
|
||||
if remoteAccount == nil || (remoteAccount.LastWebfingeredAt.Before(time.Now().Add(webfingerInterval)) && !instanceAccount(remoteAccount)) {
|
||||
if foundAccount == nil || (foundAccount.LastWebfingeredAt.Before(time.Now().Add(webfingerInterval)) && !instanceAccount(foundAccount)) {
|
||||
accountDomain, params.RemoteAccountID, err = d.fingerRemoteAccount(ctx, params.RequestingUsername, params.RemoteAccountUsername, params.RemoteAccountHost)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("GetRemoteAccount: error while fingering: %s", err)
|
||||
|
@ -187,14 +209,14 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
|
|||
fingered = time.Now()
|
||||
}
|
||||
|
||||
if !fingered.IsZero() && remoteAccount == nil {
|
||||
if !fingered.IsZero() && foundAccount == nil {
|
||||
// if we just fingered and now have a discovered account domain but still no account,
|
||||
// we should do a final lookup in the database with the discovered username + accountDomain
|
||||
// to make absolutely sure we don't already have this account
|
||||
a := >smodel.Account{}
|
||||
where := []db.Where{{Key: "username", Value: params.RemoteAccountUsername}, {Key: "domain", Value: accountDomain}}
|
||||
if dbErr := d.db.GetWhere(ctx, where, a); dbErr == nil {
|
||||
remoteAccount = a
|
||||
foundAccount = a
|
||||
} else if dbErr != db.ErrNoEntries {
|
||||
err = fmt.Errorf("GetRemoteAccount: database error looking for account with username %s and host %s: %s", params.RemoteAccountUsername, params.RemoteAccountHost, err)
|
||||
return
|
||||
|
@ -203,7 +225,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
|
|||
|
||||
// we may also have some extra information already, like the account we had in the db, or the
|
||||
// accountable representation that we dereferenced from remote
|
||||
if remoteAccount == nil {
|
||||
if foundAccount == nil {
|
||||
// we still don't have the account, so deference it if we didn't earlier
|
||||
if accountable == nil {
|
||||
accountable, err = d.dereferenceAccountable(ctx, params.RequestingUsername, params.RemoteAccountID)
|
||||
|
@ -214,7 +236,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
|
|||
}
|
||||
|
||||
// then convert
|
||||
remoteAccount, err = d.typeConverter.ASRepresentationToAccount(ctx, accountable, accountDomain, false)
|
||||
foundAccount, err = d.typeConverter.ASRepresentationToAccount(ctx, accountable, accountDomain, false)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("GetRemoteAccount: error converting accountable to account: %s", err)
|
||||
return
|
||||
|
@ -227,18 +249,18 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
|
|||
err = fmt.Errorf("GetRemoteAccount: error generating new id for account: %s", err)
|
||||
return
|
||||
}
|
||||
remoteAccount.ID = ulid
|
||||
foundAccount.ID = ulid
|
||||
|
||||
_, err = d.populateAccountFields(ctx, remoteAccount, params.RequestingUsername, params.Blocking)
|
||||
_, err = d.populateAccountFields(ctx, foundAccount, params.RequestingUsername, params.Blocking)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("GetRemoteAccount: error populating further account fields: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
remoteAccount.LastWebfingeredAt = fingered
|
||||
remoteAccount.UpdatedAt = time.Now()
|
||||
foundAccount.LastWebfingeredAt = fingered
|
||||
foundAccount.UpdatedAt = time.Now()
|
||||
|
||||
err = d.db.Put(ctx, remoteAccount)
|
||||
err = d.db.Put(ctx, foundAccount)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("GetRemoteAccount: error putting new account: %s", err)
|
||||
return
|
||||
|
@ -248,9 +270,9 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
|
|||
}
|
||||
|
||||
// we had the account already, but now we know the account domain, so update it if it's different
|
||||
if !strings.EqualFold(remoteAccount.Domain, accountDomain) {
|
||||
remoteAccount.Domain = accountDomain
|
||||
remoteAccount, err = d.db.UpdateAccount(ctx, remoteAccount)
|
||||
if !strings.EqualFold(foundAccount.Domain, accountDomain) {
|
||||
foundAccount.Domain = accountDomain
|
||||
foundAccount, err = d.db.UpdateAccount(ctx, foundAccount)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("GetRemoteAccount: error updating account: %s", err)
|
||||
return
|
||||
|
@ -260,7 +282,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
|
|||
// make sure the account fields are populated before returning:
|
||||
// the caller might want to block until everything is loaded
|
||||
var fieldsChanged bool
|
||||
fieldsChanged, err = d.populateAccountFields(ctx, remoteAccount, params.RequestingUsername, params.Blocking)
|
||||
fieldsChanged, err = d.populateAccountFields(ctx, foundAccount, params.RequestingUsername, params.Blocking)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetRemoteAccount: error populating remoteAccount fields: %s", err)
|
||||
}
|
||||
|
@ -268,12 +290,12 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
|
|||
var fingeredChanged bool
|
||||
if !fingered.IsZero() {
|
||||
fingeredChanged = true
|
||||
remoteAccount.LastWebfingeredAt = fingered
|
||||
foundAccount.LastWebfingeredAt = fingered
|
||||
}
|
||||
|
||||
if fieldsChanged || fingeredChanged {
|
||||
remoteAccount.UpdatedAt = time.Now()
|
||||
remoteAccount, err = d.db.UpdateAccount(ctx, remoteAccount)
|
||||
foundAccount.UpdatedAt = time.Now()
|
||||
foundAccount, err = d.db.UpdateAccount(ctx, foundAccount)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetRemoteAccount: error updating remoteAccount: %s", err)
|
||||
}
|
||||
|
|
|
@ -21,9 +21,11 @@ package dereferencing_test
|
|||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/ap"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
@ -42,11 +44,11 @@ func (suite *AccountTestSuite) TestDereferenceGroup() {
|
|||
})
|
||||
suite.NoError(err)
|
||||
suite.NotNil(group)
|
||||
suite.NotNil(group)
|
||||
|
||||
// group values should be set
|
||||
suite.Equal("https://unknown-instance.com/groups/some_group", group.URI)
|
||||
suite.Equal("https://unknown-instance.com/@some_group", group.URL)
|
||||
suite.WithinDuration(time.Now(), group.LastWebfingeredAt, 5*time.Second)
|
||||
|
||||
// group should be in the database
|
||||
dbGroup, err := suite.db.GetAccountByURI(context.Background(), group.URI)
|
||||
|
@ -65,11 +67,11 @@ func (suite *AccountTestSuite) TestDereferenceService() {
|
|||
})
|
||||
suite.NoError(err)
|
||||
suite.NotNil(service)
|
||||
suite.NotNil(service)
|
||||
|
||||
// service values should be set
|
||||
suite.Equal("https://owncast.example.org/federation/user/rgh", service.URI)
|
||||
suite.Equal("https://owncast.example.org/federation/user/rgh", service.URL)
|
||||
suite.WithinDuration(time.Now(), service.LastWebfingeredAt, 5*time.Second)
|
||||
|
||||
// service should be in the database
|
||||
dbService, err := suite.db.GetAccountByURI(context.Background(), service.URI)
|
||||
|
@ -79,6 +81,102 @@ func (suite *AccountTestSuite) TestDereferenceService() {
|
|||
suite.Equal("example.org", dbService.Domain)
|
||||
}
|
||||
|
||||
/*
|
||||
We shouldn't try webfingering or making http calls to dereference local accounts
|
||||
that might be passed into GetRemoteAccount for whatever reason, so these tests are
|
||||
here to make sure that such cases are (basically) short-circuit evaluated and given
|
||||
back as-is without trying to make any calls to one's own instance.
|
||||
*/
|
||||
|
||||
func (suite *AccountTestSuite) TestDereferenceLocalAccountAsRemoteURL() {
|
||||
fetchingAccount := suite.testAccounts["local_account_1"]
|
||||
targetAccount := suite.testAccounts["local_account_2"]
|
||||
|
||||
fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{
|
||||
RequestingUsername: fetchingAccount.Username,
|
||||
RemoteAccountID: testrig.URLMustParse(targetAccount.URI),
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.NotNil(fetchedAccount)
|
||||
suite.Empty(fetchedAccount.Domain)
|
||||
}
|
||||
|
||||
func (suite *AccountTestSuite) TestDereferenceLocalAccountAsUsername() {
|
||||
fetchingAccount := suite.testAccounts["local_account_1"]
|
||||
targetAccount := suite.testAccounts["local_account_2"]
|
||||
|
||||
fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{
|
||||
RequestingUsername: fetchingAccount.Username,
|
||||
RemoteAccountUsername: targetAccount.Username,
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.NotNil(fetchedAccount)
|
||||
suite.Empty(fetchedAccount.Domain)
|
||||
}
|
||||
|
||||
func (suite *AccountTestSuite) TestDereferenceLocalAccountAsUsernameDomain() {
|
||||
fetchingAccount := suite.testAccounts["local_account_1"]
|
||||
targetAccount := suite.testAccounts["local_account_2"]
|
||||
|
||||
fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{
|
||||
RequestingUsername: fetchingAccount.Username,
|
||||
RemoteAccountUsername: targetAccount.Username,
|
||||
RemoteAccountHost: config.GetHost(),
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.NotNil(fetchedAccount)
|
||||
suite.Empty(fetchedAccount.Domain)
|
||||
}
|
||||
|
||||
func (suite *AccountTestSuite) TestDereferenceLocalAccountAsUsernameDomainAndURL() {
|
||||
fetchingAccount := suite.testAccounts["local_account_1"]
|
||||
targetAccount := suite.testAccounts["local_account_2"]
|
||||
|
||||
fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{
|
||||
RequestingUsername: fetchingAccount.Username,
|
||||
RemoteAccountID: testrig.URLMustParse(targetAccount.URI),
|
||||
RemoteAccountUsername: targetAccount.Username,
|
||||
RemoteAccountHost: config.GetHost(),
|
||||
})
|
||||
suite.NoError(err)
|
||||
suite.NotNil(fetchedAccount)
|
||||
suite.Empty(fetchedAccount.Domain)
|
||||
}
|
||||
|
||||
func (suite *AccountTestSuite) TestDereferenceLocalAccountWithUnknownUsername() {
|
||||
fetchingAccount := suite.testAccounts["local_account_1"]
|
||||
|
||||
fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{
|
||||
RequestingUsername: fetchingAccount.Username,
|
||||
RemoteAccountUsername: "thisaccountdoesnotexist",
|
||||
})
|
||||
suite.EqualError(err, "GetRemoteAccount: couldn't retrieve account locally and won't try to resolve it")
|
||||
suite.Nil(fetchedAccount)
|
||||
}
|
||||
|
||||
func (suite *AccountTestSuite) TestDereferenceLocalAccountWithUnknownUsernameDomain() {
|
||||
fetchingAccount := suite.testAccounts["local_account_1"]
|
||||
|
||||
fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{
|
||||
RequestingUsername: fetchingAccount.Username,
|
||||
RemoteAccountUsername: "thisaccountdoesnotexist",
|
||||
RemoteAccountHost: "localhost:8080",
|
||||
})
|
||||
suite.EqualError(err, "GetRemoteAccount: couldn't retrieve account locally and won't try to resolve it")
|
||||
suite.Nil(fetchedAccount)
|
||||
}
|
||||
|
||||
func (suite *AccountTestSuite) TestDereferenceLocalAccountWithUnknownUserURI() {
|
||||
fetchingAccount := suite.testAccounts["local_account_1"]
|
||||
|
||||
fetchedAccount, err := suite.dereferencer.GetRemoteAccount(context.Background(), dereferencing.GetRemoteAccountParams{
|
||||
RequestingUsername: fetchingAccount.Username,
|
||||
RemoteAccountID: testrig.URLMustParse("http://localhost:8080/users/thisaccountdoesnotexist"),
|
||||
})
|
||||
suite.EqualError(err, "GetRemoteAccount: couldn't retrieve account locally and won't try to resolve it")
|
||||
suite.Nil(fetchedAccount)
|
||||
}
|
||||
|
||||
func TestAccountTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(AccountTestSuite))
|
||||
}
|
||||
|
|
|
@ -39,7 +39,6 @@ import (
|
|||
|
||||
func (p *processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *apimodel.SearchQuery) (*apimodel.SearchResult, gtserror.WithCode) {
|
||||
l := log.WithFields(kv.Fields{
|
||||
|
||||
{"query", search.Query},
|
||||
}...)
|
||||
|
||||
|
@ -62,7 +61,7 @@ func (p *processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a
|
|||
|
||||
/*
|
||||
SEARCH BY MENTION
|
||||
check if the query is something like @whatever_username@example.org -- this means it's a remote account
|
||||
check if the query is something like @whatever_username@example.org -- this means it's likely a remote account
|
||||
*/
|
||||
maybeNamestring := query
|
||||
if maybeNamestring[0] != '@' {
|
||||
|
@ -135,7 +134,6 @@ func (p *processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a
|
|||
|
||||
func (p *processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL, resolve bool) (*gtsmodel.Status, error) {
|
||||
l := log.WithFields(kv.Fields{
|
||||
|
||||
{"uri", uri.String()},
|
||||
{"resolve", resolve},
|
||||
}...)
|
||||
|
@ -161,67 +159,46 @@ func (p *processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, u
|
|||
}
|
||||
|
||||
func (p *processor) searchAccountByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL, resolve bool) (*gtsmodel.Account, error) {
|
||||
if maybeAccount, err := p.db.GetAccountByURI(ctx, uri.String()); err == nil {
|
||||
return maybeAccount, nil
|
||||
} else if maybeAccount, err := p.db.GetAccountByURL(ctx, uri.String()); err == nil {
|
||||
// it might be a web url like http://example.org/@user instead
|
||||
// of an AP uri like http://example.org/users/user, check first
|
||||
if maybeAccount, err := p.db.GetAccountByURL(ctx, uri.String()); err == nil {
|
||||
return maybeAccount, nil
|
||||
}
|
||||
|
||||
if resolve {
|
||||
// we don't have it locally so try and dereference it
|
||||
account, err := p.federator.GetRemoteAccount(ctx, dereferencing.GetRemoteAccountParams{
|
||||
RequestingUsername: authed.Account.Username,
|
||||
RemoteAccountID: uri,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("searchAccountByURI: error dereferencing account with uri %s: %s", uri.String(), err)
|
||||
}
|
||||
return account, nil
|
||||
if uri.Host == config.GetHost() || uri.Host == config.GetAccountDomain() {
|
||||
// this is a local account; if we don't have it now then
|
||||
// we should just bail instead of trying to get it remote
|
||||
if maybeAccount, err := p.db.GetAccountByURI(ctx, uri.String()); err == nil {
|
||||
return maybeAccount, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// we don't have it yet, try to find it remotely
|
||||
return p.federator.GetRemoteAccount(ctx, dereferencing.GetRemoteAccountParams{
|
||||
RequestingUsername: authed.Account.Username,
|
||||
RemoteAccountID: uri,
|
||||
Blocking: true,
|
||||
SkipResolve: !resolve,
|
||||
})
|
||||
}
|
||||
|
||||
func (p *processor) searchAccountByMention(ctx context.Context, authed *oauth.Auth, username string, domain string, resolve bool) (*gtsmodel.Account, error) {
|
||||
maybeAcct := >smodel.Account{}
|
||||
var err error
|
||||
|
||||
// if it's a local account we can skip a whole bunch of stuff
|
||||
if domain == config.GetHost() || domain == config.GetAccountDomain() || domain == "" {
|
||||
maybeAcct, err = p.db.GetLocalAccountByUsername(ctx, username)
|
||||
if err != nil && err != db.ErrNoEntries {
|
||||
maybeAcct, err := p.db.GetLocalAccountByUsername(ctx, username)
|
||||
if err == nil || err == db.ErrNoEntries {
|
||||
return maybeAcct, nil
|
||||
}
|
||||
return nil, fmt.Errorf("searchAccountByMention: error getting local account by username: %s", err)
|
||||
}
|
||||
return maybeAcct, nil
|
||||
}
|
||||
|
||||
// it's not a local account so first we'll check if it's in the database already...
|
||||
where := []db.Where{
|
||||
{Key: "username", Value: username, CaseInsensitive: true},
|
||||
{Key: "domain", Value: domain, CaseInsensitive: true},
|
||||
}
|
||||
err = p.db.GetWhere(ctx, where, maybeAcct)
|
||||
if err == nil {
|
||||
// we've got it stored locally already!
|
||||
return maybeAcct, nil
|
||||
}
|
||||
|
||||
if err != db.ErrNoEntries {
|
||||
// if it's not errNoEntries there's been a real database error so bail at this point
|
||||
return nil, fmt.Errorf("searchAccountByMention: database error: %s", err)
|
||||
}
|
||||
|
||||
// we got a db.ErrNoEntries, so we just don't have the account locally stored -- check if we can dereference it
|
||||
if resolve {
|
||||
maybeAcct, err = p.federator.GetRemoteAccount(ctx, dereferencing.GetRemoteAccountParams{
|
||||
// we don't have it yet, try to find it remotely
|
||||
return p.federator.GetRemoteAccount(ctx, dereferencing.GetRemoteAccountParams{
|
||||
RequestingUsername: authed.Account.Username,
|
||||
RemoteAccountUsername: username,
|
||||
RemoteAccountHost: domain,
|
||||
Blocking: true,
|
||||
SkipResolve: !resolve,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("searchAccountByMention: error getting remote account: %s", err)
|
||||
}
|
||||
return maybeAcct, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue