diff --git a/internal/ap/activitystreams_test.go b/internal/ap/activitystreams_test.go index ee03f9b0f..d769fa42f 100644 --- a/internal/ap/activitystreams_test.go +++ b/internal/ap/activitystreams_test.go @@ -49,6 +49,7 @@ func TestASCollection(t *testing.T) { // Create new collection using builder function. c := ap.NewASCollection(ap.CollectionParams{ ID: parseURI(idURI), + Query: url.Values{"limit": []string{"40"}}, Total: total, }) @@ -56,7 +57,7 @@ func TestASCollection(t *testing.T) { s := toJSON(c) // Ensure outputs are equal. - assert.Equal(t, s, expect) + assert.Equal(t, expect, s) } func TestASCollectionPage(t *testing.T) { @@ -110,7 +111,7 @@ func TestASCollectionPage(t *testing.T) { s := toJSON(p) // Ensure outputs are equal. - assert.Equal(t, s, expect) + assert.Equal(t, expect, s) } func TestASOrderedCollection(t *testing.T) { @@ -131,6 +132,7 @@ func TestASOrderedCollection(t *testing.T) { // Create new collection using builder function. c := ap.NewASOrderedCollection(ap.CollectionParams{ ID: parseURI(idURI), + Query: url.Values{"limit": []string{"40"}}, Total: total, }) @@ -138,7 +140,7 @@ func TestASOrderedCollection(t *testing.T) { s := toJSON(c) // Ensure outputs are equal. - assert.Equal(t, s, expect) + assert.Equal(t, expect, s) } func TestASOrderedCollectionPage(t *testing.T) { @@ -192,7 +194,7 @@ func TestASOrderedCollectionPage(t *testing.T) { s := toJSON(p) // Ensure outputs are equal. - assert.Equal(t, s, expect) + assert.Equal(t, expect, s) } func parseURI(s string) *url.URL { diff --git a/internal/ap/collections.go b/internal/ap/collections.go index e86d989ff..ba3887a5b 100644 --- a/internal/ap/collections.go +++ b/internal/ap/collections.go @@ -20,7 +20,6 @@ package ap import ( "fmt" "net/url" - "strconv" "github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams/vocab" @@ -169,6 +168,10 @@ type CollectionParams struct { // ID (i.e. NOT the page). ID *url.URL + // First page details. + First paging.Page + Query url.Values + // Total no. items. Total int } @@ -224,7 +227,7 @@ type ItemsPropertyBuilder interface { // NewASCollection builds and returns a new ActivityStreams Collection from given parameters. func NewASCollection(params CollectionParams) vocab.ActivityStreamsCollection { collection := streams.NewActivityStreamsCollection() - buildCollection(collection, params, 40) + buildCollection(collection, params) return collection } @@ -239,7 +242,7 @@ func NewASCollectionPage(params CollectionPageParams) vocab.ActivityStreamsColle // NewASOrderedCollection builds and returns a new ActivityStreams OrderedCollection from given parameters. func NewASOrderedCollection(params CollectionParams) vocab.ActivityStreamsOrderedCollection { collection := streams.NewActivityStreamsOrderedCollection() - buildCollection(collection, params, 40) + buildCollection(collection, params) return collection } @@ -251,7 +254,7 @@ func NewASOrderedCollectionPage(params CollectionPageParams) vocab.ActivityStrea return collectionPage } -func buildCollection[C CollectionBuilder](collection C, params CollectionParams, pageLimit int) { +func buildCollection[C CollectionBuilder](collection C, params CollectionParams) { // Add the collection ID property. idProp := streams.NewJSONLDIdProperty() idProp.SetIRI(params.ID) @@ -262,15 +265,20 @@ func buildCollection[C CollectionBuilder](collection C, params CollectionParams, totalItems.Set(params.Total) collection.SetActivityStreamsTotalItems(totalItems) - // Clone the collection ID page - // to add first page query data. - firstIRI := new(url.URL) - *firstIRI = *params.ID + // Append paging query params + // to those already in ID prop. + pageQueryParams := appendQuery( + params.Query, + params.ID.Query(), + ) - // Note that simply adding a limit signals to our - // endpoint to use paging (which will start at beginning). - limit := "limit=" + strconv.Itoa(pageLimit) - firstIRI.RawQuery = appendQuery(firstIRI.RawQuery, limit) + // Build the first page link IRI. + firstIRI := params.First.ToLinkURL( + params.ID.Scheme, + params.ID.Host, + params.ID.Path, + pageQueryParams, + ) // Add the collection first IRI property. first := streams.NewActivityStreamsFirstProperty() @@ -284,12 +292,19 @@ func buildCollectionPage[C CollectionPageBuilder, I ItemsPropertyBuilder](collec partOfProp.SetIRI(params.ID) collectionPage.SetActivityStreamsPartOf(partOfProp) + // Append paging query params + // to those already in ID prop. + pageQueryParams := appendQuery( + params.Query, + params.ID.Query(), + ) + // Build the current page link IRI. currentIRI := params.Current.ToLinkURL( params.ID.Scheme, params.ID.Host, params.ID.Path, - params.Query, + pageQueryParams, ) // Add the collection ID property for @@ -303,7 +318,7 @@ func buildCollectionPage[C CollectionPageBuilder, I ItemsPropertyBuilder](collec params.ID.Scheme, params.ID.Host, params.ID.Path, - params.Query, + pageQueryParams, ) if nextIRI != nil { @@ -318,7 +333,7 @@ func buildCollectionPage[C CollectionPageBuilder, I ItemsPropertyBuilder](collec params.ID.Scheme, params.ID.Host, params.ID.Path, - params.Query, + pageQueryParams, ) if prevIRI != nil { @@ -349,11 +364,13 @@ func buildCollectionPage[C CollectionPageBuilder, I ItemsPropertyBuilder](collec setItems(itemsProp) } -// appendQuery appends part to an existing raw -// query with ampersand, else just returning part. -func appendQuery(raw, part string) string { - if raw != "" { - return raw + "&" + part +// appendQuery appends query values in 'src' to 'dst', returning 'dst'. +func appendQuery(dst, src url.Values) url.Values { + if dst == nil { + return src } - return part + for k, vs := range src { + dst[k] = append(dst[k], vs...) + } + return dst } diff --git a/internal/api/activitypub/users/repliesget.go b/internal/api/activitypub/users/repliesget.go index fd9dc090b..3ac4ccbbb 100644 --- a/internal/api/activitypub/users/repliesget.go +++ b/internal/api/activitypub/users/repliesget.go @@ -20,14 +20,13 @@ package users import ( "encoding/json" "errors" - "fmt" "net/http" - "strconv" "strings" "github.com/gin-gonic/gin" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/paging" ) // StatusRepliesGETHandler swagger:operation GET /users/{username}/statuses/{status}/replies s2sRepliesGet @@ -120,36 +119,43 @@ func (m *Module) StatusRepliesGETHandler(c *gin.Context) { return } - var page bool - if pageString := c.Query(PageKey); pageString != "" { - i, err := strconv.ParseBool(pageString) - if err != nil { - err := fmt.Errorf("error parsing %s: %s", PageKey, err) - apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) - return - } - page = i + // Look for supplied 'only_other_accounts' query key. + onlyOtherAccounts, errWithCode := apiutil.ParseOnlyOtherAccounts( + c.Query(apiutil.OnlyOtherAccountsKey), + true, // default = enabled + ) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return } - onlyOtherAccounts := false - onlyOtherAccountsString := c.Query(OnlyOtherAccountsKey) - if onlyOtherAccountsString != "" { - i, err := strconv.ParseBool(onlyOtherAccountsString) - if err != nil { - err := fmt.Errorf("error parsing %s: %s", OnlyOtherAccountsKey, err) - apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) - return - } - onlyOtherAccounts = i + // Look for given paging query parameters. + page, errWithCode := paging.ParseIDPage(c, + 1, // min limit + 40, // max limit + 0, // default = disabled + ) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return } - minID := "" - minIDString := c.Query(MinIDKey) - if minIDString != "" { - minID = minIDString + // COMPATIBILITY FIX: 'page=true' enables paging. + if page == nil && c.Query("page") == "true" { + page = new(paging.Page) + page.Max = paging.MaxID("") + page.Min = paging.MinID("") + page.Limit = 20 // default } - resp, errWithCode := m.processor.Fedi().StatusRepliesGet(c.Request.Context(), requestedUsername, requestedStatusID, page, onlyOtherAccounts, c.Query("only_other_accounts") != "", minID) + // Fetch serialized status replies response for input status. + resp, errWithCode := m.processor.Fedi().StatusRepliesGet( + c.Request.Context(), + requestedUsername, + requestedStatusID, + page, + onlyOtherAccounts, + ) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return @@ -157,7 +163,8 @@ func (m *Module) StatusRepliesGETHandler(c *gin.Context) { b, err := json.Marshal(resp) if err != nil { - apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGetV1) + errWithCode := gtserror.NewErrorInternalError(err) + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return } diff --git a/internal/api/activitypub/users/repliesget_test.go b/internal/api/activitypub/users/repliesget_test.go index 26492d8ce..ac25f3617 100644 --- a/internal/api/activitypub/users/repliesget_test.go +++ b/internal/api/activitypub/users/repliesget_test.go @@ -18,10 +18,10 @@ package users_test import ( + "bytes" "context" "encoding/json" - "fmt" - "io/ioutil" + "io" "net/http" "net/http/httptest" "testing" @@ -31,6 +31,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams/vocab" + "github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -49,7 +50,7 @@ func (suite *RepliesGetTestSuite) TestGetReplies() { // setup request recorder := httptest.NewRecorder() ctx, _ := testrig.CreateGinTestContext(recorder, nil) - ctx.Request = httptest.NewRequest(http.MethodGet, targetStatus.URI+"/replies", nil) // the endpoint we're hitting + ctx.Request = httptest.NewRequest(http.MethodGet, targetStatus.URI+"/replies?only_other_accounts=false", nil) // the endpoint we're hitting ctx.Request.Header.Set("accept", "application/activity+json") ctx.Request.Header.Set("Signature", signedRequest.SignatureHeader) ctx.Request.Header.Set("Date", signedRequest.DateHeader) @@ -76,13 +77,26 @@ func (suite *RepliesGetTestSuite) TestGetReplies() { // check response suite.EqualValues(http.StatusOK, recorder.Code) + // Read response body. result := recorder.Result() defer result.Body.Close() - b, err := ioutil.ReadAll(result.Body) + b, err := io.ReadAll(result.Body) assert.NoError(suite.T(), err) - assert.Equal(suite.T(), `{"@context":"https://www.w3.org/ns/activitystreams","first":{"id":"http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/replies?page=true","next":"http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/replies?only_other_accounts=false\u0026page=true","partOf":"http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/replies","type":"CollectionPage"},"id":"http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/replies","type":"Collection"}`, string(b)) - // should be a Collection + // Indent JSON + // for readability. + b = indentJSON(b) + + // Create JSON string of expected output. + expect := toJSON(map[string]any{ + "@context": "https://www.w3.org/ns/activitystreams", + "type": "OrderedCollection", + "id": targetStatus.URI + "/replies?only_other_accounts=false", + "first": targetStatus.URI + "/replies?limit=20&only_other_accounts=false", + "totalItems": 1, + }) + assert.Equal(suite.T(), expect, string(b)) + m := make(map[string]interface{}) err = json.Unmarshal(b, &m) assert.NoError(suite.T(), err) @@ -90,7 +104,7 @@ func (suite *RepliesGetTestSuite) TestGetReplies() { t, err := streams.ToType(context.Background(), m) assert.NoError(suite.T(), err) - _, ok := t.(vocab.ActivityStreamsCollection) + _, ok := t.(vocab.ActivityStreamsOrderedCollection) assert.True(suite.T(), ok) } @@ -131,14 +145,29 @@ func (suite *RepliesGetTestSuite) TestGetRepliesNext() { // check response suite.EqualValues(http.StatusOK, recorder.Code) + // Read response body. result := recorder.Result() defer result.Body.Close() - b, err := ioutil.ReadAll(result.Body) + b, err := io.ReadAll(result.Body) assert.NoError(suite.T(), err) - assert.Equal(suite.T(), `{"@context":"https://www.w3.org/ns/activitystreams","id":"http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/replies?page=true\u0026only_other_accounts=false","items":"http://localhost:8080/users/admin/statuses/01FF25D5Q0DH7CHD57CTRS6WK0","next":"http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/replies?only_other_accounts=false\u0026page=true\u0026min_id=01FF25D5Q0DH7CHD57CTRS6WK0","partOf":"http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/replies","type":"CollectionPage"}`, string(b)) + // Indent JSON + // for readability. + b = indentJSON(b) + + // Create JSON string of expected output. + expect := toJSON(map[string]any{ + "@context": "https://www.w3.org/ns/activitystreams", + "type": "OrderedCollectionPage", + "id": targetStatus.URI + "/replies?limit=20&only_other_accounts=false", + "partOf": targetStatus.URI + "/replies?only_other_accounts=false", + "next": "http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/replies?limit=20&min_id=01FF25D5Q0DH7CHD57CTRS6WK0&only_other_accounts=false", + "prev": "http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/replies?limit=20&max_id=01FF25D5Q0DH7CHD57CTRS6WK0&only_other_accounts=false", + "orderedItems": "http://localhost:8080/users/admin/statuses/01FF25D5Q0DH7CHD57CTRS6WK0", + "totalItems": 1, + }) + assert.Equal(suite.T(), expect, string(b)) - // should be a Collection m := make(map[string]interface{}) err = json.Unmarshal(b, &m) assert.NoError(suite.T(), err) @@ -146,10 +175,10 @@ func (suite *RepliesGetTestSuite) TestGetRepliesNext() { t, err := streams.ToType(context.Background(), m) assert.NoError(suite.T(), err) - page, ok := t.(vocab.ActivityStreamsCollectionPage) + page, ok := t.(vocab.ActivityStreamsOrderedCollectionPage) assert.True(suite.T(), ok) - assert.Equal(suite.T(), page.GetActivityStreamsItems().Len(), 1) + assert.Equal(suite.T(), page.GetActivityStreamsOrderedItems().Len(), 1) } func (suite *RepliesGetTestSuite) TestGetRepliesLast() { @@ -162,7 +191,7 @@ func (suite *RepliesGetTestSuite) TestGetRepliesLast() { // setup request recorder := httptest.NewRecorder() ctx, _ := testrig.CreateGinTestContext(recorder, nil) - ctx.Request = httptest.NewRequest(http.MethodGet, targetStatus.URI+"/replies?only_other_accounts=false&page=true&min_id=01FF25D5Q0DH7CHD57CTRS6WK0", nil) // the endpoint we're hitting + ctx.Request = httptest.NewRequest(http.MethodGet, targetStatus.URI+"/replies?min_id=01FF25D5Q0DH7CHD57CTRS6WK0&only_other_accounts=false", nil) ctx.Request.Header.Set("accept", "application/activity+json") ctx.Request.Header.Set("Signature", signedRequest.SignatureHeader) ctx.Request.Header.Set("Date", signedRequest.DateHeader) @@ -189,15 +218,27 @@ func (suite *RepliesGetTestSuite) TestGetRepliesLast() { // check response suite.EqualValues(http.StatusOK, recorder.Code) + // Read response body. result := recorder.Result() defer result.Body.Close() - b, err := ioutil.ReadAll(result.Body) + b, err := io.ReadAll(result.Body) assert.NoError(suite.T(), err) - fmt.Println(string(b)) - assert.Equal(suite.T(), `{"@context":"https://www.w3.org/ns/activitystreams","id":"http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/replies?page=true\u0026only_other_accounts=false\u0026min_id=01FF25D5Q0DH7CHD57CTRS6WK0","items":[],"next":"http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/replies?only_other_accounts=false\u0026page=true","partOf":"http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/replies","type":"CollectionPage"}`, string(b)) + // Indent JSON + // for readability. + b = indentJSON(b) + + // Create JSON string of expected output. + expect := toJSON(map[string]any{ + "@context": "https://www.w3.org/ns/activitystreams", + "type": "OrderedCollectionPage", + "id": targetStatus.URI + "/replies?min_id=01FF25D5Q0DH7CHD57CTRS6WK0&only_other_accounts=false", + "partOf": targetStatus.URI + "/replies?only_other_accounts=false", + "orderedItems": []any{}, // empty + "totalItems": 1, + }) + assert.Equal(suite.T(), expect, string(b)) - // should be a Collection m := make(map[string]interface{}) err = json.Unmarshal(b, &m) assert.NoError(suite.T(), err) @@ -205,12 +246,39 @@ func (suite *RepliesGetTestSuite) TestGetRepliesLast() { t, err := streams.ToType(context.Background(), m) assert.NoError(suite.T(), err) - page, ok := t.(vocab.ActivityStreamsCollectionPage) + page, ok := t.(vocab.ActivityStreamsOrderedCollectionPage) assert.True(suite.T(), ok) - assert.Equal(suite.T(), page.GetActivityStreamsItems().Len(), 0) + assert.Equal(suite.T(), page.GetActivityStreamsOrderedItems().Len(), 0) } func TestRepliesGetTestSuite(t *testing.T) { suite.Run(t, new(RepliesGetTestSuite)) } + +// toJSON will return indented JSON serialized form of 'a'. +func toJSON(a any) string { + v, ok := a.(vocab.Type) + if ok { + m, err := ap.Serialize(v) + if err != nil { + panic(err) + } + a = m + } + b, err := json.MarshalIndent(a, "", " ") + if err != nil { + panic(err) + } + return string(b) +} + +// indentJSON will return indented JSON from raw provided JSON. +func indentJSON(b []byte) []byte { + var dst bytes.Buffer + err := json.Indent(&dst, b, "", " ") + if err != nil { + panic(err) + } + return dst.Bytes() +} diff --git a/internal/api/util/parsequery.go b/internal/api/util/parsequery.go index 6a9116dcf..da6320b67 100644 --- a/internal/api/util/parsequery.go +++ b/internal/api/util/parsequery.go @@ -41,6 +41,10 @@ const ( SinceIDKey = "since_id" MinIDKey = "min_id" + /* AP endpoint keys */ + + OnlyOtherAccountsKey = "only_other_accounts" + /* Search keys */ SearchExcludeUnreviewedKey = "exclude_unreviewed" @@ -66,20 +70,6 @@ const ( DomainPermissionImportKey = "import" ) -// parseError returns gtserror.WithCode set to 400 Bad Request, to indicate -// to the caller that a key was set to a value that could not be parsed. -func parseError(key string, value, defaultValue any, err error) gtserror.WithCode { - err = fmt.Errorf("error parsing key %s with value %s as %T: %w", key, value, defaultValue, err) - return gtserror.NewErrorBadRequest(err, err.Error()) -} - -// requiredError returns gtserror.WithCode set to 400 Bad Request, to indicate -// to the caller a required key value was not provided, or was empty. -func requiredError(key string) gtserror.WithCode { - err := fmt.Errorf("required key %s was not set or had empty value", key) - return gtserror.NewErrorBadRequest(err, err.Error()) -} - /* Parse functions for *OPTIONAL* parameters with default values. */ @@ -129,6 +119,10 @@ func ParseDomainPermissionImport(value string, defaultValue bool) (bool, gtserro return parseBool(value, defaultValue, DomainPermissionImportKey) } +func ParseOnlyOtherAccounts(value string, defaultValue bool) (bool, gtserror.WithCode) { + return parseBool(value, defaultValue, OnlyOtherAccountsKey) +} + /* Parse functions for *REQUIRED* parameters. */ @@ -248,3 +242,17 @@ func parseInt(value string, defaultValue int, max int, min int, key string) (int return i, nil } + +// parseError returns gtserror.WithCode set to 400 Bad Request, to indicate +// to the caller that a key was set to a value that could not be parsed. +func parseError(key string, value, defaultValue any, err error) gtserror.WithCode { + err = fmt.Errorf("error parsing key %s with value %s as %T: %w", key, value, defaultValue, err) + return gtserror.NewErrorBadRequest(err, err.Error()) +} + +// requiredError returns gtserror.WithCode set to 400 Bad Request, to indicate +// to the caller a required key value was not provided, or was empty. +func requiredError(key string) gtserror.WithCode { + err := fmt.Errorf("required key %s was not set or had empty value", key) + return gtserror.NewErrorBadRequest(err, err.Error()) +} diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index 80346412c..dd161e1ec 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -18,7 +18,6 @@ package bundb import ( - "container/list" "context" "errors" "time" @@ -515,16 +514,7 @@ func (s *statusDB) GetStatusesUsingEmoji(ctx context.Context, emojiID string) ([ return s.GetStatusesByIDs(ctx, statusIDs) } -func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error) { - if onlyDirect { - // Only want the direct parent, no further than first level - parent, err := s.GetStatusByID(ctx, status.InReplyToID) - if err != nil { - return nil, err - } - return []*gtsmodel.Status{parent}, nil - } - +func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, error) { var parents []*gtsmodel.Status for id := status.InReplyToID; id != ""; { @@ -533,7 +523,7 @@ func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status return nil, err } - // Append parent to slice + // Append parent status to slice parents = append(parents, parent) // Set the next parent ID @@ -543,67 +533,33 @@ func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status return parents, nil } -func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error) { - foundStatuses := &list.List{} - foundStatuses.PushFront(status) - s.statusChildren(ctx, status, foundStatuses, onlyDirect, minID) +func (s *statusDB) GetStatusChildren(ctx context.Context, statusID string) ([]*gtsmodel.Status, error) { + // Get all replies for the currently set status. + replies, err := s.GetStatusReplies(ctx, statusID) + if err != nil { + return nil, err + } - children := []*gtsmodel.Status{} - for e := foundStatuses.Front(); e != nil; e = e.Next() { - // only append children, not the overall parent status - entry, ok := e.Value.(*gtsmodel.Status) - if !ok { - log.Panic(ctx, "found status could not be asserted to *gtsmodel.Status") + // Make estimated preallocation based on direct replies. + children := make([]*gtsmodel.Status, 0, len(replies)*2) + + for _, status := range replies { + // Append status to children. + children = append(children, status) + + // Further, recursively get all children for this reply. + grandChildren, err := s.GetStatusChildren(ctx, status.ID) + if err != nil { + return nil, err } - if entry.ID != status.ID { - children = append(children, entry) - } + // Append all sub children after status. + children = append(children, grandChildren...) } return children, nil } -func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) { - childIDs, err := s.getStatusReplyIDs(ctx, status.ID) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - log.Errorf(ctx, "error getting status %s children: %v", status.ID, err) - return - } - - for _, id := range childIDs { - if id <= minID { - continue - } - - // Fetch child with ID from database - child, err := s.GetStatusByID(ctx, id) - if err != nil { - log.Errorf(ctx, "error getting child status %q: %v", id, err) - continue - } - - insertLoop: - for e := foundStatuses.Front(); e != nil; e = e.Next() { - entry, ok := e.Value.(*gtsmodel.Status) - if !ok { - log.Panic(ctx, "found status could not be asserted to *gtsmodel.Status") - } - - if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID { - foundStatuses.InsertAfter(child, e) - break insertLoop - } - } - - // if we're not only looking for direct children of status, then do the same children-finding - // operation for the found child status too. - if !onlyDirect { - s.statusChildren(ctx, child, foundStatuses, false, minID) - } - } -} - func (s *statusDB) GetStatusReplies(ctx context.Context, statusID string) ([]*gtsmodel.Status, error) { statusIDs, err := s.getStatusReplyIDs(ctx, statusID) if err != nil { diff --git a/internal/db/bundb/status_test.go b/internal/db/bundb/status_test.go index a69608796..c0ff6c0da 100644 --- a/internal/db/bundb/status_test.go +++ b/internal/db/bundb/status_test.go @@ -163,9 +163,21 @@ func (suite *StatusTestSuite) TestGetStatusTwice() { suite.Less(duration2, duration1) } +func (suite *StatusTestSuite) TestGetStatusReplies() { + targetStatus := suite.testStatuses["local_account_1_status_1"] + children, err := suite.db.GetStatusReplies(context.Background(), targetStatus.ID) + suite.NoError(err) + suite.Len(children, 2) + for _, c := range children { + suite.Equal(targetStatus.URI, c.InReplyToURI) + suite.Equal(targetStatus.AccountID, c.InReplyToAccountID) + suite.Equal(targetStatus.ID, c.InReplyToID) + } +} + func (suite *StatusTestSuite) TestGetStatusChildren() { targetStatus := suite.testStatuses["local_account_1_status_1"] - children, err := suite.db.GetStatusChildren(context.Background(), targetStatus, true, "") + children, err := suite.db.GetStatusChildren(context.Background(), targetStatus.ID) suite.NoError(err) suite.Len(children, 2) for _, c := range children { diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go index 1d820d081..a2bc87b88 100644 --- a/internal/db/bundb/util.go +++ b/internal/db/bundb/util.go @@ -18,6 +18,7 @@ package bundb import ( + "slices" "strings" "github.com/superseriousbusiness/gotosocial/internal/cache" @@ -99,7 +100,7 @@ func loadPagedIDs(cache *cache.SliceCache[string], key string, page *paging.Page // order. Depending on the paging requested // this may be an unexpected order. if page.GetOrder().Ascending() { - ids = paging.Reverse(ids) + slices.Reverse(ids) } // Page the resulting IDs. diff --git a/internal/db/status.go b/internal/db/status.go index 0be37421a..1ebf503a8 100644 --- a/internal/db/status.go +++ b/internal/db/status.go @@ -55,7 +55,7 @@ type Status interface { // GetStatusesUsingEmoji fetches all status models using emoji with given ID stored in their 'emojis' column. GetStatusesUsingEmoji(ctx context.Context, emojiID string) ([]*gtsmodel.Status, error) - // GetStatusReplies returns the *direct* (i.e. in_reply_to_id column) replies to this status ID. + // GetStatusReplies returns the *direct* (i.e. in_reply_to_id column) replies to this status ID, ordered DESC by ID. GetStatusReplies(ctx context.Context, statusID string) ([]*gtsmodel.Status, error) // CountStatusReplies returns the number of stored *direct* (i.e. in_reply_to_id column) replies to this status ID. @@ -71,14 +71,10 @@ type Status interface { IsStatusBoostedBy(ctx context.Context, statusID string, accountID string) (bool, error) // GetStatusParents gets the parent statuses of a given status. - // - // If onlyDirect is true, only the immediate parent will be returned. - GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error) + GetStatusParents(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, error) // GetStatusChildren gets the child statuses of a given status. - // - // If onlyDirect is true, only the immediate children will be returned. - GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error) + GetStatusChildren(ctx context.Context, statusID string) ([]*gtsmodel.Status, error) // IsStatusBookmarkedBy checks if a given status has been bookmarked by a given account ID IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, error) diff --git a/internal/paging/boundary.go b/internal/paging/boundary.go index 15af65e0c..83d265515 100644 --- a/internal/paging/boundary.go +++ b/internal/paging/boundary.go @@ -131,3 +131,20 @@ func (b Boundary) Find(in []string) int { } return -1 } + +// Boundary_FindFunc is functionally equivalent to Boundary{}.Find() but for an arbitrary type with ID. +// Note: this is not a Boundary{} method as Go generics are not supported in method receiver functions. +func Boundary_FindFunc[T any](b Boundary, in []T, get func(T) string) int { //nolint:revive + if get == nil { + panic("nil function") + } + if b.Value == "" { + return -1 + } + for i := range in { + if get(in[i]) == b.Value { + return i + } + } + return -1 +} diff --git a/internal/paging/page.go b/internal/paging/page.go index a56f674dd..a1cf76c74 100644 --- a/internal/paging/page.go +++ b/internal/paging/page.go @@ -19,9 +19,8 @@ package paging import ( "net/url" + "slices" "strconv" - - "golang.org/x/exp/slices" ) type Page struct { @@ -117,7 +116,7 @@ func (p *Page) Page(in []string) []string { // Output slice must // ALWAYS be descending. - in = Reverse(in) + slices.Reverse(in) } } else { // Default sort is descending, @@ -143,6 +142,66 @@ func (p *Page) Page(in []string) []string { return in } +// Page_PageFunc is functionally equivalent to Page{}.Page(), but for an arbitrary type with ID. +// Note: this is not a Page{} method as Go generics are not supported in method receiver functions. +func Page_PageFunc[WithID any](p *Page, in []WithID, get func(WithID) string) []WithID { //nolint:revive + if p == nil { + // no paging. + return in + } + + if p.order().Ascending() { + // Sort type is ascending, input + // data is assumed to be ascending. + + if minIdx := Boundary_FindFunc(p.Min, in, get); minIdx != -1 { + // Reslice skipping up to min. + in = in[minIdx+1:] + } + + if maxIdx := Boundary_FindFunc(p.Max, in, get); maxIdx != -1 { + // Reslice stripping past max. + in = in[:maxIdx] + } + + if p.Limit > 0 && p.Limit < len(in) { + // Reslice input to limit. + in = in[:p.Limit] + } + + if len(in) > 1 { + // Clone input before + // any modifications. + in = slices.Clone(in) + + // Output slice must + // ALWAYS be descending. + slices.Reverse(in) + } + } else { + // Default sort is descending, + // catching all cases when NOT + // ascending (even zero value). + + if maxIdx := Boundary_FindFunc(p.Max, in, get); maxIdx != -1 { + // Reslice skipping up to max. + in = in[maxIdx+1:] + } + + if minIdx := Boundary_FindFunc(p.Min, in, get); minIdx != -1 { + // Reslice stripping past min. + in = in[:minIdx] + } + + if p.Limit > 0 && p.Limit < len(in) { + // Reslice input to limit. + in = in[:p.Limit] + } + } + + return in +} + // Next creates a new instance for the next returnable page, using // given max value. This preserves original limit and max key name. func (p *Page) Next(lo, hi string) *Page { @@ -225,21 +284,24 @@ func (p *Page) ToLinkURL(proto, host, path string, queryParams url.Values) *url. if queryParams == nil { // Allocate new query parameters. queryParams = make(url.Values) + } else { + // Before edit clone existing params. + queryParams = cloneQuery(queryParams) } if p.Min.Value != "" { // A page-minimum query parameter is available. - queryParams.Add(p.Min.Name, p.Min.Value) + queryParams.Set(p.Min.Name, p.Min.Value) } if p.Max.Value != "" { // A page-maximum query parameter is available. - queryParams.Add(p.Max.Name, p.Max.Value) + queryParams.Set(p.Max.Name, p.Max.Value) } if p.Limit > 0 { // A page limit query parameter is available. - queryParams.Add("limit", strconv.Itoa(p.Limit)) + queryParams.Set("limit", strconv.Itoa(p.Limit)) } // Build URL string. @@ -250,3 +312,12 @@ func (p *Page) ToLinkURL(proto, host, path string, queryParams url.Values) *url. RawQuery: queryParams.Encode(), } } + +// cloneQuery clones input map of url values. +func cloneQuery(src url.Values) url.Values { + dst := make(url.Values, len(src)) + for k, vs := range src { + dst[k] = slices.Clone(vs) + } + return dst +} diff --git a/internal/paging/page_test.go b/internal/paging/page_test.go index 9eeb90882..3046dfcdd 100644 --- a/internal/paging/page_test.go +++ b/internal/paging/page_test.go @@ -19,12 +19,12 @@ package paging_test import ( "math/rand" + "slices" "testing" "time" "github.com/oklog/ulid" "github.com/superseriousbusiness/gotosocial/internal/paging" - "golang.org/x/exp/slices" ) // random reader according to current-time source seed. @@ -77,9 +77,7 @@ func TestPage(t *testing.T) { var cases = []Case{ CreateCase("minID and maxID set", func(ids []string) ([]string, *paging.Page, []string) { // Ensure input slice sorted ascending for min_id - slices.SortFunc(ids, func(a, b string) bool { - return a > b // i.e. largest at lowest idx - }) + slices.SortFunc(ids, ascending) // Select random indices in slice. minIdx := randRd.Intn(len(ids)) @@ -93,7 +91,7 @@ var cases = []Case{ expect := slices.Clone(ids) expect = cutLower(expect, minID) expect = cutUpper(expect, maxID) - expect = paging.Reverse(expect) + slices.Reverse(expect) // Return page and expected IDs. return ids, &paging.Page{ @@ -103,9 +101,7 @@ var cases = []Case{ }), CreateCase("minID, maxID and limit set", func(ids []string) ([]string, *paging.Page, []string) { // Ensure input slice sorted ascending for min_id - slices.SortFunc(ids, func(a, b string) bool { - return a > b // i.e. largest at lowest idx - }) + slices.SortFunc(ids, ascending) // Select random parameters in slice. minIdx := randRd.Intn(len(ids)) @@ -120,7 +116,7 @@ var cases = []Case{ expect := slices.Clone(ids) expect = cutLower(expect, minID) expect = cutUpper(expect, maxID) - expect = paging.Reverse(expect) + slices.Reverse(expect) // Now limit the slice. if limit < len(expect) { @@ -136,9 +132,7 @@ var cases = []Case{ }), CreateCase("minID, maxID and too-large limit set", func(ids []string) ([]string, *paging.Page, []string) { // Ensure input slice sorted ascending for min_id - slices.SortFunc(ids, func(a, b string) bool { - return a > b // i.e. largest at lowest idx - }) + slices.SortFunc(ids, ascending) // Select random parameters in slice. minIdx := randRd.Intn(len(ids)) @@ -152,7 +146,7 @@ var cases = []Case{ expect := slices.Clone(ids) expect = cutLower(expect, minID) expect = cutUpper(expect, maxID) - expect = paging.Reverse(expect) + slices.Reverse(expect) // Return page and expected IDs. return ids, &paging.Page{ @@ -163,9 +157,7 @@ var cases = []Case{ }), CreateCase("sinceID and maxID set", func(ids []string) ([]string, *paging.Page, []string) { // Ensure input slice sorted descending for since_id - slices.SortFunc(ids, func(a, b string) bool { - return a < b // i.e. smallest at lowest idx - }) + slices.SortFunc(ids, descending) // Select random indices in slice. sinceIdx := randRd.Intn(len(ids)) @@ -188,9 +180,7 @@ var cases = []Case{ }), CreateCase("maxID set", func(ids []string) ([]string, *paging.Page, []string) { // Ensure input slice sorted descending for max_id - slices.SortFunc(ids, func(a, b string) bool { - return a < b // i.e. smallest at lowest idx - }) + slices.SortFunc(ids, descending) // Select random indices in slice. maxIdx := randRd.Intn(len(ids)) @@ -209,9 +199,7 @@ var cases = []Case{ }), CreateCase("sinceID set", func(ids []string) ([]string, *paging.Page, []string) { // Ensure input slice sorted descending for since_id - slices.SortFunc(ids, func(a, b string) bool { - return a < b - }) + slices.SortFunc(ids, descending) // Select random indices in slice. sinceIdx := randRd.Intn(len(ids)) @@ -230,9 +218,7 @@ var cases = []Case{ }), CreateCase("minID set", func(ids []string) ([]string, *paging.Page, []string) { // Ensure input slice sorted ascending for min_id - slices.SortFunc(ids, func(a, b string) bool { - return a > b // i.e. largest at lowest idx - }) + slices.SortFunc(ids, ascending) // Select random indices in slice. minIdx := randRd.Intn(len(ids)) @@ -243,7 +229,7 @@ var cases = []Case{ // Create expected output. expect := slices.Clone(ids) expect = cutLower(expect, minID) - expect = paging.Reverse(expect) + slices.Reverse(expect) // Return page and expected IDs. return ids, &paging.Page{ @@ -296,3 +282,21 @@ func generateSlice(len int) []string { } return in } + +func ascending(a, b string) int { + if a > b { + return 1 + } else if a < b { + return -1 + } + return 0 +} + +func descending(a, b string) int { + if a < b { + return 1 + } else if a > b { + return -1 + } + return 0 +} diff --git a/internal/paging/util.go b/internal/paging/util.go deleted file mode 100644 index dd941dd88..000000000 --- a/internal/paging/util.go +++ /dev/null @@ -1,43 +0,0 @@ -// GoToSocial -// Copyright (C) GoToSocial Authors admin@gotosocial.org -// SPDX-License-Identifier: AGPL-3.0-or-later -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package paging - -// Reverse will reverse the given input slice. -func Reverse(in []string) []string { - var ( - // Start at front. - i = 0 - - // Start at back. - j = len(in) - 1 - ) - - for i < j { - // Swap i,j index values in slice. - in[i], in[j] = in[j], in[i] - - // incr + decr, - // looping until - // they meet in - // the middle. - i++ - j-- - } - - return in -} diff --git a/internal/processing/fedi/collections.go b/internal/processing/fedi/collections.go index cbabbfdd6..ccca10754 100644 --- a/internal/processing/fedi/collections.go +++ b/internal/processing/fedi/collections.go @@ -47,8 +47,15 @@ func (p *Processor) InboxPost(ctx context.Context, w http.ResponseWriter, r *htt // OutboxGet returns the activitypub representation of a local user's outbox. // This contains links to PUBLIC posts made by this user. -func (p *Processor) OutboxGet(ctx context.Context, requestedUsername string, page bool, maxID string, minID string) (interface{}, gtserror.WithCode) { - requestedAccount, _, errWithCode := p.authenticate(ctx, requestedUsername) +func (p *Processor) OutboxGet( + ctx context.Context, + requestedUser string, + page bool, + maxID string, + minID string, +) (interface{}, gtserror.WithCode) { + // Authenticate the incoming request, getting related user accounts. + _, receiver, errWithCode := p.authenticate(ctx, requestedUser) if errWithCode != nil { return nil, errWithCode } @@ -70,7 +77,7 @@ func (p *Processor) OutboxGet(ctx context.Context, requestedUsername string, pag "last": "https://example.org/users/whatever/outbox?min_id=0&page=true" } */ - collection, err := p.converter.OutboxToASCollection(ctx, requestedAccount.OutboxURI) + collection, err := p.converter.OutboxToASCollection(ctx, receiver.OutboxURI) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -85,15 +92,16 @@ func (p *Processor) OutboxGet(ctx context.Context, requestedUsername string, pag // scenario 2 -- get the requested page // limit pages to 30 entries per page - publicStatuses, err := p.state.DB.GetAccountStatuses(ctx, requestedAccount.ID, 30, true, true, maxID, minID, false, true) + publicStatuses, err := p.state.DB.GetAccountStatuses(ctx, receiver.ID, 30, true, true, maxID, minID, false, true) if err != nil && !errors.Is(err, db.ErrNoEntries) { return nil, gtserror.NewErrorInternalError(err) } - outboxPage, err := p.converter.StatusesToASOutboxPage(ctx, requestedAccount.OutboxURI, maxID, minID, publicStatuses) + outboxPage, err := p.converter.StatusesToASOutboxPage(ctx, receiver.OutboxURI, maxID, minID, publicStatuses) if err != nil { return nil, gtserror.NewErrorInternalError(err) } + data, err = ap.Serialize(outboxPage) if err != nil { return nil, gtserror.NewErrorInternalError(err) @@ -104,21 +112,22 @@ func (p *Processor) OutboxGet(ctx context.Context, requestedUsername string, pag // FollowersGet handles the getting of a fedi/activitypub representation of a user/account's followers, performing appropriate // authentication before returning a JSON serializable interface to the caller. -func (p *Processor) FollowersGet(ctx context.Context, requestedUsername string, page *paging.Page) (interface{}, gtserror.WithCode) { - requestedAccount, _, errWithCode := p.authenticate(ctx, requestedUsername) +func (p *Processor) FollowersGet(ctx context.Context, requestedUser string, page *paging.Page) (interface{}, gtserror.WithCode) { + // Authenticate the incoming request, getting related user accounts. + _, receiver, errWithCode := p.authenticate(ctx, requestedUser) if errWithCode != nil { return nil, errWithCode } // Parse the collection ID object from account's followers URI. - collectionID, err := url.Parse(requestedAccount.FollowersURI) + collectionID, err := url.Parse(receiver.FollowersURI) if err != nil { - err := gtserror.Newf("error parsing account followers uri %s: %w", requestedAccount.FollowersURI, err) + err := gtserror.Newf("error parsing account followers uri %s: %w", receiver.FollowersURI, err) return nil, gtserror.NewErrorInternalError(err) } // Calculate total number of followers available for account. - total, err := p.state.DB.CountAccountFollowers(ctx, requestedAccount.ID) + total, err := p.state.DB.CountAccountFollowers(ctx, receiver.ID) if err != nil { err := gtserror.Newf("error counting followers: %w", err) return nil, gtserror.NewErrorInternalError(err) @@ -126,30 +135,36 @@ func (p *Processor) FollowersGet(ctx context.Context, requestedUsername string, var obj vocab.Type - // Start building AS collection params. + // Start the AS collection params. var params ap.CollectionParams params.ID = collectionID params.Total = total if page == nil { - // i.e. paging disabled, the simplest case. - // - // Just build collection object from params. + // i.e. paging disabled, return collection + // that links to first page (i.e. path below). + params.Query = make(url.Values, 1) + params.Query.Set("limit", "40") // enables paging obj = ap.NewASOrderedCollection(params) } else { // i.e. paging enabled // Get the request page of full follower objects with attached accounts. - followers, err := p.state.DB.GetAccountFollowers(ctx, requestedAccount.ID, page) + followers, err := p.state.DB.GetAccountFollowers(ctx, receiver.ID, page) if err != nil { err := gtserror.Newf("error getting followers: %w", err) return nil, gtserror.NewErrorInternalError(err) } - // Get the lowest and highest - // ID values, used for paging. - lo := followers[len(followers)-1].ID - hi := followers[0].ID + // page ID values. + var lo, hi string + + if len(followers) > 0 { + // Get the lowest and highest + // ID values, used for paging. + lo = followers[len(followers)-1].ID + hi = followers[0].ID + } // Start building AS collection page params. var pageParams ap.CollectionPageParams @@ -196,21 +211,22 @@ func (p *Processor) FollowersGet(ctx context.Context, requestedUsername string, // FollowingGet handles the getting of a fedi/activitypub representation of a user/account's following, performing appropriate // authentication before returning a JSON serializable interface to the caller. -func (p *Processor) FollowingGet(ctx context.Context, requestedUsername string, page *paging.Page) (interface{}, gtserror.WithCode) { - requestedAccount, _, errWithCode := p.authenticate(ctx, requestedUsername) +func (p *Processor) FollowingGet(ctx context.Context, requestedUser string, page *paging.Page) (interface{}, gtserror.WithCode) { + // Authenticate the incoming request, getting related user accounts. + _, receiver, errWithCode := p.authenticate(ctx, requestedUser) if errWithCode != nil { return nil, errWithCode } - // Parse the collection ID object from account's following URI. - collectionID, err := url.Parse(requestedAccount.FollowingURI) + // Parse collection ID from account's following URI. + collectionID, err := url.Parse(receiver.FollowingURI) if err != nil { - err := gtserror.Newf("error parsing account following uri %s: %w", requestedAccount.FollowingURI, err) + err := gtserror.Newf("error parsing account following uri %s: %w", receiver.FollowingURI, err) return nil, gtserror.NewErrorInternalError(err) } // Calculate total number of following available for account. - total, err := p.state.DB.CountAccountFollows(ctx, requestedAccount.ID) + total, err := p.state.DB.CountAccountFollows(ctx, receiver.ID) if err != nil { err := gtserror.Newf("error counting follows: %w", err) return nil, gtserror.NewErrorInternalError(err) @@ -218,32 +234,38 @@ func (p *Processor) FollowingGet(ctx context.Context, requestedUsername string, var obj vocab.Type - // Start building AS collection params. + // Start AS collection params. var params ap.CollectionParams params.ID = collectionID params.Total = total if page == nil { - // i.e. paging disabled, the simplest case. - // - // Just build collection object from params. + // i.e. paging disabled, return collection + // that links to first page (i.e. path below). + params.Query = make(url.Values, 1) + params.Query.Set("limit", "40") // enables paging obj = ap.NewASOrderedCollection(params) } else { // i.e. paging enabled // Get the request page of full follower objects with attached accounts. - follows, err := p.state.DB.GetAccountFollows(ctx, requestedAccount.ID, page) + follows, err := p.state.DB.GetAccountFollows(ctx, receiver.ID, page) if err != nil { err := gtserror.Newf("error getting follows: %w", err) return nil, gtserror.NewErrorInternalError(err) } - // Get the lowest and highest - // ID values, used for paging. - lo := follows[len(follows)-1].ID - hi := follows[0].ID + // page ID values. + var lo, hi string - // Start building AS collection page params. + if len(follows) > 0 { + // Get the lowest and highest + // ID values, used for paging. + lo = follows[len(follows)-1].ID + hi = follows[0].ID + } + + // Start AS collection page params. var pageParams ap.CollectionPageParams pageParams.CollectionParams = params @@ -288,20 +310,21 @@ func (p *Processor) FollowingGet(ctx context.Context, requestedUsername string, // FeaturedCollectionGet returns an ordered collection of the requested username's Pinned posts. // The returned collection have an `items` property which contains an ordered list of status URIs. -func (p *Processor) FeaturedCollectionGet(ctx context.Context, requestedUsername string) (interface{}, gtserror.WithCode) { - requestedAccount, _, errWithCode := p.authenticate(ctx, requestedUsername) +func (p *Processor) FeaturedCollectionGet(ctx context.Context, requestedUser string) (interface{}, gtserror.WithCode) { + // Authenticate the incoming request, getting related user accounts. + _, receiver, errWithCode := p.authenticate(ctx, requestedUser) if errWithCode != nil { return nil, errWithCode } - statuses, err := p.state.DB.GetAccountPinnedStatuses(ctx, requestedAccount.ID) + statuses, err := p.state.DB.GetAccountPinnedStatuses(ctx, receiver.ID) if err != nil { if !errors.Is(err, db.ErrNoEntries) { return nil, gtserror.NewErrorInternalError(err) } } - collection, err := p.converter.StatusesToASFeaturedCollection(ctx, requestedAccount.FeaturedCollectionURI, statuses) + collection, err := p.converter.StatusesToASFeaturedCollection(ctx, receiver.FeaturedCollectionURI, statuses) if err != nil { return nil, gtserror.NewErrorInternalError(err) } diff --git a/internal/processing/fedi/common.go b/internal/processing/fedi/common.go index c41f1e00c..f395ec3cf 100644 --- a/internal/processing/fedi/common.go +++ b/internal/processing/fedi/common.go @@ -20,7 +20,6 @@ package fedi import ( "context" "errors" - "fmt" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" @@ -28,17 +27,17 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *Processor) authenticate(ctx context.Context, requestedUsername string) ( - *gtsmodel.Account, // requestedAccount - *gtsmodel.Account, // requestingAccount +func (p *Processor) authenticate(ctx context.Context, requestedUser string) ( + *gtsmodel.Account, // requester: i.e. user making the request + *gtsmodel.Account, // receiver: i.e. the receiving inbox user gtserror.WithCode, ) { - // Get LOCAL account with the requested username. - requestedAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, requestedUsername, "") + // First get the requested (receiving) LOCAL account with username from database. + receiver, err := p.state.DB.GetAccountByUsernameDomain(ctx, requestedUser, "") if err != nil { if !errors.Is(err, db.ErrNoEntries) { // Real db error. - err = gtserror.Newf("db error getting account %s: %w", requestedUsername, err) + err = gtserror.Newf("db error getting account %s: %w", requestedUser, err) return nil, nil, gtserror.NewErrorInternalError(err) } @@ -46,41 +45,43 @@ func (p *Processor) authenticate(ctx context.Context, requestedUsername string) return nil, nil, gtserror.NewErrorNotFound(err) } + var requester *gtsmodel.Account + // Ensure request signed, and use signature URI to // get requesting account, dereferencing if necessary. - pubKeyAuth, errWithCode := p.federator.AuthenticateFederatedRequest(ctx, requestedUsername) + pubKeyAuth, errWithCode := p.federator.AuthenticateFederatedRequest(ctx, requestedUser) if errWithCode != nil { return nil, nil, errWithCode } - requestingAccount, _, err := p.federator.GetAccountByURI( - gtscontext.SetFastFail(ctx), - requestedUsername, - pubKeyAuth.OwnerURI, - ) - if err != nil { - err = gtserror.Newf("error getting account %s: %w", pubKeyAuth.OwnerURI, err) - return nil, nil, gtserror.NewErrorUnauthorized(err) + if requester = pubKeyAuth.Owner; requester == nil { + requester, _, err = p.federator.GetAccountByURI( + gtscontext.SetFastFail(ctx), + requestedUser, + pubKeyAuth.OwnerURI, + ) + if err != nil { + err = gtserror.Newf("error getting account %s: %w", pubKeyAuth.OwnerURI, err) + return nil, nil, gtserror.NewErrorUnauthorized(err) + } } - if !requestingAccount.SuspendedAt.IsZero() { + if !requester.SuspendedAt.IsZero() { // Account was marked as suspended by a // local admin action. Stop request early. - err = fmt.Errorf("account %s marked as suspended", requestingAccount.ID) - return nil, nil, gtserror.NewErrorForbidden(err) + const text = "requesting account is suspended" + return nil, nil, gtserror.NewErrorForbidden(errors.New(text)) } // Ensure no block exists between requester + requested. - blocked, err := p.state.DB.IsEitherBlocked(ctx, requestedAccount.ID, requestingAccount.ID) + blocked, err := p.state.DB.IsEitherBlocked(ctx, receiver.ID, requester.ID) if err != nil { err = gtserror.Newf("db error getting checking block: %w", err) return nil, nil, gtserror.NewErrorInternalError(err) - } - - if blocked { - err = fmt.Errorf("block exists between accounts %s and %s", requestedAccount.ID, requestingAccount.ID) + } else if blocked { + err = gtserror.Newf("block exists between accounts %s and %s", requester.ID, receiver.ID) return nil, nil, gtserror.NewErrorForbidden(err) } - return requestedAccount, requestingAccount, nil + return requester, receiver, nil } diff --git a/internal/processing/fedi/fedi.go b/internal/processing/fedi/fedi.go index 11be26a3e..eeef94113 100644 --- a/internal/processing/fedi/fedi.go +++ b/internal/processing/fedi/fedi.go @@ -19,21 +19,33 @@ package fedi import ( "github.com/superseriousbusiness/gotosocial/internal/federation" + "github.com/superseriousbusiness/gotosocial/internal/processing/common" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/visibility" ) type Processor struct { + // embed common logic + c *common.Processor + state *state.State federator *federation.Federator converter *typeutils.Converter filter *visibility.Filter } -// New returns a new fedi processor. -func New(state *state.State, converter *typeutils.Converter, federator *federation.Federator, filter *visibility.Filter) Processor { +// New returns a +// new fedi processor. +func New( + state *state.State, + common *common.Processor, + converter *typeutils.Converter, + federator *federation.Federator, + filter *visibility.Filter, +) Processor { return Processor{ + c: common, state: state, federator: federator, converter: converter, diff --git a/internal/processing/fedi/status.go b/internal/processing/fedi/status.go index c8534eb5e..b8b75841c 100644 --- a/internal/processing/fedi/status.go +++ b/internal/processing/fedi/status.go @@ -19,161 +19,192 @@ package fedi import ( "context" - "fmt" + "errors" "net/url" + "slices" + "strconv" + "github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/paging" ) // StatusGet handles the getting of a fedi/activitypub representation of a local status. // It performs appropriate authentication before returning a JSON serializable interface. -func (p *Processor) StatusGet(ctx context.Context, requestedUsername string, requestedStatusID string) (interface{}, gtserror.WithCode) { +func (p *Processor) StatusGet(ctx context.Context, requestedUser string, statusID string) (interface{}, gtserror.WithCode) { // Authenticate using http signature. - requestedAccount, requestingAccount, errWithCode := p.authenticate(ctx, requestedUsername) + // Authenticate the incoming request, getting related user accounts. + requester, receiver, errWithCode := p.authenticate(ctx, requestedUser) if errWithCode != nil { return nil, errWithCode } - status, err := p.state.DB.GetStatusByID(ctx, requestedStatusID) + status, err := p.state.DB.GetStatusByID(ctx, statusID) if err != nil { return nil, gtserror.NewErrorNotFound(err) } - if status.AccountID != requestedAccount.ID { - err := fmt.Errorf("status with id %s does not belong to account with id %s", status.ID, requestedAccount.ID) - return nil, gtserror.NewErrorNotFound(err) + if status.AccountID != receiver.ID { + const text = "status does not belong to receiving account" + return nil, gtserror.NewErrorNotFound(errors.New(text)) } - visible, err := p.filter.StatusVisible(ctx, requestingAccount, status) + visible, err := p.filter.StatusVisible(ctx, requester, status) if err != nil { return nil, gtserror.NewErrorInternalError(err) } if !visible { - err := fmt.Errorf("status with id %s not visible to user with id %s", status.ID, requestingAccount.ID) - return nil, gtserror.NewErrorNotFound(err) + const text = "status not vising to requesting account" + return nil, gtserror.NewErrorNotFound(errors.New(text)) } statusable, err := p.converter.StatusToAS(ctx, status) if err != nil { + err := gtserror.Newf("error converting status: %w", err) return nil, gtserror.NewErrorInternalError(err) } data, err := ap.Serialize(statusable) if err != nil { + err := gtserror.Newf("error serializing status: %w", err) return nil, gtserror.NewErrorInternalError(err) } return data, nil } -// GetStatus handles the getting of a fedi/activitypub representation of replies to a status, performing appropriate -// authentication before returning a JSON serializable interface to the caller. -func (p *Processor) StatusRepliesGet(ctx context.Context, requestedUsername string, requestedStatusID string, page bool, onlyOtherAccounts bool, onlyOtherAccountsSet bool, minID string) (interface{}, gtserror.WithCode) { - requestedAccount, requestingAccount, errWithCode := p.authenticate(ctx, requestedUsername) +// GetStatus handles the getting of a fedi/activitypub representation of replies to a status, +// performing appropriate authentication before returning a JSON serializable interface to the caller. +func (p *Processor) StatusRepliesGet( + ctx context.Context, + requestedUser string, + statusID string, + page *paging.Page, + onlyOtherAccounts bool, +) (interface{}, gtserror.WithCode) { + // Authenticate the incoming request, getting related user accounts. + requester, receiver, errWithCode := p.authenticate(ctx, requestedUser) if errWithCode != nil { return nil, errWithCode } - status, err := p.state.DB.GetStatusByID(ctx, requestedStatusID) - if err != nil { - return nil, gtserror.NewErrorNotFound(err) + // Get target status and ensure visible to requester. + status, errWithCode := p.c.GetVisibleTargetStatus(ctx, + requester, + statusID, + ) + if errWithCode != nil { + return nil, errWithCode } - if status.AccountID != requestedAccount.ID { - return nil, gtserror.NewErrorNotFound(fmt.Errorf("status with id %s does not belong to account with id %s", status.ID, requestedAccount.ID)) + // Ensure status is by receiving account. + if status.AccountID != receiver.ID { + const text = "status does not belong to receiving account" + return nil, gtserror.NewErrorNotFound(errors.New(text)) } - visible, err := p.filter.StatusVisible(ctx, requestedAccount, status) + // Parse replies collection ID from status' URI with onlyOtherAccounts param. + onlyOtherAccStr := "only_other_accounts=" + strconv.FormatBool(onlyOtherAccounts) + collectionID, err := url.Parse(status.URI + "/replies?" + onlyOtherAccStr) if err != nil { + err := gtserror.Newf("error parsing status uri %s: %w", status.URI, err) return nil, gtserror.NewErrorInternalError(err) } - if !visible { - return nil, gtserror.NewErrorNotFound(fmt.Errorf("status with id %s not visible to user with id %s", status.ID, requestingAccount.ID)) + + // Get *all* available replies for status (i.e. without paging). + replies, err := p.state.DB.GetStatusReplies(ctx, status.ID) + if err != nil { + err := gtserror.Newf("error getting status replies: %w", err) + return nil, gtserror.NewErrorInternalError(err) } - var data map[string]interface{} + if onlyOtherAccounts { + // If 'onlyOtherAccounts' is set, drop all by original status author. + replies = slices.DeleteFunc(replies, func(reply *gtsmodel.Status) bool { + return reply.AccountID == status.AccountID + }) + } - // now there are three scenarios: - // 1. we're asked for the whole collection and not a page -- we can just return the collection, with no items, but a link to 'first' page. - // 2. we're asked for a page but only_other_accounts has not been set in the query -- so we should just return the first page of the collection, with no items. - // 3. we're asked for a page, and only_other_accounts has been set, and min_id has optionally been set -- so we need to return some actual items! - switch { - case !page: - // scenario 1 - // get the collection - collection, err := p.converter.StatusToASRepliesCollection(ctx, status, onlyOtherAccounts) - if err != nil { - return nil, gtserror.NewErrorInternalError(err) + // Reslice replies dropping all those invisible to requester. + replies, err = p.filter.StatusesVisible(ctx, requester, replies) + if err != nil { + err := gtserror.Newf("error filtering status replies: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + + var obj vocab.Type + + // Start AS collection params. + var params ap.CollectionParams + params.ID = collectionID + params.Total = len(replies) + + if page == nil { + // i.e. paging disabled, return collection + // that links to first page (i.e. path below). + params.Query = make(url.Values, 1) + params.Query.Set("limit", "20") // enables paging + obj = ap.NewASOrderedCollection(params) + } else { + // i.e. paging enabled + + // Page and reslice the replies according to given parameters. + replies = paging.Page_PageFunc(page, replies, func(reply *gtsmodel.Status) string { + return reply.ID + }) + + // page ID values. + var lo, hi string + + if len(replies) > 0 { + // Get the lowest and highest + // ID values, used for paging. + lo = replies[len(replies)-1].ID + hi = replies[0].ID } - data, err = ap.Serialize(collection) - if err != nil { - return nil, gtserror.NewErrorInternalError(err) - } - case page && !onlyOtherAccountsSet: - // scenario 2 - // get the collection - collection, err := p.converter.StatusToASRepliesCollection(ctx, status, onlyOtherAccounts) - if err != nil { - return nil, gtserror.NewErrorInternalError(err) - } - // but only return the first page - data, err = ap.Serialize(collection.GetActivityStreamsFirst().GetActivityStreamsCollectionPage()) - if err != nil { - return nil, gtserror.NewErrorInternalError(err) - } - default: - // scenario 3 - // get immediate children - replies, err := p.state.DB.GetStatusChildren(ctx, status, true, minID) - if err != nil { - return nil, gtserror.NewErrorInternalError(err) - } + // Start AS collection page params. + var pageParams ap.CollectionPageParams + pageParams.CollectionParams = params - // filter children and extract URIs - replyURIs := map[string]*url.URL{} - for _, r := range replies { - // only show public or unlocked statuses as replies - if r.Visibility != gtsmodel.VisibilityPublic && r.Visibility != gtsmodel.VisibilityUnlocked { - continue - } + // Current page details. + pageParams.Current = page + pageParams.Count = len(replies) - // respect onlyOtherAccounts parameter - if onlyOtherAccounts && r.AccountID == requestedAccount.ID { - continue - } + // Set linked next/prev parameters. + pageParams.Next = page.Next(lo, hi) + pageParams.Prev = page.Prev(lo, hi) - // only show replies that the status owner can see - visibleToStatusOwner, err := p.filter.StatusVisible(ctx, requestedAccount, r) - if err != nil || !visibleToStatusOwner { - continue - } + // Set the collection item property builder function. + pageParams.Append = func(i int, itemsProp ap.ItemsPropertyBuilder) { + // Get follower URI at index. + status := replies[i] + uri := status.URI - // only show replies that the requester can see - visibleToRequester, err := p.filter.StatusVisible(ctx, requestingAccount, r) - if err != nil || !visibleToRequester { - continue - } - - rURI, err := url.Parse(r.URI) + // Parse URL object from URI. + iri, err := url.Parse(uri) if err != nil { - continue + log.Errorf(ctx, "error parsing status uri %s: %v", uri, err) + return } - replyURIs[r.ID] = rURI + // Add to item property. + itemsProp.AppendIRI(iri) } - repliesPage, err := p.converter.StatusURIsToASRepliesPage(ctx, status, onlyOtherAccounts, minID, replyURIs) - if err != nil { - return nil, gtserror.NewErrorInternalError(err) - } - data, err = ap.Serialize(repliesPage) - if err != nil { - return nil, gtserror.NewErrorInternalError(err) - } + // Build AS collection page object from params. + obj = ap.NewASOrderedCollectionPage(pageParams) + } + + // Serialized the prepared object. + data, err := ap.Serialize(obj) + if err != nil { + err := gtserror.Newf("error serializing: %w", err) + return nil, gtserror.NewErrorInternalError(err) } return data, nil diff --git a/internal/processing/processor.go b/internal/processing/processor.go index 65f05f49e..ac930aeb2 100644 --- a/internal/processing/processor.go +++ b/internal/processing/processor.go @@ -156,23 +156,23 @@ func NewProcessor( // // Start with sub processors that will // be required by the workers processor. - commonProcessor := common.New(state, converter, federator, filter) - processor.account = account.New(&commonProcessor, state, converter, mediaManager, oauthServer, federator, filter, parseMentionFunc) + common := common.New(state, converter, federator, filter) + processor.account = account.New(&common, state, converter, mediaManager, oauthServer, federator, filter, parseMentionFunc) processor.media = media.New(state, converter, mediaManager, federator.TransportController()) processor.stream = stream.New(state, oauthServer) // Instantiate the rest of the sub // processors + pin them to this struct. - processor.account = account.New(&commonProcessor, state, converter, mediaManager, oauthServer, federator, filter, parseMentionFunc) + processor.account = account.New(&common, state, converter, mediaManager, oauthServer, federator, filter, parseMentionFunc) processor.admin = admin.New(state, cleaner, converter, mediaManager, federator.TransportController(), emailSender) - processor.fedi = fedi.New(state, converter, federator, filter) + processor.fedi = fedi.New(state, &common, converter, federator, filter) processor.list = list.New(state, converter) processor.markers = markers.New(state, converter) - processor.polls = polls.New(&commonProcessor, state, converter) + processor.polls = polls.New(&common, state, converter) processor.report = report.New(state, converter) processor.timeline = timeline.New(state, converter, filter) processor.search = search.New(state, federator, converter, filter) - processor.status = status.New(state, &commonProcessor, &processor.polls, federator, converter, filter, parseMentionFunc) + processor.status = status.New(state, &common, &processor.polls, federator, converter, filter, parseMentionFunc) processor.user = user.New(state, emailSender) // Workers processor handles asynchronous diff --git a/internal/processing/status/get.go b/internal/processing/status/get.go index ae6918e3f..c182bd148 100644 --- a/internal/processing/status/get.go +++ b/internal/processing/status/get.go @@ -67,7 +67,7 @@ func (p *Processor) contextGet( Descendants: []apimodel.Status{}, } - parents, err := p.state.DB.GetStatusParents(ctx, targetStatus, false) + parents, err := p.state.DB.GetStatusParents(ctx, targetStatus) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -85,7 +85,7 @@ func (p *Processor) contextGet( return context.Ancestors[i].ID < context.Ancestors[j].ID }) - children, err := p.state.DB.GetStatusChildren(ctx, targetStatus, false, "") + children, err := p.state.DB.GetStatusChildren(ctx, targetStatus.ID) if err != nil { return nil, gtserror.NewErrorInternalError(err) } diff --git a/internal/util/ptr.go b/internal/util/ptr.go index 2ce96e1d1..0ad207617 100644 --- a/internal/util/ptr.go +++ b/internal/util/ptr.go @@ -33,3 +33,11 @@ func EqualPtrs[T comparable](t1, t2 *T) bool { func Ptr[T any](t T) *T { return &t } + +// PtrValueOr returns either value of ptr, or default. +func PtrValueOr[T any](t *T, _default T) T { + if t != nil { + return *t + } + return _default +} diff --git a/internal/visibility/home_timeline.go b/internal/visibility/home_timeline.go index 56290e836..273ca8457 100644 --- a/internal/visibility/home_timeline.go +++ b/internal/visibility/home_timeline.go @@ -20,7 +20,6 @@ package visibility import ( "context" "errors" - "fmt" "time" "github.com/superseriousbusiness/gotosocial/internal/cache" @@ -219,7 +218,7 @@ func (f *Filter) isVisibleConversation(ctx context.Context, owner *gtsmodel.Acco status.AccountID, ) if err != nil { - return false, fmt.Errorf("error checking follow %s->%s: %w", owner.ID, status.AccountID, err) + return false, gtserror.Newf("error checking follow %s->%s: %w", owner.ID, status.AccountID, err) } if !followAuthor { @@ -236,7 +235,7 @@ func (f *Filter) isVisibleConversation(ctx context.Context, owner *gtsmodel.Acco mention.TargetAccountID, ) if err != nil { - return false, fmt.Errorf("error checking mention follow %s->%s: %w", owner.ID, mention.TargetAccountID, err) + return false, gtserror.Newf("error checking mention follow %s->%s: %w", owner.ID, mention.TargetAccountID, err) } if follow { diff --git a/internal/visibility/public_timeline.go b/internal/visibility/public_timeline.go index 77ce5760c..63e802614 100644 --- a/internal/visibility/public_timeline.go +++ b/internal/visibility/public_timeline.go @@ -19,11 +19,11 @@ package visibility import ( "context" - "fmt" "time" "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" ) @@ -105,7 +105,7 @@ func (f *Filter) isStatusPublicTimelineable(ctx context.Context, requester *gtsm parentID, ) if err != nil { - return false, fmt.Errorf("isStatusPublicTimelineable: error getting status parent %s: %w", parentID, err) + return false, gtserror.Newf("error getting status parent %s: %w", parentID, err) } if parent.AccountID != status.AccountID { diff --git a/internal/visibility/status.go b/internal/visibility/status.go index d41bbc80b..3684bae4f 100644 --- a/internal/visibility/status.go +++ b/internal/visibility/status.go @@ -19,32 +19,26 @@ package visibility import ( "context" - "fmt" + "slices" "github.com/superseriousbusiness/gotosocial/internal/cache" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" ) // StatusesVisible calls StatusVisible for each status in the statuses slice, and returns a slice of only statuses which are visible to the requester. func (f *Filter) StatusesVisible(ctx context.Context, requester *gtsmodel.Account, statuses []*gtsmodel.Status) ([]*gtsmodel.Status, error) { - // Preallocate slice of maximum possible length. - filtered := make([]*gtsmodel.Status, 0, len(statuses)) - - for _, status := range statuses { - // Check whether status is visible to requester. + var errs gtserror.MultiError + filtered := slices.DeleteFunc(statuses, func(status *gtsmodel.Status) bool { visible, err := f.StatusVisible(ctx, requester, status) if err != nil { - return nil, err + errs.Append(err) + return true } - - if visible { - // Add filtered status to ret slice. - filtered = append(filtered, status) - } - } - - return filtered, nil + return !visible + }) + return filtered, errs.Combine() } // StatusVisible will check if given status is visible to requester, accounting for requester with no auth (i.e is nil), suspensions, disabled local users, account blocks and status privacy. @@ -85,13 +79,13 @@ func (f *Filter) StatusVisible(ctx context.Context, requester *gtsmodel.Account, func (f *Filter) isStatusVisible(ctx context.Context, requester *gtsmodel.Account, status *gtsmodel.Status) (bool, error) { // Ensure that status is fully populated for further processing. if err := f.state.DB.PopulateStatus(ctx, status); err != nil { - return false, fmt.Errorf("isStatusVisible: error populating status %s: %w", status.ID, err) + return false, gtserror.Newf("error populating status %s: %w", status.ID, err) } // Check whether status accounts are visible to the requester. visible, err := f.areStatusAccountsVisible(ctx, requester, status) if err != nil { - return false, fmt.Errorf("isStatusVisible: error checking status %s account visibility: %w", status.ID, err) + return false, gtserror.Newf("error checking status %s account visibility: %w", status.ID, err) } else if !visible { return false, nil } @@ -127,7 +121,7 @@ func (f *Filter) isStatusVisible(ctx context.Context, requester *gtsmodel.Accoun // Boosted status needs its mentions populating, fetch these from database. status.BoostOf.Mentions, err = f.state.DB.GetMentions(ctx, status.BoostOf.MentionIDs) if err != nil { - return false, fmt.Errorf("isStatusVisible: error populating boosted status %s mentions: %w", status.BoostOfID, err) + return false, gtserror.Newf("error populating boosted status %s mentions: %w", status.BoostOfID, err) } } @@ -145,7 +139,7 @@ func (f *Filter) isStatusVisible(ctx context.Context, requester *gtsmodel.Accoun status.AccountID, ) if err != nil { - return false, fmt.Errorf("isStatusVisible: error checking follow %s->%s: %w", requester.ID, status.AccountID, err) + return false, gtserror.Newf("error checking follow %s->%s: %w", requester.ID, status.AccountID, err) } if !follows { @@ -162,7 +156,7 @@ func (f *Filter) isStatusVisible(ctx context.Context, requester *gtsmodel.Accoun status.AccountID, ) if err != nil { - return false, fmt.Errorf("isStatusVisible: error checking mutual follow %s<->%s: %w", requester.ID, status.AccountID, err) + return false, gtserror.Newf("error checking mutual follow %s<->%s: %w", requester.ID, status.AccountID, err) } if !mutuals { @@ -187,7 +181,7 @@ func (f *Filter) areStatusAccountsVisible(ctx context.Context, requester *gtsmod // Check whether status author's account is visible to requester. visible, err := f.AccountVisible(ctx, requester, status.Account) if err != nil { - return false, fmt.Errorf("error checking status author visibility: %w", err) + return false, gtserror.Newf("error checking status author visibility: %w", err) } if !visible { @@ -206,7 +200,7 @@ func (f *Filter) areStatusAccountsVisible(ctx context.Context, requester *gtsmod // Check whether boosted status author's account is visible to requester. visible, err := f.AccountVisible(ctx, requester, status.BoostOfAccount) if err != nil { - return false, fmt.Errorf("error checking boosted author visibility: %w", err) + return false, gtserror.Newf("error checking boosted author visibility: %w", err) } if !visible { diff --git a/testrig/testmodels.go b/testrig/testmodels.go index 248508013..b04e202a7 100644 --- a/testrig/testmodels.go +++ b/testrig/testmodels.go @@ -3163,7 +3163,7 @@ func NewTestDereferenceRequests(accounts map[string]*gtsmodel.Account) map[strin DateHeader: date, } - target = URLMustParse(statuses["local_account_1_status_1"].URI + "/replies") + target = URLMustParse(statuses["local_account_1_status_1"].URI + "/replies?only_other_accounts=false") sig, digest, date = GetSignatureForDereference(accounts["remote_account_1"].PublicKeyURI, accounts["remote_account_1"].PrivateKey, target) fossSatanDereferenceLocalAccount1Status1Replies := ActivityWithSignature{ SignatureHeader: sig, @@ -3179,7 +3179,7 @@ func NewTestDereferenceRequests(accounts map[string]*gtsmodel.Account) map[strin DateHeader: date, } - target = URLMustParse(statuses["local_account_1_status_1"].URI + "/replies?only_other_accounts=false&page=true&min_id=01FF25D5Q0DH7CHD57CTRS6WK0") + target = URLMustParse(statuses["local_account_1_status_1"].URI + "/replies?min_id=01FF25D5Q0DH7CHD57CTRS6WK0&only_other_accounts=false") sig, digest, date = GetSignatureForDereference(accounts["remote_account_1"].PublicKeyURI, accounts["remote_account_1"].PrivateKey, target) fossSatanDereferenceLocalAccount1Status1RepliesLast := ActivityWithSignature{ SignatureHeader: sig,