[feature] add paging to account follows, followers and follow requests endpoints (#2186)

This commit is contained in:
kim 2023-09-12 14:00:35 +01:00 committed by GitHub
parent 4b594516ec
commit 7293d6029b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
51 changed files with 2281 additions and 641 deletions

View file

@ -3072,6 +3072,13 @@ paths:
- accounts - accounts
/api/v1/accounts/{id}/followers: /api/v1/accounts/{id}/followers:
get: get:
description: |-
The next and previous queries can be parsed from the returned Link header.
Example:
```
<https://example.org/api/v1/accounts/0657WMDEC3KQDTD6NZ4XJZBK4M/followers?limit=80&max_id=01FC0SKA48HNSVR6YKZCQGS2V8>; rel="next", <https://example.org/api/v1/accounts/0657WMDEC3KQDTD6NZ4XJZBK4M/followers?limit=80&min_id=01FC0SKW5JK2Q4EVAV2B462YY0>; rel="prev"
````
operationId: accountFollowers operationId: accountFollowers
parameters: parameters:
- description: Account ID. - description: Account ID.
@ -3079,6 +3086,25 @@ paths:
name: id name: id
required: true required: true
type: string type: string
- description: 'Return only follower accounts *OLDER* than the given max ID. The follower account with the specified ID will not be included in the response. NOTE: the ID is of the internal follow, NOT any of the returned accounts.'
in: query
name: max_id
type: string
- description: 'Return only follower accounts *NEWER* than the given since ID. The follower account with the specified ID will not be included in the response. NOTE: the ID is of the internal follow, NOT any of the returned accounts.'
in: query
name: since_id
type: string
- description: 'Return only follower accounts *IMMEDIATELY NEWER* than the given min ID. The follower account with the specified ID will not be included in the response. NOTE: the ID is of the internal follow, NOT any of the returned accounts.'
in: query
name: min_id
type: string
- default: 40
description: Number of follower accounts to return.
in: query
maximum: 80
minimum: 1
name: limit
type: integer
produces: produces:
- application/json - application/json
responses: responses:
@ -3106,6 +3132,13 @@ paths:
- accounts - accounts
/api/v1/accounts/{id}/following: /api/v1/accounts/{id}/following:
get: get:
description: |-
The next and previous queries can be parsed from the returned Link header.
Example:
```
<https://example.org/api/v1/accounts/0657WMDEC3KQDTD6NZ4XJZBK4M/following?limit=80&max_id=01FC0SKA48HNSVR6YKZCQGS2V8>; rel="next", <https://example.org/api/v1/accounts/0657WMDEC3KQDTD6NZ4XJZBK4M/following?limit=80&min_id=01FC0SKW5JK2Q4EVAV2B462YY0>; rel="prev"
````
operationId: accountFollowing operationId: accountFollowing
parameters: parameters:
- description: Account ID. - description: Account ID.
@ -3113,6 +3146,25 @@ paths:
name: id name: id
required: true required: true
type: string type: string
- description: 'Return only following accounts *OLDER* than the given max ID. The following account with the specified ID will not be included in the response. NOTE: the ID is of the internal follow, NOT any of the returned accounts.'
in: query
name: max_id
type: string
- description: 'Return only following accounts *NEWER* than the given since ID. The following account with the specified ID will not be included in the response. NOTE: the ID is of the internal follow, NOT any of the returned accounts.'
in: query
name: since_id
type: string
- description: 'Return only following accounts *IMMEDIATELY NEWER* than the given min ID. The following account with the specified ID will not be included in the response. NOTE: the ID is of the internal follow, NOT any of the returned accounts.'
in: query
name: min_id
type: string
- default: 40
description: Number of following accounts to return.
in: query
maximum: 80
minimum: 1
name: limit
type: integer
produces: produces:
- application/json - application/json
responses: responses:
@ -4679,19 +4731,25 @@ paths:
```` ````
operationId: blocksGet operationId: blocksGet
parameters: parameters:
- default: 20 - description: 'Return only blocked accounts *OLDER* than the given max ID. The blocked account with the specified ID will not be included in the response. NOTE: the ID is of the internal block, NOT any of the returned accounts.'
description: Number of blocks to return.
in: query
name: limit
type: integer
- description: Return only blocks *OLDER* than the given block ID. The block with the specified ID will not be included in the response.
in: query in: query
name: max_id name: max_id
type: string type: string
- description: Return only blocks *NEWER* than the given block ID. The block with the specified ID will not be included in the response. - description: 'Return only blocked accounts *NEWER* than the given since ID. The blocked account with the specified ID will not be included in the response. NOTE: the ID is of the internal block, NOT any of the returned accounts.'
in: query in: query
name: since_id name: since_id
type: string type: string
- description: 'Return only blocked accounts *IMMEDIATELY NEWER* than the given min ID. The blocked account with the specified ID will not be included in the response. NOTE: the ID is of the internal block, NOT any of the returned accounts.'
in: query
name: min_id
type: string
- default: 40
description: Number of blocked accounts to return.
in: query
maximum: 80
minimum: 1
name: limit
type: integer
produces: produces:
- application/json - application/json
responses: responses:
@ -4857,12 +4915,32 @@ paths:
- featured_tags - featured_tags
/api/v1/follow_requests: /api/v1/follow_requests:
get: get:
description: Accounts will be sorted in order of follow request date descending (newest first). description: |-
The next and previous queries can be parsed from the returned Link header.
Example:
```
<https://example.org/api/v1/follow_requests?limit=80&max_id=01FC0SKA48HNSVR6YKZCQGS2V8>; rel="next", <https://example.org/api/v1/follow_requests?limit=80&min_id=01FC0SKW5JK2Q4EVAV2B462YY0>; rel="prev"
````
operationId: getFollowRequests operationId: getFollowRequests
parameters: parameters:
- default: 40 - description: 'Return only follow requesting accounts *OLDER* than the given max ID. The follow requester with the specified ID will not be included in the response. NOTE: the ID is of the internal follow request, NOT any of the returned accounts.'
description: Number of accounts to return.
in: query in: query
name: max_id
type: string
- description: 'Return only follow requesting accounts *NEWER* than the given since ID. The follow requester with the specified ID will not be included in the response. NOTE: the ID is of the internal follow request, NOT any of the returned accounts.'
in: query
name: since_id
type: string
- description: 'Return only follow requesting accounts *IMMEDIATELY NEWER* than the given min ID. The follow requester with the specified ID will not be included in the response. NOTE: the ID is of the internal follow request, NOT any of the returned accounts.'
in: query
name: min_id
type: string
- default: 40
description: Number of follow requesting accounts to return.
in: query
maximum: 80
minimum: 1
name: limit name: limit
type: integer type: integer
produces: produces:

1
go.mod
View file

@ -46,6 +46,7 @@ require (
github.com/superseriousbusiness/exif-terminator v0.5.0 github.com/superseriousbusiness/exif-terminator v0.5.0
github.com/superseriousbusiness/oauth2/v4 v4.3.2-SSB.0.20230227143000-f4900831d6c8 github.com/superseriousbusiness/oauth2/v4 v4.3.2-SSB.0.20230227143000-f4900831d6c8
github.com/tdewolff/minify/v2 v2.12.9 github.com/tdewolff/minify/v2 v2.12.9
github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80
github.com/ulule/limiter/v3 v3.11.2 github.com/ulule/limiter/v3 v3.11.2
github.com/uptrace/bun v1.1.15 github.com/uptrace/bun v1.1.15
github.com/uptrace/bun/dialect/pgdialect v1.1.15 github.com/uptrace/bun/dialect/pgdialect v1.1.15

2
go.sum
View file

@ -568,6 +568,8 @@ github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563 h1:Otn9S136ELckZ
github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563/go.mod h1:mLqSmt7Dv/CNneF2wfcChfN1rvapyQr01LGKnKex0DQ= github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563/go.mod h1:mLqSmt7Dv/CNneF2wfcChfN1rvapyQr01LGKnKex0DQ=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 h1:nrZ3ySNYwJbSpD6ce9duiP+QkD3JuLCcWkdaehUS/3Y=
github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80/go.mod h1:iFyPdL66DjUD96XmzVL3ZntbzcflLnznH0fr99w5VqE=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M=

View file

@ -18,21 +18,33 @@
package accounts_test package accounts_test
import ( import (
"context"
"encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"math/rand"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"strconv"
"strings" "strings"
"testing" "testing"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/accounts" "github.com/superseriousbusiness/gotosocial/internal/api/client/accounts"
"github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
"github.com/tomnomnom/linkheader"
) )
// random reader according to current-time source seed.
var randRd = rand.New(rand.NewSource(time.Now().Unix()))
type FollowTestSuite struct { type FollowTestSuite struct {
AccountStandardTestSuite AccountStandardTestSuite
} }
@ -69,6 +81,405 @@ func (suite *FollowTestSuite) TestFollowSelf() {
assert.NoError(suite.T(), err) assert.NoError(suite.T(), err)
} }
func (suite *FollowTestSuite) TestGetFollowersPageBackwardLimit2() {
suite.testGetFollowersPage(2, "backward")
}
func (suite *FollowTestSuite) TestGetFollowersPageBackwardLimit4() {
suite.testGetFollowersPage(4, "backward")
}
func (suite *FollowTestSuite) TestGetFollowersPageBackwardLimit6() {
suite.testGetFollowersPage(6, "backward")
}
func (suite *FollowTestSuite) TestGetFollowersPageForwardLimit2() {
suite.testGetFollowersPage(2, "forward")
}
func (suite *FollowTestSuite) TestGetFollowersPageForwardLimit4() {
suite.testGetFollowersPage(4, "forward")
}
func (suite *FollowTestSuite) TestGetFollowersPageForwardLimit6() {
suite.testGetFollowersPage(6, "forward")
}
func (suite *FollowTestSuite) testGetFollowersPage(limit int, direction string) {
ctx := context.Background()
// The authed local account we are going to use for HTTP requests
requestingAccount := suite.testAccounts["local_account_1"]
suite.clearAccountRelations(requestingAccount.ID)
// Get current time.
now := time.Now()
var i int
for _, targetAccount := range suite.testAccounts {
if targetAccount.ID == requestingAccount.ID {
// we cannot be our own target...
continue
}
// Get next simple ID.
id := strconv.Itoa(i)
i++
// put a follow in the database
err := suite.db.PutFollow(ctx, &gtsmodel.Follow{
ID: id,
CreatedAt: now,
UpdatedAt: now,
URI: fmt.Sprintf("%s/follow/%s", targetAccount.URI, id),
AccountID: targetAccount.ID,
TargetAccountID: requestingAccount.ID,
})
suite.NoError(err)
// Bump now by 1 second.
now = now.Add(time.Second)
}
// Get _ALL_ follows we expect to see without any paging (this filters invisible).
apiRsp, err := suite.processor.Account().FollowersGet(ctx, requestingAccount, requestingAccount.ID, nil)
suite.NoError(err)
expectAccounts := apiRsp.Items // interfaced{} account slice
// Iteratively set
// link query string.
var query string
switch direction {
case "backward":
// Set the starting query to page backward from newest.
acc := expectAccounts[0].(*model.Account)
newest, _ := suite.db.GetFollow(ctx, acc.ID, requestingAccount.ID)
expectAccounts = expectAccounts[1:]
query = fmt.Sprintf("limit=%d&max_id=%s", limit, newest.ID)
case "forward":
// Set the starting query to page forward from the oldest.
acc := expectAccounts[len(expectAccounts)-1].(*model.Account)
oldest, _ := suite.db.GetFollow(ctx, acc.ID, requestingAccount.ID)
expectAccounts = expectAccounts[:len(expectAccounts)-1]
query = fmt.Sprintf("limit=%d&min_id=%s", limit, oldest.ID)
}
for p := 0; ; p++ {
// Prepare new request for endpoint
recorder := httptest.NewRecorder()
endpoint := fmt.Sprintf("/api/v1/accounts/%s/followers", requestingAccount.ID)
ctx := suite.newContext(recorder, http.MethodGet, []byte{}, endpoint, "")
ctx.Params = gin.Params{{Key: "id", Value: requestingAccount.ID}}
ctx.Request.URL.RawQuery = query // setting provided next query value
// call the handler and check for valid response code.
suite.T().Logf("direction=%q page=%d query=%q", direction, p, query)
suite.accountsModule.AccountFollowersGETHandler(ctx)
suite.Equal(http.StatusOK, recorder.Code)
var accounts []*model.Account
// Decode response body into API account models
result := recorder.Result()
dec := json.NewDecoder(result.Body)
err := dec.Decode(&accounts)
suite.NoError(err)
_ = result.Body.Close()
var (
// start provides the starting index for loop in accounts.
start func([]*model.Account) int
// iter performs the loop iter step with index.
iter func(int) int
// check performs the loop conditional check against index and accounts.
check func(int, []*model.Account) bool
// expect pulls the next account to check against from expectAccounts.
expect func([]interface{}) interface{}
// trunc drops the last checked account from expectAccounts.
trunc func([]interface{}) []interface{}
)
switch direction {
case "backward":
// When paging backwards (DESC) we:
// - iter from end of received accounts
// - iterate backward through received accounts
// - stop when we reach last index of received accounts
// - compare each received with the first index of expected accounts
// - after each compare, drop the first index of expected accounts
start = func([]*model.Account) int { return 0 }
iter = func(i int) int { return i + 1 }
check = func(idx int, i []*model.Account) bool { return idx < len(i) }
expect = func(i []interface{}) interface{} { return i[0] }
trunc = func(i []interface{}) []interface{} { return i[1:] }
case "forward":
// When paging forwards (ASC) we:
// - iter from end of received accounts
// - iterate backward through received accounts
// - stop when we reach first index of received accounts
// - compare each received with the last index of expected accounts
// - after each compare, drop the last index of expected accounts
start = func(i []*model.Account) int { return len(i) - 1 }
iter = func(i int) int { return i - 1 }
check = func(idx int, i []*model.Account) bool { return idx >= 0 }
expect = func(i []interface{}) interface{} { return i[len(i)-1] }
trunc = func(i []interface{}) []interface{} { return i[:len(i)-1] }
}
for i := start(accounts); check(i, accounts); i = iter(i) {
// Get next expected account.
iface := expect(expectAccounts)
// Check that expected account matches received.
expectAccID := iface.(*model.Account).ID
receivdAccID := accounts[i].ID
suite.Equal(expectAccID, receivdAccID, "unexpected account at position in response on page=%d", p)
// Drop checked from expected accounts.
expectAccounts = trunc(expectAccounts)
}
if len(expectAccounts) == 0 {
// Reached end.
break
}
// Parse response link header values.
values := result.Header.Values("Link")
links := linkheader.ParseMultiple(values)
filteredLinks := links.FilterByRel("next")
suite.NotEmpty(filteredLinks, "no next link provided with more remaining accounts on page=%d", p)
// A ref link header was set.
link := filteredLinks[0]
// Parse URI from URI string.
uri, err := url.Parse(link.URL)
suite.NoError(err)
// Set next raw query value.
query = uri.RawQuery
}
}
func (suite *FollowTestSuite) TestGetFollowingPageBackwardLimit2() {
suite.testGetFollowingPage(2, "backward")
}
func (suite *FollowTestSuite) TestGetFollowingPageBackwardLimit4() {
suite.testGetFollowingPage(4, "backward")
}
func (suite *FollowTestSuite) TestGetFollowingPageBackwardLimit6() {
suite.testGetFollowingPage(6, "backward")
}
func (suite *FollowTestSuite) TestGetFollowingPageForwardLimit2() {
suite.testGetFollowingPage(2, "forward")
}
func (suite *FollowTestSuite) TestGetFollowingPageForwardLimit4() {
suite.testGetFollowingPage(4, "forward")
}
func (suite *FollowTestSuite) TestGetFollowingPageForwardLimit6() {
suite.testGetFollowingPage(6, "forward")
}
func (suite *FollowTestSuite) testGetFollowingPage(limit int, direction string) {
ctx := context.Background()
// The authed local account we are going to use for HTTP requests
requestingAccount := suite.testAccounts["local_account_1"]
suite.clearAccountRelations(requestingAccount.ID)
// Get current time.
now := time.Now()
var i int
for _, targetAccount := range suite.testAccounts {
if targetAccount.ID == requestingAccount.ID {
// we cannot be our own target...
continue
}
// Get next simple ID.
id := strconv.Itoa(i)
i++
// put a follow in the database
err := suite.db.PutFollow(ctx, &gtsmodel.Follow{
ID: id,
CreatedAt: now,
UpdatedAt: now,
URI: fmt.Sprintf("%s/follow/%s", requestingAccount.URI, id),
AccountID: requestingAccount.ID,
TargetAccountID: targetAccount.ID,
})
suite.NoError(err)
// Bump now by 1 second.
now = now.Add(time.Second)
}
// Get _ALL_ follows we expect to see without any paging (this filters invisible).
apiRsp, err := suite.processor.Account().FollowingGet(ctx, requestingAccount, requestingAccount.ID, nil)
suite.NoError(err)
expectAccounts := apiRsp.Items // interfaced{} account slice
// Iteratively set
// link query string.
var query string
switch direction {
case "backward":
// Set the starting query to page backward from newest.
acc := expectAccounts[0].(*model.Account)
newest, _ := suite.db.GetFollow(ctx, requestingAccount.ID, acc.ID)
expectAccounts = expectAccounts[1:]
query = fmt.Sprintf("limit=%d&max_id=%s", limit, newest.ID)
case "forward":
// Set the starting query to page forward from the oldest.
acc := expectAccounts[len(expectAccounts)-1].(*model.Account)
oldest, _ := suite.db.GetFollow(ctx, requestingAccount.ID, acc.ID)
expectAccounts = expectAccounts[:len(expectAccounts)-1]
query = fmt.Sprintf("limit=%d&min_id=%s", limit, oldest.ID)
}
for p := 0; ; p++ {
// Prepare new request for endpoint
recorder := httptest.NewRecorder()
endpoint := fmt.Sprintf("/api/v1/accounts/%s/following", requestingAccount.ID)
ctx := suite.newContext(recorder, http.MethodGet, []byte{}, endpoint, "")
ctx.Params = gin.Params{{Key: "id", Value: requestingAccount.ID}}
ctx.Request.URL.RawQuery = query // setting provided next query value
// call the handler and check for valid response code.
suite.T().Logf("direction=%q page=%d query=%q", direction, p, query)
suite.accountsModule.AccountFollowingGETHandler(ctx)
suite.Equal(http.StatusOK, recorder.Code)
var accounts []*model.Account
// Decode response body into API account models
result := recorder.Result()
dec := json.NewDecoder(result.Body)
err := dec.Decode(&accounts)
suite.NoError(err)
_ = result.Body.Close()
var (
// start provides the starting index for loop in accounts.
start func([]*model.Account) int
// iter performs the loop iter step with index.
iter func(int) int
// check performs the loop conditional check against index and accounts.
check func(int, []*model.Account) bool
// expect pulls the next account to check against from expectAccounts.
expect func([]interface{}) interface{}
// trunc drops the last checked account from expectAccounts.
trunc func([]interface{}) []interface{}
)
switch direction {
case "backward":
// When paging backwards (DESC) we:
// - iter from end of received accounts
// - iterate backward through received accounts
// - stop when we reach last index of received accounts
// - compare each received with the first index of expected accounts
// - after each compare, drop the first index of expected accounts
start = func([]*model.Account) int { return 0 }
iter = func(i int) int { return i + 1 }
check = func(idx int, i []*model.Account) bool { return idx < len(i) }
expect = func(i []interface{}) interface{} { return i[0] }
trunc = func(i []interface{}) []interface{} { return i[1:] }
case "forward":
// When paging forwards (ASC) we:
// - iter from end of received accounts
// - iterate backward through received accounts
// - stop when we reach first index of received accounts
// - compare each received with the last index of expected accounts
// - after each compare, drop the last index of expected accounts
start = func(i []*model.Account) int { return len(i) - 1 }
iter = func(i int) int { return i - 1 }
check = func(idx int, i []*model.Account) bool { return idx >= 0 }
expect = func(i []interface{}) interface{} { return i[len(i)-1] }
trunc = func(i []interface{}) []interface{} { return i[:len(i)-1] }
}
for i := start(accounts); check(i, accounts); i = iter(i) {
// Get next expected account.
iface := expect(expectAccounts)
// Check that expected account matches received.
expectAccID := iface.(*model.Account).ID
receivdAccID := accounts[i].ID
suite.Equal(expectAccID, receivdAccID, "unexpected account at position in response on page=%d", p)
// Drop checked from expected accounts.
expectAccounts = trunc(expectAccounts)
}
if len(expectAccounts) == 0 {
// Reached end.
break
}
// Parse response link header values.
values := result.Header.Values("Link")
links := linkheader.ParseMultiple(values)
filteredLinks := links.FilterByRel("next")
suite.NotEmpty(filteredLinks, "no next link provided with more remaining accounts on page=%d", p)
// A ref link header was set.
link := filteredLinks[0]
// Parse URI from URI string.
uri, err := url.Parse(link.URL)
suite.NoError(err)
// Set next raw query value.
query = uri.RawQuery
}
}
func (suite *FollowTestSuite) clearAccountRelations(id string) {
// Esnure no account blocks exist between accounts.
_ = suite.db.DeleteAccountBlocks(
context.Background(),
id,
)
// Ensure no account follows exist between accounts.
_ = suite.db.DeleteAccountFollows(
context.Background(),
id,
)
// Ensure no account follow_requests exist between accounts.
_ = suite.db.DeleteAccountFollowRequests(
context.Background(),
id,
)
}
func TestFollowTestSuite(t *testing.T) { func TestFollowTestSuite(t *testing.T) {
suite.Run(t, new(FollowTestSuite)) suite.Run(t, new(FollowTestSuite))
} }

View file

@ -25,12 +25,20 @@ import (
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/paging"
) )
// AccountFollowersGETHandler swagger:operation GET /api/v1/accounts/{id}/followers accountFollowers // AccountFollowersGETHandler swagger:operation GET /api/v1/accounts/{id}/followers accountFollowers
// //
// See followers of account with given id. // See followers of account with given id.
// //
// The next and previous queries can be parsed from the returned Link header.
// Example:
//
// ```
// <https://example.org/api/v1/accounts/0657WMDEC3KQDTD6NZ4XJZBK4M/followers?limit=80&max_id=01FC0SKA48HNSVR6YKZCQGS2V8>; rel="next", <https://example.org/api/v1/accounts/0657WMDEC3KQDTD6NZ4XJZBK4M/followers?limit=80&min_id=01FC0SKW5JK2Q4EVAV2B462YY0>; rel="prev"
// ````
//
// --- // ---
// tags: // tags:
// - accounts // - accounts
@ -45,6 +53,42 @@ import (
// description: Account ID. // description: Account ID.
// in: path // in: path
// required: true // required: true
// -
// name: max_id
// type: string
// description: >-
// Return only follower accounts *OLDER* than the given max ID.
// The follower account with the specified ID will not be included in the response.
// NOTE: the ID is of the internal follow, NOT any of the returned accounts.
// in: query
// required: false
// -
// name: since_id
// type: string
// description: >-
// Return only follower accounts *NEWER* than the given since ID.
// The follower account with the specified ID will not be included in the response.
// NOTE: the ID is of the internal follow, NOT any of the returned accounts.
// in: query
// required: false
// -
// name: min_id
// type: string
// description: >-
// Return only follower accounts *IMMEDIATELY NEWER* than the given min ID.
// The follower account with the specified ID will not be included in the response.
// NOTE: the ID is of the internal follow, NOT any of the returned accounts.
// in: query
// required: false
// -
// name: limit
// type: integer
// description: Number of follower accounts to return.
// default: 40
// minimum: 1
// maximum: 80
// in: query
// required: false
// //
// security: // security:
// - OAuth2 Bearer: // - OAuth2 Bearer:
@ -87,11 +131,25 @@ func (m *Module) AccountFollowersGETHandler(c *gin.Context) {
return return
} }
followers, errWithCode := m.processor.Account().FollowersGet(c.Request.Context(), authed.Account, targetAcctID) page, errWithCode := paging.ParseIDPage(c,
1, // min limit
80, // max limit
40, // default limit
)
if errWithCode != nil { if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return return
} }
c.JSON(http.StatusOK, followers) resp, errWithCode := m.processor.Account().FollowersGet(c.Request.Context(), authed.Account, targetAcctID, page)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
if resp.LinkHeader != "" {
c.Header("Link", resp.LinkHeader)
}
c.JSON(http.StatusOK, resp.Items)
} }

View file

@ -25,12 +25,20 @@ import (
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/paging"
) )
// AccountFollowingGETHandler swagger:operation GET /api/v1/accounts/{id}/following accountFollowing // AccountFollowingGETHandler swagger:operation GET /api/v1/accounts/{id}/following accountFollowing
// //
// See accounts followed by given account id. // See accounts followed by given account id.
// //
// The next and previous queries can be parsed from the returned Link header.
// Example:
//
// ```
// <https://example.org/api/v1/accounts/0657WMDEC3KQDTD6NZ4XJZBK4M/following?limit=80&max_id=01FC0SKA48HNSVR6YKZCQGS2V8>; rel="next", <https://example.org/api/v1/accounts/0657WMDEC3KQDTD6NZ4XJZBK4M/following?limit=80&min_id=01FC0SKW5JK2Q4EVAV2B462YY0>; rel="prev"
// ````
//
// --- // ---
// tags: // tags:
// - accounts // - accounts
@ -45,6 +53,42 @@ import (
// description: Account ID. // description: Account ID.
// in: path // in: path
// required: true // required: true
// -
// name: max_id
// type: string
// description: >-
// Return only following accounts *OLDER* than the given max ID.
// The following account with the specified ID will not be included in the response.
// NOTE: the ID is of the internal follow, NOT any of the returned accounts.
// in: query
// required: false
// -
// name: since_id
// type: string
// description: >-
// Return only following accounts *NEWER* than the given since ID.
// The following account with the specified ID will not be included in the response.
// NOTE: the ID is of the internal follow, NOT any of the returned accounts.
// in: query
// required: false
// -
// name: min_id
// type: string
// description: >-
// Return only following accounts *IMMEDIATELY NEWER* than the given min ID.
// The following account with the specified ID will not be included in the response.
// NOTE: the ID is of the internal follow, NOT any of the returned accounts.
// in: query
// required: false
// -
// name: limit
// type: integer
// description: Number of following accounts to return.
// default: 40
// minimum: 1
// maximum: 80
// in: query
// required: false
// //
// security: // security:
// - OAuth2 Bearer: // - OAuth2 Bearer:
@ -87,11 +131,25 @@ func (m *Module) AccountFollowingGETHandler(c *gin.Context) {
return return
} }
following, errWithCode := m.processor.Account().FollowingGet(c.Request.Context(), authed.Account, targetAcctID) page, errWithCode := paging.ParseIDPage(c,
1, // min limit
80, // max limit
40, // default limit
)
if errWithCode != nil { if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return return
} }
c.JSON(http.StatusOK, following) resp, errWithCode := m.processor.Account().FollowingGet(c.Request.Context(), authed.Account, targetAcctID, page)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
if resp.LinkHeader != "" {
c.Header("Link", resp.LinkHeader)
}
c.JSON(http.StatusOK, resp.Items)
} }

View file

@ -47,25 +47,40 @@ import (
// //
// parameters: // parameters:
// - // -
// name: limit
// type: integer
// description: Number of blocks to return.
// default: 20
// in: query
// -
// name: max_id // name: max_id
// type: string // type: string
// description: >- // description: >-
// Return only blocks *OLDER* than the given block ID. // Return only blocked accounts *OLDER* than the given max ID.
// The block with the specified ID will not be included in the response. // The blocked account with the specified ID will not be included in the response.
// NOTE: the ID is of the internal block, NOT any of the returned accounts.
// in: query // in: query
// required: false
// - // -
// name: since_id // name: since_id
// type: string // type: string
// description: >- // description: >-
// Return only blocks *NEWER* than the given block ID. // Return only blocked accounts *NEWER* than the given since ID.
// The block with the specified ID will not be included in the response. // The blocked account with the specified ID will not be included in the response.
// NOTE: the ID is of the internal block, NOT any of the returned accounts.
// in: query // in: query
// -
// name: min_id
// type: string
// description: >-
// Return only blocked accounts *IMMEDIATELY NEWER* than the given min ID.
// The blocked account with the specified ID will not be included in the response.
// NOTE: the ID is of the internal block, NOT any of the returned accounts.
// in: query
// required: false
// -
// name: limit
// type: integer
// description: Number of blocked accounts to return.
// default: 40
// minimum: 1
// maximum: 80
// in: query
// required: false
// //
// security: // security:
// - OAuth2 Bearer: // - OAuth2 Bearer:
@ -105,15 +120,15 @@ func (m *Module) BlocksGETHandler(c *gin.Context) {
page, errWithCode := paging.ParseIDPage(c, page, errWithCode := paging.ParseIDPage(c,
1, // min limit 1, // min limit
100, // max limit 80, // max limit
20, // default limit 40, // default limit
) )
if errWithCode != nil { if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return return
} }
resp, errWithCode := m.processor.BlocksGet( resp, errWithCode := m.processor.Account().BlocksGet(
c.Request.Context(), c.Request.Context(),
authed.Account, authed.Account,
page, page,

View file

@ -87,7 +87,7 @@ func (m *Module) FollowRequestAuthorizePOSTHandler(c *gin.Context) {
return return
} }
relationship, errWithCode := m.processor.FollowRequestAccept(c.Request.Context(), authed, originAccountID) relationship, errWithCode := m.processor.Account().FollowRequestAccept(c.Request.Context(), authed.Account, originAccountID)
if errWithCode != nil { if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return return

View file

@ -24,12 +24,19 @@ import (
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/paging"
) )
// FollowRequestGETHandler swagger:operation GET /api/v1/follow_requests getFollowRequests // FollowRequestGETHandler swagger:operation GET /api/v1/follow_requests getFollowRequests
// //
// Get an array of accounts that have requested to follow you. // Get an array of accounts that have requested to follow you.
// Accounts will be sorted in order of follow request date descending (newest first). //
// The next and previous queries can be parsed from the returned Link header.
// Example:
//
// ```
// <https://example.org/api/v1/follow_requests?limit=80&max_id=01FC0SKA48HNSVR6YKZCQGS2V8>; rel="next", <https://example.org/api/v1/follow_requests?limit=80&min_id=01FC0SKW5JK2Q4EVAV2B462YY0>; rel="prev"
// ````
// //
// --- // ---
// tags: // tags:
@ -40,11 +47,41 @@ import (
// //
// parameters: // parameters:
// - // -
// name: max_id
// type: string
// description: >-
// Return only follow requesting accounts *OLDER* than the given max ID.
// The follow requester with the specified ID will not be included in the response.
// NOTE: the ID is of the internal follow request, NOT any of the returned accounts.
// in: query
// required: false
// -
// name: since_id
// type: string
// description: >-
// Return only follow requesting accounts *NEWER* than the given since ID.
// The follow requester with the specified ID will not be included in the response.
// NOTE: the ID is of the internal follow request, NOT any of the returned accounts.
// in: query
// required: false
// -
// name: min_id
// type: string
// description: >-
// Return only follow requesting accounts *IMMEDIATELY NEWER* than the given min ID.
// The follow requester with the specified ID will not be included in the response.
// NOTE: the ID is of the internal follow request, NOT any of the returned accounts.
// in: query
// required: false
// -
// name: limit // name: limit
// type: integer // type: integer
// description: Number of accounts to return. // description: Number of follow requesting accounts to return.
// default: 40 // default: 40
// minimum: 1
// maximum: 80
// in: query // in: query
// required: false
// //
// security: // security:
// - OAuth2 Bearer: // - OAuth2 Bearer:
@ -82,11 +119,25 @@ func (m *Module) FollowRequestGETHandler(c *gin.Context) {
return return
} }
accts, errWithCode := m.processor.FollowRequestsGet(c.Request.Context(), authed) page, errWithCode := paging.ParseIDPage(c,
1, // min limit
80, // max limit
40, // default limit
)
if errWithCode != nil { if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return return
} }
c.JSON(http.StatusOK, accts) resp, errWithCode := m.processor.Account().FollowRequestsGet(c.Request.Context(), authed.Account, page)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
if resp.LinkHeader != "" {
c.Header("Link", resp.LinkHeader)
}
c.JSON(http.StatusOK, resp.Items)
} }

View file

@ -22,17 +22,25 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io"
"math/rand"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"strconv"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/tomnomnom/linkheader"
) )
// random reader according to current-time source seed.
var randRd = rand.New(rand.NewSource(time.Now().Unix()))
type GetTestSuite struct { type GetTestSuite struct {
FollowRequestStandardTestSuite FollowRequestStandardTestSuite
} }
@ -68,7 +76,7 @@ func (suite *GetTestSuite) TestGet() {
defer result.Body.Close() defer result.Body.Close()
// check the response // check the response
b, err := ioutil.ReadAll(result.Body) b, err := io.ReadAll(result.Body)
assert.NoError(suite.T(), err) assert.NoError(suite.T(), err)
dst := new(bytes.Buffer) dst := new(bytes.Buffer)
err = json.Indent(dst, b, "", " ") err = json.Indent(dst, b, "", " ")
@ -99,6 +107,214 @@ func (suite *GetTestSuite) TestGet() {
]`, dst.String()) ]`, dst.String())
} }
func (suite *GetTestSuite) TestGetPageBackwardLimit2() {
suite.testGetPage(2, "backward")
}
func (suite *GetTestSuite) TestGetPageBackwardLimit4() {
suite.testGetPage(4, "backward")
}
func (suite *GetTestSuite) TestGetPageBackwardLimit6() {
suite.testGetPage(6, "backward")
}
func (suite *GetTestSuite) TestGetPageForwardLimit2() {
suite.testGetPage(2, "forward")
}
func (suite *GetTestSuite) TestGetPageForwardLimit4() {
suite.testGetPage(4, "forward")
}
func (suite *GetTestSuite) TestGetPageForwardLimit6() {
suite.testGetPage(6, "forward")
}
func (suite *GetTestSuite) testGetPage(limit int, direction string) {
ctx := context.Background()
// The authed local account we are going to use for HTTP requests
requestingAccount := suite.testAccounts["local_account_1"]
suite.clearAccountRelations(requestingAccount.ID)
// Get current time.
now := time.Now()
var i int
for _, targetAccount := range suite.testAccounts {
if targetAccount.ID == requestingAccount.ID {
// we cannot be our own target...
continue
}
// Get next simple ID.
id := strconv.Itoa(i)
i++
// put a follow request in the database
err := suite.db.PutFollowRequest(ctx, &gtsmodel.FollowRequest{
ID: id,
CreatedAt: now,
UpdatedAt: now,
URI: fmt.Sprintf("%s/follow/%s", targetAccount.URI, id),
AccountID: targetAccount.ID,
TargetAccountID: requestingAccount.ID,
})
suite.NoError(err)
// Bump now by 1 second.
now = now.Add(time.Second)
}
// Get _ALL_ follow requests we expect to see without any paging (this filters invisible).
apiRsp, err := suite.processor.Account().FollowRequestsGet(ctx, requestingAccount, nil)
suite.NoError(err)
expectAccounts := apiRsp.Items // interfaced{} account slice
// Iteratively set
// link query string.
var query string
switch direction {
case "backward":
// Set the starting query to page backward from newest.
acc := expectAccounts[0].(*model.Account)
newest, _ := suite.db.GetFollowRequest(ctx, acc.ID, requestingAccount.ID)
expectAccounts = expectAccounts[1:]
query = fmt.Sprintf("limit=%d&max_id=%s", limit, newest.ID)
case "forward":
// Set the starting query to page forward from the oldest.
acc := expectAccounts[len(expectAccounts)-1].(*model.Account)
oldest, _ := suite.db.GetFollowRequest(ctx, acc.ID, requestingAccount.ID)
expectAccounts = expectAccounts[:len(expectAccounts)-1]
query = fmt.Sprintf("limit=%d&min_id=%s", limit, oldest.ID)
}
for p := 0; ; p++ {
// Prepare new request for endpoint
recorder := httptest.NewRecorder()
ctx := suite.newContext(recorder, http.MethodGet, []byte{}, "/api/v1/follow_requests", "")
ctx.Request.URL.RawQuery = query // setting provided next query value
// call the handler and check for valid response code.
suite.T().Logf("direction=%q page=%d query=%q", direction, p, query)
suite.followRequestModule.FollowRequestGETHandler(ctx)
suite.Equal(http.StatusOK, recorder.Code)
var accounts []*model.Account
// Decode response body into API account models
result := recorder.Result()
dec := json.NewDecoder(result.Body)
err := dec.Decode(&accounts)
suite.NoError(err)
_ = result.Body.Close()
var (
// start provides the starting index for loop in accounts.
start func([]*model.Account) int
// iter performs the loop iter step with index.
iter func(int) int
// check performs the loop conditional check against index and accounts.
check func(int, []*model.Account) bool
// expect pulls the next account to check against from expectAccounts.
expect func([]interface{}) interface{}
// trunc drops the last checked account from expectAccounts.
trunc func([]interface{}) []interface{}
)
switch direction {
case "backward":
// When paging backwards (DESC) we:
// - iter from end of received accounts
// - iterate backward through received accounts
// - stop when we reach last index of received accounts
// - compare each received with the first index of expected accounts
// - after each compare, drop the first index of expected accounts
start = func([]*model.Account) int { return 0 }
iter = func(i int) int { return i + 1 }
check = func(idx int, i []*model.Account) bool { return idx < len(i) }
expect = func(i []interface{}) interface{} { return i[0] }
trunc = func(i []interface{}) []interface{} { return i[1:] }
case "forward":
// When paging forwards (ASC) we:
// - iter from end of received accounts
// - iterate backward through received accounts
// - stop when we reach first index of received accounts
// - compare each received with the last index of expected accounts
// - after each compare, drop the last index of expected accounts
start = func(i []*model.Account) int { return len(i) - 1 }
iter = func(i int) int { return i - 1 }
check = func(idx int, i []*model.Account) bool { return idx >= 0 }
expect = func(i []interface{}) interface{} { return i[len(i)-1] }
trunc = func(i []interface{}) []interface{} { return i[:len(i)-1] }
}
for i := start(accounts); check(i, accounts); i = iter(i) {
// Get next expected account.
iface := expect(expectAccounts)
// Check that expected account matches received.
expectAccID := iface.(*model.Account).ID
receivdAccID := accounts[i].ID
suite.Equal(expectAccID, receivdAccID, "unexpected account at position in response on page=%d", p)
// Drop checked from expected accounts.
expectAccounts = trunc(expectAccounts)
}
if len(expectAccounts) == 0 {
// Reached end.
break
}
// Parse response link header values.
values := result.Header.Values("Link")
links := linkheader.ParseMultiple(values)
filteredLinks := links.FilterByRel("next")
suite.NotEmpty(filteredLinks, "no next link provided with more remaining accounts on page=%d", p)
// A ref link header was set.
link := filteredLinks[0]
// Parse URI from URI string.
uri, err := url.Parse(link.URL)
suite.NoError(err)
// Set next raw query value.
query = uri.RawQuery
}
}
func (suite *GetTestSuite) clearAccountRelations(id string) {
// Esnure no account blocks exist between accounts.
_ = suite.db.DeleteAccountBlocks(
context.Background(),
id,
)
// Ensure no account follows exist between accounts.
_ = suite.db.DeleteAccountFollows(
context.Background(),
id,
)
// Ensure no account follow_requests exist between accounts.
_ = suite.db.DeleteAccountFollowRequests(
context.Background(),
id,
)
}
func TestGetTestSuite(t *testing.T) { func TestGetTestSuite(t *testing.T) {
suite.Run(t, &GetTestSuite{}) suite.Run(t, &GetTestSuite{})
} }

View file

@ -85,7 +85,7 @@ func (m *Module) FollowRequestRejectPOSTHandler(c *gin.Context) {
return return
} }
relationship, errWithCode := m.processor.FollowRequestReject(c.Request.Context(), authed, originAccountID) relationship, errWithCode := m.processor.Account().FollowRequestReject(c.Request.Context(), authed.Account, originAccountID)
if errWithCode != nil { if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return return

View file

@ -102,8 +102,8 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount
return &rel, nil return &rel, nil
} }
func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) {
followIDs, err := r.getAccountFollowIDs(ctx, accountID) followIDs, err := r.getAccountFollowIDs(ctx, accountID, page)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -118,8 +118,8 @@ func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID s
return r.GetFollowsByIDs(ctx, followIDs) return r.GetFollowsByIDs(ctx, followIDs)
} }
func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) {
followerIDs, err := r.getAccountFollowerIDs(ctx, accountID) followerIDs, err := r.getAccountFollowerIDs(ctx, accountID, page)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -134,16 +134,16 @@ func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID
return r.GetFollowsByIDs(ctx, followerIDs) return r.GetFollowsByIDs(ctx, followerIDs)
} }
func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) {
followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID) followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID, page)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return r.GetFollowRequestsByIDs(ctx, followReqIDs) return r.GetFollowRequestsByIDs(ctx, followReqIDs)
} }
func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) {
followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID) followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID, page)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -151,39 +151,15 @@ func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, account
} }
func (r *relationshipDB) GetAccountBlocks(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Block, error) { func (r *relationshipDB) GetAccountBlocks(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Block, error) {
// Load block IDs from cache with database loader callback. blockIDs, err := r.getAccountBlockIDs(ctx, accountID, page)
blockIDs, err := r.state.Caches.GTS.BlockIDs().Load(accountID, func() ([]string, error) {
var blockIDs []string
// Block IDs not in cache, perform DB query!
q := newSelectBlocks(r.db, accountID)
if _, err := q.Exec(ctx, &blockIDs); err != nil {
return nil, err
}
return blockIDs, nil
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Our cached / selected block IDs are
// ALWAYS stored in descending order.
// Depending on the paging requested
// this may be an unexpected order.
if !page.GetOrder().Ascending() {
blockIDs = paging.Reverse(blockIDs)
}
// Page the resulting block IDs.
blockIDs = page.Page(blockIDs)
// Convert these IDs to full block objects.
return r.GetBlocksByIDs(ctx, blockIDs) return r.GetBlocksByIDs(ctx, blockIDs)
} }
func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) { func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) {
followIDs, err := r.getAccountFollowIDs(ctx, accountID) followIDs, err := r.getAccountFollowIDs(ctx, accountID, nil)
return len(followIDs), err return len(followIDs), err
} }
@ -193,7 +169,7 @@ func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID
} }
func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) { func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) {
followerIDs, err := r.getAccountFollowerIDs(ctx, accountID) followerIDs, err := r.getAccountFollowerIDs(ctx, accountID, nil)
return len(followerIDs), err return len(followerIDs), err
} }
@ -203,17 +179,22 @@ func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, account
} }
func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) { func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) {
followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID) followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID, nil)
return len(followReqIDs), err return len(followReqIDs), err
} }
func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) { func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) {
followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID) followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID, nil)
return len(followReqIDs), err return len(followReqIDs), err
} }
func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string) ([]string, error) { func (r *relationshipDB) CountAccountBlocks(ctx context.Context, accountID string) (int, error) {
return r.state.Caches.GTS.FollowIDs().Load(">"+accountID, func() ([]string, error) { blockIDs, err := r.getAccountBlockIDs(ctx, accountID, nil)
return len(blockIDs), err
}
func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return loadPagedIDs(r.state.Caches.GTS.FollowIDs(), ">"+accountID, page, func() ([]string, error) {
var followIDs []string var followIDs []string
// Follow IDs not in cache, perform DB query! // Follow IDs not in cache, perform DB query!
@ -240,8 +221,8 @@ func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID
}) })
} }
func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string) ([]string, error) { func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return r.state.Caches.GTS.FollowIDs().Load("<"+accountID, func() ([]string, error) { return loadPagedIDs(r.state.Caches.GTS.FollowIDs(), "<"+accountID, page, func() ([]string, error) {
var followIDs []string var followIDs []string
// Follow IDs not in cache, perform DB query! // Follow IDs not in cache, perform DB query!
@ -268,8 +249,8 @@ func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, account
}) })
} }
func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string) ([]string, error) { func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return r.state.Caches.GTS.FollowRequestIDs().Load(">"+accountID, func() ([]string, error) { return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs(), ">"+accountID, page, func() ([]string, error) {
var followReqIDs []string var followReqIDs []string
// Follow request IDs not in cache, perform DB query! // Follow request IDs not in cache, perform DB query!
@ -282,8 +263,8 @@ func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, account
}) })
} }
func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string) ([]string, error) { func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return r.state.Caches.GTS.FollowRequestIDs().Load("<"+accountID, func() ([]string, error) { return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs(), "<"+accountID, page, func() ([]string, error) {
var followReqIDs []string var followReqIDs []string
// Follow request IDs not in cache, perform DB query! // Follow request IDs not in cache, perform DB query!
@ -296,13 +277,27 @@ func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, acco
}) })
} }
func (r *relationshipDB) getAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return loadPagedIDs(r.state.Caches.GTS.BlockIDs(), accountID, page, func() ([]string, error) {
var blockIDs []string
// Block IDs not in cache, perform DB query!
q := newSelectBlocks(r.db, accountID)
if _, err := q.Exec(ctx, &blockIDs); err != nil {
return nil, err
}
return blockIDs, nil
})
}
// newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID. // newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID.
func newSelectFollowRequests(db *DB, accountID string) *bun.SelectQuery { func newSelectFollowRequests(db *DB, accountID string) *bun.SelectQuery {
return db.NewSelect(). return db.NewSelect().
TableExpr("?", bun.Ident("follow_requests")). TableExpr("?", bun.Ident("follow_requests")).
ColumnExpr("?", bun.Ident("id")). ColumnExpr("?", bun.Ident("id")).
Where("? = ?", bun.Ident("target_account_id"), accountID). Where("? = ?", bun.Ident("target_account_id"), accountID).
OrderExpr("? DESC", bun.Ident("updated_at")) OrderExpr("? DESC", bun.Ident("id"))
} }
// newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID. // newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID.
@ -311,7 +306,7 @@ func newSelectFollowRequesting(db *DB, accountID string) *bun.SelectQuery {
TableExpr("?", bun.Ident("follow_requests")). TableExpr("?", bun.Ident("follow_requests")).
ColumnExpr("?", bun.Ident("id")). ColumnExpr("?", bun.Ident("id")).
Where("? = ?", bun.Ident("target_account_id"), accountID). Where("? = ?", bun.Ident("target_account_id"), accountID).
OrderExpr("? DESC", bun.Ident("updated_at")) OrderExpr("? DESC", bun.Ident("id"))
} }
// newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID. // newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID.
@ -320,7 +315,7 @@ func newSelectFollows(db *DB, accountID string) *bun.SelectQuery {
Table("follows"). Table("follows").
Column("id"). Column("id").
Where("? = ?", bun.Ident("account_id"), accountID). Where("? = ?", bun.Ident("account_id"), accountID).
OrderExpr("? DESC", bun.Ident("updated_at")) OrderExpr("? DESC", bun.Ident("id"))
} }
// newSelectLocalFollows returns a new select query for all rows in the follows table with // newSelectLocalFollows returns a new select query for all rows in the follows table with
@ -338,7 +333,7 @@ func newSelectLocalFollows(db *DB, accountID string) *bun.SelectQuery {
Column("id"). Column("id").
Where("? IS NULL", bun.Ident("domain")), Where("? IS NULL", bun.Ident("domain")),
). ).
OrderExpr("? DESC", bun.Ident("updated_at")) OrderExpr("? DESC", bun.Ident("id"))
} }
// newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID. // newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID.
@ -347,7 +342,7 @@ func newSelectFollowers(db *DB, accountID string) *bun.SelectQuery {
Table("follows"). Table("follows").
Column("id"). Column("id").
Where("? = ?", bun.Ident("target_account_id"), accountID). Where("? = ?", bun.Ident("target_account_id"), accountID).
OrderExpr("? DESC", bun.Ident("updated_at")) OrderExpr("? DESC", bun.Ident("id"))
} }
// newSelectLocalFollowers returns a new select query for all rows in the follows table with // newSelectLocalFollowers returns a new select query for all rows in the follows table with
@ -365,14 +360,14 @@ func newSelectLocalFollowers(db *DB, accountID string) *bun.SelectQuery {
Column("id"). Column("id").
Where("? IS NULL", bun.Ident("domain")), Where("? IS NULL", bun.Ident("domain")),
). ).
OrderExpr("? DESC", bun.Ident("updated_at")) OrderExpr("? DESC", bun.Ident("id"))
} }
// newSelectBlocks returns a new select query for all rows in the blocks table with account_id = accountID. // newSelectBlocks returns a new select query for all rows in the blocks table with account_id = accountID.
func newSelectBlocks(db *DB, accountID string) *bun.SelectQuery { func newSelectBlocks(db *DB, accountID string) *bun.SelectQuery {
return db.NewSelect(). return db.NewSelect().
TableExpr("?", bun.Ident("blocks")). TableExpr("?", bun.Ident("blocks")).
ColumnExpr("?", bun.Ident("?")). ColumnExpr("?", bun.Ident("id")).
Where("? = ?", bun.Ident("account_id"), accountID). Where("? = ?", bun.Ident("account_id"), accountID).
OrderExpr("? DESC", bun.Ident("updated_at")) OrderExpr("? DESC", bun.Ident("id"))
} }

View file

@ -753,14 +753,14 @@ func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() {
suite.FailNow(err.Error()) suite.FailNow(err.Error())
} }
followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID) followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID, nil)
suite.NoError(err) suite.NoError(err)
suite.Len(followRequests, 1) suite.Len(followRequests, 1)
} }
func (suite *RelationshipTestSuite) TestGetAccountFollows() { func (suite *RelationshipTestSuite) TestGetAccountFollows() {
account := suite.testAccounts["local_account_1"] account := suite.testAccounts["local_account_1"]
follows, err := suite.db.GetAccountFollows(context.Background(), account.ID) follows, err := suite.db.GetAccountFollows(context.Background(), account.ID, nil)
suite.NoError(err) suite.NoError(err)
suite.Len(follows, 2) suite.Len(follows, 2)
} }
@ -781,7 +781,7 @@ func (suite *RelationshipTestSuite) TestCountAccountFollows() {
func (suite *RelationshipTestSuite) TestGetAccountFollowers() { func (suite *RelationshipTestSuite) TestGetAccountFollowers() {
account := suite.testAccounts["local_account_1"] account := suite.testAccounts["local_account_1"]
follows, err := suite.db.GetAccountFollowers(context.Background(), account.ID) follows, err := suite.db.GetAccountFollowers(context.Background(), account.ID, nil)
suite.NoError(err) suite.NoError(err)
suite.Len(follows, 2) suite.Len(follows, 2)
} }

View file

@ -114,6 +114,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
follows, err := t.state.DB.GetAccountFollows( follows, err := t.state.DB.GetAccountFollows(
gtscontext.SetBarebones(ctx), gtscontext.SetBarebones(ctx),
accountID, accountID,
nil, // select all
) )
if err != nil && !errors.Is(err, db.ErrNoEntries) { if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.Newf("db error getting follows for account %s: %w", accountID, err) return nil, gtserror.Newf("db error getting follows for account %s: %w", accountID, err)

View file

@ -167,8 +167,8 @@ func (suite *TimelineTestSuite) TestGetHomeTimelineNoFollowing() {
follows, err := suite.state.DB.GetAccountFollows( follows, err := suite.state.DB.GetAccountFollows(
gtscontext.SetBarebones(ctx), gtscontext.SetBarebones(ctx),
viewingAccount.ID, viewingAccount.ID,
nil, // select all
) )
if err != nil { if err != nil {
suite.FailNow(err.Error()) suite.FailNow(err.Error())
} }

View file

@ -20,7 +20,9 @@ package bundb
import ( import (
"strings" "strings"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/paging"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -83,6 +85,29 @@ func whereStartsLike(
) )
} }
// loadPagedIDs loads a page of IDs from given SliceCache by `key`, resorting to `loadDESC` if required. Uses `page` to sort + page resulting IDs.
// NOTE: IDs returned from `cache` / `loadDESC` MUST be in descending order, otherwise paging will not work correctly / return things out of order.
func loadPagedIDs(cache *cache.SliceCache[string], key string, page *paging.Page, loadDESC func() ([]string, error)) ([]string, error) {
// Check cache for IDs, else load.
ids, err := cache.Load(key, loadDESC)
if err != nil {
return nil, err
}
// Our cached / selected IDs are ALWAYS
// fetched from `loadDESC` in descending
// order. Depending on the paging requested
// this may be an unexpected order.
if page.GetOrder().Ascending() {
ids = paging.Reverse(ids)
}
// Page the resulting IDs.
ids = page.Page(ids)
return ids, nil
}
// updateWhere parses []db.Where and adds it to the given update query. // updateWhere parses []db.Where and adds it to the given update query.
func updateWhere(q *bun.UpdateQuery, where []db.Where) { func updateWhere(q *bun.UpdateQuery, where []db.Where) {
for _, w := range where { for _, w := range where {

View file

@ -138,43 +138,46 @@ type Relationship interface {
RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) error RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) error
// GetAccountFollows returns a slice of follows owned by the given accountID. // GetAccountFollows returns a slice of follows owned by the given accountID.
GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) GetAccountFollows(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error)
// GetAccountLocalFollows returns a slice of follows owned by the given accountID, only including follows from this instance. // GetAccountLocalFollows returns a slice of follows owned by the given accountID, only including follows from this instance.
GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
// GetAccountFollowers fetches follows that target given accountID.
GetAccountFollowers(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error)
// GetAccountLocalFollowers fetches follows that target given accountID, only including follows from this instance.
GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
// GetAccountFollowRequests returns all follow requests targeting the given account.
GetAccountFollowRequests(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error)
// GetAccountFollowRequesting returns all follow requests originating from the given account.
GetAccountFollowRequesting(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error)
// GetAccountBlocks returns all blocks originating from the given account, with given optional paging parameters.
GetAccountBlocks(ctx context.Context, accountID string, paging *paging.Page) ([]*gtsmodel.Block, error)
// CountAccountFollows returns the amount of accounts that the given accountID is following. // CountAccountFollows returns the amount of accounts that the given accountID is following.
CountAccountFollows(ctx context.Context, accountID string) (int, error) CountAccountFollows(ctx context.Context, accountID string) (int, error)
// CountAccountLocalFollows returns the amount of accounts that the given accountID is following, only including follows from this instance. // CountAccountLocalFollows returns the amount of accounts that the given accountID is following, only including follows from this instance.
CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error)
// GetAccountFollowers fetches follows that target given accountID.
GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
// GetAccountLocalFollowers fetches follows that target given accountID, only including follows from this instance.
GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
// CountAccountFollowers returns the amounts that the given ID is followed by. // CountAccountFollowers returns the amounts that the given ID is followed by.
CountAccountFollowers(ctx context.Context, accountID string) (int, error) CountAccountFollowers(ctx context.Context, accountID string) (int, error)
// CountAccountLocalFollowers returns the amounts that the given ID is followed by, only including follows from this instance. // CountAccountLocalFollowers returns the amounts that the given ID is followed by, only including follows from this instance.
CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error)
// GetAccountFollowRequests returns all follow requests targeting the given account.
GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error)
// GetAccountFollowRequesting returns all follow requests originating from the given account.
GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error)
// CountAccountFollowRequests returns number of follow requests targeting the given account. // CountAccountFollowRequests returns number of follow requests targeting the given account.
CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error)
// CountAccountFollowerRequests returns number of follow requests originating from the given account. // CountAccountFollowerRequests returns number of follow requests originating from the given account.
CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error)
// GetAccountBlocks returns all blocks originating from the given account, with given optional paging parameters. // CountAccountBlocks ...
GetAccountBlocks(ctx context.Context, accountID string, paging *paging.Page) ([]*gtsmodel.Block, error) CountAccountBlocks(ctx context.Context, accountID string) (int, error)
// GetNote gets a private note from a source account on a target account, if it exists. // GetNote gets a private note from a source account on a target account, if it exists.
GetNote(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.AccountNote, error) GetNote(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.AccountNote, error)

View file

@ -38,7 +38,7 @@ func (f *federatingDB) Followers(ctx context.Context, actorIRI *url.URL) (follow
return nil, err return nil, err
} }
follows, err := f.state.DB.GetAccountFollowers(ctx, acct.ID) follows, err := f.state.DB.GetAccountFollowers(ctx, acct.ID, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("Followers: db error getting followers for account id %s: %s", acct.ID, err) return nil, fmt.Errorf("Followers: db error getting followers for account id %s: %s", acct.ID, err)
} }

View file

@ -38,7 +38,7 @@ func (f *federatingDB) Following(ctx context.Context, actorIRI *url.URL) (follow
return nil, err return nil, err
} }
follows, err := f.state.DB.GetAccountFollows(ctx, acct.ID) follows, err := f.state.DB.GetAccountFollows(ctx, acct.ID, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("Following: db error getting following for account id %s: %w", acct.ID, err) return nil, fmt.Errorf("Following: db error getting following for account id %s: %w", acct.ID, err)
} }

View file

@ -47,8 +47,8 @@ func (suite *FollowingTestSuite) TestGetFollowing() {
suite.Equal(`{ suite.Equal(`{
"@context": "https://www.w3.org/ns/activitystreams", "@context": "https://www.w3.org/ns/activitystreams",
"items": [ "items": [
"http://localhost:8080/users/admin", "http://localhost:8080/users/1happyturtle",
"http://localhost:8080/users/1happyturtle" "http://localhost:8080/users/admin"
], ],
"type": "Collection" "type": "Collection"
}`, string(fJson)) }`, string(fJson))

View file

@ -89,7 +89,7 @@ func (f *federatingDB) InboxesForIRI(c context.Context, iri *url.URL) (inboxIRIs
return nil, fmt.Errorf("couldn't find local account with username %s: %s", localAccountUsername, err) return nil, fmt.Errorf("couldn't find local account with username %s: %s", localAccountUsername, err)
} }
follows, err := f.state.DB.GetAccountFollowers(c, account.ID) follows, err := f.state.DB.GetAccountFollowers(c, account.ID, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("couldn't get followers of local account %s: %s", localAccountUsername, err) return nil, fmt.Errorf("couldn't get followers of local account %s: %s", localAccountUsername, err)
} }

View file

@ -17,10 +17,10 @@
package paging package paging
// MinID returns an ID boundary with given min ID value, // EitherMinID returns an ID boundary with given min ID value,
// using either the `since_id`,"DESC" name,ordering or // using either the `since_id`,"DESC" name,ordering or
// `min_id`,"ASC" name,ordering depending on which is set. // `min_id`,"ASC" name,ordering depending on which is set.
func MinID(minID, sinceID string) Boundary { func EitherMinID(minID, sinceID string) Boundary {
/* /*
Paging with `since_id` vs `min_id`: Paging with `since_id` vs `min_id`:
@ -47,19 +47,29 @@ func MinID(minID, sinceID string) Boundary {
*/ */
switch { switch {
case minID != "": case minID != "":
return Boundary{ return MinID(minID)
Name: "min_id",
Value: minID,
Order: OrderAscending,
}
default: default:
// default min is `since_id` // default min is `since_id`
return SinceID(sinceID)
}
}
// SinceID ...
func SinceID(sinceID string) Boundary {
return Boundary{ return Boundary{
Name: "since_id", Name: "since_id",
Value: sinceID, Value: sinceID,
Order: OrderDescending, Order: OrderDescending,
} }
} }
// MinID ...
func MinID(minID string) Boundary {
return Boundary{
Name: "min_id",
Value: minID,
Order: OrderAscending,
}
} }
// MaxID returns an ID boundary with given max // MaxID returns an ID boundary with given max
@ -111,7 +121,7 @@ func (b Boundary) new(value string) Boundary {
// Find finds the boundary's set value in input slice, or returns -1. // Find finds the boundary's set value in input slice, or returns -1.
func (b Boundary) Find(in []string) int { func (b Boundary) Find(in []string) int {
if zero(b.Value) { if b.Value == "" {
return -1 return -1
} }
for i := range in { for i := range in {
@ -121,15 +131,3 @@ func (b Boundary) Find(in []string) int {
} }
return -1 return -1
} }
// Query returns this boundary as assembled query key=value pair.
func (b Boundary) Query() string {
switch {
case zero(b.Value):
return ""
case b.Name == "":
panic("value without boundary name")
default:
return b.Name + "=" + b.Value
}
}

View file

@ -20,7 +20,6 @@ package paging
import ( import (
"net/url" "net/url"
"strconv" "strconv"
"strings"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
) )
@ -70,26 +69,10 @@ func (p *Page) GetOrder() Order {
} }
func (p *Page) order() Order { func (p *Page) order() Order {
var (
// Check if min/max values set.
minValue = zero(p.Min.Value)
maxValue = zero(p.Max.Value)
// Check if min/max orders set.
minOrder = (p.Min.Order != 0)
maxOrder = (p.Max.Order != 0)
)
switch { switch {
// Boundaries with a value AND order set case p.Min.Order != 0:
// take priority. Min always comes first.
case minValue && minOrder:
return p.Min.Order return p.Min.Order
case maxValue && maxOrder: case p.Max.Order != 0:
return p.Max.Order
case minOrder:
return p.Min.Order
case maxOrder:
return p.Max.Order return p.Max.Order
default: default:
return 0 return 0
@ -108,31 +91,9 @@ func (p *Page) Page(in []string) []string {
return in return in
} }
if o := p.order(); !o.Ascending() { if p.order().Ascending() {
// Default sort is descending,
// catching all cases when NOT
// ascending (even zero value).
//
// NOTE: sorted data does not always
// occur according to string ineqs
// so we unfortunately cannot check.
if maxIdx := p.Max.Find(in); maxIdx != -1 {
// Reslice skipping up to max.
in = in[maxIdx+1:]
}
if minIdx := p.Min.Find(in); minIdx != -1 {
// Reslice stripping past min.
in = in[:minIdx]
}
} else {
// Sort type is ascending, input // Sort type is ascending, input
// data is assumed to be ascending. // data is assumed to be ascending.
//
// NOTE: sorted data does not always
// occur according to string ineqs
// so we unfortunately cannot check.
if minIdx := p.Min.Find(in); minIdx != -1 { if minIdx := p.Min.Find(in); minIdx != -1 {
// Reslice skipping up to min. // Reslice skipping up to min.
@ -144,6 +105,11 @@ func (p *Page) Page(in []string) []string {
in = in[:maxIdx] in = in[:maxIdx]
} }
if p.Limit > 0 && p.Limit < len(in) {
// Reslice input to limit.
in = in[:p.Limit]
}
if len(in) > 1 { if len(in) > 1 {
// Clone input before // Clone input before
// any modifications. // any modifications.
@ -153,20 +119,34 @@ func (p *Page) Page(in []string) []string {
// ALWAYS be descending. // ALWAYS be descending.
in = Reverse(in) in = Reverse(in)
} }
} else {
// Default sort is descending,
// catching all cases when NOT
// ascending (even zero value).
if maxIdx := p.Max.Find(in); maxIdx != -1 {
// Reslice skipping up to max.
in = in[maxIdx+1:]
}
if minIdx := p.Min.Find(in); minIdx != -1 {
// Reslice stripping past min.
in = in[:minIdx]
} }
if p.Limit > 0 && p.Limit < len(in) { if p.Limit > 0 && p.Limit < len(in) {
// Reslice input to limit. // Reslice input to limit.
in = in[:p.Limit] in = in[:p.Limit]
} }
}
return in return in
} }
// Next creates a new instance for the next returnable page, using // Next creates a new instance for the next returnable page, using
// given max value. This preserves original limit and max key name. // given max value. This preserves original limit and max key name.
func (p *Page) Next(max string) *Page { func (p *Page) Next(lo, hi string) *Page {
if p == nil || max == "" { if p == nil || lo == "" || hi == "" {
// no paging. // no paging.
return nil return nil
} }
@ -177,16 +157,27 @@ func (p *Page) Next(max string) *Page {
// Set original limit. // Set original limit.
p2.Limit = p.Limit p2.Limit = p.Limit
// Create new from old. if p.order().Ascending() {
p2.Max = p.Max.new(max) // When ascending, next page
// needs to start with min at
// the next highest value.
p2.Min = p.Min.new(hi)
p2.Max = p.Max.new("")
} else {
// When descending, next page
// needs to start with max at
// the next lowest value.
p2.Min = p.Min.new("")
p2.Max = p.Max.new(lo)
}
return p2 return p2
} }
// Prev creates a new instance for the prev returnable page, using // Prev creates a new instance for the prev returnable page, using
// given min value. This preserves original limit and min key name. // given min value. This preserves original limit and min key name.
func (p *Page) Prev(min string) *Page { func (p *Page) Prev(lo, hi string) *Page {
if p == nil || min == "" { if p == nil || lo == "" || hi == "" {
// no paging. // no paging.
return nil return nil
} }
@ -197,55 +188,56 @@ func (p *Page) Prev(min string) *Page {
// Set original limit. // Set original limit.
p2.Limit = p.Limit p2.Limit = p.Limit
// Create new from old. if p.order().Ascending() {
p2.Min = p.Min.new(min) // When ascending, prev page
// needs to start with max at
// the next lowest value.
p2.Min = p.Min.new("")
p2.Max = p.Max.new(lo)
} else {
// When descending, next page
// needs to start with max at
// the next lowest value.
p2.Min = p.Min.new(hi)
p2.Max = p.Max.new("")
}
return p2 return p2
} }
// ToLink builds a URL link for given endpoint information and extra query parameters, // ToLink builds a URL link for given endpoint information and extra query parameters,
// appending this Page's minimum / maximum boundaries and available limit (if any). // appending this Page's minimum / maximum boundaries and available limit (if any).
func (p *Page) ToLink(proto, host, path string, queryParams []string) string { func (p *Page) ToLink(proto, host, path string, queryParams url.Values) string {
if p == nil { if p == nil {
// no paging. // no paging.
return "" return ""
} }
// Check length before if queryParams == nil {
// adding boundary params. // Allocate new query parameters.
old := len(queryParams) queryParams = make(url.Values)
}
if minParam := p.Min.Query(); minParam != "" { if p.Min.Value != "" {
// A page-minimum query parameter is available. // A page-minimum query parameter is available.
queryParams = append(queryParams, minParam) queryParams.Add(p.Min.Name, p.Min.Value)
} }
if maxParam := p.Max.Query(); maxParam != "" { if p.Max.Value != "" {
// A page-maximum query parameter is available. // A page-maximum query parameter is available.
queryParams = append(queryParams, maxParam) queryParams.Add(p.Max.Name, p.Max.Value)
}
if len(queryParams) == old {
// No page boundaries.
return ""
} }
if p.Limit > 0 { if p.Limit > 0 {
// Build limit key-value query parameter. // A page limit query parameter is available.
param := "limit=" + strconv.Itoa(p.Limit) queryParams.Add("limit", strconv.Itoa(p.Limit))
// Append `limit=$value` query parameter.
queryParams = append(queryParams, param)
} }
// Join collected params into query str.
query := strings.Join(queryParams, "&")
// Build URL string. // Build URL string.
return (&url.URL{ return (&url.URL{
Scheme: proto, Scheme: proto,
Host: host, Host: host,
Path: path, Path: path,
RawQuery: query, RawQuery: queryParams.Encode(),
}).String() }).String()
} }

View file

@ -97,7 +97,7 @@ var cases = []Case{
// Return page and expected IDs. // Return page and expected IDs.
return ids, &paging.Page{ return ids, &paging.Page{
Min: paging.MinID(minID, ""), Min: paging.MinID(minID),
Max: paging.MaxID(maxID), Max: paging.MaxID(maxID),
}, expect }, expect
}), }),
@ -129,7 +129,7 @@ var cases = []Case{
// Return page and expected IDs. // Return page and expected IDs.
return ids, &paging.Page{ return ids, &paging.Page{
Min: paging.MinID(minID, ""), Min: paging.MinID(minID),
Max: paging.MaxID(maxID), Max: paging.MaxID(maxID),
Limit: limit, Limit: limit,
}, expect }, expect
@ -156,7 +156,7 @@ var cases = []Case{
// Return page and expected IDs. // Return page and expected IDs.
return ids, &paging.Page{ return ids, &paging.Page{
Min: paging.MinID(minID, ""), Min: paging.MinID(minID),
Max: paging.MaxID(maxID), Max: paging.MaxID(maxID),
Limit: len(ids) * 2, Limit: len(ids) * 2,
}, expect }, expect
@ -182,7 +182,7 @@ var cases = []Case{
// Return page and expected IDs. // Return page and expected IDs.
return ids, &paging.Page{ return ids, &paging.Page{
Min: paging.MinID("", sinceID), Min: paging.SinceID(sinceID),
Max: paging.MaxID(maxID), Max: paging.MaxID(maxID),
}, expect }, expect
}), }),
@ -225,7 +225,7 @@ var cases = []Case{
// Return page and expected IDs. // Return page and expected IDs.
return ids, &paging.Page{ return ids, &paging.Page{
Min: paging.MinID("", sinceID), Min: paging.SinceID(sinceID),
}, expect }, expect
}), }),
CreateCase("minID set", func(ids []string) ([]string, *paging.Page, []string) { CreateCase("minID set", func(ids []string) ([]string, *paging.Page, []string) {
@ -247,7 +247,7 @@ var cases = []Case{
// Return page and expected IDs. // Return page and expected IDs.
return ids, &paging.Page{ return ids, &paging.Page{
Min: paging.MinID(minID, ""), Min: paging.MinID(minID),
}, expect }, expect
}), }),
} }

View file

@ -30,9 +30,9 @@ import (
// While conversely, a zero default limit will not enforce paging, returning a nil page value. // While conversely, a zero default limit will not enforce paging, returning a nil page value.
func ParseIDPage(c *gin.Context, min, max, _default int) (*Page, gtserror.WithCode) { func ParseIDPage(c *gin.Context, min, max, _default int) (*Page, gtserror.WithCode) {
// Extract request query params. // Extract request query params.
sinceID := c.Query("since_id") sinceID, haveSince := c.GetQuery("since_id")
minID := c.Query("min_id") minID, haveMin := c.GetQuery("min_id")
maxID := c.Query("max_id") maxID, haveMax := c.GetQuery("max_id")
// Extract request limit parameter. // Extract request limit parameter.
limit, errWithCode := ParseLimit(c, min, max, _default) limit, errWithCode := ParseLimit(c, min, max, _default)
@ -40,20 +40,38 @@ func ParseIDPage(c *gin.Context, min, max, _default int) (*Page, gtserror.WithCo
return nil, errWithCode return nil, errWithCode
} }
if sinceID == "" && switch {
minID == "" && case haveMin:
maxID == "" && // A min_id was supplied, even if the value
limit == 0 { // itself is empty. This indicates ASC order.
// No ID paging params provided, and no default
// limit value which indicates paging not enforced.
return nil, nil
}
return &Page{ return &Page{
Min: MinID(minID, sinceID), Min: MinID(minID),
Max: MaxID(maxID), Max: MaxID(maxID),
Limit: limit, Limit: limit,
}, nil }, nil
case haveMax || haveSince:
// A max_id or since_id was supplied, even if the
// value itself is empty. This indicates DESC order.
return &Page{
Min: SinceID(sinceID),
Max: MaxID(maxID),
Limit: limit,
}, nil
case limit == 0:
// No ID paging params provided, and no default
// limit value which indicates paging not enforced.
return nil, nil
default:
// only limit.
return &Page{
Min: SinceID(""),
Max: MaxID(""),
Limit: limit,
}, nil
}
} }
// ParseShortcodeDomainPage parses an emoji shortcode domain Page from a request context, returning BadRequest // ParseShortcodeDomainPage parses an emoji shortcode domain Page from a request context, returning BadRequest
@ -62,8 +80,8 @@ func ParseIDPage(c *gin.Context, min, max, _default int) (*Page, gtserror.WithCo
// a zero default limit will not enforce paging, returning a nil page value. // a zero default limit will not enforce paging, returning a nil page value.
func ParseShortcodeDomainPage(c *gin.Context, min, max, _default int) (*Page, gtserror.WithCode) { func ParseShortcodeDomainPage(c *gin.Context, min, max, _default int) (*Page, gtserror.WithCode) {
// Extract request query parameters. // Extract request query parameters.
minShortcode := c.Query("min_shortcode_domain") minShortcode, haveMin := c.GetQuery("min_shortcode_domain")
maxShortcode := c.Query("max_shortcode_domain") maxShortcode, haveMax := c.GetQuery("max_shortcode_domain")
// Extract request limit parameter. // Extract request limit parameter.
limit, errWithCode := ParseLimit(c, min, max, _default) limit, errWithCode := ParseLimit(c, min, max, _default)
@ -71,8 +89,8 @@ func ParseShortcodeDomainPage(c *gin.Context, min, max, _default int) (*Page, gt
return nil, errWithCode return nil, errWithCode
} }
if minShortcode == "" && if !haveMin &&
maxShortcode == "" && !haveMax &&
limit == 0 { limit == 0 {
// No ID paging params provided, and no default // No ID paging params provided, and no default
// limit value which indicates paging not enforced. // limit value which indicates paging not enforced.
@ -89,7 +107,10 @@ func ParseShortcodeDomainPage(c *gin.Context, min, max, _default int) (*Page, gt
// ParseLimit parses the limit query parameter from a request context, returning BadRequest on error parsing and _default if zero limit given. // ParseLimit parses the limit query parameter from a request context, returning BadRequest on error parsing and _default if zero limit given.
func ParseLimit(c *gin.Context, min, max, _default int) (int, gtserror.WithCode) { func ParseLimit(c *gin.Context, min, max, _default int) (int, gtserror.WithCode) {
// Get limit query param. // Get limit query param.
str := c.Query("limit") str, ok := c.GetQuery("limit")
if !ok {
return _default, nil
}
// Attempt to parse limit int. // Attempt to parse limit int.
i, err := strconv.Atoi(str) i, err := strconv.Atoi(str)

View file

@ -18,6 +18,7 @@
package paging package paging
import ( import (
"net/url"
"strings" "strings"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@ -35,18 +36,13 @@ type ResponseParams struct {
Path string // path to use for next/prev queries in the link header Path string // path to use for next/prev queries in the link header
Next *Page // page details for the next page Next *Page // page details for the next page
Prev *Page // page details for the previous page Prev *Page // page details for the previous page
Query []string // any extra query parameters to provide in the link header, should be in the format 'example=value' Query url.Values // any extra query parameters to provide in the link header, should be in the format 'example=value'
} }
// PackageResponse is a convenience function for returning // PackageResponse is a convenience function for returning
// a bunch of pageable items (notifications, statuses, etc), as well // a bunch of pageable items (notifications, statuses, etc), as well
// as a Link header to inform callers of where to find next/prev items. // as a Link header to inform callers of where to find next/prev items.
func PackageResponse(params ResponseParams) *apimodel.PageableResponse { func PackageResponse(params ResponseParams) *apimodel.PageableResponse {
if len(params.Items) == 0 {
// No items to page through.
return EmptyResponse()
}
var ( var (
// Extract paging params. // Extract paging params.
nextPg = params.Next nextPg = params.Next

View file

@ -42,9 +42,9 @@ func (suite *PagingSuite) TestPagingStandard() {
resp := paging.PackageResponse(params) resp := paging.PackageResponse(params)
suite.Equal(make([]interface{}, 10, 10), resp.Items) suite.Equal(make([]interface{}, 10, 10), resp.Items)
suite.Equal(`<https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?max_id=01H11KA1DM2VH3747YDE7FV5HN&limit=10>; rel="next", <https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?min_id=01H11KBBVRRDYYC5KEPME1NP5R&limit=10>; rel="prev"`, resp.LinkHeader) suite.Equal(`<https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&max_id=01H11KA1DM2VH3747YDE7FV5HN>; rel="next", <https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&min_id=01H11KBBVRRDYYC5KEPME1NP5R>; rel="prev"`, resp.LinkHeader)
suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?max_id=01H11KA1DM2VH3747YDE7FV5HN&limit=10`, resp.NextLink) suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&max_id=01H11KA1DM2VH3747YDE7FV5HN`, resp.NextLink)
suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?min_id=01H11KBBVRRDYYC5KEPME1NP5R&limit=10`, resp.PrevLink) suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&min_id=01H11KBBVRRDYYC5KEPME1NP5R`, resp.PrevLink)
} }
func (suite *PagingSuite) TestPagingNoLimit() { func (suite *PagingSuite) TestPagingNoLimit() {
@ -77,9 +77,9 @@ func (suite *PagingSuite) TestPagingNoNextID() {
resp := paging.PackageResponse(params) resp := paging.PackageResponse(params)
suite.Equal(make([]interface{}, 10, 10), resp.Items) suite.Equal(make([]interface{}, 10, 10), resp.Items)
suite.Equal(`<https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?min_id=01H11KBBVRRDYYC5KEPME1NP5R&limit=10>; rel="prev"`, resp.LinkHeader) suite.Equal(`<https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&min_id=01H11KBBVRRDYYC5KEPME1NP5R>; rel="prev"`, resp.LinkHeader)
suite.Equal(``, resp.NextLink) suite.Equal(``, resp.NextLink)
suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?min_id=01H11KBBVRRDYYC5KEPME1NP5R&limit=10`, resp.PrevLink) suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&min_id=01H11KBBVRRDYYC5KEPME1NP5R`, resp.PrevLink)
} }
func (suite *PagingSuite) TestPagingNoPrevID() { func (suite *PagingSuite) TestPagingNoPrevID() {
@ -94,27 +94,11 @@ func (suite *PagingSuite) TestPagingNoPrevID() {
resp := paging.PackageResponse(params) resp := paging.PackageResponse(params)
suite.Equal(make([]interface{}, 10, 10), resp.Items) suite.Equal(make([]interface{}, 10, 10), resp.Items)
suite.Equal(`<https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?max_id=01H11KA1DM2VH3747YDE7FV5HN&limit=10>; rel="next"`, resp.LinkHeader) suite.Equal(`<https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&max_id=01H11KA1DM2VH3747YDE7FV5HN>; rel="next"`, resp.LinkHeader)
suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?max_id=01H11KA1DM2VH3747YDE7FV5HN&limit=10`, resp.NextLink) suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&max_id=01H11KA1DM2VH3747YDE7FV5HN`, resp.NextLink)
suite.Equal(``, resp.PrevLink) suite.Equal(``, resp.PrevLink)
} }
func (suite *PagingSuite) TestPagingNoItems() {
config.SetHost("example.org")
params := paging.ResponseParams{
Next: nextPage("01H11KA1DM2VH3747YDE7FV5HN", 10),
Prev: prevPage("01H11KBBVRRDYYC5KEPME1NP5R", 10),
}
resp := paging.PackageResponse(params)
suite.Empty(resp.Items)
suite.Empty(resp.LinkHeader)
suite.Empty(resp.NextLink)
suite.Empty(resp.PrevLink)
}
func TestPagingSuite(t *testing.T) { func TestPagingSuite(t *testing.T) {
suite.Run(t, &PagingSuite{}) suite.Run(t, &PagingSuite{})
} }
@ -128,7 +112,7 @@ func nextPage(id string, limit int) *paging.Page {
func prevPage(id string, limit int) *paging.Page { func prevPage(id string, limit int) *paging.Page {
return &paging.Page{ return &paging.Page{
Min: paging.MinID(id, ""), Min: paging.MinID(id),
Limit: limit, Limit: limit,
} }
} }

View file

@ -41,9 +41,3 @@ func Reverse(in []string) []string {
return in return in
} }
// zero is a shorthand to check a generic value is its zero value.
func zero[T comparable](t T) bool {
var z T
return t == z
}

View file

@ -22,6 +22,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing/common"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/text" "github.com/superseriousbusiness/gotosocial/internal/text"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
@ -32,6 +33,9 @@ import (
// //
// It also contains logic for actions towards accounts such as following, blocking, seeing follows, etc. // It also contains logic for actions towards accounts such as following, blocking, seeing follows, etc.
type Processor struct { type Processor struct {
// common processor logic
c *common.Processor
state *state.State state *state.State
tc typeutils.TypeConverter tc typeutils.TypeConverter
mediaManager *media.Manager mediaManager *media.Manager
@ -44,6 +48,7 @@ type Processor struct {
// New returns a new account processor. // New returns a new account processor.
func New( func New(
common *common.Processor,
state *state.State, state *state.State,
tc typeutils.TypeConverter, tc typeutils.TypeConverter,
mediaManager *media.Manager, mediaManager *media.Manager,
@ -53,6 +58,7 @@ func New(
parseMention gtsmodel.ParseMentionFunc, parseMention gtsmodel.ParseMentionFunc,
) Processor { ) Processor {
return Processor{ return Processor{
c: common,
state: state, state: state,
tc: tc, tc: tc,
mediaManager: mediaManager, mediaManager: mediaManager,

View file

@ -30,6 +30,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/processing/account" "github.com/superseriousbusiness/gotosocial/internal/processing/account"
"github.com/superseriousbusiness/gotosocial/internal/processing/common"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/transport"
@ -113,7 +114,8 @@ func (suite *AccountStandardTestSuite) SetupTest() {
suite.emailSender = testrig.NewEmailSender("../../../web/template/", suite.sentEmails) suite.emailSender = testrig.NewEmailSender("../../../web/template/", suite.sentEmails)
filter := visibility.NewFilter(&suite.state) filter := visibility.NewFilter(&suite.state)
suite.accountProcessor = account.New(&suite.state, suite.tc, suite.mediaManager, suite.oauthServer, suite.federator, filter, processing.GetParseMentionFunc(suite.db, suite.federator)) common := common.New(&suite.state, suite.tc, suite.federator, filter)
suite.accountProcessor = account.New(&common, &suite.state, suite.tc, suite.mediaManager, suite.oauthServer, suite.federator, filter, processing.GetParseMentionFunc(suite.db, suite.federator))
testrig.StandardDBSetup(suite.db, nil) testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../testrig/media")
} }

View file

@ -28,8 +28,11 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/paging"
"github.com/superseriousbusiness/gotosocial/internal/uris" "github.com/superseriousbusiness/gotosocial/internal/uris"
"github.com/superseriousbusiness/gotosocial/internal/util"
) )
// BlockCreate handles the creation of a block from requestingAccount to targetAccountID, either remote or local. // BlockCreate handles the creation of a block from requestingAccount to targetAccountID, either remote or local.
@ -128,6 +131,53 @@ func (p *Processor) BlockRemove(ctx context.Context, requestingAccount *gtsmodel
return p.RelationshipGet(ctx, requestingAccount, targetAccountID) return p.RelationshipGet(ctx, requestingAccount, targetAccountID)
} }
// BlocksGet ...
func (p *Processor) BlocksGet(
ctx context.Context,
requestingAccount *gtsmodel.Account,
page *paging.Page,
) (*apimodel.PageableResponse, gtserror.WithCode) {
blocks, err := p.state.DB.GetAccountBlocks(ctx,
requestingAccount.ID,
page,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorInternalError(err)
}
// Check for empty response.
count := len(blocks)
if len(blocks) == 0 {
return util.EmptyPageableResponse(), nil
}
items := make([]interface{}, 0, count)
for _, block := range blocks {
// Convert target account to frontend API model. (target will never be nil)
account, err := p.tc.AccountToAPIAccountBlocked(ctx, block.TargetAccount)
if err != nil {
log.Errorf(ctx, "error converting account to public api account: %v", err)
continue
}
// Append target to return items.
items = append(items, account)
}
// Get the lowest and highest
// ID values, used for paging.
lo := blocks[count-1].ID
hi := blocks[0].ID
return paging.PackageResponse(paging.ResponseParams{
Items: items,
Path: "/api/v1/blocks",
Next: page.Next(lo, hi),
Prev: page.Prev(lo, hi),
}), nil
}
func (p *Processor) getBlockTarget(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*gtsmodel.Account, *gtsmodel.Block, gtserror.WithCode) { func (p *Processor) getBlockTarget(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*gtsmodel.Account, *gtsmodel.Block, gtserror.WithCode) {
// Account should not block or unblock itself. // Account should not block or unblock itself.
if requestingAccount.ID == targetAccountID { if requestingAccount.ID == targetAccountID {

View file

@ -160,7 +160,7 @@ func (p *Processor) deleteUserAndTokensForAccount(ctx context.Context, account *
// - Follow requests created by account. // - Follow requests created by account.
func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.Account) error { func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.Account) error {
// Delete follows targeting this account. // Delete follows targeting this account.
followedBy, err := p.state.DB.GetAccountFollowers(ctx, account.ID) followedBy, err := p.state.DB.GetAccountFollowers(ctx, account.ID, nil)
if err != nil && !errors.Is(err, db.ErrNoEntries) { if err != nil && !errors.Is(err, db.ErrNoEntries) {
return gtserror.Newf("db error getting follows targeting account %s: %w", account.ID, err) return gtserror.Newf("db error getting follows targeting account %s: %w", account.ID, err)
} }
@ -172,7 +172,7 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.
} }
// Delete follow requests targeting this account. // Delete follow requests targeting this account.
followRequestedBy, err := p.state.DB.GetAccountFollowRequests(ctx, account.ID) followRequestedBy, err := p.state.DB.GetAccountFollowRequests(ctx, account.ID, nil)
if err != nil && !errors.Is(err, db.ErrNoEntries) { if err != nil && !errors.Is(err, db.ErrNoEntries) {
return gtserror.Newf("db error getting follow requests targeting account %s: %w", account.ID, err) return gtserror.Newf("db error getting follow requests targeting account %s: %w", account.ID, err)
} }
@ -193,7 +193,7 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.
) )
// Delete follows originating from this account. // Delete follows originating from this account.
following, err := p.state.DB.GetAccountFollows(ctx, account.ID) following, err := p.state.DB.GetAccountFollows(ctx, account.ID, nil)
if err != nil && !errors.Is(err, db.ErrNoEntries) { if err != nil && !errors.Is(err, db.ErrNoEntries) {
return gtserror.Newf("db error getting follows owned by account %s: %w", account.ID, err) return gtserror.Newf("db error getting follows owned by account %s: %w", account.ID, err)
} }
@ -211,7 +211,7 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.
} }
// Delete follow requests originating from this account. // Delete follow requests originating from this account.
followRequesting, err := p.state.DB.GetAccountFollowRequesting(ctx, account.ID) followRequesting, err := p.state.DB.GetAccountFollowRequesting(ctx, account.ID, nil)
if err != nil && !errors.Is(err, db.ErrNoEntries) { if err != nil && !errors.Is(err, db.ErrNoEntries) {
return gtserror.Newf("db error getting follow requests owned by account %s: %w", account.ID, err) return gtserror.Newf("db error getting follow requests owned by account %s: %w", account.ID, err)
} }

View file

@ -20,7 +20,6 @@ package account
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/ap"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@ -35,7 +34,7 @@ import (
// FollowCreate handles a follow request to an account, either remote or local. // FollowCreate handles a follow request to an account, either remote or local.
func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) { func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) {
targetAccount, errWithCode := p.getFollowTarget(ctx, requestingAccount.ID, form.ID) targetAccount, errWithCode := p.getFollowTarget(ctx, requestingAccount, form.ID)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }
@ -46,7 +45,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
requestingAccount.ID, requestingAccount.ID,
targetAccount.ID, targetAccount.ID,
); err != nil && !errors.Is(err, db.ErrNoEntries) { ); err != nil && !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("FollowCreate: db error checking existing follow: %w", err) err = gtserror.Newf("db error checking existing follow: %w", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} else if follow != nil { } else if follow != nil {
// Already follows, update if necessary + return relationship. // Already follows, update if necessary + return relationship.
@ -66,7 +65,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
requestingAccount.ID, requestingAccount.ID,
targetAccount.ID, targetAccount.ID,
); err != nil && !errors.Is(err, db.ErrNoEntries) { ); err != nil && !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("FollowCreate: db error checking existing follow request: %w", err) err = gtserror.Newf("db error checking existing follow request: %w", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} else if followRequest != nil { } else if followRequest != nil {
// Already requested, update if necessary + return relationship. // Already requested, update if necessary + return relationship.
@ -100,7 +99,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
} }
if err := p.state.DB.PutFollowRequest(ctx, fr); err != nil { if err := p.state.DB.PutFollowRequest(ctx, fr); err != nil {
err = fmt.Errorf("FollowCreate: error creating follow request in db: %s", err) err = gtserror.Newf("error creating follow request in db: %s", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
@ -112,7 +111,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
// Because we know the requestingAccount is also // Because we know the requestingAccount is also
// local, we don't need to federate the accept out. // local, we don't need to federate the accept out.
if _, err := p.state.DB.AcceptFollowRequest(ctx, requestingAccount.ID, form.ID); err != nil { if _, err := p.state.DB.AcceptFollowRequest(ctx, requestingAccount.ID, form.ID); err != nil {
err = fmt.Errorf("FollowCreate: error accepting follow request for local unlocked account: %w", err) err = gtserror.Newf("error accepting follow request for local unlocked account: %w", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
} else if targetAccount.IsRemote() { } else if targetAccount.IsRemote() {
@ -132,7 +131,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
// FollowRemove handles the removal of a follow/follow request to an account, either remote or local. // FollowRemove handles the removal of a follow/follow request to an account, either remote or local.
func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
targetAccount, errWithCode := p.getFollowTarget(ctx, requestingAccount.ID, targetAccountID) targetAccount, errWithCode := p.getFollowTarget(ctx, requestingAccount, targetAccountID)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }
@ -140,7 +139,7 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode
// Unfollow and deal with side effects. // Unfollow and deal with side effects.
msgs, err := p.unfollow(ctx, requestingAccount, targetAccount) msgs, err := p.unfollow(ctx, requestingAccount, targetAccount)
if err != nil { if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("FollowRemove: account %s not found in the db: %s", targetAccountID, err)) return nil, gtserror.NewErrorNotFound(gtserror.Newf("account %s not found in the db: %s", targetAccountID, err))
} }
// Batch queue accreted client api messages. // Batch queue accreted client api messages.
@ -166,7 +165,6 @@ func (p *Processor) updateFollow(
currentNotify *bool, currentNotify *bool,
update func(...string) error, update func(...string) error,
) (*apimodel.Relationship, gtserror.WithCode) { ) (*apimodel.Relationship, gtserror.WithCode) {
if form.Reblogs == nil && form.Notify == nil { if form.Reblogs == nil && form.Notify == nil {
// There's nothing to update. // There's nothing to update.
return p.RelationshipGet(ctx, requestingAccount, form.ID) return p.RelationshipGet(ctx, requestingAccount, form.ID)
@ -192,7 +190,7 @@ func (p *Processor) updateFollow(
} }
if err := update(columns...); err != nil { if err := update(columns...); err != nil {
err = fmt.Errorf("updateFollow: error updating existing follow (request): %w", err) err = gtserror.Newf("error updating existing follow (request): %w", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
@ -201,38 +199,23 @@ func (p *Processor) updateFollow(
// getFollowTarget is a convenience function which: // getFollowTarget is a convenience function which:
// - Checks if account is trying to follow/unfollow itself. // - Checks if account is trying to follow/unfollow itself.
// - Returns not found if there's a block in place between accounts. // - Returns not found if target should not be visible to requester.
// - Returns target account according to its id. // - Returns target account according to its id.
func (p *Processor) getFollowTarget(ctx context.Context, requestingAccountID string, targetAccountID string) (*gtsmodel.Account, gtserror.WithCode) { func (p *Processor) getFollowTarget(ctx context.Context, requester *gtsmodel.Account, targetID string) (*gtsmodel.Account, gtserror.WithCode) {
// Check for requester.
if requester == nil {
err := errors.New("no authorized user")
return nil, gtserror.NewErrorUnauthorized(err)
}
// Account can't follow or unfollow itself. // Account can't follow or unfollow itself.
if requestingAccountID == targetAccountID { if requester.ID == targetID {
err := errors.New("account can't follow or unfollow itself") err := errors.New("account can't follow or unfollow itself")
return nil, gtserror.NewErrorNotAcceptable(err) return nil, gtserror.NewErrorNotAcceptable(err)
} }
// Do nothing if a block exists in either direction between accounts. // Fetch the target account for requesting user account.
if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccountID, targetAccountID); err != nil { return p.c.GetVisibleTargetAccount(ctx, requester, targetID)
err = fmt.Errorf("db error checking block between accounts: %w", err)
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
err = errors.New("block exists between accounts")
return nil, gtserror.NewErrorNotFound(err)
}
// Ensure target account retrievable.
targetAccount, err := p.state.DB.GetAccountByID(ctx, targetAccountID)
if err != nil {
if !errors.Is(err, db.ErrNoEntries) {
// Real db error.
err = fmt.Errorf("db error looking for target account %s: %w", targetAccountID, err)
return nil, gtserror.NewErrorInternalError(err)
}
// Account not found.
err = fmt.Errorf("target account %s not found in the db", targetAccountID)
return nil, gtserror.NewErrorNotFound(err, err.Error())
}
return targetAccount, nil
} }
// unfollow is a convenience function for having requesting account // unfollow is a convenience function for having requesting account
@ -248,7 +231,7 @@ func (p *Processor) unfollow(ctx context.Context, requestingAccount *gtsmodel.Ac
// Get follow from requesting account to target account. // Get follow from requesting account to target account.
follow, err := p.state.DB.GetFollow(ctx, requestingAccount.ID, targetAccount.ID) follow, err := p.state.DB.GetFollow(ctx, requestingAccount.ID, targetAccount.ID)
if err != nil && !errors.Is(err, db.ErrNoEntries) { if err != nil && !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("unfollow: error getting follow from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err) err = gtserror.Newf("error getting follow from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err)
return nil, err return nil, err
} }
@ -257,7 +240,7 @@ func (p *Processor) unfollow(ctx context.Context, requestingAccount *gtsmodel.Ac
err = p.state.DB.DeleteFollowByID(ctx, follow.ID) err = p.state.DB.DeleteFollowByID(ctx, follow.ID)
if err != nil { if err != nil {
if !errors.Is(err, db.ErrNoEntries) { if !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("unfollow: error deleting request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err) err = gtserror.Newf("error deleting request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err)
return nil, err return nil, err
} }
@ -284,7 +267,7 @@ func (p *Processor) unfollow(ctx context.Context, requestingAccount *gtsmodel.Ac
// Get follow request from requesting account to target account. // Get follow request from requesting account to target account.
followReq, err := p.state.DB.GetFollowRequest(ctx, requestingAccount.ID, targetAccount.ID) followReq, err := p.state.DB.GetFollowRequest(ctx, requestingAccount.ID, targetAccount.ID)
if err != nil && !errors.Is(err, db.ErrNoEntries) { if err != nil && !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("unfollow: error getting follow request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err) err = gtserror.Newf("error getting follow request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err)
return nil, err return nil, err
} }
@ -293,7 +276,7 @@ func (p *Processor) unfollow(ctx context.Context, requestingAccount *gtsmodel.Ac
err = p.state.DB.DeleteFollowRequestByID(ctx, followReq.ID) err = p.state.DB.DeleteFollowRequestByID(ctx, followReq.ID)
if err != nil { if err != nil {
if !errors.Is(err, db.ErrNoEntries) { if !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("unfollow: error deleting follow request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err) err = gtserror.Newf("error deleting follow request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err)
return nil, err return nil, err
} }

View file

@ -0,0 +1,119 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package account
import (
"context"
"errors"
"github.com/superseriousbusiness/gotosocial/internal/ap"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/paging"
)
// FollowRequestAccept handles the accepting of a follow request from the sourceAccountID to the requestingAccount (the currently authorized account).
func (p *Processor) FollowRequestAccept(ctx context.Context, requestingAccount *gtsmodel.Account, sourceAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
follow, err := p.state.DB.AcceptFollowRequest(ctx, sourceAccountID, requestingAccount.ID)
if err != nil {
return nil, gtserror.NewErrorNotFound(err)
}
if follow.Account != nil {
// Only enqueue work in the case we have a request creating account stored.
// NOTE: due to how AcceptFollowRequest works, the inverse shouldn't be possible.
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityAccept,
GTSModel: follow,
OriginAccount: follow.Account,
TargetAccount: follow.TargetAccount,
})
}
return p.RelationshipGet(ctx, requestingAccount, sourceAccountID)
}
// FollowRequestReject handles the rejection of a follow request from the sourceAccountID to the requestingAccount (the currently authorized account).
func (p *Processor) FollowRequestReject(ctx context.Context, requestingAccount *gtsmodel.Account, sourceAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
followRequest, err := p.state.DB.GetFollowRequest(ctx, sourceAccountID, requestingAccount.ID)
if err != nil {
return nil, gtserror.NewErrorNotFound(err)
}
err = p.state.DB.RejectFollowRequest(ctx, sourceAccountID, requestingAccount.ID)
if err != nil {
return nil, gtserror.NewErrorNotFound(err)
}
if followRequest.Account != nil {
// Only enqueue work in the case we have a request creating account stored.
// NOTE: due to how GetFollowRequest works, the inverse shouldn't be possible.
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityReject,
GTSModel: followRequest,
OriginAccount: followRequest.Account,
TargetAccount: followRequest.TargetAccount,
})
}
return p.RelationshipGet(ctx, requestingAccount, sourceAccountID)
}
// FollowRequestsGet fetches a list of the accounts that are follow requesting the given requestingAccount (the currently authorized account).
func (p *Processor) FollowRequestsGet(ctx context.Context, requestingAccount *gtsmodel.Account, page *paging.Page) (*apimodel.PageableResponse, gtserror.WithCode) {
// Fetch follow requests targeting the given requesting account model.
followRequests, err := p.state.DB.GetAccountFollowRequests(ctx, requestingAccount.ID, page)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorInternalError(err)
}
// Check for empty response.
count := len(followRequests)
if count == 0 {
return paging.EmptyResponse(), nil
}
// Get the lowest and highest
// ID values, used for paging.
lo := followRequests[count-1].ID
hi := followRequests[0].ID
// Func to fetch follow source at index.
getIdx := func(i int) *gtsmodel.Account {
return followRequests[i].Account
}
// Get a filtered slice of public API account models.
items := p.c.GetVisibleAPIAccountsPaged(ctx,
requestingAccount,
getIdx,
count,
)
return paging.PackageResponse(paging.ResponseParams{
Items: items,
Path: "/api/v1/follow_requests",
Next: page.Next(lo, hi),
Prev: page.Prev(lo, hi),
}), nil
}

View file

@ -20,128 +20,120 @@ package account
import ( import (
"context" "context"
"errors" "errors"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/paging"
) )
// FollowersGet fetches a list of the target account's followers. // FollowersGet fetches a list of the target account's followers.
func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, page *paging.Page) (*apimodel.PageableResponse, gtserror.WithCode) {
if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, targetAccountID); err != nil { // Fetch target account to check it exists, and visibility of requester->target.
err = fmt.Errorf("FollowersGet: db error checking block: %w", err) _, errWithCode := p.c.GetVisibleTargetAccount(ctx, requestingAccount, targetAccountID)
return nil, gtserror.NewErrorInternalError(err) if errWithCode != nil {
} else if blocked { return nil, errWithCode
err = errors.New("FollowersGet: block exists between accounts")
return nil, gtserror.NewErrorNotFound(err)
} }
follows, err := p.state.DB.GetAccountFollowers(ctx, targetAccountID) follows, err := p.state.DB.GetAccountFollowers(ctx, targetAccountID, page)
if err != nil { if err != nil && !errors.Is(err, db.ErrNoEntries) {
if !errors.Is(err, db.ErrNoEntries) { err = gtserror.Newf("db error getting followers: %w", err)
err = fmt.Errorf("FollowersGet: db error getting followers: %w", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
return []apimodel.Account{}, nil
// Check for empty response.
count := len(follows)
if count == 0 {
return paging.EmptyResponse(), nil
} }
return p.accountsFromFollows(ctx, follows, requestingAccount.ID) // Get the lowest and highest
// ID values, used for paging.
lo := follows[count-1].ID
hi := follows[0].ID
// Func to fetch follow source at index.
getIdx := func(i int) *gtsmodel.Account {
return follows[i].Account
}
// Get a filtered slice of public API account models.
items := p.c.GetVisibleAPIAccountsPaged(ctx,
requestingAccount,
getIdx,
len(follows),
)
return paging.PackageResponse(paging.ResponseParams{
Items: items,
Path: "/api/v1/accounts/" + targetAccountID + "/followers",
Next: page.Next(lo, hi),
Prev: page.Prev(lo, hi),
}), nil
} }
// FollowingGet fetches a list of the accounts that target account is following. // FollowingGet fetches a list of the accounts that target account is following.
func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, page *paging.Page) (*apimodel.PageableResponse, gtserror.WithCode) {
if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, targetAccountID); err != nil { // Fetch target account to check it exists, and visibility of requester->target.
err = fmt.Errorf("FollowingGet: db error checking block: %w", err) _, errWithCode := p.c.GetVisibleTargetAccount(ctx, requestingAccount, targetAccountID)
return nil, gtserror.NewErrorInternalError(err) if errWithCode != nil {
} else if blocked { return nil, errWithCode
err = errors.New("FollowingGet: block exists between accounts")
return nil, gtserror.NewErrorNotFound(err)
} }
follows, err := p.state.DB.GetAccountFollows(ctx, targetAccountID) // Fetch known accounts that follow given target account ID.
if err != nil { follows, err := p.state.DB.GetAccountFollows(ctx, targetAccountID, page)
if !errors.Is(err, db.ErrNoEntries) { if err != nil && !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("FollowingGet: db error getting followers: %w", err) err = gtserror.Newf("db error getting followers: %w", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
return []apimodel.Account{}, nil
// Check for empty response.
count := len(follows)
if count == 0 {
return paging.EmptyResponse(), nil
} }
return p.targetAccountsFromFollows(ctx, follows, requestingAccount.ID) // Get the lowest and highest
// ID values, used for paging.
lo := follows[count-1].ID
hi := follows[0].ID
// Func to fetch follow source at index.
getIdx := func(i int) *gtsmodel.Account {
return follows[i].TargetAccount
}
// Get a filtered slice of public API account models.
items := p.c.GetVisibleAPIAccountsPaged(ctx,
requestingAccount,
getIdx,
len(follows),
)
return paging.PackageResponse(paging.ResponseParams{
Items: items,
Path: "/api/v1/accounts/" + targetAccountID + "/following",
Next: page.Next(lo, hi),
Prev: page.Prev(lo, hi),
}), nil
} }
// RelationshipGet returns a relationship model describing the relationship of the targetAccount to the Authed account. // RelationshipGet returns a relationship model describing the relationship of the targetAccount to the Authed account.
func (p *Processor) RelationshipGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { func (p *Processor) RelationshipGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
if requestingAccount == nil { if requestingAccount == nil {
return nil, gtserror.NewErrorForbidden(errors.New("not authed")) return nil, gtserror.NewErrorForbidden(gtserror.New("not authed"))
} }
gtsR, err := p.state.DB.GetRelationship(ctx, requestingAccount.ID, targetAccountID) gtsR, err := p.state.DB.GetRelationship(ctx, requestingAccount.ID, targetAccountID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error getting relationship: %s", err)) return nil, gtserror.NewErrorInternalError(gtserror.Newf("error getting relationship: %s", err))
} }
r, err := p.tc.RelationshipToAPIRelationship(ctx, gtsR) r, err := p.tc.RelationshipToAPIRelationship(ctx, gtsR)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting relationship: %s", err)) return nil, gtserror.NewErrorInternalError(gtserror.Newf("error converting relationship: %s", err))
} }
return r, nil return r, nil
} }
func (p *Processor) accountsFromFollows(ctx context.Context, follows []*gtsmodel.Follow, requestingAccountID string) ([]apimodel.Account, gtserror.WithCode) {
accounts := make([]apimodel.Account, 0, len(follows))
for _, follow := range follows {
if follow.Account == nil {
// No account set for some reason; just skip.
log.WithContext(ctx).WithField("follow", follow).Warn("follow had no associated account")
continue
}
if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccountID, follow.AccountID); err != nil {
err = fmt.Errorf("accountsFromFollows: db error checking block: %w", err)
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
continue
}
account, err := p.tc.AccountToAPIAccountPublic(ctx, follow.Account)
if err != nil {
err = fmt.Errorf("accountsFromFollows: error converting account to api account: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
accounts = append(accounts, *account)
}
return accounts, nil
}
func (p *Processor) targetAccountsFromFollows(ctx context.Context, follows []*gtsmodel.Follow, requestingAccountID string) ([]apimodel.Account, gtserror.WithCode) {
accounts := make([]apimodel.Account, 0, len(follows))
for _, follow := range follows {
if follow.TargetAccount == nil {
// No account set for some reason; just skip.
log.WithContext(ctx).WithField("follow", follow).Warn("follow had no associated target account")
continue
}
if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccountID, follow.TargetAccountID); err != nil {
err = fmt.Errorf("targetAccountsFromFollows: db error checking block: %w", err)
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
continue
}
account, err := p.tc.AccountToAPIAccountPublic(ctx, follow.TargetAccount)
if err != nil {
err = fmt.Errorf("targetAccountsFromFollows: error converting account to api account: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
accounts = append(accounts, *account)
}
return accounts, nil
}

View file

@ -1,86 +0,0 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package processing
import (
"context"
"errors"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/paging"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
// BlocksGet ...
func (p *Processor) BlocksGet(
ctx context.Context,
requestingAccount *gtsmodel.Account,
page *paging.Page,
) (*apimodel.PageableResponse, gtserror.WithCode) {
blocks, err := p.state.DB.GetAccountBlocks(ctx,
requestingAccount.ID,
page,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorInternalError(err)
}
// Check for zero length.
count := len(blocks)
if len(blocks) == 0 {
return util.EmptyPageableResponse(), nil
}
var (
items = make([]interface{}, 0, count)
// Set next + prev values before API converting
// so the caller can still page even on error.
nextMaxIDValue = blocks[count-1].ID
prevMinIDValue = blocks[0].ID
)
for _, block := range blocks {
if block.TargetAccount == nil {
// All models should be populated at this point.
log.Warnf(ctx, "block target account was nil: %v", err)
continue
}
// Convert target account to frontend API model.
account, err := p.tc.AccountToAPIAccountBlocked(ctx, block.TargetAccount)
if err != nil {
log.Errorf(ctx, "error converting account to public api account: %v", err)
continue
}
// Append target to return items.
items = append(items, account)
}
return paging.PackageResponse(paging.ResponseParams{
Items: items,
Path: "/api/v1/blocks",
Next: page.Next(nextMaxIDValue),
Prev: page.Prev(prevMinIDValue),
}), nil
}

View file

@ -0,0 +1,238 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package common
import (
"context"
"errors"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
)
// GetTargetAccountBy fetches the target account with db load function, given the authorized (or, nil) requester's
// account. This returns an approprate gtserror.WithCode accounting (ha) for not found and visibility to requester.
func (p *Processor) GetTargetAccountBy(
ctx context.Context,
requester *gtsmodel.Account,
getTargetFromDB func() (*gtsmodel.Account, error),
) (
account *gtsmodel.Account,
visible bool,
errWithCode gtserror.WithCode,
) {
// Fetch the target account from db.
target, err := getTargetFromDB()
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, false, gtserror.NewErrorInternalError(err)
}
if target == nil {
// DB loader could not find account in database.
err := errors.New("target account not found")
return nil, false, gtserror.NewErrorNotFound(err)
}
// Check whether target account is visible to requesting account.
visible, err = p.filter.AccountVisible(ctx, requester, target)
if err != nil {
return nil, false, gtserror.NewErrorInternalError(err)
}
if requester != nil && visible {
// Ensure the account is up-to-date.
p.federator.RefreshAccountAsync(ctx,
requester.Username,
target,
nil,
false,
)
}
return target, visible, nil
}
// GetTargetAccountByID is a call-through to GetTargetAccountBy() using the db GetAccountByID() function.
func (p *Processor) GetTargetAccountByID(
ctx context.Context,
requester *gtsmodel.Account,
targetID string,
) (
account *gtsmodel.Account,
visible bool,
errWithCode gtserror.WithCode,
) {
return p.GetTargetAccountBy(ctx, requester, func() (*gtsmodel.Account, error) {
return p.state.DB.GetAccountByID(ctx, targetID)
})
}
// GetVisibleTargetAccount calls GetTargetAccountByID(),
// but converts a non-visible result to not-found error.
func (p *Processor) GetVisibleTargetAccount(
ctx context.Context,
requester *gtsmodel.Account,
targetID string,
) (
account *gtsmodel.Account,
errWithCode gtserror.WithCode,
) {
// Fetch the target account by ID from the database.
target, visible, errWithCode := p.GetTargetAccountByID(ctx,
requester,
targetID,
)
if errWithCode != nil {
return nil, errWithCode
}
if !visible {
// Pretend account doesn't exist if not visible.
err := errors.New("target account not found")
return nil, gtserror.NewErrorNotFound(err)
}
return target, nil
}
// GetAPIAccount fetches the appropriate API account model depending on whether requester = target.
func (p *Processor) GetAPIAccount(
ctx context.Context,
requester *gtsmodel.Account,
target *gtsmodel.Account,
) (
apiAcc *apimodel.Account,
errWithCode gtserror.WithCode,
) {
var err error
if requester != nil && requester.ID == target.ID {
// Only return sensitive account model _if_ requester = target.
apiAcc, err = p.converter.AccountToAPIAccountSensitive(ctx, target)
} else {
// Else, fall back to returning the public account model.
apiAcc, err = p.converter.AccountToAPIAccountPublic(ctx, target)
}
if err != nil {
err := gtserror.Newf("error converting account: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
return apiAcc, nil
}
// GetAPIAccountBlocked fetches the limited "blocked" account model for given target.
func (p *Processor) GetAPIAccountBlocked(
ctx context.Context,
targetAcc *gtsmodel.Account,
) (
apiAcc *apimodel.Account,
errWithCode gtserror.WithCode,
) {
apiAccount, err := p.converter.AccountToAPIAccountBlocked(ctx, targetAcc)
if err != nil {
err = gtserror.Newf("error converting account: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
return apiAccount, nil
}
// GetVisibleAPIAccounts converts an array of gtsmodel.Accounts (inputted by next function) into
// public API model accounts, checking first for visibility. Please note that all errors will be
// logged at ERROR level, but will not be returned. Callers are likely to run into show-stopping
// errors in the lead-up to this function, whereas calling this should not be a show-stopper.
func (p *Processor) GetVisibleAPIAccounts(
ctx context.Context,
requester *gtsmodel.Account,
next func(int) *gtsmodel.Account,
length int,
) []*apimodel.Account {
return p.getVisibleAPIAccounts(ctx, 3, requester, next, length)
}
// GetVisibleAPIAccountsPaged is functionally equivalent to GetVisibleAPIAccounts(),
// except the accounts are returned as a converted slice of accounts as interface{}.
func (p *Processor) GetVisibleAPIAccountsPaged(
ctx context.Context,
requester *gtsmodel.Account,
next func(int) *gtsmodel.Account,
length int,
) []interface{} {
accounts := p.getVisibleAPIAccounts(ctx, 3, requester, next, length)
if len(accounts) == 0 {
return nil
}
items := make([]interface{}, len(accounts))
for i, account := range accounts {
items[i] = account
}
return items
}
func (p *Processor) getVisibleAPIAccounts(
ctx context.Context,
calldepth int, // used to skip wrapping func above these's names
requester *gtsmodel.Account,
next func(int) *gtsmodel.Account,
length int,
) []*apimodel.Account {
// Start new log entry with
// the above calling func's name.
l := log.
WithContext(ctx).
WithField("caller", log.Caller(calldepth+1))
// Preallocate slice according to expected length.
accounts := make([]*apimodel.Account, 0, length)
for i := 0; i < length; i++ {
// Get next account.
account := next(i)
if account == nil {
continue
}
// Check whether this account is visible to requesting account.
visible, err := p.filter.AccountVisible(ctx, requester, account)
if err != nil {
l.Errorf("error checking account visibility: %v", err)
continue
}
if !visible {
// Not visible to requester.
continue
}
// Convert the account to a public API model representation.
apiAcc, err := p.converter.AccountToAPIAccountPublic(ctx, account)
if err != nil {
l.Errorf("error converting account: %v", err)
continue
}
// Append API model to return slice.
accounts = append(accounts, apiAcc)
}
return accounts
}

View file

@ -0,0 +1,50 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package common
import (
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
)
// Processor provides a processor with logic
// common to multiple logical domains of the
// processing subsection of the codebase.
type Processor struct {
state *state.State
converter typeutils.TypeConverter
federator federation.Federator
filter *visibility.Filter
}
// New returns a new Processor instance.
func New(
state *state.State,
converter typeutils.TypeConverter,
federator federation.Federator,
filter *visibility.Filter,
) Processor {
return Processor{
state: state,
converter: converter,
federator: federator,
filter: filter,
}
}

View file

@ -0,0 +1,248 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package common
import (
"context"
"errors"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
)
// GetTargetStatusBy fetches the target status with db load function, given the authorized (or, nil) requester's
// account. This returns an approprate gtserror.WithCode accounting for not found and visibility to requester.
func (p *Processor) GetTargetStatusBy(
ctx context.Context,
requester *gtsmodel.Account,
getTargetFromDB func() (*gtsmodel.Status, error),
) (
status *gtsmodel.Status,
visible bool,
errWithCode gtserror.WithCode,
) {
// Fetch the target status from db.
target, err := getTargetFromDB()
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, false, gtserror.NewErrorInternalError(err)
}
if target == nil {
// DB loader could not find status in database.
err := errors.New("target status not found")
return nil, false, gtserror.NewErrorNotFound(err)
}
// Check whether target status is visible to requesting account.
visible, err = p.filter.StatusVisible(ctx, requester, target)
if err != nil {
return nil, false, gtserror.NewErrorInternalError(err)
}
if requester != nil && visible {
// Ensure remote status is up-to-date.
p.federator.RefreshStatusAsync(ctx,
requester.Username,
target,
nil,
false,
)
}
return target, visible, nil
}
// GetTargetStatusByID is a call-through to GetTargetStatus() using the db GetStatusByID() function.
func (p *Processor) GetTargetStatusByID(
ctx context.Context,
requester *gtsmodel.Account,
targetID string,
) (
status *gtsmodel.Status,
visible bool,
errWithCode gtserror.WithCode,
) {
return p.GetTargetStatusBy(ctx, requester, func() (*gtsmodel.Status, error) {
return p.state.DB.GetStatusByID(ctx, targetID)
})
}
// GetVisibleTargetStatus calls GetTargetStatusByID(),
// but converts a non-visible result to not-found error.
func (p *Processor) GetVisibleTargetStatus(
ctx context.Context,
requester *gtsmodel.Account,
targetID string,
) (
status *gtsmodel.Status,
errWithCode gtserror.WithCode,
) {
// Fetch the target status by ID from the database.
target, visible, errWithCode := p.GetTargetStatusByID(ctx,
requester,
targetID,
)
if errWithCode != nil {
return nil, errWithCode
}
if !visible {
// Target should not be seen by requester.
err := errors.New("target status not found")
return nil, gtserror.NewErrorNotFound(err)
}
return target, nil
}
// GetAPIStatus fetches the appropriate API status model for target.
func (p *Processor) GetAPIStatus(
ctx context.Context,
requester *gtsmodel.Account,
target *gtsmodel.Status,
) (
apiStatus *apimodel.Status,
errWithCode gtserror.WithCode,
) {
apiStatus, err := p.converter.StatusToAPIStatus(ctx, target, requester)
if err != nil {
err = gtserror.Newf("error converting status: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
return apiStatus, nil
}
// GetVisibleAPIStatuses converts an array of gtsmodel.Status (inputted by next function) into
// API model statuses, checking first for visibility. Please note that all errors will be
// logged at ERROR level, but will not be returned. Callers are likely to run into show-stopping
// errors in the lead-up to this function, whereas calling this should not be a show-stopper.
func (p *Processor) GetVisibleAPIStatuses(
ctx context.Context,
requester *gtsmodel.Account,
next func(int) *gtsmodel.Status,
length int,
) []*apimodel.Status {
return p.getVisibleAPIStatuses(ctx, 3, requester, next, length)
}
// GetVisibleAPIStatusesPaged is functionally equivalent to GetVisibleAPIStatuses(),
// except the statuses are returned as a converted slice of statuses as interface{}.
func (p *Processor) GetVisibleAPIStatusesPaged(
ctx context.Context,
requester *gtsmodel.Account,
next func(int) *gtsmodel.Status,
length int,
) []interface{} {
statuses := p.getVisibleAPIStatuses(ctx, 3, requester, next, length)
if len(statuses) == 0 {
return nil
}
items := make([]interface{}, len(statuses))
for i, status := range statuses {
items[i] = status
}
return items
}
func (p *Processor) getVisibleAPIStatuses(
ctx context.Context,
calldepth int, // used to skip wrapping func above these's names
requester *gtsmodel.Account,
next func(int) *gtsmodel.Status,
length int,
) []*apimodel.Status {
// Start new log entry with
// the above calling func's name.
l := log.
WithContext(ctx).
WithField("caller", log.Caller(calldepth+1))
// Preallocate slice according to expected length.
statuses := make([]*apimodel.Status, 0, length)
for i := 0; i < length; i++ {
// Get next status.
status := next(i)
if status == nil {
continue
}
// Check whether this status is visible to requesting account.
visible, err := p.filter.StatusVisible(ctx, requester, status)
if err != nil {
l.Errorf("error checking status visibility: %v", err)
continue
}
if !visible {
// Not visible to requester.
continue
}
// Convert the status to an API model representation.
apiStatus, err := p.converter.StatusToAPIStatus(ctx, status, requester)
if err != nil {
l.Errorf("error converting status: %v", err)
continue
}
// Append API model to return slice.
statuses = append(statuses, apiStatus)
}
return statuses
}
// InvalidateTimelinedStatus is a shortcut function for invalidating the cached
// representation one status in the home timeline and all list timelines of the
// given accountID. It should only be called in cases where a status update
// does *not* need to be passed into the processor via the worker queue, since
// such invalidation will, in that case, be handled by the processor instead.
func (p *Processor) InvalidateTimelinedStatus(ctx context.Context, accountID string, statusID string) error {
// Get lists first + bail if this fails.
lists, err := p.state.DB.GetListsForAccountID(ctx, accountID)
if err != nil {
return gtserror.Newf("db error getting lists for account %s: %w", accountID, err)
}
// Start new log entry with
// the above calling func's name.
l := log.
WithContext(ctx).
WithField("caller", log.Caller(3)).
WithField("accountID", accountID).
WithField("statusID", statusID)
// Unprepare item from home + list timelines, just log
// if something goes wrong since this is not a showstopper.
if err := p.state.Timelines.Home.UnprepareItem(ctx, accountID, statusID); err != nil {
l.Errorf("error unpreparing item from home timeline: %v", err)
}
for _, list := range lists {
if err := p.state.Timelines.List.UnprepareItem(ctx, list.ID, statusID); err != nil {
l.Errorf("error unpreparing item from list timeline %s: %v", list.ID, err)
}
}
return nil
}

View file

@ -1,123 +0,0 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package processing
import (
"context"
"errors"
"github.com/superseriousbusiness/gotosocial/internal/ap"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
func (p *Processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) {
followRequests, err := p.state.DB.GetAccountFollowRequests(ctx, auth.Account.ID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorInternalError(err)
}
accts := make([]apimodel.Account, 0, len(followRequests))
for _, followRequest := range followRequests {
if followRequest.Account == nil {
// The creator of the follow doesn't exist,
// just skip this one.
log.WithContext(ctx).WithField("followRequest", followRequest).Warn("follow request had no associated account")
continue
}
apiAcct, err := p.tc.AccountToAPIAccountPublic(ctx, followRequest.Account)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
accts = append(accts, *apiAcct)
}
return accts, nil
}
func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) {
follow, err := p.state.DB.AcceptFollowRequest(ctx, accountID, auth.Account.ID)
if err != nil {
return nil, gtserror.NewErrorNotFound(err)
}
if follow.Account == nil {
// The creator of the follow doesn't exist,
// so we can't do further processing.
log.WithContext(ctx).WithField("follow", follow).Warn("follow had no associated account")
return p.relationship(ctx, auth.Account.ID, accountID)
}
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityAccept,
GTSModel: follow,
OriginAccount: follow.Account,
TargetAccount: follow.TargetAccount,
})
return p.relationship(ctx, auth.Account.ID, accountID)
}
func (p *Processor) FollowRequestReject(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) {
followRequest, err := p.state.DB.GetFollowRequest(ctx, accountID, auth.Account.ID)
if err != nil {
return nil, gtserror.NewErrorNotFound(err)
}
err = p.state.DB.RejectFollowRequest(ctx, accountID, auth.Account.ID)
if err != nil {
return nil, gtserror.NewErrorNotFound(err)
}
if followRequest.Account == nil {
// The creator of the request doesn't exist,
// so we can't do further processing.
return p.relationship(ctx, auth.Account.ID, accountID)
}
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityReject,
GTSModel: followRequest,
OriginAccount: followRequest.Account,
TargetAccount: followRequest.TargetAccount,
})
return p.relationship(ctx, auth.Account.ID, accountID)
}
func (p *Processor) relationship(ctx context.Context, accountID string, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
relationship, err := p.state.DB.GetRelationship(ctx, accountID, targetAccountID)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
apiRelationship, err := p.tc.RelationshipToAPIRelationship(ctx, relationship)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
return apiRelationship, nil
}

View file

@ -30,35 +30,57 @@ import (
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
// TODO: move this to the "internal/processing/account" pkg
type FollowRequestTestSuite struct { type FollowRequestTestSuite struct {
ProcessingStandardTestSuite ProcessingStandardTestSuite
} }
func (suite *FollowRequestTestSuite) TestFollowRequestAccept() { func (suite *FollowRequestTestSuite) TestFollowRequestAccept() {
requestingAccount := suite.testAccounts["remote_account_2"] // The authed local account we are going to use for HTTP requests
targetAccount := suite.testAccounts["local_account_1"] requestingAccount := suite.testAccounts["local_account_1"]
// The remote account whose follow request we are accepting
targetAccount := suite.testAccounts["remote_account_2"]
// put a follow request in the database // put a follow request in the database
fr := &gtsmodel.FollowRequest{ fr := &gtsmodel.FollowRequest{
ID: "01FJ1S8DX3STJJ6CEYPMZ1M0R3", ID: "01FJ1S8DX3STJJ6CEYPMZ1M0R3",
CreatedAt: time.Now(), CreatedAt: time.Now(),
UpdatedAt: time.Now(), UpdatedAt: time.Now(),
URI: fmt.Sprintf("%s/follow/01FJ1S8DX3STJJ6CEYPMZ1M0R3", requestingAccount.URI), URI: fmt.Sprintf("%s/follow/01FJ1S8DX3STJJ6CEYPMZ1M0R3", targetAccount.URI),
AccountID: requestingAccount.ID, AccountID: targetAccount.ID,
TargetAccountID: targetAccount.ID, TargetAccountID: requestingAccount.ID,
} }
err := suite.db.Put(context.Background(), fr) err := suite.db.Put(context.Background(), fr)
suite.NoError(err) suite.NoError(err)
relationship, errWithCode := suite.processor.FollowRequestAccept(context.Background(), suite.testAutheds["local_account_1"], requestingAccount.ID) relationship, errWithCode := suite.processor.Account().FollowRequestAccept(
context.Background(),
requestingAccount,
targetAccount.ID,
)
suite.NoError(errWithCode) suite.NoError(errWithCode)
suite.EqualValues(&apimodel.Relationship{ID: "01FHMQX3GAABWSM0S2VZEC2SWC", Following: false, ShowingReblogs: false, Notifying: false, FollowedBy: true, Blocking: false, BlockedBy: false, Muting: false, MutingNotifications: false, Requested: false, DomainBlocking: false, Endorsed: false, Note: ""}, relationship) suite.EqualValues(&apimodel.Relationship{
ID: "01FHMQX3GAABWSM0S2VZEC2SWC",
Following: false,
ShowingReblogs: false,
Notifying: false,
FollowedBy: true,
Blocking: false,
BlockedBy: false,
Muting: false,
MutingNotifications: false,
Requested: false,
DomainBlocking: false,
Endorsed: false,
Note: "",
}, relationship)
// accept should be sent to Some_User // accept should be sent to Some_User
var sent [][]byte var sent [][]byte
if !testrig.WaitFor(func() bool { if !testrig.WaitFor(func() bool {
sentI, ok := suite.httpClient.SentMessages.Load(requestingAccount.InboxURI) sentI, ok := suite.httpClient.SentMessages.Load(targetAccount.InboxURI)
if ok { if ok {
sent, ok = sentI.([][]byte) sent, ok = sentI.([][]byte)
if !ok { if !ok {
@ -87,41 +109,45 @@ func (suite *FollowRequestTestSuite) TestFollowRequestAccept() {
err = json.Unmarshal(sent[0], accept) err = json.Unmarshal(sent[0], accept)
suite.NoError(err) suite.NoError(err)
suite.Equal(targetAccount.URI, accept.Actor) suite.Equal(requestingAccount.URI, accept.Actor)
suite.Equal(requestingAccount.URI, accept.Object.Actor) suite.Equal(targetAccount.URI, accept.Object.Actor)
suite.Equal(fr.URI, accept.Object.ID) suite.Equal(fr.URI, accept.Object.ID)
suite.Equal(targetAccount.URI, accept.Object.Object) suite.Equal(requestingAccount.URI, accept.Object.Object)
suite.Equal(targetAccount.URI, accept.Object.To) suite.Equal(requestingAccount.URI, accept.Object.To)
suite.Equal("Follow", accept.Object.Type) suite.Equal("Follow", accept.Object.Type)
suite.Equal(requestingAccount.URI, accept.To) suite.Equal(targetAccount.URI, accept.To)
suite.Equal("Accept", accept.Type) suite.Equal("Accept", accept.Type)
} }
func (suite *FollowRequestTestSuite) TestFollowRequestReject() { func (suite *FollowRequestTestSuite) TestFollowRequestReject() {
requestingAccount := suite.testAccounts["remote_account_2"] requestingAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["local_account_1"] targetAccount := suite.testAccounts["remote_account_2"]
// put a follow request in the database // put a follow request in the database
fr := &gtsmodel.FollowRequest{ fr := &gtsmodel.FollowRequest{
ID: "01FJ1S8DX3STJJ6CEYPMZ1M0R3", ID: "01FJ1S8DX3STJJ6CEYPMZ1M0R3",
CreatedAt: time.Now(), CreatedAt: time.Now(),
UpdatedAt: time.Now(), UpdatedAt: time.Now(),
URI: fmt.Sprintf("%s/follow/01FJ1S8DX3STJJ6CEYPMZ1M0R3", requestingAccount.URI), URI: fmt.Sprintf("%s/follow/01FJ1S8DX3STJJ6CEYPMZ1M0R3", targetAccount.URI),
AccountID: requestingAccount.ID, AccountID: targetAccount.ID,
TargetAccountID: targetAccount.ID, TargetAccountID: requestingAccount.ID,
} }
err := suite.db.Put(context.Background(), fr) err := suite.db.Put(context.Background(), fr)
suite.NoError(err) suite.NoError(err)
relationship, errWithCode := suite.processor.FollowRequestReject(context.Background(), suite.testAutheds["local_account_1"], requestingAccount.ID) relationship, errWithCode := suite.processor.Account().FollowRequestReject(
context.Background(),
requestingAccount,
targetAccount.ID,
)
suite.NoError(errWithCode) suite.NoError(errWithCode)
suite.EqualValues(&apimodel.Relationship{ID: "01FHMQX3GAABWSM0S2VZEC2SWC", Following: false, ShowingReblogs: false, Notifying: false, FollowedBy: false, Blocking: false, BlockedBy: false, Muting: false, MutingNotifications: false, Requested: false, DomainBlocking: false, Endorsed: false, Note: ""}, relationship) suite.EqualValues(&apimodel.Relationship{ID: "01FHMQX3GAABWSM0S2VZEC2SWC", Following: false, ShowingReblogs: false, Notifying: false, FollowedBy: false, Blocking: false, BlockedBy: false, Muting: false, MutingNotifications: false, Requested: false, DomainBlocking: false, Endorsed: false, Note: ""}, relationship)
// reject should be sent to Some_User // reject should be sent to Some_User
var sent [][]byte var sent [][]byte
if !testrig.WaitFor(func() bool { if !testrig.WaitFor(func() bool {
sentI, ok := suite.httpClient.SentMessages.Load(requestingAccount.InboxURI) sentI, ok := suite.httpClient.SentMessages.Load(targetAccount.InboxURI)
if ok { if ok {
sent, ok = sentI.([][]byte) sent, ok = sentI.([][]byte)
if !ok { if !ok {
@ -150,13 +176,13 @@ func (suite *FollowRequestTestSuite) TestFollowRequestReject() {
err = json.Unmarshal(sent[0], reject) err = json.Unmarshal(sent[0], reject)
suite.NoError(err) suite.NoError(err)
suite.Equal(targetAccount.URI, reject.Actor) suite.Equal(requestingAccount.URI, reject.Actor)
suite.Equal(requestingAccount.URI, reject.Object.Actor) suite.Equal(targetAccount.URI, reject.Object.Actor)
suite.Equal(fr.URI, reject.Object.ID) suite.Equal(fr.URI, reject.Object.ID)
suite.Equal(targetAccount.URI, reject.Object.Object) suite.Equal(requestingAccount.URI, reject.Object.Object)
suite.Equal(targetAccount.URI, reject.Object.To) suite.Equal(requestingAccount.URI, reject.Object.To)
suite.Equal("Follow", reject.Object.Type) suite.Equal("Follow", reject.Object.Type)
suite.Equal(requestingAccount.URI, reject.To) suite.Equal(targetAccount.URI, reject.To)
suite.Equal("Reject", reject.Type) suite.Equal("Reject", reject.Type)
} }

View file

@ -24,6 +24,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing/account" "github.com/superseriousbusiness/gotosocial/internal/processing/account"
"github.com/superseriousbusiness/gotosocial/internal/processing/admin" "github.com/superseriousbusiness/gotosocial/internal/processing/admin"
"github.com/superseriousbusiness/gotosocial/internal/processing/common"
"github.com/superseriousbusiness/gotosocial/internal/processing/fedi" "github.com/superseriousbusiness/gotosocial/internal/processing/fedi"
"github.com/superseriousbusiness/gotosocial/internal/processing/list" "github.com/superseriousbusiness/gotosocial/internal/processing/list"
"github.com/superseriousbusiness/gotosocial/internal/processing/markers" "github.com/superseriousbusiness/gotosocial/internal/processing/markers"
@ -147,7 +148,8 @@ func NewProcessor(
// //
// Start with sub processors that will // Start with sub processors that will
// be required by the workers processor. // be required by the workers processor.
accountProcessor := account.New(state, tc, mediaManager, oauthServer, federator, filter, parseMentionFunc) commonProcessor := common.New(state, tc, federator, filter)
accountProcessor := account.New(&commonProcessor, state, tc, mediaManager, oauthServer, federator, filter, parseMentionFunc)
mediaProcessor := media.New(state, tc, mediaManager, federator.TransportController()) mediaProcessor := media.New(state, tc, mediaManager, federator.TransportController())
streamProcessor := stream.New(state, oauthServer) streamProcessor := stream.New(state, oauthServer)

View file

@ -66,6 +66,7 @@ func (suite *GetTestSuite) emptyAccountFollows(ctx context.Context, accountID st
follows, err := suite.state.DB.GetAccountFollows( follows, err := suite.state.DB.GetAccountFollows(
gtscontext.SetBarebones(ctx), gtscontext.SetBarebones(ctx),
accountID, accountID,
nil, // select all
) )
if err != nil { if err != nil {
suite.FailNow(err.Error()) suite.FailNow(err.Error())
@ -82,6 +83,7 @@ func (suite *GetTestSuite) emptyAccountFollows(ctx context.Context, accountID st
follows, err = suite.state.DB.GetAccountFollows( follows, err = suite.state.DB.GetAccountFollows(
gtscontext.SetBarebones(ctx), gtscontext.SetBarebones(ctx),
accountID, accountID,
nil, // select all
) )
if err != nil { if err != nil {
suite.FailNow(err.Error()) suite.FailNow(err.Error())

View file

@ -364,6 +364,7 @@ func NewTestAccounts() map[string]*gtsmodel.Account {
SuspendedAt: time.Time{}, SuspendedAt: time.Time{},
HideCollections: util.Ptr(false), HideCollections: util.Ptr(false),
SuspensionOrigin: "", SuspensionOrigin: "",
EnableRSS: util.Ptr(false),
}, },
"admin_account": { "admin_account": {
ID: "01F8MH17FWEB39HZJ76B6VXSKF", ID: "01F8MH17FWEB39HZJ76B6VXSKF",
@ -539,6 +540,7 @@ func NewTestAccounts() map[string]*gtsmodel.Account {
SuspendedAt: time.Time{}, SuspendedAt: time.Time{},
HideCollections: util.Ptr(false), HideCollections: util.Ptr(false),
SuspensionOrigin: "", SuspensionOrigin: "",
EnableRSS: util.Ptr(false),
}, },
"remote_account_2": { "remote_account_2": {
ID: "01FHMQX3GAABWSM0S2VZEC2SWC", ID: "01FHMQX3GAABWSM0S2VZEC2SWC",
@ -575,6 +577,7 @@ func NewTestAccounts() map[string]*gtsmodel.Account {
SuspendedAt: time.Time{}, SuspendedAt: time.Time{},
HideCollections: util.Ptr(false), HideCollections: util.Ptr(false),
SuspensionOrigin: "", SuspensionOrigin: "",
EnableRSS: util.Ptr(false),
}, },
"remote_account_3": { "remote_account_3": {
ID: "062G5WYKY35KKD12EMSM3F8PJ8", ID: "062G5WYKY35KKD12EMSM3F8PJ8",
@ -612,6 +615,7 @@ func NewTestAccounts() map[string]*gtsmodel.Account {
HideCollections: util.Ptr(false), HideCollections: util.Ptr(false),
SuspensionOrigin: "", SuspensionOrigin: "",
HeaderMediaAttachmentID: "01PFPMWK2FF0D9WMHEJHR07C3R", HeaderMediaAttachmentID: "01PFPMWK2FF0D9WMHEJHR07C3R",
EnableRSS: util.Ptr(false),
}, },
"remote_account_4": { "remote_account_4": {
ID: "07GZRBAEMBNKGZ8Z9VSKSXKR98", ID: "07GZRBAEMBNKGZ8Z9VSKSXKR98",

2
vendor/github.com/tomnomnom/linkheader/.gitignore generated vendored Normal file
View file

@ -0,0 +1,2 @@
cpu.out
linkheader.test

6
vendor/github.com/tomnomnom/linkheader/.travis.yml generated vendored Normal file
View file

@ -0,0 +1,6 @@
language: go
go:
- 1.6
- 1.7
- tip

View file

@ -0,0 +1,10 @@
# Contributing
* Raise an issue if appropriate
* Fork the repo
* Bootstrap the dev dependencies (run `./script/bootstrap`)
* Make your changes
* Use [gofmt](https://golang.org/cmd/gofmt/)
* Make sure the tests pass (run `./script/test`)
* Make sure the linters pass (run `./script/lint`)
* Issue a pull request

21
vendor/github.com/tomnomnom/linkheader/LICENSE generated vendored Normal file
View file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2016 Tom Hudson
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

35
vendor/github.com/tomnomnom/linkheader/README.mkd generated vendored Normal file
View file

@ -0,0 +1,35 @@
# Golang Link Header Parser
Library for parsing HTTP Link headers. Requires Go 1.6 or higher.
Docs can be found on [the GoDoc page](https://godoc.org/github.com/tomnomnom/linkheader).
[![Build Status](https://travis-ci.org/tomnomnom/linkheader.svg)](https://travis-ci.org/tomnomnom/linkheader)
## Basic Example
```go
package main
import (
"fmt"
"github.com/tomnomnom/linkheader"
)
func main() {
header := "<https://api.github.com/user/58276/repos?page=2>; rel=\"next\"," +
"<https://api.github.com/user/58276/repos?page=2>; rel=\"last\""
links := linkheader.Parse(header)
for _, link := range links {
fmt.Printf("URL: %s; Rel: %s\n", link.URL, link.Rel)
}
}
// Output:
// URL: https://api.github.com/user/58276/repos?page=2; Rel: next
// URL: https://api.github.com/user/58276/repos?page=2; Rel: last
```

151
vendor/github.com/tomnomnom/linkheader/main.go generated vendored Normal file
View file

@ -0,0 +1,151 @@
// Package linkheader provides functions for parsing HTTP Link headers
package linkheader
import (
"fmt"
"strings"
)
// A Link is a single URL and related parameters
type Link struct {
URL string
Rel string
Params map[string]string
}
// HasParam returns if a Link has a particular parameter or not
func (l Link) HasParam(key string) bool {
for p := range l.Params {
if p == key {
return true
}
}
return false
}
// Param returns the value of a parameter if it exists
func (l Link) Param(key string) string {
for k, v := range l.Params {
if key == k {
return v
}
}
return ""
}
// String returns the string representation of a link
func (l Link) String() string {
p := make([]string, 0, len(l.Params))
for k, v := range l.Params {
p = append(p, fmt.Sprintf("%s=\"%s\"", k, v))
}
if l.Rel != "" {
p = append(p, fmt.Sprintf("%s=\"%s\"", "rel", l.Rel))
}
return fmt.Sprintf("<%s>; %s", l.URL, strings.Join(p, "; "))
}
// Links is a slice of Link structs
type Links []Link
// FilterByRel filters a group of Links by the provided Rel attribute
func (l Links) FilterByRel(r string) Links {
links := make(Links, 0)
for _, link := range l {
if link.Rel == r {
links = append(links, link)
}
}
return links
}
// String returns the string representation of multiple Links
// for use in HTTP responses etc
func (l Links) String() string {
if l == nil {
return fmt.Sprint(nil)
}
var strs []string
for _, link := range l {
strs = append(strs, link.String())
}
return strings.Join(strs, ", ")
}
// Parse parses a raw Link header in the form:
// <url>; rel="foo", <url>; rel="bar"; wat="dis"
// returning a slice of Link structs
func Parse(raw string) Links {
var links Links
// One chunk: <url>; rel="foo"
for _, chunk := range strings.Split(raw, ",") {
link := Link{URL: "", Rel: "", Params: make(map[string]string)}
// Figure out what each piece of the chunk is
for _, piece := range strings.Split(chunk, ";") {
piece = strings.Trim(piece, " ")
if piece == "" {
continue
}
// URL
if piece[0] == '<' && piece[len(piece)-1] == '>' {
link.URL = strings.Trim(piece, "<>")
continue
}
// Params
key, val := parseParam(piece)
if key == "" {
continue
}
// Special case for rel
if strings.ToLower(key) == "rel" {
link.Rel = val
} else {
link.Params[key] = val
}
}
if link.URL != "" {
links = append(links, link)
}
}
return links
}
// ParseMultiple is like Parse, but accepts a slice of headers
// rather than just one header string
func ParseMultiple(headers []string) Links {
links := make(Links, 0)
for _, header := range headers {
links = append(links, Parse(header)...)
}
return links
}
// parseParam takes a raw param in the form key="val" and
// returns the key and value as seperate strings
func parseParam(raw string) (key, val string) {
parts := strings.SplitN(raw, "=", 2)
if len(parts) == 1 {
return parts[0], ""
}
if len(parts) != 2 {
return "", ""
}
key = parts[0]
val = strings.Trim(parts[1], "\"")
return key, val
}

3
vendor/modules.txt vendored
View file

@ -672,6 +672,9 @@ github.com/tdewolff/parse/v2/strconv
# github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc # github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc
## explicit ## explicit
github.com/tmthrgd/go-hex github.com/tmthrgd/go-hex
# github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80
## explicit
github.com/tomnomnom/linkheader
# github.com/twitchyliquid64/golang-asm v0.15.1 # github.com/twitchyliquid64/golang-asm v0.15.1
## explicit; go 1.13 ## explicit; go 1.13
github.com/twitchyliquid64/golang-asm/asm/arch github.com/twitchyliquid64/golang-asm/asm/arch