From 10eed3c80df541d9b2f7e625773eb865fdab6dbe Mon Sep 17 00:00:00 2001 From: tsmethurst Date: Fri, 27 Oct 2023 15:24:21 +0200 Subject: [PATCH] start working on multiple tag names for tag timeline --- internal/api/client/timelines/tag.go | 9 ++++- internal/api/util/parsequery.go | 1 + internal/db/bundb/timeline.go | 6 +-- internal/db/bundb/timeline_test.go | 2 +- internal/db/timeline.go | 4 +- internal/processing/timeline/tag.go | 56 ++++++++++++++++++---------- 6 files changed, 51 insertions(+), 27 deletions(-) diff --git a/internal/api/client/timelines/tag.go b/internal/api/client/timelines/tag.go index 58754705b..caa7dc5c6 100644 --- a/internal/api/client/timelines/tag.go +++ b/internal/api/client/timelines/tag.go @@ -125,10 +125,17 @@ func (m *Module) TagTimelineGETHandler(c *gin.Context) { return } + // Append any additional tags + // passed as `any[]` parameter. + tagNames := append( + []string{tagName}, + c.QueryArray(apiutil.TagAnyKey)..., + ) + resp, errWithCode := m.processor.Timeline().TagTimelineGet( c.Request.Context(), authed.Account, - tagName, + tagNames, c.Query(apiutil.MaxIDKey), c.Query(apiutil.SinceIDKey), c.Query(apiutil.MinIDKey), diff --git a/internal/api/util/parsequery.go b/internal/api/util/parsequery.go index 6a9116dcf..70e6269e5 100644 --- a/internal/api/util/parsequery.go +++ b/internal/api/util/parsequery.go @@ -54,6 +54,7 @@ const ( /* Tag keys */ TagNameKey = "tag_name" + TagAnyKey = "any[]" /* Web endpoint keys */ diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go index a07f1a844..e766d2a43 100644 --- a/internal/db/bundb/timeline.go +++ b/internal/db/bundb/timeline.go @@ -463,7 +463,7 @@ func (t *timelineDB) GetListTimeline( func (t *timelineDB) GetTagTimeline( ctx context.Context, - tagID string, + tagIDs []string, maxID string, sinceID string, minID string, @@ -492,8 +492,8 @@ func (t *timelineDB) GetTagTimeline( ). // Public only. Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic). - // This tag only. - Where("? = ?", bun.Ident("status_to_tag.tag_id"), tagID) + // Provided tag IDs only. + Where("? IN (?)", bun.Ident("status_to_tag.tag_id"), bun.In(tagIDs)) if maxID == "" || maxID >= id.Highest { const future = 24 * time.Hour diff --git a/internal/db/bundb/timeline_test.go b/internal/db/bundb/timeline_test.go index ac169ec4a..bd125ab43 100644 --- a/internal/db/bundb/timeline_test.go +++ b/internal/db/bundb/timeline_test.go @@ -311,7 +311,7 @@ func (suite *TimelineTestSuite) TestGetTagTimelineNoParams() { tag = suite.testTags["welcome"] ) - s, err := suite.db.GetTagTimeline(ctx, tag.ID, "", "", "", 1) + s, err := suite.db.GetTagTimeline(ctx, []string{tag.ID}, "", "", "", 1) if err != nil { suite.FailNow(err.Error()) } diff --git a/internal/db/timeline.go b/internal/db/timeline.go index 43ac655d0..7c3224746 100644 --- a/internal/db/timeline.go +++ b/internal/db/timeline.go @@ -49,7 +49,7 @@ type Timeline interface { // Statuses should be returned in descending order of when they were created (newest first). GetListTimeline(ctx context.Context, listID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Status, error) - // GetTagTimeline returns a slice of public-visibility statuses that use the given tagID. + // GetTagTimeline returns a slice of public-visibility statuses that use the given tagIDs. // Statuses should be returned in descending order of when they were created (newest first). - GetTagTimeline(ctx context.Context, tagID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Status, error) + GetTagTimeline(ctx context.Context, tagIDs []string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Status, error) } diff --git a/internal/processing/timeline/tag.go b/internal/processing/timeline/tag.go index 45632ce06..e7b99b0b6 100644 --- a/internal/processing/timeline/tag.go +++ b/internal/processing/timeline/tag.go @@ -32,30 +32,35 @@ import ( ) // TagTimelineGet gets a pageable timeline for the given -// tagName and given paging parameters. It will ensure +// tagNames and given paging parameters. It will ensure // that each status in the timeline is actually visible // to requestingAcct before returning it. func (p *Processor) TagTimelineGet( ctx context.Context, requestingAcct *gtsmodel.Account, - tagName string, + tagNames []string, maxID string, sinceID string, minID string, limit int, ) (*apimodel.PageableResponse, gtserror.WithCode) { - tag, errWithCode := p.getTag(ctx, tagName) - if errWithCode != nil { - return nil, errWithCode + tagIDs := make([]string, 0, len(tagNames)) + for _, tagName := range tagNames { + tag, errWithCode := p.getTag(ctx, tagName) + if errWithCode != nil { + return nil, errWithCode + } + + if tag == nil || !*tag.Useable || !*tag.Listable { + // Obey mastodon API by returning 404 for this. + err := fmt.Errorf("tag was not found, or not useable/listable on this instance") + return nil, gtserror.NewErrorNotFound(err, err.Error()) + } + + tagIDs = append(tagIDs, tag.ID) } - if tag == nil || !*tag.Useable || !*tag.Listable { - // Obey mastodon API by returning 404 for this. - err := fmt.Errorf("tag was not found, or not useable/listable on this instance") - return nil, gtserror.NewErrorNotFound(err, err.Error()) - } - - statuses, err := p.state.DB.GetTagTimeline(ctx, tag.ID, maxID, sinceID, minID, limit) + statuses, err := p.state.DB.GetTagTimeline(ctx, tagIDs, maxID, sinceID, minID, limit) if err != nil && !errors.Is(err, db.ErrNoEntries) { err = gtserror.Newf("db error getting statuses: %w", err) return nil, gtserror.NewErrorInternalError(err) @@ -66,8 +71,7 @@ func (p *Processor) TagTimelineGet( requestingAcct, statuses, limit, - // Use API URL for tag. - "/api/v1/timelines/tag/"+tagName, + tagNames, ) } @@ -95,7 +99,7 @@ func (p *Processor) packageTagResponse( requestingAcct *gtsmodel.Account, statuses []*gtsmodel.Status, limit int, - requestPath string, + tagNames []string, ) (*apimodel.PageableResponse, gtserror.WithCode) { count := len(statuses) if count == 0 { @@ -131,11 +135,23 @@ func (p *Processor) packageTagResponse( items = append(items, apiStatus) } + // Use first / "primary" tag for API endpoint. + path := "/api/v1/timelines/tag/" + tagNames[0] + + // Add any additional tags. + var extraQueryParams []string + if len(tagNames) > 1 { + for _, tagName := range tagNames[1:] { + extraQueryParams = append(extraQueryParams, "any[]="+tagName) + } + } + return util.PackagePageableResponse(util.PageableResponseParams{ - Items: items, - Path: requestPath, - NextMaxIDValue: nextMaxIDValue, - PrevMinIDValue: prevMinIDValue, - Limit: limit, + Items: items, + Path: path, + NextMaxIDValue: nextMaxIDValue, + PrevMinIDValue: prevMinIDValue, + Limit: limit, + ExtraQueryParams: extraQueryParams, }) }