use new paging logic for GetAccounts() series of admin endpoints, small changes to query building

This commit is contained in:
kim 2024-04-12 15:40:09 +01:00
parent 94d0dd7ad5
commit 05ccf89a4e
7 changed files with 189 additions and 226 deletions

View file

@ -174,6 +174,7 @@ import (
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/paging"
)
func (m *Module) AccountsGETV1Handler(c *gin.Context) {
@ -199,7 +200,7 @@ func (m *Module) AccountsGETV1Handler(c *gin.Context) {
return
}
limit, errWithCode := apiutil.ParseLimit(c.Query(apiutil.LimitKey), 100, 200, 1)
page, errWithCode := paging.ParseIDPage(c, 1, 200, 100)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
@ -326,14 +327,14 @@ func (m *Module) AccountsGETV1Handler(c *gin.Context) {
ByDomain: c.Query(apiutil.AdminByDomainKey),
Email: c.Query(apiutil.AdminEmailKey),
IP: c.Query(apiutil.AdminIPKey),
MaxID: apiutil.ParseMaxID(c.Query(apiutil.MaxIDKey), ""),
SinceID: apiutil.ParseSinceID(c.Query(apiutil.SinceIDKey), ""),
MinID: apiutil.ParseMinID(c.Query(apiutil.MinIDKey), ""),
Limit: limit,
APIVersion: 1,
}
resp, errWithCode := m.processor.Admin().AccountsGet(c.Request.Context(), params)
resp, errWithCode := m.processor.Admin().AccountsGet(
c.Request.Context(),
params,
page,
)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return

View file

@ -147,6 +147,7 @@ import (
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/paging"
)
func (m *Module) AccountsGETV2Handler(c *gin.Context) {
@ -172,7 +173,7 @@ func (m *Module) AccountsGETV2Handler(c *gin.Context) {
return
}
limit, errWithCode := apiutil.ParseLimit(c.Query(apiutil.LimitKey), 100, 200, 1)
page, errWithCode := paging.ParseIDPage(c, 1, 200, 100)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
@ -190,14 +191,14 @@ func (m *Module) AccountsGETV2Handler(c *gin.Context) {
ByDomain: c.Query(apiutil.AdminByDomainKey),
Email: c.Query(apiutil.AdminEmailKey),
IP: c.Query(apiutil.AdminIPKey),
MaxID: apiutil.ParseMaxID(c.Query(apiutil.MaxIDKey), ""),
SinceID: apiutil.ParseSinceID(c.Query(apiutil.SinceIDKey), ""),
MinID: apiutil.ParseMinID(c.Query(apiutil.MinIDKey), ""),
Limit: limit,
APIVersion: 2,
}
resp, errWithCode := m.processor.Admin().AccountsGet(c.Request.Context(), params)
resp, errWithCode := m.processor.Admin().AccountsGet(
c.Request.Context(),
params,
page,
)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return

View file

@ -258,17 +258,6 @@ type AdminGetAccountsRequest struct {
Email string
// Lookup users with this IP address.
IP string
// All results returned will be
// older than the item with this ID.
MaxID string
// All results returned will be
// newer than the item with this ID.
SinceID string
// Returns results immediately newer
// than the item with this ID.
MinID string
// Maximum number of results to return.
Limit int
// API version to use for this request (1 or 2).
// Set internally, not by callers.
APIVersion int

View file

@ -23,6 +23,7 @@ import (
"time"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/paging"
)
// Account contains functions related to account getting/setting/creation.
@ -70,11 +71,11 @@ type Account interface {
domain string,
email string,
ip net.IP,
maxID string,
sinceID string,
minID string,
limit int,
) ([]*gtsmodel.Account, error)
page *paging.Page,
) (
[]*gtsmodel.Account,
error,
)
// PopulateAccount ensures that all sub-models of an account are populated (e.g. avatar, header etc).
PopulateAccount(ctx context.Context, account *gtsmodel.Account) error

View file

@ -33,6 +33,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/paging"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
@ -236,29 +237,45 @@ func (a *accountDB) GetAccounts(
domain string,
email string,
ip net.IP,
maxID string,
sinceID string,
minID string,
limit int,
) ([]*gtsmodel.Account, error) {
// Ensure reasonable
if limit < 0 {
limit = 0
}
// Make educated guess for slice size
page *paging.Page,
) (
[]*gtsmodel.Account,
error,
) {
var (
accountIDs = make([]string, 0, limit)
accountIDIn []string
useAccountIDIn bool
frontToBack = true
)
// local users lists,
// required for some
// limiting parameters.
users []*gtsmodel.User
// We need users for this query.
users, err := a.state.DB.GetAllUsers(gtscontext.SetBarebones(ctx))
if err != nil {
return nil, fmt.Errorf("error getting users: %w", err)
}
// lazyLoadUsers only loads the users
// slice if it's required by params.
lazyLoadUsers = func() (err error) {
if users == nil {
users, err = a.state.DB.GetAllUsers(gtscontext.SetBarebones(ctx))
if err != nil {
return fmt.Errorf("error getting users: %w", err)
}
}
return nil
}
// Get paging params.
//
// Note this may be min_id OR since_id
// from the API, this gets handled below
// when checking order to reverse slice.
minID = page.GetMin()
maxID = page.GetMax()
limit = page.GetLimit()
order = page.GetOrder()
// Make educated guess for slice size
accountIDs = make([]string, 0, limit)
accountIDIn []string
useAccountIDIn bool
)
q := a.db.
NewSelect().
@ -280,49 +297,27 @@ func (a *accountDB) GetAccounts(
q = q.Where("? < ?", bun.Ident("account.created_at"), maxIDAcct.CreatedAt)
}
if sinceID != "" {
// Return only accounts NEWER
// than account with sinceID.
sinceIDAcct, err := a.GetAccountByID(
gtscontext.SetBarebones(ctx),
sinceID,
)
if err != nil {
return nil, fmt.Errorf("error getting sinceID account %s: %w", sinceID, err)
}
q = q.Where("? > ?", bun.Ident("account.created_at"), sinceIDAcct.CreatedAt)
}
// Return only accounts NEWER
// than account with minID.
if minID != "" {
// Return only accounts NEWER
// than account with minID.
minIDAcct, err := a.GetAccountByID(
gtscontext.SetBarebones(ctx),
sinceID,
minID,
)
if err != nil {
return nil, fmt.Errorf("error getting minID account %s: %w", minID, err)
}
q = q.Where("? > ?", bun.Ident("account.created_at"), minIDAcct.CreatedAt)
// Paging up.
frontToBack = false
}
if origin == "local" {
// Get only local accounts.
q = q.Where("? IS NULL", bun.Ident("account.domain"))
} else if origin == "remote" {
// Get only remote accounts.
q = q.Where("? IS NOT NULL", bun.Ident("account.domain"))
}
switch status {
case "active":
// Get only enabled accounts.
if err := lazyLoadUsers(); err != nil {
return nil, err
}
for _, user := range users {
if !*user.Disabled {
accountIDIn = append(accountIDIn, user.AccountID)
@ -332,6 +327,9 @@ func (a *accountDB) GetAccounts(
case "pending":
// Get only unapproved accounts.
if err := lazyLoadUsers(); err != nil {
return nil, err
}
for _, user := range users {
if !*user.Approved {
accountIDIn = append(accountIDIn, user.AccountID)
@ -341,6 +339,9 @@ func (a *accountDB) GetAccounts(
case "disabled":
// Get only disabled accounts.
if err := lazyLoadUsers(); err != nil {
return nil, err
}
for _, user := range users {
if *user.Disabled {
accountIDIn = append(accountIDIn, user.AccountID)
@ -359,6 +360,9 @@ func (a *accountDB) GetAccounts(
if mods {
// Get only mod accounts.
if err := lazyLoadUsers(); err != nil {
return nil, err
}
for _, user := range users {
if *user.Moderator || *user.Admin {
accountIDIn = append(accountIDIn, user.AccountID)
@ -382,6 +386,9 @@ func (a *accountDB) GetAccounts(
}
if email != "" {
if err := lazyLoadUsers(); err != nil {
return nil, err
}
for _, user := range users {
if user.Email == email || user.UnconfirmedEmail == email {
accountIDIn = append(accountIDIn, user.AccountID)
@ -391,6 +398,9 @@ func (a *accountDB) GetAccounts(
}
if ip != nil {
if err := lazyLoadUsers(); err != nil {
return nil, err
}
for _, user := range users {
if user.SignUpIP.String() == ip.String() {
accountIDIn = append(accountIDIn, user.AccountID)
@ -399,7 +409,37 @@ func (a *accountDB) GetAccounts(
useAccountIDIn = true
}
if origin == "local" && !useAccountIDIn {
// In the case we're not already limiting
// by specific subset of account IDs, just
// use existing list of user.AccountIDs
// instead of adding WHERE to the query.
if err := lazyLoadUsers(); err != nil {
return nil, err
}
for _, user := range users {
accountIDIn = append(accountIDIn, user.AccountID)
}
useAccountIDIn = true
} else if origin == "remote" {
// Get only remote accounts.
q = q.Where("? IS NOT NULL", bun.Ident("account.domain"))
if useAccountIDIn {
// useAccountIDIn specifically indicates
// a parameter that limits querying to
// local accounts, there will be none.
return nil, nil
}
}
if useAccountIDIn {
if len(accountIDIn) == 0 {
// There will be no
// possible answer.
return nil, nil
}
q = q.Where("? IN (?)", bun.Ident("account.id"), bun.In(accountIDIn))
}
@ -409,12 +449,12 @@ func (a *accountDB) GetAccounts(
q = q.Limit(limit)
}
if frontToBack {
// Page down.
q = q.Order("account.created_at DESC")
} else {
if order == paging.OrderAscending {
// Page up.
q = q.Order("account.created_at ASC")
} else {
// Page down.
q = q.Order("account.created_at DESC")
}
if err := q.Scan(ctx, &accountIDs); err != nil {
@ -427,7 +467,7 @@ func (a *accountDB) GetAccounts(
// If we're paging up, we still want accounts
// to be sorted by createdAt desc, so reverse ids slice.
if !frontToBack {
if order == paging.OrderAscending {
slices.Reverse(accountIDs)
}

View file

@ -34,6 +34,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/bundb"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/paging"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@ -494,20 +495,17 @@ func (suite *AccountTestSuite) TestPopulateAccountWithUnknownMovedToURI() {
func (suite *AccountTestSuite) TestGetAccountsAll() {
var (
ctx = context.Background()
origin = ""
status = ""
mods = false
invitedBy = ""
username = ""
displayName = ""
domain = ""
email = ""
ip net.IP = nil
maxID = ""
sinceID = ""
minID = ""
limit = 100
ctx = context.Background()
origin = ""
status = ""
mods = false
invitedBy = ""
username = ""
displayName = ""
domain = ""
email = ""
ip net.IP = nil
page *paging.Page = nil
)
accounts, err := suite.db.GetAccounts(
@ -521,10 +519,7 @@ func (suite *AccountTestSuite) TestGetAccountsAll() {
domain,
email,
ip,
maxID,
sinceID,
minID,
limit,
page,
)
if err != nil {
suite.FailNow(err.Error())
@ -545,10 +540,9 @@ func (suite *AccountTestSuite) TestGetAccountsModsOnly() {
domain = ""
email = ""
ip net.IP = nil
maxID = ""
sinceID = ""
minID = ""
limit = 100
page = &paging.Page{
Limit: 100,
}
)
accounts, err := suite.db.GetAccounts(
@ -562,10 +556,7 @@ func (suite *AccountTestSuite) TestGetAccountsModsOnly() {
domain,
email,
ip,
maxID,
sinceID,
minID,
limit,
page,
)
if err != nil {
suite.FailNow(err.Error())
@ -586,10 +577,9 @@ func (suite *AccountTestSuite) TestGetAccountsLocalWithEmail() {
domain = ""
email = "tortle.dude@example.org"
ip net.IP = nil
maxID = ""
sinceID = ""
minID = ""
limit = 100
page = &paging.Page{
Limit: 100,
}
)
accounts, err := suite.db.GetAccounts(
@ -603,10 +593,7 @@ func (suite *AccountTestSuite) TestGetAccountsLocalWithEmail() {
domain,
email,
ip,
maxID,
sinceID,
minID,
limit,
page,
)
if err != nil {
suite.FailNow(err.Error())
@ -627,10 +614,9 @@ func (suite *AccountTestSuite) TestGetPendingAccounts() {
domain = ""
email = ""
ip net.IP = nil
maxID = ""
sinceID = ""
minID = ""
limit = 100
page = &paging.Page{
Limit: 100,
}
)
accounts, err := suite.db.GetAccounts(
@ -644,10 +630,7 @@ func (suite *AccountTestSuite) TestGetPendingAccounts() {
domain,
email,
ip,
maxID,
sinceID,
minID,
limit,
page,
)
if err != nil {
suite.FailNow(err.Error())

View file

@ -22,6 +22,7 @@ import (
"errors"
"fmt"
"net"
"net/url"
"slices"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@ -29,13 +30,17 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/superseriousbusiness/gotosocial/internal/paging"
)
func (p *Processor) AccountsGet(
ctx context.Context,
request *apimodel.AdminGetAccountsRequest,
) (*apimodel.PageableResponse, gtserror.WithCode) {
page *paging.Page,
) (
*apimodel.PageableResponse,
gtserror.WithCode,
) {
// Validate "origin".
if v := request.Origin; v != "" {
valid := []string{"local", "remote"}
@ -84,10 +89,7 @@ func (p *Processor) AccountsGet(
request.ByDomain,
request.Email,
ip,
request.MaxID,
request.SinceID,
request.MinID,
request.Limit,
page,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
err = gtserror.Newf("db error getting accounts: %w", err)
@ -96,11 +98,11 @@ func (p *Processor) AccountsGet(
count := len(accounts)
if count == 0 {
return util.EmptyPageableResponse(), nil
return paging.EmptyResponse(), nil
}
nextMax := accounts[count-1].ID
prevMin := accounts[0].ID
hi := accounts[count-1].ID
lo := accounts[0].ID
items := make([]interface{}, 0, count)
for _, account := range accounts {
@ -109,7 +111,6 @@ func (p *Processor) AccountsGet(
log.Errorf(ctx, "error converting to api account: %v", err)
continue
}
items = append(items, apiAccount)
}
@ -117,10 +118,10 @@ func (p *Processor) AccountsGet(
// the API version used to call this function.
switch request.APIVersion {
case 1:
return packageAccountsV1(items, nextMax, prevMin, request)
return packageAccountsV1(items, lo, hi, request, page)
case 2:
return packageAccountsV2(items, nextMax, prevMin, request)
return packageAccountsV2(items, lo, hi, request, page)
default:
log.Panic(ctx, "api version was neither 1 nor 2")
@ -130,11 +131,11 @@ func (p *Processor) AccountsGet(
func packageAccountsV1(
items []interface{},
nextMax string,
prevMin string,
loID, hiID string,
request *apimodel.AdminGetAccountsRequest,
page *paging.Page,
) (*apimodel.PageableResponse, gtserror.WithCode) {
extraQueryParams := []string{}
queryParams := make(url.Values, 8)
// Translate origin to v1.
if v := request.Origin; v != "" {
@ -146,10 +147,7 @@ func packageAccountsV1(
k = apiutil.AdminRemoteKey
}
extraQueryParams = append(
extraQueryParams,
k+"=true",
)
queryParams.Add(k, "true")
}
// Translate status to v1.
@ -169,142 +167,92 @@ func packageAccountsV1(
k = apiutil.AdminSuspendedKey
}
extraQueryParams = append(
extraQueryParams,
k+"=true",
)
queryParams.Add(k, "true")
}
if v := request.Username; v != "" {
extraQueryParams = append(
extraQueryParams,
apiutil.UsernameKey+"="+v,
)
queryParams.Add(apiutil.UsernameKey, v)
}
if v := request.DisplayName; v != "" {
extraQueryParams = append(
extraQueryParams,
apiutil.AdminDisplayNameKey+"="+v,
)
queryParams.Add(apiutil.AdminDisplayNameKey, v)
}
if v := request.ByDomain; v != "" {
extraQueryParams = append(
extraQueryParams,
apiutil.AdminByDomainKey+"="+v,
)
queryParams.Add(apiutil.AdminByDomainKey, v)
}
if v := request.Email; v != "" {
extraQueryParams = append(
extraQueryParams,
apiutil.AdminEmailKey+"="+v,
)
queryParams.Add(apiutil.AdminEmailKey, v)
}
if v := request.IP; v != "" {
extraQueryParams = append(
extraQueryParams,
apiutil.AdminIPKey+"="+v,
)
queryParams.Add(apiutil.AdminIPKey, v)
}
// Translate permissions to v1.
if v := request.Permissions; v != "" {
extraQueryParams = append(
extraQueryParams,
apiutil.AdminStaffKey+"=true",
)
queryParams.Add(apiutil.AdminStaffKey, v)
}
return util.PackagePageableResponse(util.PageableResponseParams{
Items: items,
Path: "/api/v1/admin/accounts",
NextMaxIDValue: nextMax,
PrevMinIDValue: prevMin,
Limit: request.Limit,
ExtraQueryParams: extraQueryParams,
})
return paging.PackageResponse(paging.ResponseParams{
Items: items,
Path: "/api/v1/admin/accounts",
Next: page.Next(loID, hiID),
Prev: page.Prev(loID, hiID),
Query: queryParams,
}), nil
}
func packageAccountsV2(
items []interface{},
nextMax string,
prevMin string,
loID, hiID string,
request *apimodel.AdminGetAccountsRequest,
page *paging.Page,
) (*apimodel.PageableResponse, gtserror.WithCode) {
extraQueryParams := []string{}
queryParams := make(url.Values, 9)
if v := request.Origin; v != "" {
extraQueryParams = append(
extraQueryParams,
apiutil.AdminOriginKey+"="+v,
)
queryParams.Add(apiutil.AdminOriginKey, v)
}
if v := request.Status; v != "" {
extraQueryParams = append(
extraQueryParams,
apiutil.AdminStatusKey+"="+v,
)
queryParams.Add(apiutil.AdminStatusKey, v)
}
if v := request.Permissions; v != "" {
extraQueryParams = append(
extraQueryParams,
apiutil.AdminPermissionsKey+"="+v,
)
queryParams.Add(apiutil.AdminPermissionsKey, v)
}
if v := request.InvitedBy; v != "" {
extraQueryParams = append(
extraQueryParams,
apiutil.AdminInvitedByKey+"="+v,
)
queryParams.Add(apiutil.AdminInvitedByKey, v)
}
if v := request.Username; v != "" {
extraQueryParams = append(
extraQueryParams,
apiutil.UsernameKey+"="+v,
)
queryParams.Add(apiutil.UsernameKey, v)
}
if v := request.DisplayName; v != "" {
extraQueryParams = append(
extraQueryParams,
apiutil.AdminDisplayNameKey+"="+v,
)
queryParams.Add(apiutil.AdminDisplayNameKey, v)
}
if v := request.ByDomain; v != "" {
extraQueryParams = append(
extraQueryParams,
apiutil.AdminByDomainKey+"="+v,
)
queryParams.Add(apiutil.AdminByDomainKey, v)
}
if v := request.Email; v != "" {
extraQueryParams = append(
extraQueryParams,
apiutil.AdminEmailKey+"="+v,
)
queryParams.Add(apiutil.AdminEmailKey, v)
}
if v := request.IP; v != "" {
extraQueryParams = append(
extraQueryParams,
apiutil.AdminIPKey+"="+v,
)
queryParams.Add(apiutil.AdminIPKey, v)
}
return util.PackagePageableResponse(util.PageableResponseParams{
Items: items,
Path: "/api/v2/admin/accounts",
NextMaxIDValue: nextMax,
PrevMinIDValue: prevMin,
Limit: request.Limit,
ExtraQueryParams: extraQueryParams,
})
return paging.PackageResponse(paging.ResponseParams{
Items: items,
Path: "/api/v2/admin/accounts",
Next: page.Next(loID, hiID),
Prev: page.Prev(loID, hiID),
Query: queryParams,
}), nil
}