diff --git a/cmd/gotosocial/action/server/server.go b/cmd/gotosocial/action/server/server.go index 76e58c2f8..fa4ec9b82 100644 --- a/cmd/gotosocial/action/server/server.go +++ b/cmd/gotosocial/action/server/server.go @@ -32,7 +32,10 @@ import ( apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/middleware" + tlprocessor "github.com/superseriousbusiness/gotosocial/internal/processing/timeline" + "github.com/superseriousbusiness/gotosocial/internal/timeline" "github.com/superseriousbusiness/gotosocial/internal/tracing" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "go.uber.org/automaxprocs/maxprocs" "github.com/superseriousbusiness/gotosocial/internal/config" @@ -72,7 +75,6 @@ var Start action.GTSAction = func(ctx context.Context) error { defer state.Caches.Stop() // Initialize Tracing - if err := tracing.Initialize(); err != nil { return fmt.Errorf("error initializing tracing: %w", err) } @@ -110,36 +112,56 @@ var Start action.GTSAction = func(ctx context.Context) error { state.Workers.Start() defer state.Workers.Stop() - // build backend handlers + // Build handlers used in later initializations. mediaManager := media.NewManager(&state) oauthServer := oauth.New(ctx, dbService) typeConverter := typeutils.NewConverter(dbService) + filter := visibility.NewFilter(&state) federatingDB := federatingdb.New(&state, typeConverter) transportController := transport.NewController(&state, federatingDB, &federation.Clock{}, client) federator := federation.NewFederator(&state, federatingDB, transportController, typeConverter, mediaManager) - // decide whether to create a noop email sender (won't send emails) or a real one + // Decide whether to create a noop email + // sender (won't send emails) or a real one. var emailSender email.Sender if smtpHost := config.GetSMTPHost(); smtpHost != "" { - // host is defined so create a proper sender + // Host is defined; create a proper sender. emailSender, err = email.NewSender() if err != nil { return fmt.Errorf("error creating email sender: %s", err) } } else { - // no host is defined so create a noop sender + // No host is defined; create a noop sender. emailSender, err = email.NewNoopSender(nil) if err != nil { return fmt.Errorf("error creating noop email sender: %s", err) } } - // create the message processor using the other services we've created so far - processor := processing.NewProcessor(typeConverter, federator, oauthServer, mediaManager, &state, emailSender) - if err := processor.Start(); err != nil { - return fmt.Errorf("error creating processor: %s", err) + // Initialize timelines. + state.Timelines.Home = timeline.NewManager( + tlprocessor.HomeTimelineGrab(&state), + tlprocessor.HomeTimelineFilter(&state, filter), + tlprocessor.HomeTimelineStatusPrepare(&state, typeConverter), + tlprocessor.SkipInsert(), + ) + if err := state.Timelines.Home.Start(); err != nil { + return fmt.Errorf("error starting home timeline: %s", err) } + state.Timelines.List = timeline.NewManager( + tlprocessor.ListTimelineGrab(&state), + tlprocessor.ListTimelineFilter(&state, filter), + tlprocessor.ListTimelineStatusPrepare(&state, typeConverter), + tlprocessor.SkipInsert(), + ) + if err := state.Timelines.List.Start(); err != nil { + return fmt.Errorf("error starting list timeline: %s", err) + } + + // Create the processor using all the other services we've created so far. + processor := processing.NewProcessor(typeConverter, federator, oauthServer, mediaManager, &state, emailSender) + // Set state client / federator worker enqueue functions state.Workers.EnqueueClientAPI = processor.EnqueueClientAPI state.Workers.EnqueueFederator = processor.EnqueueFederator diff --git a/cmd/gotosocial/action/testrig/testrig.go b/cmd/gotosocial/action/testrig/testrig.go index 5d4f20773..8f55c4b4a 100644 --- a/cmd/gotosocial/action/testrig/testrig.go +++ b/cmd/gotosocial/action/testrig/testrig.go @@ -38,9 +38,12 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/middleware" "github.com/superseriousbusiness/gotosocial/internal/oidc" + tlprocessor "github.com/superseriousbusiness/gotosocial/internal/processing/timeline" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" + "github.com/superseriousbusiness/gotosocial/internal/timeline" "github.com/superseriousbusiness/gotosocial/internal/tracing" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/internal/web" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -89,11 +92,31 @@ var Start action.GTSAction = func(ctx context.Context) error { federator := testrig.NewTestFederator(&state, transportController, mediaManager) emailSender := testrig.NewEmailSender("./web/template/", nil) + typeConverter := testrig.NewTestTypeConverter(state.DB) + filter := visibility.NewFilter(&state) + + // Initialize timelines. + state.Timelines.Home = timeline.NewManager( + tlprocessor.HomeTimelineGrab(&state), + tlprocessor.HomeTimelineFilter(&state, filter), + tlprocessor.HomeTimelineStatusPrepare(&state, typeConverter), + tlprocessor.SkipInsert(), + ) + if err := state.Timelines.Home.Start(); err != nil { + return fmt.Errorf("error starting home timeline: %s", err) + } + + state.Timelines.List = timeline.NewManager( + tlprocessor.ListTimelineGrab(&state), + tlprocessor.ListTimelineFilter(&state, filter), + tlprocessor.ListTimelineStatusPrepare(&state, typeConverter), + tlprocessor.SkipInsert(), + ) + if err := state.Timelines.List.Start(); err != nil { + return fmt.Errorf("error starting list timeline: %s", err) + } processor := testrig.NewTestProcessor(&state, federator, emailSender, mediaManager) - if err := processor.Start(); err != nil { - return fmt.Errorf("error starting processor: %s", err) - } /* HTTP router initialization diff --git a/docs/api/swagger.yaml b/docs/api/swagger.yaml index fb76bd6a2..6c6b97a34 100644 --- a/docs/api/swagger.yaml +++ b/docs/api/swagger.yaml @@ -1635,6 +1635,28 @@ definitions: type: object x-go-name: InstanceV2Users x-go-package: github.com/superseriousbusiness/gotosocial/internal/api/model + list: + properties: + id: + description: The ID of the list. + type: string + x-go-name: ID + replies_policy: + description: |- + RepliesPolicy for this list. + followed = Show replies to any followed user + list = Show replies to members of the list + none = Show replies to no one + type: string + x-go-name: RepliesPolicy + title: + description: The user-defined title of the list. + type: string + x-go-name: Title + title: List represents a user-created list of accounts that the user follows. + type: object + x-go-name: List + x-go-package: github.com/superseriousbusiness/gotosocial/internal/api/model mediaDimensions: properties: aspect: @@ -2881,6 +2903,40 @@ paths: summary: See accounts followed by given account id. tags: - accounts + /api/v1/accounts/{id}/lists: + get: + operationId: accountLists + parameters: + - description: Account ID. + in: path + name: id + required: true + type: string + produces: + - application/json + responses: + "200": + description: Array of all lists containing this account. + schema: + items: + $ref: '#/definitions/list' + type: array + "400": + description: bad request + "401": + description: unauthorized + "404": + description: not found + "406": + description: not acceptable + "500": + description: internal server error + security: + - OAuth2 Bearer: + - read:lists + summary: See all lists of yours that contain requested account. + tags: + - accounts /api/v1/accounts/{id}/statuses: get: description: The statuses will be returned in descending chronological order (newest first), with sequential IDs (bigger = newer). @@ -3211,7 +3267,7 @@ paths: name: id required: true type: string - - description: Type of action to be taken (`disable`, `silence`, or `suspend`). + - description: Type of action to be taken, currently only supports `suspend`. in: formData name: type required: true @@ -4453,6 +4509,343 @@ paths: description: internal server error tags: - instance + /api/v1/list: + post: + consumes: + - application/json + - application/xml + - application/x-www-form-urlencoded + operationId: listCreate + parameters: + - description: Title of this list. + example: Cool People + in: formData + name: title + required: true + type: string + x-go-name: Title + - default: list + description: |- + RepliesPolicy for this list. + followed = Show replies to any followed user + list = Show replies to members of the list + none = Show replies to no one + example: list + in: formData + name: replies_policy + type: string + x-go-name: RepliesPolicy + produces: + - application/json + responses: + "200": + description: The newly created list. + schema: + $ref: '#/definitions/list' + "400": + description: bad request + "401": + description: unauthorized + "403": + description: forbidden + "404": + description: not found + "406": + description: not acceptable + "500": + description: internal server error + security: + - OAuth2 Bearer: + - write:lists + summary: Create a new list. + tags: + - lists + put: + consumes: + - application/json + - application/xml + - application/x-www-form-urlencoded + operationId: listUpdate + parameters: + - description: ID of the list + example: Cool People + in: path + name: id + required: true + type: string + x-go-name: Title + - description: Title of this list. + example: Cool People + in: formData + name: title + type: string + x-go-name: RepliesPolicy + - description: |- + RepliesPolicy for this list. + followed = Show replies to any followed user + list = Show replies to members of the list + none = Show replies to no one + example: list + in: formData + name: replies_policy + type: string + produces: + - application/json + responses: + "200": + description: The newly updated list. + schema: + $ref: '#/definitions/list' + "400": + description: bad request + "401": + description: unauthorized + "403": + description: forbidden + "404": + description: not found + "406": + description: not acceptable + "500": + description: internal server error + security: + - OAuth2 Bearer: + - write:lists + summary: Update an existing list. + tags: + - lists + /api/v1/list/{id}: + delete: + operationId: listDelete + parameters: + - description: ID of the list + in: path + name: id + required: true + type: string + produces: + - application/json + responses: + "200": + description: list deleted + "400": + description: bad request + "401": + description: unauthorized + "404": + description: not found + "406": + description: not acceptable + "500": + description: internal server error + security: + - OAuth2 Bearer: + - write:lists + summary: Delete a single list with the given ID. + tags: + - lists + get: + operationId: list + parameters: + - description: ID of the list + in: path + name: id + required: true + type: string + produces: + - application/json + responses: + "200": + description: Requested list. + schema: + $ref: '#/definitions/list' + "400": + description: bad request + "401": + description: unauthorized + "404": + description: not found + "406": + description: not acceptable + "500": + description: internal server error + security: + - OAuth2 Bearer: + - read:lists + summary: Get a single list with the given ID. + tags: + - lists + /api/v1/list/{id}/accounts: + delete: + consumes: + - application/json + - application/xml + - application/x-www-form-urlencoded + operationId: removeListAccounts + parameters: + - description: ID of the list + in: path + name: id + required: true + type: string + - description: Array of accountIDs to modify. Each accountID must correspond to an account that the requesting account follows. + in: formData + items: + type: string + name: account_ids + required: true + type: array + produces: + - application/json + responses: + "200": + description: list accounts updated + "400": + description: bad request + "401": + description: unauthorized + "404": + description: not found + "406": + description: not acceptable + "500": + description: internal server error + security: + - OAuth2 Bearer: + - read:lists + summary: Remove one or more accounts from the given list. + tags: + - lists + get: + description: |- + The returned Link header can be used to generate the previous and next queries when scrolling up or down a timeline. + + Example: + + ``` + ; rel="next", ; rel="prev" + ```` + operationId: listAccounts + parameters: + - description: ID of the list + in: path + name: id + required: true + type: string + - description: Return only list entries *OLDER* than the given max ID. The account from the list entry with the specified ID will not be included in the response. + in: query + name: max_id + type: string + - description: Return only list entries *NEWER* than the given since ID. The account from the list entry with the specified ID will not be included in the response. + in: query + name: since_id + type: string + - description: Return only list entries *IMMEDIATELY NEWER* than the given min ID. The account from the list entry with the specified ID will not be included in the response. + in: query + name: min_id + type: string + - default: 20 + description: Number of accounts to return. + in: query + name: limit + type: integer + produces: + - application/json + responses: + "200": + description: Array of accounts. + headers: + Link: + description: Links to the next and previous queries. + type: string + schema: + items: + $ref: '#/definitions/account' + type: array + "400": + description: bad request + "401": + description: unauthorized + "404": + description: not found + "406": + description: not acceptable + "500": + description: internal server error + security: + - OAuth2 Bearer: + - read:lists + summary: Page through accounts in this list. + tags: + - lists + post: + consumes: + - application/json + - application/xml + - application/x-www-form-urlencoded + operationId: addListAccounts + parameters: + - description: ID of the list + in: path + name: id + required: true + type: string + - description: Array of accountIDs to modify. Each accountID must correspond to an account that the requesting account follows. + in: formData + items: + type: string + name: account_ids + required: true + type: array + produces: + - application/json + responses: + "200": + description: list accounts updated + "400": + description: bad request + "401": + description: unauthorized + "404": + description: not found + "406": + description: not acceptable + "500": + description: internal server error + security: + - OAuth2 Bearer: + - read:lists + summary: Add one or more accounts to the given list. + tags: + - lists + /api/v1/lists: + get: + operationId: lists + produces: + - application/json + responses: + "200": + description: Array of all lists owned by the requesting user. + schema: + items: + $ref: '#/definitions/list' + type: array + "400": + description: bad request + "401": + description: unauthorized + "404": + description: not found + "406": + description: not acceptable + "500": + description: internal server error + security: + - OAuth2 Bearer: + - read:lists + summary: Get all lists for owned by authorized user. + tags: + - lists /api/v1/media/{id}: get: operationId: mediaGet @@ -5579,6 +5972,18 @@ paths: name: stream required: true type: string + - description: |- + ID of the list to subscribe to. + Only used if stream type is 'list'. + in: query + name: list + type: string + - description: |- + Name of the tag to subscribe to. + Only used if stream type is 'hashtag' or 'hashtag:local'. + in: query + name: tag + type: string produces: - application/json responses: @@ -5696,6 +6101,65 @@ paths: summary: See statuses/posts by accounts you follow. tags: - timelines + /api/v1/timelines/list/{id}: + get: + description: |- + The statuses will be returned in descending chronological order (newest first), with sequential IDs (bigger = newer). + + The returned Link header can be used to generate the previous and next queries when scrolling up or down a timeline. + + Example: + + ``` + ; rel="next", ; rel="prev" + ```` + operationId: listTimeline + parameters: + - description: ID of the list + in: path + name: id + required: true + type: string + - description: Return only statuses *OLDER* than the given max status ID. The status with the specified ID will not be included in the response. + in: query + name: max_id + type: string + - description: Return only statuses *NEWER* than the given since status ID. The status with the specified ID will not be included in the response. + in: query + name: since_id + type: string + - description: Return only statuses *NEWER* than the given since status ID. The status with the specified ID will not be included in the response. + in: query + name: min_id + type: string + - default: 20 + description: Number of statuses to return. + in: query + name: limit + type: integer + produces: + - application/json + responses: + "200": + description: Array of statuses. + headers: + Link: + description: Links to the next and previous queries. + type: string + schema: + items: + $ref: '#/definitions/status' + type: array + "400": + description: bad request + "401": + description: unauthorized + security: + - OAuth2 Bearer: + - read:lists + summary: See statuses/posts from the given list timeline. + tags: + - timelines /api/v1/timelines/public: get: description: |- @@ -5980,6 +6444,7 @@ securityDefinitions: read:custom_emojis: grant read access to custom_emojis read:favourites: grant read access to favourites read:follows: grant read access to follows + read:lists: grant read access to lists read:media: grant read access to media read:notifications: grants read access to notifications read:search: grant read access to searches @@ -5990,6 +6455,7 @@ securityDefinitions: write:accounts: grants write access to accounts write:blocks: grants write access to blocks write:follows: grants write access to follows + write:lists: grants write access to lists write:media: grants write access to media write:statuses: grants write access to statuses write:user: grants write access to user-level info diff --git a/docs/swagger.go b/docs/swagger.go index 546579772..a65b4bf40 100644 --- a/docs/swagger.go +++ b/docs/swagger.go @@ -37,6 +37,7 @@ // read:custom_emojis: grant read access to custom_emojis // read:favourites: grant read access to favourites // read:follows: grant read access to follows +// read:lists: grant read access to lists // read:media: grant read access to media // read:search: grant read access to searches // read:statuses: grants read access to statuses @@ -47,6 +48,7 @@ // write:accounts: grants write access to accounts // write:blocks: grants write access to blocks // write:follows: grants write access to follows +// write:lists: grants write access to lists // write:media: grants write access to media // write:statuses: grants write access to statuses // write:user: grants write access to user-level info diff --git a/example/config.yaml b/example/config.yaml index 668d1729c..7f119ab16 100644 --- a/example/config.yaml +++ b/example/config.yaml @@ -289,6 +289,14 @@ cache: follow-request-ttl: "30m" follow-request-sweep-freq: "1m" + list-max-size: 2000 + list-ttl: "30m" + list-sweep-freq: "1m" + + list-entry-max-size: 2000 + list-entry-ttl: "30m" + list-entry-sweep-freq: "1m" + media-max-size: 1000 media-ttl: "30m" media-sweep-freq: "1m" diff --git a/internal/api/activitypub/emoji/emojiget_test.go b/internal/api/activitypub/emoji/emojiget_test.go index 16b004299..2438e09f4 100644 --- a/internal/api/activitypub/emoji/emojiget_test.go +++ b/internal/api/activitypub/emoji/emojiget_test.go @@ -36,6 +36,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -74,9 +75,14 @@ func (suite *EmojiGetTestSuite) SetupTest() { suite.state.DB = suite.db suite.storage = testrig.NewInMemoryStorage() suite.state.Storage = suite.storage - suite.tc = testrig.NewTestTypeConverter(suite.db) + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + suite.tc, + ) + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) @@ -86,8 +92,6 @@ func (suite *EmojiGetTestSuite) SetupTest() { testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") suite.signatureCheck = middleware.SignatureCheck(suite.db.IsURIBlocked) - - suite.NoError(suite.processor.Start()) } func (suite *EmojiGetTestSuite) TearDownTest() { diff --git a/internal/api/activitypub/users/inboxpost_test.go b/internal/api/activitypub/users/inboxpost_test.go index 26c4029a2..4d494cf64 100644 --- a/internal/api/activitypub/users/inboxpost_test.go +++ b/internal/api/activitypub/users/inboxpost_test.go @@ -89,7 +89,6 @@ func (suite *InboxPostTestSuite) TestPostBlock() { emailSender := testrig.NewEmailSender("../../../../web/template/", nil) processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager) userModule := users.New(processor) - suite.NoError(processor.Start()) // setup request recorder := httptest.NewRecorder() @@ -190,7 +189,6 @@ func (suite *InboxPostTestSuite) TestPostUnblock() { emailSender := testrig.NewEmailSender("../../../../web/template/", nil) processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager) userModule := users.New(processor) - suite.NoError(processor.Start()) // setup request recorder := httptest.NewRecorder() @@ -296,7 +294,6 @@ func (suite *InboxPostTestSuite) TestPostUpdate() { emailSender := testrig.NewEmailSender("../../../../web/template/", nil) processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager) userModule := users.New(processor) - suite.NoError(processor.Start()) // setup request recorder := httptest.NewRecorder() @@ -425,7 +422,6 @@ func (suite *InboxPostTestSuite) TestPostDelete() { emailSender := testrig.NewEmailSender("../../../../web/template/", nil) processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager) userModule := users.New(processor) - suite.NoError(processor.Start()) // setup request recorder := httptest.NewRecorder() diff --git a/internal/api/activitypub/users/outboxget_test.go b/internal/api/activitypub/users/outboxget_test.go index a75d850b6..1abe31ef6 100644 --- a/internal/api/activitypub/users/outboxget_test.go +++ b/internal/api/activitypub/users/outboxget_test.go @@ -106,7 +106,6 @@ func (suite *OutboxGetTestSuite) TestGetOutboxFirstPage() { emailSender := testrig.NewEmailSender("../../../../web/template/", nil) processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager) userModule := users.New(processor) - suite.NoError(processor.Start()) // setup request recorder := httptest.NewRecorder() @@ -181,7 +180,6 @@ func (suite *OutboxGetTestSuite) TestGetOutboxNextPage() { emailSender := testrig.NewEmailSender("../../../../web/template/", nil) processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager) userModule := users.New(processor) - suite.NoError(processor.Start()) // setup request recorder := httptest.NewRecorder() diff --git a/internal/api/activitypub/users/repliesget_test.go b/internal/api/activitypub/users/repliesget_test.go index a37868fd7..f81dddadd 100644 --- a/internal/api/activitypub/users/repliesget_test.go +++ b/internal/api/activitypub/users/repliesget_test.go @@ -106,7 +106,6 @@ func (suite *RepliesGetTestSuite) TestGetRepliesNext() { emailSender := testrig.NewEmailSender("../../../../web/template/", nil) processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager) userModule := users.New(processor) - suite.NoError(processor.Start()) // setup request recorder := httptest.NewRecorder() @@ -171,7 +170,6 @@ func (suite *RepliesGetTestSuite) TestGetRepliesLast() { emailSender := testrig.NewEmailSender("../../../../web/template/", nil) processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager) userModule := users.New(processor) - suite.NoError(processor.Start()) // setup request recorder := httptest.NewRecorder() diff --git a/internal/api/activitypub/users/user_test.go b/internal/api/activitypub/users/user_test.go index aed687fef..8c33ce2f2 100644 --- a/internal/api/activitypub/users/user_test.go +++ b/internal/api/activitypub/users/user_test.go @@ -31,6 +31,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -83,6 +84,13 @@ func (suite *UserStandardTestSuite) SetupTest() { suite.db = testrig.NewTestDB(&suite.state) suite.state.DB = suite.db suite.tc = testrig.NewTestTypeConverter(suite.db) + + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + suite.tc, + ) + suite.storage = testrig.NewInMemoryStorage() suite.state.Storage = suite.storage suite.mediaManager = testrig.NewTestMediaManager(&suite.state) @@ -94,8 +102,6 @@ func (suite *UserStandardTestSuite) SetupTest() { testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") suite.signatureCheck = middleware.SignatureCheck(suite.db.IsURIBlocked) - - suite.NoError(suite.processor.Start()) } func (suite *UserStandardTestSuite) TearDownTest() { diff --git a/internal/api/client/accounts/account_test.go b/internal/api/client/accounts/account_test.go index b168f216c..678fc8a5d 100644 --- a/internal/api/client/accounts/account_test.go +++ b/internal/api/client/accounts/account_test.go @@ -36,6 +36,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -86,6 +87,12 @@ func (suite *AccountStandardTestSuite) SetupTest() { suite.storage = testrig.NewInMemoryStorage() suite.state.Storage = suite.storage + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + testrig.NewTestTypeConverter(suite.db), + ) + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.sentEmails = make(map[string]string) @@ -94,8 +101,6 @@ func (suite *AccountStandardTestSuite) SetupTest() { suite.accountsModule = accounts.New(suite.processor) testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") - - suite.NoError(suite.processor.Start()) } func (suite *AccountStandardTestSuite) TearDownTest() { diff --git a/internal/api/client/accounts/accounts.go b/internal/api/client/accounts/accounts.go index a6bedd6e1..298104a8d 100644 --- a/internal/api/client/accounts/accounts.go +++ b/internal/api/client/accounts/accounts.go @@ -70,6 +70,8 @@ const ( UnblockPath = BasePathWithID + "/unblock" // DeleteAccountPath is for deleting one's account via the API DeleteAccountPath = BasePath + "/delete" + // ListsPath is for seeing which lists an account is. + ListsPath = BasePathWithID + "/lists" ) type Module struct { @@ -115,4 +117,7 @@ func (m *Module) Route(attachHandler func(method string, path string, f ...gin.H // block or unblock account attachHandler(http.MethodPost, BlockPath, m.AccountBlockPOSTHandler) attachHandler(http.MethodPost, UnblockPath, m.AccountUnblockPOSTHandler) + + // account lists + attachHandler(http.MethodGet, ListsPath, m.AccountListsGETHandler) } diff --git a/internal/api/client/accounts/lists.go b/internal/api/client/accounts/lists.go new file mode 100644 index 000000000..4ce1bf729 --- /dev/null +++ b/internal/api/client/accounts/lists.go @@ -0,0 +1,97 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package accounts + +import ( + "errors" + "net/http" + + "github.com/gin-gonic/gin" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" +) + +// AccountListsGETHandler swagger:operation GET /api/v1/accounts/{id}/lists accountLists +// +// See all lists of yours that contain requested account. +// +// --- +// tags: +// - accounts +// +// produces: +// - application/json +// +// parameters: +// - +// name: id +// type: string +// description: Account ID. +// in: path +// required: true +// +// security: +// - OAuth2 Bearer: +// - read:lists +// +// responses: +// '200': +// name: lists +// description: Array of all lists containing this account. +// schema: +// type: array +// items: +// "$ref": "#/definitions/list" +// '400': +// description: bad request +// '401': +// description: unauthorized +// '404': +// description: not found +// '406': +// description: not acceptable +// '500': +// description: internal server error +func (m *Module) AccountListsGETHandler(c *gin.Context) { + authed, err := oauth.Authed(c, false, false, false, false) + if err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1) + return + } + + targetAcctID := c.Param(IDKey) + if targetAcctID == "" { + err := errors.New("no account id specified") + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + + lists, errWithCode := m.processor.Account().ListsGet(c.Request.Context(), authed.Account, targetAcctID) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + c.JSON(http.StatusOK, lists) +} diff --git a/internal/api/client/accounts/lists_test.go b/internal/api/client/accounts/lists_test.go new file mode 100644 index 000000000..6984d6ef8 --- /dev/null +++ b/internal/api/client/accounts/lists_test.go @@ -0,0 +1,103 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package accounts_test + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/suite" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +type ListsTestSuite struct { + AccountStandardTestSuite +} + +func (suite *ListsTestSuite) getLists(targetAccountID string, expectedHTTPStatus int, expectedBody string) []*apimodel.List { + var ( + recorder = httptest.NewRecorder() + ctx, _ = testrig.CreateGinTestContext(recorder, nil) + request = httptest.NewRequest(http.MethodGet, "http://localhost:8080/api/v1/accounts/"+targetAccountID+"/lists", nil) + ) + + // Set up the test context. + ctx.Request = request + ctx.AddParam("id", targetAccountID) + ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) + ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"])) + ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) + ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) + + // Trigger the handler. + suite.accountsModule.AccountListsGETHandler(ctx) + + // Read the result. + result := recorder.Result() + defer result.Body.Close() + + b, err := io.ReadAll(result.Body) + if err != nil { + suite.FailNow(err.Error()) + } + + errs := gtserror.MultiError{} + + // Check expected code + body. + if resultCode := recorder.Code; expectedHTTPStatus != resultCode { + errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode)) + } + + // If we got an expected body, return early. + if expectedBody != "" && string(b) != expectedBody { + errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b))) + } + + if err := errs.Combine(); err != nil { + suite.FailNow("", "%v (body %s)", err, string(b)) + } + + // Return list response. + resp := new([]*apimodel.List) + if err := json.Unmarshal(b, resp); err != nil { + suite.FailNow(err.Error()) + } + + return *resp +} + +func (suite *ListsTestSuite) TestGetListsHit() { + targetAccount := suite.testAccounts["admin_account"] + suite.getLists(targetAccount.ID, http.StatusOK, `[{"id":"01H0G8E4Q2J3FE3JDWJVWEDCD1","title":"Cool Ass Posters From This Instance","replies_policy":"followed"}]`) +} + +func (suite *ListsTestSuite) TestGetListsNoHit() { + targetAccount := suite.testAccounts["remote_account_1"] + suite.getLists(targetAccount.ID, http.StatusOK, `[]`) +} + +func TestListsTestSuite(t *testing.T) { + suite.Run(t, new(ListsTestSuite)) +} diff --git a/internal/api/client/admin/admin_test.go b/internal/api/client/admin/admin_test.go index c6de665fa..261e9ff4e 100644 --- a/internal/api/client/admin/admin_test.go +++ b/internal/api/client/admin/admin_test.go @@ -36,6 +36,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -92,6 +93,12 @@ func (suite *AdminStandardTestSuite) SetupTest() { suite.storage = testrig.NewInMemoryStorage() suite.state.Storage = suite.storage + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + testrig.NewTestTypeConverter(suite.db), + ) + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.sentEmails = make(map[string]string) diff --git a/internal/api/client/bookmarks/bookmarks_test.go b/internal/api/client/bookmarks/bookmarks_test.go index 6f20c4762..b41964584 100644 --- a/internal/api/client/bookmarks/bookmarks_test.go +++ b/internal/api/client/bookmarks/bookmarks_test.go @@ -42,6 +42,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -98,6 +99,13 @@ func (suite *BookmarkTestSuite) SetupTest() { suite.state.Storage = suite.storage suite.tc = testrig.NewTestTypeConverter(suite.db) + + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + suite.tc, + ) + testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") @@ -107,8 +115,6 @@ func (suite *BookmarkTestSuite) SetupTest() { suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) suite.statusModule = statuses.New(suite.processor) suite.bookmarkModule = bookmarks.New(suite.processor) - - suite.NoError(suite.processor.Start()) } func (suite *BookmarkTestSuite) TearDownTest() { diff --git a/internal/api/client/favourites/favourites_test.go b/internal/api/client/favourites/favourites_test.go index c6e42e113..1a3a324a8 100644 --- a/internal/api/client/favourites/favourites_test.go +++ b/internal/api/client/favourites/favourites_test.go @@ -29,6 +29,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -82,6 +83,13 @@ func (suite *FavouritesStandardTestSuite) SetupTest() { suite.state.Storage = suite.storage suite.tc = testrig.NewTestTypeConverter(suite.db) + + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + suite.tc, + ) + testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") @@ -90,8 +98,6 @@ func (suite *FavouritesStandardTestSuite) SetupTest() { suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) suite.favModule = favourites.New(suite.processor) - - suite.NoError(suite.processor.Start()) } func (suite *FavouritesStandardTestSuite) TearDownTest() { diff --git a/internal/api/client/favourites/favouritesget.go b/internal/api/client/favourites/favouritesget.go index 198dd1b12..112bbd856 100644 --- a/internal/api/client/favourites/favouritesget.go +++ b/internal/api/client/favourites/favouritesget.go @@ -128,7 +128,7 @@ func (m *Module) FavouritesGETHandler(c *gin.Context) { limit = int(i) } - resp, errWithCode := m.processor.FavedTimelineGet(c.Request.Context(), authed, maxID, minID, limit) + resp, errWithCode := m.processor.Timeline().FavedTimelineGet(c.Request.Context(), authed, maxID, minID, limit) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return diff --git a/internal/api/client/followrequests/followrequest_test.go b/internal/api/client/followrequests/followrequest_test.go index a1aca89ff..58d191fa7 100644 --- a/internal/api/client/followrequests/followrequest_test.go +++ b/internal/api/client/followrequests/followrequest_test.go @@ -35,6 +35,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -83,6 +84,12 @@ func (suite *FollowRequestStandardTestSuite) SetupTest() { suite.storage = testrig.NewInMemoryStorage() suite.state.Storage = suite.storage + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + testrig.NewTestTypeConverter(suite.db), + ) + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) @@ -90,8 +97,6 @@ func (suite *FollowRequestStandardTestSuite) SetupTest() { suite.followRequestModule = followrequests.New(suite.processor) testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") - - suite.NoError(suite.processor.Start()) } func (suite *FollowRequestStandardTestSuite) TearDownTest() { diff --git a/internal/api/client/instance/instance_test.go b/internal/api/client/instance/instance_test.go index 2fe29f75e..745d76a24 100644 --- a/internal/api/client/instance/instance_test.go +++ b/internal/api/client/instance/instance_test.go @@ -35,6 +35,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -85,6 +86,12 @@ func (suite *InstanceStandardTestSuite) SetupTest() { suite.storage = testrig.NewInMemoryStorage() suite.state.Storage = suite.storage + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + testrig.NewTestTypeConverter(suite.db), + ) + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.sentEmails = make(map[string]string) diff --git a/internal/api/client/lists/list.go b/internal/api/client/lists/list.go index b1c193397..515075271 100644 --- a/internal/api/client/lists/list.go +++ b/internal/api/client/lists/list.go @@ -25,8 +25,15 @@ import ( ) const ( + IDKey = "id" // BasePath is the base path for serving the lists API, minus the 'api' prefix - BasePath = "/v1/lists" + BasePath = "/v1/lists" + BasePathWithID = BasePath + "/:" + IDKey + AccountsPath = BasePathWithID + "/accounts" + MaxIDKey = "max_id" + LimitKey = "limit" + SinceIDKey = "since_id" + MinIDKey = "min_id" ) type Module struct { @@ -40,5 +47,15 @@ func New(processor *processing.Processor) *Module { } func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) { + // create / get / update / delete lists + attachHandler(http.MethodPost, BasePath, m.ListCreatePOSTHandler) attachHandler(http.MethodGet, BasePath, m.ListsGETHandler) + attachHandler(http.MethodGet, BasePathWithID, m.ListGETHandler) + attachHandler(http.MethodPut, BasePathWithID, m.ListUpdatePUTHandler) + attachHandler(http.MethodDelete, BasePathWithID, m.ListDELETEHandler) + + // get / add / remove list accounts + attachHandler(http.MethodGet, AccountsPath, m.ListAccountsGETHandler) + attachHandler(http.MethodPost, AccountsPath, m.ListAccountsPOSTHandler) + attachHandler(http.MethodDelete, AccountsPath, m.ListAccountsDELETEHandler) } diff --git a/internal/api/client/lists/listaccounts.go b/internal/api/client/lists/listaccounts.go new file mode 100644 index 000000000..3a24cab27 --- /dev/null +++ b/internal/api/client/lists/listaccounts.go @@ -0,0 +1,156 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package lists + +import ( + "errors" + "net/http" + + "github.com/gin-gonic/gin" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" +) + +// ListAccountsGETHandler swagger:operation GET /api/v1/list/{id}/accounts listAccounts +// +// Page through accounts in this list. +// +// The returned Link header can be used to generate the previous and next queries when scrolling up or down a timeline. +// +// Example: +// +// ``` +// ; rel="next", ; rel="prev" +// ```` +// +// --- +// tags: +// - lists +// +// produces: +// - application/json +// +// parameters: +// - +// name: id +// type: string +// description: ID of the list +// in: path +// required: true +// - +// name: max_id +// type: string +// description: >- +// Return only list entries *OLDER* than the given max ID. +// The account from the list entry with the specified ID will not be included in the response. +// in: query +// required: false +// - +// name: since_id +// type: string +// description: >- +// Return only list entries *NEWER* than the given since ID. +// The account from the list entry with the specified ID will not be included in the response. +// in: query +// - +// name: min_id +// type: string +// description: >- +// Return only list entries *IMMEDIATELY NEWER* than the given min ID. +// The account from the list entry with the specified ID will not be included in the response. +// in: query +// required: false +// - +// name: limit +// type: integer +// description: Number of accounts to return. +// default: 20 +// in: query +// required: false +// +// security: +// - OAuth2 Bearer: +// - read:lists +// +// responses: +// '200': +// headers: +// Link: +// type: string +// description: Links to the next and previous queries. +// name: accounts +// description: Array of accounts. +// schema: +// type: array +// items: +// "$ref": "#/definitions/account" +// '400': +// description: bad request +// '401': +// description: unauthorized +// '404': +// description: not found +// '406': +// description: not acceptable +// '500': +// description: internal server error +func (m *Module) ListAccountsGETHandler(c *gin.Context) { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1) + return + } + + targetListID := c.Param(IDKey) + if targetListID == "" { + err := errors.New("no list id specified") + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + + limit, errWithCode := apiutil.ParseLimit(c.Query(apiutil.LimitKey), 20) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + resp, errWithCode := m.processor.List().GetListAccounts( + c.Request.Context(), + authed.Account, + targetListID, + c.Query(MaxIDKey), + c.Query(SinceIDKey), + c.Query(MinIDKey), + limit, + ) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + if resp.LinkHeader != "" { + c.Header("Link", resp.LinkHeader) + } + c.JSON(http.StatusOK, resp.Items) +} diff --git a/internal/api/client/lists/listaccountsadd.go b/internal/api/client/lists/listaccountsadd.go new file mode 100644 index 000000000..5cf907b06 --- /dev/null +++ b/internal/api/client/lists/listaccountsadd.go @@ -0,0 +1,120 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package lists + +import ( + "errors" + "net/http" + + "github.com/gin-gonic/gin" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" +) + +// ListAccountsPOSTHandler swagger:operation POST /api/v1/list/{id}/accounts addListAccounts +// +// Add one or more accounts to the given list. +// +// --- +// tags: +// - lists +// +// consumes: +// - application/json +// - application/xml +// - application/x-www-form-urlencoded +// +// produces: +// - application/json +// +// parameters: +// - +// name: id +// type: string +// description: ID of the list +// in: path +// required: true +// - +// name: account_ids +// type: array +// items: +// type: string +// description: >- +// Array of accountIDs to modify. +// Each accountID must correspond to an account +// that the requesting account follows. +// in: formData +// required: true +// +// security: +// - OAuth2 Bearer: +// - read:lists +// +// responses: +// '200': +// description: list accounts updated +// '400': +// description: bad request +// '401': +// description: unauthorized +// '404': +// description: not found +// '406': +// description: not acceptable +// '500': +// description: internal server error +func (m *Module) ListAccountsPOSTHandler(c *gin.Context) { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1) + return + } + + targetListID := c.Param(IDKey) + if targetListID == "" { + err := errors.New("no list id specified") + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + + form := &apimodel.ListAccountsChangeRequest{} + if err := c.ShouldBind(form); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if len(form.AccountIDs) == 0 { + err := errors.New("no account IDs given") + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if errWithCode := m.processor.List().AddToList(c.Request.Context(), authed.Account, targetListID, form.AccountIDs); errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + c.JSON(http.StatusOK, gin.H{}) +} diff --git a/internal/api/client/lists/listaccountsremove.go b/internal/api/client/lists/listaccountsremove.go new file mode 100644 index 000000000..6ce7e3cd3 --- /dev/null +++ b/internal/api/client/lists/listaccountsremove.go @@ -0,0 +1,120 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package lists + +import ( + "errors" + "net/http" + + "github.com/gin-gonic/gin" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" +) + +// ListAccountsDELETEHandler swagger:operation DELETE /api/v1/list/{id}/accounts removeListAccounts +// +// Remove one or more accounts from the given list. +// +// --- +// tags: +// - lists +// +// consumes: +// - application/json +// - application/xml +// - application/x-www-form-urlencoded +// +// produces: +// - application/json +// +// parameters: +// - +// name: id +// type: string +// description: ID of the list +// in: path +// required: true +// - +// name: account_ids +// type: array +// items: +// type: string +// description: >- +// Array of accountIDs to modify. +// Each accountID must correspond to an account +// that the requesting account follows. +// in: formData +// required: true +// +// security: +// - OAuth2 Bearer: +// - read:lists +// +// responses: +// '200': +// description: list accounts updated +// '400': +// description: bad request +// '401': +// description: unauthorized +// '404': +// description: not found +// '406': +// description: not acceptable +// '500': +// description: internal server error +func (m *Module) ListAccountsDELETEHandler(c *gin.Context) { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1) + return + } + + targetListID := c.Param(IDKey) + if targetListID == "" { + err := errors.New("no list id specified") + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + + form := &apimodel.ListAccountsChangeRequest{} + if err := c.ShouldBind(form); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if len(form.AccountIDs) == 0 { + err := errors.New("no account IDs given") + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if errWithCode := m.processor.List().RemoveFromList(c.Request.Context(), authed.Account, targetListID, form.AccountIDs); errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + c.JSON(http.StatusOK, gin.H{}) +} diff --git a/internal/api/client/lists/listcreate.go b/internal/api/client/lists/listcreate.go new file mode 100644 index 000000000..09a654c74 --- /dev/null +++ b/internal/api/client/lists/listcreate.go @@ -0,0 +1,106 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package lists + +import ( + "net/http" + "strings" + + "github.com/gin-gonic/gin" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/validate" +) + +// ListCreatePOSTHandler swagger:operation POST /api/v1/list listCreate +// +// Create a new list. +// +// --- +// tags: +// - lists +// +// consumes: +// - application/json +// - application/xml +// - application/x-www-form-urlencoded +// +// produces: +// - application/json +// +// security: +// - OAuth2 Bearer: +// - write:lists +// +// responses: +// '200': +// description: "The newly created list." +// schema: +// "$ref": "#/definitions/list" +// '400': +// description: bad request +// '401': +// description: unauthorized +// '403': +// description: forbidden +// '404': +// description: not found +// '406': +// description: not acceptable +// '500': +// description: internal server error +func (m *Module) ListCreatePOSTHandler(c *gin.Context) { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1) + return + } + + form := &apimodel.ListCreateRequest{} + if err := c.ShouldBind(form); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if err := validate.ListTitle(form.Title); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + + repliesPolicy := gtsmodel.RepliesPolicy(strings.ToLower(form.RepliesPolicy)) + if err := validate.ListRepliesPolicy(repliesPolicy); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + + apiList, errWithCode := m.processor.List().Create(c.Request.Context(), authed.Account, form.Title, repliesPolicy) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + c.JSON(http.StatusOK, apiList) +} diff --git a/internal/api/client/lists/listdelete.go b/internal/api/client/lists/listdelete.go new file mode 100644 index 000000000..394ddfb6b --- /dev/null +++ b/internal/api/client/lists/listdelete.go @@ -0,0 +1,91 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package lists + +import ( + "errors" + "net/http" + + "github.com/gin-gonic/gin" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" +) + +// ListDELETEHandler swagger:operation DELETE /api/v1/list/{id} listDelete +// +// Delete a single list with the given ID. +// +// --- +// tags: +// - lists +// +// produces: +// - application/json +// +// parameters: +// - +// name: id +// type: string +// description: ID of the list +// in: path +// required: true +// +// security: +// - OAuth2 Bearer: +// - write:lists +// +// responses: +// '200': +// description: list deleted +// '400': +// description: bad request +// '401': +// description: unauthorized +// '404': +// description: not found +// '406': +// description: not acceptable +// '500': +// description: internal server error +func (m *Module) ListDELETEHandler(c *gin.Context) { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1) + return + } + + targetListID := c.Param(IDKey) + if targetListID == "" { + err := errors.New("no list id specified") + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if errWithCode := m.processor.List().Delete(c.Request.Context(), authed.Account, targetListID); errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + c.JSON(http.StatusOK, gin.H{}) +} diff --git a/internal/api/client/lists/listget.go b/internal/api/client/lists/listget.go new file mode 100644 index 000000000..3aed594d4 --- /dev/null +++ b/internal/api/client/lists/listget.go @@ -0,0 +1,95 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package lists + +import ( + "errors" + "net/http" + + "github.com/gin-gonic/gin" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" +) + +// ListGETHandler swagger:operation GET /api/v1/list/{id} list +// +// Get a single list with the given ID. +// +// --- +// tags: +// - lists +// +// produces: +// - application/json +// +// parameters: +// - +// name: id +// type: string +// description: ID of the list +// in: path +// required: true +// +// security: +// - OAuth2 Bearer: +// - read:lists +// +// responses: +// '200': +// name: list +// description: Requested list. +// schema: +// "$ref": "#/definitions/list" +// '400': +// description: bad request +// '401': +// description: unauthorized +// '404': +// description: not found +// '406': +// description: not acceptable +// '500': +// description: internal server error +func (m *Module) ListGETHandler(c *gin.Context) { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1) + return + } + + targetListID := c.Param(IDKey) + if targetListID == "" { + err := errors.New("no list id specified") + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + + resp, errWithCode := m.processor.List().Get(c.Request.Context(), authed.Account, targetListID) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + c.JSON(http.StatusOK, resp) +} diff --git a/internal/api/client/lists/listsgets.go b/internal/api/client/lists/listsget.go similarity index 60% rename from internal/api/client/lists/listsgets.go rename to internal/api/client/lists/listsget.go index 66b713611..f16152a9d 100644 --- a/internal/api/client/lists/listsgets.go +++ b/internal/api/client/lists/listsget.go @@ -26,9 +26,42 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/oauth" ) -// ListsGETHandler returns a list of lists created by/for the authed account +// ListsGETHandler swagger:operation GET /api/v1/lists lists +// +// Get all lists for owned by authorized user. +// +// --- +// tags: +// - lists +// +// produces: +// - application/json +// +// security: +// - OAuth2 Bearer: +// - read:lists +// +// responses: +// '200': +// name: lists +// description: Array of all lists owned by the requesting user. +// schema: +// type: array +// items: +// "$ref": "#/definitions/list" +// '400': +// description: bad request +// '401': +// description: unauthorized +// '404': +// description: not found +// '406': +// description: not acceptable +// '500': +// description: internal server error func (m *Module) ListsGETHandler(c *gin.Context) { - if _, err := oauth.Authed(c, true, true, true, true); err != nil { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) return } @@ -38,6 +71,11 @@ func (m *Module) ListsGETHandler(c *gin.Context) { return } - // todo: implement this; currently it's a no-op - c.JSON(http.StatusOK, []string{}) + lists, errWithCode := m.processor.List().GetAll(c.Request.Context(), authed.Account) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + c.JSON(http.StatusOK, lists) } diff --git a/internal/api/client/lists/listupdate.go b/internal/api/client/lists/listupdate.go new file mode 100644 index 000000000..80c5a8be3 --- /dev/null +++ b/internal/api/client/lists/listupdate.go @@ -0,0 +1,152 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package lists + +import ( + "errors" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/validate" +) + +// ListUpdatePUTHandler swagger:operation PUT /api/v1/list listUpdate +// +// Update an existing list. +// +// --- +// tags: +// - lists +// +// consumes: +// - application/json +// - application/xml +// - application/x-www-form-urlencoded +// +// produces: +// - application/json +// +// parameters: +// - +// name: id +// type: string +// description: ID of the list +// in: path +// required: true +// - +// name: title +// type: string +// description: Title of this list. +// in: formData +// example: Cool People +// - +// name: replies_policy +// type: string +// description: |- +// RepliesPolicy for this list. +// followed = Show replies to any followed user +// list = Show replies to members of the list +// none = Show replies to no one +// in: formData +// example: list +// +// security: +// - OAuth2 Bearer: +// - write:lists +// +// responses: +// '200': +// description: "The newly updated list." +// schema: +// "$ref": "#/definitions/list" +// '400': +// description: bad request +// '401': +// description: unauthorized +// '403': +// description: forbidden +// '404': +// description: not found +// '406': +// description: not acceptable +// '500': +// description: internal server error +func (m *Module) ListUpdatePUTHandler(c *gin.Context) { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1) + return + } + + targetListID := c.Param(IDKey) + if targetListID == "" { + err := errors.New("no list id specified") + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + + form := &apimodel.ListUpdateRequest{} + if err := c.ShouldBind(form); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if form.Title != nil { + if err := validate.ListTitle(*form.Title); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + } + + var repliesPolicy *gtsmodel.RepliesPolicy + if form.RepliesPolicy != nil { + rp := gtsmodel.RepliesPolicy(strings.ToLower(*form.RepliesPolicy)) + + if err := validate.ListRepliesPolicy(rp); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + + repliesPolicy = &rp + } + + if form.Title == nil && repliesPolicy == nil { + err = errors.New("neither title nor replies_policy was set; nothing to update") + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + + apiList, errWithCode := m.processor.List().Update(c.Request.Context(), authed.Account, targetListID, form.Title, repliesPolicy) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + c.JSON(http.StatusOK, apiList) +} diff --git a/internal/api/client/media/mediacreate_test.go b/internal/api/client/media/mediacreate_test.go index d41222dd2..f67144ce1 100644 --- a/internal/api/client/media/mediacreate_test.go +++ b/internal/api/client/media/mediacreate_test.go @@ -44,6 +44,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -90,6 +91,13 @@ func (suite *MediaCreateTestSuite) SetupSuite() { suite.state.Storage = suite.storage suite.tc = testrig.NewTestTypeConverter(suite.db) + + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + suite.tc, + ) + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) suite.oauthServer = testrig.NewTestOauthServer(suite.db) suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) diff --git a/internal/api/client/media/mediaupdate_test.go b/internal/api/client/media/mediaupdate_test.go index cd0e65013..c436ee000 100644 --- a/internal/api/client/media/mediaupdate_test.go +++ b/internal/api/client/media/mediaupdate_test.go @@ -42,6 +42,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -87,6 +88,13 @@ func (suite *MediaUpdateTestSuite) SetupSuite() { suite.state.Storage = suite.storage suite.tc = testrig.NewTestTypeConverter(suite.db) + + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + suite.tc, + ) + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) suite.oauthServer = testrig.NewTestOauthServer(suite.db) suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) diff --git a/internal/api/client/notifications/notificationget.go b/internal/api/client/notifications/notificationget.go index 3efdf171d..98e32498b 100644 --- a/internal/api/client/notifications/notificationget.go +++ b/internal/api/client/notifications/notificationget.go @@ -77,7 +77,7 @@ func (m *Module) NotificationGETHandler(c *gin.Context) { return } - resp, errWithCode := m.processor.NotificationGet(c.Request.Context(), authed.Account, targetNotifID) + resp, errWithCode := m.processor.Timeline().NotificationGet(c.Request.Context(), authed.Account, targetNotifID) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return diff --git a/internal/api/client/notifications/notificationsclear.go b/internal/api/client/notifications/notificationsclear.go index 17592f36d..cf3706a7c 100644 --- a/internal/api/client/notifications/notificationsclear.go +++ b/internal/api/client/notifications/notificationsclear.go @@ -69,7 +69,7 @@ func (m *Module) NotificationsClearPOSTHandler(c *gin.Context) { return } - errWithCode := m.processor.NotificationsClear(c.Request.Context(), authed) + errWithCode := m.processor.Timeline().NotificationsClear(c.Request.Context(), authed) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return diff --git a/internal/api/client/notifications/notificationsget.go b/internal/api/client/notifications/notificationsget.go index 6ce8adcab..fd175a115 100644 --- a/internal/api/client/notifications/notificationsget.go +++ b/internal/api/client/notifications/notificationsget.go @@ -138,7 +138,7 @@ func (m *Module) NotificationsGETHandler(c *gin.Context) { limit = int(i) } - resp, errWithCode := m.processor.NotificationsGet( + resp, errWithCode := m.processor.Timeline().NotificationsGet( c.Request.Context(), authed, c.Query(MaxIDKey), diff --git a/internal/api/client/reports/reports_test.go b/internal/api/client/reports/reports_test.go index bf0514122..a28f8ffa3 100644 --- a/internal/api/client/reports/reports_test.go +++ b/internal/api/client/reports/reports_test.go @@ -28,6 +28,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -77,6 +78,12 @@ func (suite *ReportsStandardTestSuite) SetupTest() { suite.storage = testrig.NewInMemoryStorage() suite.state.Storage = suite.storage + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + testrig.NewTestTypeConverter(suite.db), + ) + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.sentEmails = make(map[string]string) @@ -85,8 +92,6 @@ func (suite *ReportsStandardTestSuite) SetupTest() { suite.reportsModule = reports.New(suite.processor) testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") - - suite.NoError(suite.processor.Start()) } func (suite *ReportsStandardTestSuite) TearDownTest() { diff --git a/internal/api/client/search/search_test.go b/internal/api/client/search/search_test.go index 626a366f3..95507fcd9 100644 --- a/internal/api/client/search/search_test.go +++ b/internal/api/client/search/search_test.go @@ -35,6 +35,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -81,6 +82,12 @@ func (suite *SearchStandardTestSuite) SetupTest() { suite.storage = testrig.NewInMemoryStorage() suite.state.Storage = suite.storage + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + testrig.NewTestTypeConverter(suite.db), + ) + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.sentEmails = make(map[string]string) @@ -89,8 +96,6 @@ func (suite *SearchStandardTestSuite) SetupTest() { suite.searchModule = search.New(suite.processor) testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") - - suite.NoError(suite.processor.Start()) } func (suite *SearchStandardTestSuite) TearDownTest() { diff --git a/internal/api/client/statuses/status_test.go b/internal/api/client/statuses/status_test.go index 0a006631c..84e71f9c1 100644 --- a/internal/api/client/statuses/status_test.go +++ b/internal/api/client/statuses/status_test.go @@ -29,6 +29,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -83,6 +84,12 @@ func (suite *StatusStandardTestSuite) SetupTest() { suite.tc = testrig.NewTestTypeConverter(suite.db) + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + suite.tc, + ) + testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") @@ -91,8 +98,6 @@ func (suite *StatusStandardTestSuite) SetupTest() { suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) suite.statusModule = statuses.New(suite.processor) - - suite.NoError(suite.processor.Start()) } func (suite *StatusStandardTestSuite) TearDownTest() { diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go index f41bc0ac2..88c682a75 100644 --- a/internal/api/client/streaming/stream.go +++ b/internal/api/client/streaming/stream.go @@ -82,6 +82,20 @@ import ( // `direct`: receive updates for direct messages. // in: query // required: true +// - +// name: list +// type: string +// description: |- +// ID of the list to subscribe to. +// Only used if stream type is 'list'. +// in: query +// - +// name: tag +// type: string +// description: |- +// Name of the tag to subscribe to. +// Only used if stream type is 'hashtag' or 'hashtag:local'. +// in: query // // security: // - OAuth2 Bearer: @@ -164,8 +178,16 @@ func (m *Module) StreamGETHandler(c *gin.Context) { } // Get the initial stream type, if there is one. - // streamType will be an empty string if one wasn't supplied. Open() will deal with this + // By appending other query params to the streamType, + // we can allow for streaming for specific list IDs + // or hashtags. streamType := c.Query(StreamQueryKey) + if list := c.Query(StreamListKey); list != "" { + streamType += ":" + list + } else if tag := c.Query(StreamTagKey); tag != "" { + streamType += ":" + tag + } + stream, errWithCode := m.processor.Stream().Open(c.Request.Context(), account, streamType) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) @@ -240,28 +262,41 @@ func (m *Module) StreamGETHandler(c *gin.Context) { // If the message contains 'stream' and 'type' fields, we can // update the set of timelines that are subscribed for events. - // everything else is ignored. - action := msg["type"] - streamType := msg["stream"] - - // Ignore if the streamType is unknown (or missing), so a bad - // client can't cause extra memory allocations - if !slices.Contains(streampkg.AllStatusTimelines, streamType) { - l.Warnf("Unknown 'stream' field: %v", msg) + updateType, ok := msg["type"] + if !ok { + l.Warn("'type' field not provided") continue } - switch action { + updateStream, ok := msg["stream"] + if !ok { + l.Warn("'stream' field not provided") + continue + } + + // Ignore if the updateStreamType is unknown (or missing), + // so a bad client can't cause extra memory allocations + if !slices.Contains(streampkg.AllStatusTimelines, updateStream) { + l.Warnf("unknown 'stream' field: %v", msg) + continue + } + + updateList, ok := msg["list"] + if ok { + updateStream += ":" + updateList + } + + switch updateType { case "subscribe": stream.Lock() - stream.Timelines[streamType] = true + stream.StreamTypes[updateStream] = true stream.Unlock() case "unsubscribe": stream.Lock() - delete(stream.Timelines, streamType) + delete(stream.StreamTypes, updateStream) stream.Unlock() default: - l.Warnf("Invalid 'type' field: %v", msg) + l.Warnf("invalid 'type' field: %v", msg) } } }() @@ -276,7 +311,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) { case msg := <-stream.Messages: l.Tracef("sending message to websocket: %+v", msg) if err := wsConn.WriteJSON(msg); err != nil { - l.Errorf("error writing json to websocket: %v", err) + l.Debugf("error writing json to websocket: %v", err) return } @@ -290,7 +325,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) { websocket.PingMessage, []byte{}, ); err != nil { - l.Errorf("error writing ping to websocket: %v", err) + l.Debugf("error writing ping to websocket: %v", err) return } } diff --git a/internal/api/client/streaming/streaming.go b/internal/api/client/streaming/streaming.go index 71b325089..edddeab73 100644 --- a/internal/api/client/streaming/streaming.go +++ b/internal/api/client/streaming/streaming.go @@ -27,17 +27,12 @@ import ( ) const ( - // BasePath is the path for the streaming api, minus the 'api' prefix - BasePath = "/v1/streaming" - - // StreamQueryKey is the query key for the type of stream being requested - StreamQueryKey = "stream" - - // AccessTokenQueryKey is the query key for an oauth access token that should be passed in streaming requests. - AccessTokenQueryKey = "access_token" - // AccessTokenHeader is the header for an oauth access token that can be passed in streaming requests instead of AccessTokenQueryKey - //nolint:gosec - AccessTokenHeader = "Sec-Websocket-Protocol" + BasePath = "/v1/streaming" // path for the streaming api, minus the 'api' prefix + StreamQueryKey = "stream" // type of stream being requested + StreamListKey = "list" // id of list being requested + StreamTagKey = "tag" // name of tag being requested + AccessTokenQueryKey = "access_token" // oauth access token + AccessTokenHeader = "Sec-Websocket-Protocol" //nolint:gosec ) type Module struct { diff --git a/internal/api/client/streaming/streaming_test.go b/internal/api/client/streaming/streaming_test.go index b429461c6..cece99bac 100644 --- a/internal/api/client/streaming/streaming_test.go +++ b/internal/api/client/streaming/streaming_test.go @@ -41,6 +41,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -94,6 +95,13 @@ func (suite *StreamingTestSuite) SetupTest() { suite.state.Storage = suite.storage suite.tc = testrig.NewTestTypeConverter(suite.db) + + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + suite.tc, + ) + testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") @@ -102,7 +110,6 @@ func (suite *StreamingTestSuite) SetupTest() { suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) suite.streamingModule = streaming.New(suite.processor, 1, 4096) - suite.NoError(suite.processor.Start()) } func (suite *StreamingTestSuite) TearDownTest() { diff --git a/internal/api/client/timelines/home.go b/internal/api/client/timelines/home.go index f63d14fd3..f64d61287 100644 --- a/internal/api/client/timelines/home.go +++ b/internal/api/client/timelines/home.go @@ -18,9 +18,7 @@ package timelines import ( - "fmt" "net/http" - "strconv" "github.com/gin-gonic/gin" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" @@ -120,49 +118,27 @@ func (m *Module) HomeTimelineGETHandler(c *gin.Context) { return } - maxID := "" - maxIDString := c.Query(MaxIDKey) - if maxIDString != "" { - maxID = maxIDString + limit, errWithCode := apiutil.ParseLimit(c.Query(apiutil.LimitKey), 20) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return } - sinceID := "" - sinceIDString := c.Query(SinceIDKey) - if sinceIDString != "" { - sinceID = sinceIDString + local, errWithCode := apiutil.ParseLocal(c.Query(apiutil.LocalKey), false) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return } - minID := "" - minIDString := c.Query(MinIDKey) - if minIDString != "" { - minID = minIDString - } - - limit := 20 - limitString := c.Query(LimitKey) - if limitString != "" { - i, err := strconv.ParseInt(limitString, 10, 32) - if err != nil { - err := fmt.Errorf("error parsing %s: %s", LimitKey, err) - apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) - return - } - limit = int(i) - } - - local := false - localString := c.Query(LocalKey) - if localString != "" { - i, err := strconv.ParseBool(localString) - if err != nil { - err := fmt.Errorf("error parsing %s: %s", LocalKey, err) - apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) - return - } - local = i - } - - resp, errWithCode := m.processor.HomeTimelineGet(c.Request.Context(), authed, maxID, sinceID, minID, limit, local) + resp, errWithCode := m.processor.Timeline().HomeTimelineGet( + c.Request.Context(), + authed, + c.Query(MaxIDKey), + c.Query(SinceIDKey), + c.Query(MinIDKey), + limit, + local, + ) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return diff --git a/internal/api/client/timelines/list.go b/internal/api/client/timelines/list.go new file mode 100644 index 000000000..4f5232d8b --- /dev/null +++ b/internal/api/client/timelines/list.go @@ -0,0 +1,152 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package timelines + +import ( + "errors" + "net/http" + + "github.com/gin-gonic/gin" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" +) + +// ListTimelineGETHandler swagger:operation GET /api/v1/timelines/list/{id} listTimeline +// +// See statuses/posts from the given list timeline. +// +// The statuses will be returned in descending chronological order (newest first), with sequential IDs (bigger = newer). +// +// The returned Link header can be used to generate the previous and next queries when scrolling up or down a timeline. +// +// Example: +// +// ``` +// ; rel="next", ; rel="prev" +// ```` +// +// --- +// tags: +// - timelines +// +// produces: +// - application/json +// +// parameters: +// - +// name: id +// type: string +// description: ID of the list +// in: path +// required: true +// - +// name: max_id +// type: string +// description: >- +// Return only statuses *OLDER* than the given max status ID. +// The status with the specified ID will not be included in the response. +// in: query +// required: false +// - +// name: since_id +// type: string +// description: >- +// Return only statuses *NEWER* than the given since status ID. +// The status with the specified ID will not be included in the response. +// in: query +// - +// name: min_id +// type: string +// description: >- +// Return only statuses *NEWER* than the given since status ID. +// The status with the specified ID will not be included in the response. +// in: query +// required: false +// - +// name: limit +// type: integer +// description: Number of statuses to return. +// default: 20 +// in: query +// required: false +// +// security: +// - OAuth2 Bearer: +// - read:lists +// +// responses: +// '200': +// name: statuses +// description: Array of statuses. +// schema: +// type: array +// items: +// "$ref": "#/definitions/status" +// headers: +// Link: +// type: string +// description: Links to the next and previous queries. +// '401': +// description: unauthorized +// '400': +// description: bad request +func (m *Module) ListTimelineGETHandler(c *gin.Context) { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1) + return + } + + targetListID := c.Param(IDKey) + if targetListID == "" { + err := errors.New("no list id specified") + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + + limit, errWithCode := apiutil.ParseLimit(c.Query(apiutil.LimitKey), 20) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + resp, errWithCode := m.processor.Timeline().ListTimelineGet( + c.Request.Context(), + authed, + targetListID, + c.Query(MaxIDKey), + c.Query(SinceIDKey), + c.Query(MinIDKey), + limit, + ) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + if resp.LinkHeader != "" { + c.Header("Link", resp.LinkHeader) + } + c.JSON(http.StatusOK, resp.Items) +} diff --git a/internal/api/client/timelines/public.go b/internal/api/client/timelines/public.go index a8a61c398..5be9fcaa8 100644 --- a/internal/api/client/timelines/public.go +++ b/internal/api/client/timelines/public.go @@ -18,9 +18,7 @@ package timelines import ( - "fmt" "net/http" - "strconv" "github.com/gin-gonic/gin" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" @@ -131,49 +129,27 @@ func (m *Module) PublicTimelineGETHandler(c *gin.Context) { return } - maxID := "" - maxIDString := c.Query(MaxIDKey) - if maxIDString != "" { - maxID = maxIDString + limit, errWithCode := apiutil.ParseLimit(c.Query(apiutil.LimitKey), 20) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return } - sinceID := "" - sinceIDString := c.Query(SinceIDKey) - if sinceIDString != "" { - sinceID = sinceIDString + local, errWithCode := apiutil.ParseLocal(c.Query(apiutil.LocalKey), false) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return } - minID := "" - minIDString := c.Query(MinIDKey) - if minIDString != "" { - minID = minIDString - } - - limit := 20 - limitString := c.Query(LimitKey) - if limitString != "" { - i, err := strconv.ParseInt(limitString, 10, 32) - if err != nil { - err := fmt.Errorf("error parsing %s: %s", LimitKey, err) - apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) - return - } - limit = int(i) - } - - local := false - localString := c.Query(LocalKey) - if localString != "" { - i, err := strconv.ParseBool(localString) - if err != nil { - err := fmt.Errorf("error parsing %s: %s", LocalKey, err) - apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) - return - } - local = i - } - - resp, errWithCode := m.processor.PublicTimelineGet(c.Request.Context(), authed, maxID, sinceID, minID, limit, local) + resp, errWithCode := m.processor.Timeline().PublicTimelineGet( + c.Request.Context(), + authed, + c.Query(MaxIDKey), + c.Query(SinceIDKey), + c.Query(MinIDKey), + limit, + local, + ) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return diff --git a/internal/api/client/timelines/timeline.go b/internal/api/client/timelines/timeline.go index bf8ef1e2e..2580333d9 100644 --- a/internal/api/client/timelines/timeline.go +++ b/internal/api/client/timelines/timeline.go @@ -27,10 +27,12 @@ import ( const ( // BasePath is the base URI path for serving timelines, minus the 'api' prefix. BasePath = "/v1/timelines" + IDKey = "id" // HomeTimeline is the path for the home timeline HomeTimeline = BasePath + "/home" // PublicTimeline is the path for the public (and public local) timeline PublicTimeline = BasePath + "/public" + ListTimeline = BasePath + "/list/:" + IDKey // MaxIDKey is the url query for setting a max status ID to return MaxIDKey = "max_id" // SinceIDKey is the url query for returning results newer than the given ID @@ -56,4 +58,5 @@ func New(processor *processing.Processor) *Module { func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) { attachHandler(http.MethodGet, HomeTimeline, m.HomeTimelineGETHandler) attachHandler(http.MethodGet, PublicTimeline, m.PublicTimelineGETHandler) + attachHandler(http.MethodGet, ListTimeline, m.ListTimelineGETHandler) } diff --git a/internal/api/client/user/user_test.go b/internal/api/client/user/user_test.go index c26a04f31..06fc2c000 100644 --- a/internal/api/client/user/user_test.go +++ b/internal/api/client/user/user_test.go @@ -29,6 +29,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -73,6 +74,13 @@ func (suite *UserStandardTestSuite) SetupTest() { suite.state.Storage = suite.storage suite.tc = testrig.NewTestTypeConverter(suite.db) + + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + suite.tc, + ) + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.sentEmails = make(map[string]string) @@ -81,8 +89,6 @@ func (suite *UserStandardTestSuite) SetupTest() { suite.userModule = user.New(suite.processor) testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") - - suite.NoError(suite.processor.Start()) } func (suite *UserStandardTestSuite) TearDownTest() { diff --git a/internal/api/fileserver/fileserver_test.go b/internal/api/fileserver/fileserver_test.go index c2433d94a..70bd23c15 100644 --- a/internal/api/fileserver/fileserver_test.go +++ b/internal/api/fileserver/fileserver_test.go @@ -33,6 +33,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -81,6 +82,13 @@ func (suite *FileserverTestSuite) SetupSuite() { suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) suite.tc = testrig.NewTestTypeConverter(suite.db) + + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + suite.tc, + ) + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) suite.oauthServer = testrig.NewTestOauthServer(suite.db) suite.emailSender = testrig.NewEmailSender("../../../web/template/", nil) diff --git a/internal/api/model/list.go b/internal/api/model/list.go index d50c68f70..f897bcc88 100644 --- a/internal/api/model/list.go +++ b/internal/api/model/list.go @@ -17,14 +17,57 @@ package model -// List represents a list of some users that the authenticated user follows. +// List represents a user-created list of accounts that the user follows. +// +// swagger:model list type List struct { - // The internal database ID of the list. + // The ID of the list. ID string `json:"id"` // The user-defined title of the list. Title string `json:"title"` - // followed = Show replies to any followed user + // RepliesPolicy for this list. + // followed = Show replies to any followed user // list = Show replies to members of the list // none = Show replies to no one RepliesPolicy string `json:"replies_policy"` } + +// ListCreateRequest models list creation parameters. +// +// swagger:parameters listCreate +type ListCreateRequest struct { + // Title of this list. + // example: Cool People + // in: formData + // required: true + Title string `form:"title" json:"title" xml:"title"` + // RepliesPolicy for this list. + // followed = Show replies to any followed user + // list = Show replies to members of the list + // none = Show replies to no one + // example: list + // default: list + // in: formData + RepliesPolicy string `form:"replies_policy" json:"replies_policy" xml:"replies_policy"` +} + +// ListUpdateRequest models list update parameters. +// +// swagger:parameters listUpdate +type ListUpdateRequest struct { + // Title of this list. + // example: Cool People + // in: formData + Title *string `form:"title" json:"title" xml:"title"` + // RepliesPolicy for this list. + // followed = Show replies to any followed user + // list = Show replies to members of the list + // none = Show replies to no one + // in: formData + RepliesPolicy *string `form:"replies_policy" json:"replies_policy" xml:"replies_policy"` +} + +// swagger:ignore +type ListAccountsChangeRequest struct { + AccountIDs []string `form:"account_ids[]" json:"account_ids" xml:"account_ids"` +} diff --git a/internal/api/util/parsequery.go b/internal/api/util/parsequery.go new file mode 100644 index 000000000..92578a739 --- /dev/null +++ b/internal/api/util/parsequery.go @@ -0,0 +1,58 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package util + +import ( + "fmt" + "strconv" + + "github.com/superseriousbusiness/gotosocial/internal/gtserror" +) + +const ( + LimitKey = "limit" + LocalKey = "local" +) + +func ParseLimit(limit string, defaultLimit int) (int, gtserror.WithCode) { + if limit == "" { + return defaultLimit, nil + } + + i, err := strconv.Atoi(limit) + if err != nil { + err := fmt.Errorf("error parsing %s: %w", LimitKey, err) + return 0, gtserror.NewErrorBadRequest(err, err.Error()) + } + + return i, nil +} + +func ParseLocal(local string, defaultLocal bool) (bool, gtserror.WithCode) { + if local == "" { + return defaultLocal, nil + } + + i, err := strconv.ParseBool(local) + if err != nil { + err := fmt.Errorf("error parsing %s: %w", LocalKey, err) + return false, gtserror.NewErrorBadRequest(err, err.Error()) + } + + return i, nil +} diff --git a/internal/api/wellknown/webfinger/webfinger_test.go b/internal/api/wellknown/webfinger/webfinger_test.go index 26143942c..df730e302 100644 --- a/internal/api/wellknown/webfinger/webfinger_test.go +++ b/internal/api/wellknown/webfinger/webfinger_test.go @@ -30,6 +30,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -79,6 +80,13 @@ func (suite *WebfingerStandardTestSuite) SetupTest() { suite.db = testrig.NewTestDB(&suite.state) suite.state.DB = suite.db suite.tc = testrig.NewTestTypeConverter(suite.db) + + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + suite.tc, + ) + suite.storage = testrig.NewInMemoryStorage() suite.state.Storage = suite.storage suite.mediaManager = testrig.NewTestMediaManager(&suite.state) @@ -89,8 +97,6 @@ func (suite *WebfingerStandardTestSuite) SetupTest() { suite.oauthServer = testrig.NewTestOauthServer(suite.db) testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") - - suite.NoError(suite.processor.Start()) } func (suite *WebfingerStandardTestSuite) TearDownTest() { diff --git a/internal/cache/gts.go b/internal/cache/gts.go index 1032a5611..3a2d09736 100644 --- a/internal/cache/gts.go +++ b/internal/cache/gts.go @@ -35,6 +35,8 @@ type GTSCaches struct { emojiCategory *result.Cache[*gtsmodel.EmojiCategory] follow *result.Cache[*gtsmodel.Follow] followRequest *result.Cache[*gtsmodel.FollowRequest] + list *result.Cache[*gtsmodel.List] + listEntry *result.Cache[*gtsmodel.ListEntry] media *result.Cache[*gtsmodel.MediaAttachment] mention *result.Cache[*gtsmodel.Mention] notification *result.Cache[*gtsmodel.Notification] @@ -57,6 +59,8 @@ func (c *GTSCaches) Init() { c.initEmojiCategory() c.initFollow() c.initFollowRequest() + c.initList() + c.initListEntry() c.initMedia() c.initMention() c.initNotification() @@ -76,6 +80,8 @@ func (c *GTSCaches) Start() { tryStart(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq()) tryStart(c.follow, config.GetCacheGTSFollowSweepFreq()) tryStart(c.followRequest, config.GetCacheGTSFollowRequestSweepFreq()) + tryStart(c.list, config.GetCacheGTSListSweepFreq()) + tryStart(c.listEntry, config.GetCacheGTSListEntrySweepFreq()) tryStart(c.media, config.GetCacheGTSMediaSweepFreq()) tryStart(c.mention, config.GetCacheGTSMentionSweepFreq()) tryStart(c.notification, config.GetCacheGTSNotificationSweepFreq()) @@ -100,6 +106,8 @@ func (c *GTSCaches) Stop() { tryStop(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq()) tryStop(c.follow, config.GetCacheGTSFollowSweepFreq()) tryStop(c.followRequest, config.GetCacheGTSFollowRequestSweepFreq()) + tryStop(c.list, config.GetCacheGTSListSweepFreq()) + tryStop(c.listEntry, config.GetCacheGTSListEntrySweepFreq()) tryStop(c.media, config.GetCacheGTSMediaSweepFreq()) tryStop(c.mention, config.GetCacheGTSNotificationSweepFreq()) tryStop(c.notification, config.GetCacheGTSNotificationSweepFreq()) @@ -146,6 +154,16 @@ func (c *GTSCaches) FollowRequest() *result.Cache[*gtsmodel.FollowRequest] { return c.followRequest } +// List provides access to the gtsmodel List database cache. +func (c *GTSCaches) List() *result.Cache[*gtsmodel.List] { + return c.list +} + +// ListEntry provides access to the gtsmodel ListEntry database cache. +func (c *GTSCaches) ListEntry() *result.Cache[*gtsmodel.ListEntry] { + return c.listEntry +} + // Media provides access to the gtsmodel Media database cache. func (c *GTSCaches) Media() *result.Cache[*gtsmodel.MediaAttachment] { return c.media @@ -283,6 +301,30 @@ func (c *GTSCaches) initFollowRequest() { c.followRequest.SetTTL(config.GetCacheGTSFollowRequestTTL(), true) } +func (c *GTSCaches) initList() { + c.list = result.New([]result.Lookup{ + {Name: "ID"}, + }, func(l1 *gtsmodel.List) *gtsmodel.List { + l2 := new(gtsmodel.List) + *l2 = *l1 + return l2 + }, config.GetCacheGTSListMaxSize()) + c.list.SetTTL(config.GetCacheGTSListTTL(), true) + c.list.IgnoreErrors(ignoreErrors) +} + +func (c *GTSCaches) initListEntry() { + c.listEntry = result.New([]result.Lookup{ + {Name: "ID"}, + }, func(l1 *gtsmodel.ListEntry) *gtsmodel.ListEntry { + l2 := new(gtsmodel.ListEntry) + *l2 = *l1 + return l2 + }, config.GetCacheGTSListEntryMaxSize()) + c.list.SetTTL(config.GetCacheGTSListEntryTTL(), true) + c.list.IgnoreErrors(ignoreErrors) +} + func (c *GTSCaches) initMedia() { c.media = result.New([]result.Lookup{ {Name: "ID"}, @@ -359,7 +401,6 @@ func (c *GTSCaches) initStatusFave() { c.status.IgnoreErrors(ignoreErrors) } -// initTombstone will initialize the gtsmodel.Tombstone cache. func (c *GTSCaches) initTombstone() { c.tombstone = result.New([]result.Lookup{ {Name: "ID"}, diff --git a/internal/config/config.go b/internal/config/config.go index 7119fc4a7..0b7b527ea 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -199,6 +199,14 @@ type GTSCacheConfiguration struct { FollowRequestTTL time.Duration `name:"follow-request-ttl"` FollowRequestSweepFreq time.Duration `name:"follow-request-sweep-freq"` + ListMaxSize int `name:"list-max-size"` + ListTTL time.Duration `name:"list-ttl"` + ListSweepFreq time.Duration `name:"list-sweep-freq"` + + ListEntryMaxSize int `name:"list-entry-max-size"` + ListEntryTTL time.Duration `name:"list-entry-ttl"` + ListEntrySweepFreq time.Duration `name:"list-entry-sweep-freq"` + MediaMaxSize int `name:"media-max-size"` MediaTTL time.Duration `name:"media-ttl"` MediaSweepFreq time.Duration `name:"media-sweep-freq"` diff --git a/internal/config/defaults.go b/internal/config/defaults.go index 6e2d141be..53d994cda 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -153,6 +153,14 @@ var Defaults = Configuration{ FollowRequestTTL: time.Minute * 30, FollowRequestSweepFreq: time.Minute, + ListMaxSize: 2000, + ListTTL: time.Minute * 30, + ListSweepFreq: time.Minute, + + ListEntryMaxSize: 2000, + ListEntryTTL: time.Minute * 30, + ListEntrySweepFreq: time.Minute, + MediaMaxSize: 1000, MediaTTL: time.Minute * 30, MediaSweepFreq: time.Minute, diff --git a/internal/config/helpers.gen.go b/internal/config/helpers.gen.go index 9993d7aaf..2a70488e8 100644 --- a/internal/config/helpers.gen.go +++ b/internal/config/helpers.gen.go @@ -2778,6 +2778,156 @@ func GetCacheGTSFollowRequestSweepFreq() time.Duration { // SetCacheGTSFollowRequestSweepFreq safely sets the value for global configuration 'Cache.GTS.FollowRequestSweepFreq' field func SetCacheGTSFollowRequestSweepFreq(v time.Duration) { global.SetCacheGTSFollowRequestSweepFreq(v) } +// GetCacheGTSListMaxSize safely fetches the Configuration value for state's 'Cache.GTS.ListMaxSize' field +func (st *ConfigState) GetCacheGTSListMaxSize() (v int) { + st.mutex.Lock() + v = st.config.Cache.GTS.ListMaxSize + st.mutex.Unlock() + return +} + +// SetCacheGTSListMaxSize safely sets the Configuration value for state's 'Cache.GTS.ListMaxSize' field +func (st *ConfigState) SetCacheGTSListMaxSize(v int) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.GTS.ListMaxSize = v + st.reloadToViper() +} + +// CacheGTSListMaxSizeFlag returns the flag name for the 'Cache.GTS.ListMaxSize' field +func CacheGTSListMaxSizeFlag() string { return "cache-gts-list-max-size" } + +// GetCacheGTSListMaxSize safely fetches the value for global configuration 'Cache.GTS.ListMaxSize' field +func GetCacheGTSListMaxSize() int { return global.GetCacheGTSListMaxSize() } + +// SetCacheGTSListMaxSize safely sets the value for global configuration 'Cache.GTS.ListMaxSize' field +func SetCacheGTSListMaxSize(v int) { global.SetCacheGTSListMaxSize(v) } + +// GetCacheGTSListTTL safely fetches the Configuration value for state's 'Cache.GTS.ListTTL' field +func (st *ConfigState) GetCacheGTSListTTL() (v time.Duration) { + st.mutex.Lock() + v = st.config.Cache.GTS.ListTTL + st.mutex.Unlock() + return +} + +// SetCacheGTSListTTL safely sets the Configuration value for state's 'Cache.GTS.ListTTL' field +func (st *ConfigState) SetCacheGTSListTTL(v time.Duration) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.GTS.ListTTL = v + st.reloadToViper() +} + +// CacheGTSListTTLFlag returns the flag name for the 'Cache.GTS.ListTTL' field +func CacheGTSListTTLFlag() string { return "cache-gts-list-ttl" } + +// GetCacheGTSListTTL safely fetches the value for global configuration 'Cache.GTS.ListTTL' field +func GetCacheGTSListTTL() time.Duration { return global.GetCacheGTSListTTL() } + +// SetCacheGTSListTTL safely sets the value for global configuration 'Cache.GTS.ListTTL' field +func SetCacheGTSListTTL(v time.Duration) { global.SetCacheGTSListTTL(v) } + +// GetCacheGTSListSweepFreq safely fetches the Configuration value for state's 'Cache.GTS.ListSweepFreq' field +func (st *ConfigState) GetCacheGTSListSweepFreq() (v time.Duration) { + st.mutex.Lock() + v = st.config.Cache.GTS.ListSweepFreq + st.mutex.Unlock() + return +} + +// SetCacheGTSListSweepFreq safely sets the Configuration value for state's 'Cache.GTS.ListSweepFreq' field +func (st *ConfigState) SetCacheGTSListSweepFreq(v time.Duration) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.GTS.ListSweepFreq = v + st.reloadToViper() +} + +// CacheGTSListSweepFreqFlag returns the flag name for the 'Cache.GTS.ListSweepFreq' field +func CacheGTSListSweepFreqFlag() string { return "cache-gts-list-sweep-freq" } + +// GetCacheGTSListSweepFreq safely fetches the value for global configuration 'Cache.GTS.ListSweepFreq' field +func GetCacheGTSListSweepFreq() time.Duration { return global.GetCacheGTSListSweepFreq() } + +// SetCacheGTSListSweepFreq safely sets the value for global configuration 'Cache.GTS.ListSweepFreq' field +func SetCacheGTSListSweepFreq(v time.Duration) { global.SetCacheGTSListSweepFreq(v) } + +// GetCacheGTSListEntryMaxSize safely fetches the Configuration value for state's 'Cache.GTS.ListEntryMaxSize' field +func (st *ConfigState) GetCacheGTSListEntryMaxSize() (v int) { + st.mutex.Lock() + v = st.config.Cache.GTS.ListEntryMaxSize + st.mutex.Unlock() + return +} + +// SetCacheGTSListEntryMaxSize safely sets the Configuration value for state's 'Cache.GTS.ListEntryMaxSize' field +func (st *ConfigState) SetCacheGTSListEntryMaxSize(v int) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.GTS.ListEntryMaxSize = v + st.reloadToViper() +} + +// CacheGTSListEntryMaxSizeFlag returns the flag name for the 'Cache.GTS.ListEntryMaxSize' field +func CacheGTSListEntryMaxSizeFlag() string { return "cache-gts-list-entry-max-size" } + +// GetCacheGTSListEntryMaxSize safely fetches the value for global configuration 'Cache.GTS.ListEntryMaxSize' field +func GetCacheGTSListEntryMaxSize() int { return global.GetCacheGTSListEntryMaxSize() } + +// SetCacheGTSListEntryMaxSize safely sets the value for global configuration 'Cache.GTS.ListEntryMaxSize' field +func SetCacheGTSListEntryMaxSize(v int) { global.SetCacheGTSListEntryMaxSize(v) } + +// GetCacheGTSListEntryTTL safely fetches the Configuration value for state's 'Cache.GTS.ListEntryTTL' field +func (st *ConfigState) GetCacheGTSListEntryTTL() (v time.Duration) { + st.mutex.Lock() + v = st.config.Cache.GTS.ListEntryTTL + st.mutex.Unlock() + return +} + +// SetCacheGTSListEntryTTL safely sets the Configuration value for state's 'Cache.GTS.ListEntryTTL' field +func (st *ConfigState) SetCacheGTSListEntryTTL(v time.Duration) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.GTS.ListEntryTTL = v + st.reloadToViper() +} + +// CacheGTSListEntryTTLFlag returns the flag name for the 'Cache.GTS.ListEntryTTL' field +func CacheGTSListEntryTTLFlag() string { return "cache-gts-list-entry-ttl" } + +// GetCacheGTSListEntryTTL safely fetches the value for global configuration 'Cache.GTS.ListEntryTTL' field +func GetCacheGTSListEntryTTL() time.Duration { return global.GetCacheGTSListEntryTTL() } + +// SetCacheGTSListEntryTTL safely sets the value for global configuration 'Cache.GTS.ListEntryTTL' field +func SetCacheGTSListEntryTTL(v time.Duration) { global.SetCacheGTSListEntryTTL(v) } + +// GetCacheGTSListEntrySweepFreq safely fetches the Configuration value for state's 'Cache.GTS.ListEntrySweepFreq' field +func (st *ConfigState) GetCacheGTSListEntrySweepFreq() (v time.Duration) { + st.mutex.Lock() + v = st.config.Cache.GTS.ListEntrySweepFreq + st.mutex.Unlock() + return +} + +// SetCacheGTSListEntrySweepFreq safely sets the Configuration value for state's 'Cache.GTS.ListEntrySweepFreq' field +func (st *ConfigState) SetCacheGTSListEntrySweepFreq(v time.Duration) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.GTS.ListEntrySweepFreq = v + st.reloadToViper() +} + +// CacheGTSListEntrySweepFreqFlag returns the flag name for the 'Cache.GTS.ListEntrySweepFreq' field +func CacheGTSListEntrySweepFreqFlag() string { return "cache-gts-list-entry-sweep-freq" } + +// GetCacheGTSListEntrySweepFreq safely fetches the value for global configuration 'Cache.GTS.ListEntrySweepFreq' field +func GetCacheGTSListEntrySweepFreq() time.Duration { return global.GetCacheGTSListEntrySweepFreq() } + +// SetCacheGTSListEntrySweepFreq safely sets the value for global configuration 'Cache.GTS.ListEntrySweepFreq' field +func SetCacheGTSListEntrySweepFreq(v time.Duration) { global.SetCacheGTSListEntrySweepFreq(v) } + // GetCacheGTSMediaMaxSize safely fetches the Configuration value for state's 'Cache.GTS.MediaMaxSize' field func (st *ConfigState) GetCacheGTSMediaMaxSize() (v int) { st.mutex.Lock() diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index f095d1728..f0329e898 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -65,6 +65,7 @@ type DBService struct { db.Domain db.Emoji db.Instance + db.List db.Media db.Mention db.Notification @@ -179,6 +180,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { Instance: &instanceDB{ conn: conn, }, + List: &listDB{ + conn: conn, + state: state, + }, Media: &mediaDB{ conn: conn, state: state, diff --git a/internal/db/bundb/bundb_test.go b/internal/db/bundb/bundb_test.go index 2566be2ba..84e11447a 100644 --- a/internal/db/bundb/bundb_test.go +++ b/internal/db/bundb/bundb_test.go @@ -22,6 +22,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -46,6 +47,8 @@ type BunDBStandardTestSuite struct { testReports map[string]*gtsmodel.Report testBookmarks map[string]*gtsmodel.StatusBookmark testFaves map[string]*gtsmodel.StatusFave + testLists map[string]*gtsmodel.List + testListEntries map[string]*gtsmodel.ListEntry } func (suite *BunDBStandardTestSuite) SetupSuite() { @@ -63,6 +66,8 @@ func (suite *BunDBStandardTestSuite) SetupSuite() { suite.testReports = testrig.NewTestReports() suite.testBookmarks = testrig.NewTestBookmarks() suite.testFaves = testrig.NewTestFaves() + suite.testLists = testrig.NewTestLists() + suite.testListEntries = testrig.NewTestListEntries() } func (suite *BunDBStandardTestSuite) SetupTest() { @@ -70,6 +75,7 @@ func (suite *BunDBStandardTestSuite) SetupTest() { testrig.InitTestLog() suite.state.Caches.Init() suite.db = testrig.NewTestDB(&suite.state) + testrig.StartTimelines(&suite.state, visibility.NewFilter(&suite.state), testrig.NewTestTypeConverter(suite.db)) testrig.StandardDBSetup(suite.db, suite.testAccounts) } diff --git a/internal/db/bundb/list.go b/internal/db/bundb/list.go new file mode 100644 index 000000000..38701cc07 --- /dev/null +++ b/internal/db/bundb/list.go @@ -0,0 +1,467 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package bundb + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/uptrace/bun" +) + +type listDB struct { + conn *DBConn + state *state.State +} + +/* + LIST FUNCTIONS +*/ + +func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmodel.List) error, keyParts ...any) (*gtsmodel.List, error) { + list, err := l.state.Caches.GTS.List().Load(lookup, func() (*gtsmodel.List, error) { + var list gtsmodel.List + + // Not cached! Perform database query. + if err := dbQuery(&list); err != nil { + return nil, l.conn.ProcessError(err) + } + + return &list, nil + }, keyParts...) + if err != nil { + return nil, err // already processed + } + + if gtscontext.Barebones(ctx) { + // Only a barebones model was requested. + return list, nil + } + + if err := l.state.DB.PopulateList(ctx, list); err != nil { + return nil, err + } + + return list, nil +} + +func (l *listDB) GetListByID(ctx context.Context, id string) (*gtsmodel.List, error) { + return l.getList( + ctx, + "ID", + func(list *gtsmodel.List) error { + return l.conn.NewSelect(). + Model(list). + Where("? = ?", bun.Ident("list.id"), id). + Scan(ctx) + }, + id, + ) +} + +func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error) { + // Fetch IDs of all lists owned by this account. + var listIDs []string + if err := l.conn. + NewSelect(). + TableExpr("? AS ?", bun.Ident("lists"), bun.Ident("list")). + Column("list.id"). + Where("? = ?", bun.Ident("list.account_id"), accountID). + Order("list.id DESC"). + Scan(ctx, &listIDs); err != nil { + return nil, l.conn.ProcessError(err) + } + + if len(listIDs) == 0 { + return nil, nil + } + + // Select each list using its ID to ensure cache used. + lists := make([]*gtsmodel.List, 0, len(listIDs)) + for _, id := range listIDs { + list, err := l.state.DB.GetListByID(ctx, id) + if err != nil { + log.Errorf(ctx, "error fetching list %q: %v", id, err) + continue + } + + // Append list. + lists = append(lists, list) + } + + return lists, nil +} + +func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error { + var ( + err error + errs = make(gtserror.MultiError, 0, 2) + ) + + if list.Account == nil { + // List account is not set, fetch from the database. + list.Account, err = l.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + list.AccountID, + ) + if err != nil { + errs.Append(fmt.Errorf("error populating list account: %w", err)) + } + } + + if list.ListEntries == nil { + // List entries are not set, fetch from the database. + list.ListEntries, err = l.state.DB.GetListEntries( + gtscontext.SetBarebones(ctx), + list.ID, + "", "", "", 0, + ) + if err != nil { + errs.Append(fmt.Errorf("error populating list entries: %w", err)) + } + } + + return errs.Combine() +} + +func (l *listDB) PutList(ctx context.Context, list *gtsmodel.List) error { + return l.state.Caches.GTS.List().Store(list, func() error { + _, err := l.conn.NewInsert().Model(list).Exec(ctx) + return l.conn.ProcessError(err) + }) +} + +func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns ...string) error { + list.UpdatedAt = time.Now() + if len(columns) > 0 { + // If we're updating by column, ensure "updated_at" is included. + columns = append(columns, "updated_at") + } + + return l.state.Caches.GTS.List().Store(list, func() error { + if _, err := l.conn.NewUpdate(). + Model(list). + Where("? = ?", bun.Ident("list.id"), list.ID). + Column(columns...). + Exec(ctx); err != nil { + return l.conn.ProcessError(err) + } + + return nil + }) +} + +func (l *listDB) DeleteListByID(ctx context.Context, id string) error { + defer l.state.Caches.GTS.List().Invalidate("ID", id) + + // Select all entries that belong to this list. + listEntries, err := l.state.DB.GetListEntries(ctx, id, "", "", "", 0) + if err != nil { + return fmt.Errorf("error selecting entries from list %q: %w", id, err) + } + + // Delete each list entry. This will + // invalidate the list timeline too. + for _, listEntry := range listEntries { + err := l.state.DB.DeleteListEntry(ctx, listEntry.ID) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return err + } + } + + // Finally delete list itself from DB. + _, err = l.conn.NewDelete(). + Table("lists"). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx) + return l.conn.ProcessError(err) +} + +/* + LIST ENTRY functions +*/ + +func (l *listDB) getListEntry(ctx context.Context, lookup string, dbQuery func(*gtsmodel.ListEntry) error, keyParts ...any) (*gtsmodel.ListEntry, error) { + listEntry, err := l.state.Caches.GTS.ListEntry().Load(lookup, func() (*gtsmodel.ListEntry, error) { + var listEntry gtsmodel.ListEntry + + // Not cached! Perform database query. + if err := dbQuery(&listEntry); err != nil { + return nil, l.conn.ProcessError(err) + } + + return &listEntry, nil + }, keyParts...) + if err != nil { + return nil, err // already processed + } + + if gtscontext.Barebones(ctx) { + // Only a barebones model was requested. + return listEntry, nil + } + + // Further populate the list entry fields where applicable. + if err := l.state.DB.PopulateListEntry(ctx, listEntry); err != nil { + return nil, err + } + + return listEntry, nil +} + +func (l *listDB) GetListEntryByID(ctx context.Context, id string) (*gtsmodel.ListEntry, error) { + return l.getListEntry( + ctx, + "ID", + func(listEntry *gtsmodel.ListEntry) error { + return l.conn.NewSelect(). + Model(listEntry). + Where("? = ?", bun.Ident("list_entry.id"), id). + Scan(ctx) + }, + id, + ) +} + +func (l *listDB) GetListEntries(ctx context.Context, + listID string, + maxID string, + sinceID string, + minID string, + limit int, +) ([]*gtsmodel.ListEntry, error) { + // Ensure reasonable + if limit < 0 { + limit = 0 + } + + // Make educated guess for slice size + var ( + entryIDs = make([]string, 0, limit) + frontToBack = true + ) + + q := l.conn. + NewSelect(). + TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("entry")). + // Select only IDs from table + Column("entry.id"). + // Select only entries belonging to listID. + Where("? = ?", bun.Ident("entry.list_id"), listID) + + if maxID != "" { + // return only entries LOWER (ie., older) than maxID + q = q.Where("? < ?", bun.Ident("entry.id"), maxID) + } + + if sinceID != "" { + // return only entries HIGHER (ie., newer) than sinceID + q = q.Where("? > ?", bun.Ident("entry.id"), sinceID) + } + + if minID != "" { + // return only entries HIGHER (ie., newer) than minID + q = q.Where("? > ?", bun.Ident("entry.id"), minID) + + // page up + frontToBack = false + } + + if limit > 0 { + // limit amount of entries returned + q = q.Limit(limit) + } + + if frontToBack { + // Page down. + q = q.Order("entry.id DESC") + } else { + // Page up. + q = q.Order("entry.id ASC") + } + + if err := q.Scan(ctx, &entryIDs); err != nil { + return nil, l.conn.ProcessError(err) + } + + if len(entryIDs) == 0 { + return nil, nil + } + + // If we're paging up, we still want entries + // to be sorted by ID desc, so reverse ids slice. + // https://zchee.github.io/golang-wiki/SliceTricks/#reversing + if !frontToBack { + for l, r := 0, len(entryIDs)-1; l < r; l, r = l+1, r-1 { + entryIDs[l], entryIDs[r] = entryIDs[r], entryIDs[l] + } + } + + // Select each list entry using its ID to ensure cache used. + listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs)) + for _, id := range entryIDs { + listEntry, err := l.state.DB.GetListEntryByID(ctx, id) + if err != nil { + log.Errorf(ctx, "error fetching list entry %q: %v", id, err) + continue + } + + // Append list entries. + listEntries = append(listEntries, listEntry) + } + + return listEntries, nil +} + +func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error) { + entryIDs := []string{} + + if err := l.conn. + NewSelect(). + TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("entry")). + // Select only IDs from table + Column("entry.id"). + // Select only entries belonging with given followID. + Where("? = ?", bun.Ident("entry.follow_id"), followID). + Scan(ctx, &entryIDs); err != nil { + return nil, l.conn.ProcessError(err) + } + + if len(entryIDs) == 0 { + return nil, nil + } + + // Select each list entry using its ID to ensure cache used. + listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs)) + for _, id := range entryIDs { + listEntry, err := l.state.DB.GetListEntryByID(ctx, id) + if err != nil { + log.Errorf(ctx, "error fetching list entry %q: %v", id, err) + continue + } + + // Append list entries. + listEntries = append(listEntries, listEntry) + } + + return listEntries, nil +} + +func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.ListEntry) error { + var err error + + if listEntry.Follow == nil { + // ListEntry follow is not set, fetch from the database. + listEntry.Follow, err = l.state.DB.GetFollowByID( + gtscontext.SetBarebones(ctx), + listEntry.FollowID, + ) + if err != nil { + return fmt.Errorf("error populating listEntry follow: %w", err) + } + } + + return nil +} + +func (l *listDB) PutListEntries(ctx context.Context, listEntries []*gtsmodel.ListEntry) error { + return l.conn.RunInTx(ctx, func(tx bun.Tx) error { + for _, listEntry := range listEntries { + if _, err := tx. + NewInsert(). + Model(listEntry). + Exec(ctx); err != nil { + return err + } + + // Invalidate the timeline for the list this entry belongs to. + if err := l.state.Timelines.List.RemoveTimeline(ctx, listEntry.ListID); err != nil { + log.Errorf(ctx, "PutListEntries: error invalidating list timeline: %q", err) + } + } + + return nil + }) +} + +func (l *listDB) DeleteListEntry(ctx context.Context, id string) error { + defer l.state.Caches.GTS.ListEntry().Invalidate("ID", id) + + // Load list entry into cache before attempting a delete, + // as we need the followID from it in order to trigger + // timeline invalidation. + listEntry, err := l.GetListEntryByID( + // Don't populate the entry; + // we only want the list ID. + gtscontext.SetBarebones(ctx), + id, + ) + if err != nil { + if errors.Is(err, db.ErrNoEntries) { + // Already gone. + return nil + } + return err + } + + defer func() { + // Invalidate the timeline for the list this entry belongs to. + if err := l.state.Timelines.List.RemoveTimeline(ctx, listEntry.ListID); err != nil { + log.Errorf(ctx, "DeleteListEntry: error invalidating list timeline: %q", err) + } + }() + + if _, err := l.conn.NewDelete(). + Table("list_entries"). + Where("? = ?", bun.Ident("id"), listEntry.ID). + Exec(ctx); err != nil { + return l.conn.ProcessError(err) + } + + return nil +} + +func (l *listDB) DeleteListEntriesForFollowID(ctx context.Context, followID string) error { + // Fetch IDs of all entries that pertain to this follow. + var listEntryIDs []string + if err := l.conn. + NewSelect(). + TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("list_entry")). + Column("list_entry.id"). + Where("? = ?", bun.Ident("list_entry.follow_id"), followID). + Order("list_entry.id DESC"). + Scan(ctx, &listEntryIDs); err != nil { + return l.conn.ProcessError(err) + } + + for _, id := range listEntryIDs { + if err := l.DeleteListEntry(ctx, id); err != nil { + return err + } + } + + return nil +} diff --git a/internal/db/bundb/list_test.go b/internal/db/bundb/list_test.go new file mode 100644 index 000000000..296ab7c1a --- /dev/null +++ b/internal/db/bundb/list_test.go @@ -0,0 +1,315 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package bundb_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "golang.org/x/exp/slices" +) + +type ListTestSuite struct { + BunDBStandardTestSuite +} + +func (suite *ListTestSuite) testStructs() (*gtsmodel.List, *gtsmodel.Account) { + testList := >smodel.List{} + *testList = *suite.testLists["local_account_1_list_1"] + + // Populate entries on this list as we'd expect them back from the db. + entries := make([]*gtsmodel.ListEntry, 0, len(suite.testListEntries)) + for _, entry := range suite.testListEntries { + entries = append(entries, entry) + } + + // Sort by ID descending (again, as we'd expect from the db). + slices.SortFunc(entries, func(a, b *gtsmodel.ListEntry) bool { + return b.ID < a.ID + }) + + testList.ListEntries = entries + + testAccount := >smodel.Account{} + *testAccount = *suite.testAccounts["local_account_1"] + + return testList, testAccount +} + +func (suite *ListTestSuite) checkList(expected *gtsmodel.List, actual *gtsmodel.List) { + suite.Equal(expected.ID, actual.ID) + suite.Equal(expected.Title, actual.Title) + suite.Equal(expected.AccountID, actual.AccountID) + suite.Equal(expected.RepliesPolicy, actual.RepliesPolicy) + suite.NotNil(actual.Account) +} + +func (suite *ListTestSuite) checkListEntry(expected *gtsmodel.ListEntry, actual *gtsmodel.ListEntry) { + suite.Equal(expected.ID, actual.ID) + suite.Equal(expected.ListID, actual.ListID) + suite.Equal(expected.FollowID, actual.FollowID) +} + +func (suite *ListTestSuite) checkListEntries(expected []*gtsmodel.ListEntry, actual []*gtsmodel.ListEntry) { + var ( + lExpected = len(expected) + lActual = len(actual) + ) + + if lExpected != lActual { + suite.FailNow("", "expected %d list entries, got %d", lExpected, lActual) + } + + var topID string + for i, expectedEntry := range expected { + actualEntry := actual[i] + + // Ensure ID descending. + if topID == "" { + topID = actualEntry.ID + } else { + suite.Less(actualEntry.ID, topID) + } + + suite.checkListEntry(expectedEntry, actualEntry) + } +} + +func (suite *ListTestSuite) TestGetListByID() { + testList, _ := suite.testStructs() + + dbList, err := suite.db.GetListByID(context.Background(), testList.ID) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.checkList(testList, dbList) + suite.checkListEntries(testList.ListEntries, dbList.ListEntries) +} + +func (suite *ListTestSuite) TestGetListsForAccountID() { + testList, testAccount := suite.testStructs() + + dbLists, err := suite.db.GetListsForAccountID(context.Background(), testAccount.ID) + if err != nil { + suite.FailNow(err.Error()) + } + + if l := len(dbLists); l != 1 { + suite.FailNow("", "expected %d lists, got %d", 1, l) + } + + suite.checkList(testList, dbLists[0]) +} + +func (suite *ListTestSuite) TestGetListEntries() { + testList, _ := suite.testStructs() + + dbListEntries, err := suite.db.GetListEntries(context.Background(), testList.ID, "", "", "", 0) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.checkListEntries(testList.ListEntries, dbListEntries) +} + +func (suite *ListTestSuite) TestPutList() { + ctx := context.Background() + _, testAccount := suite.testStructs() + + testList := >smodel.List{ + ID: "01H0J2PMYM54618VCV8Y8QYAT4", + Title: "Test List!", + AccountID: testAccount.ID, + } + + if err := suite.db.PutList(ctx, testList); err != nil { + suite.FailNow(err.Error()) + } + + dbList, err := suite.db.GetListByID(ctx, testList.ID) + if err != nil { + suite.FailNow(err.Error()) + } + + // Bodge testlist as though default had been set. + testList.RepliesPolicy = gtsmodel.RepliesPolicyFollowed + suite.checkList(testList, dbList) +} + +func (suite *ListTestSuite) TestUpdateList() { + ctx := context.Background() + testList, _ := suite.testStructs() + + // Get List in the cache first. + dbList, err := suite.db.GetListByID(ctx, testList.ID) + if err != nil { + suite.FailNow(err.Error()) + } + + // Now do the update. + testList.Title = "New Title!" + if err := suite.db.UpdateList(ctx, testList, "title"); err != nil { + suite.FailNow(err.Error()) + } + + // Cache should be invalidated + // + we should have updated list. + dbList, err = suite.db.GetListByID(ctx, testList.ID) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.checkList(testList, dbList) +} + +func (suite *ListTestSuite) TestDeleteList() { + ctx := context.Background() + testList, _ := suite.testStructs() + + // Get List in the cache first. + if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil { + suite.FailNow(err.Error()) + } + + // Now do the delete. + if err := suite.db.DeleteListByID(ctx, testList.ID); err != nil { + suite.FailNow(err.Error()) + } + + // Cache should be invalidated + // + we should have no list. + _, err := suite.db.GetListByID(ctx, testList.ID) + suite.ErrorIs(err, db.ErrNoEntries) + + // All entries belonging to this + // list should now be deleted. + listEntries, err := suite.db.GetListEntries(ctx, testList.ID, "", "", "", 0) + if err != nil { + suite.FailNow(err.Error()) + } + suite.Empty(listEntries) +} + +func (suite *ListTestSuite) TestPutListEntries() { + ctx := context.Background() + testList, _ := suite.testStructs() + + listEntries := []*gtsmodel.ListEntry{ + { + ID: "01H0MKMQY69HWDSDR2SWGA17R4", + ListID: testList.ID, + FollowID: "01H0MKNFRFZS8R9WV6DBX31Y03", // random id, doesn't exist + }, + { + ID: "01H0MKPGQF0E7QAVW5BKTHZ630", + ListID: testList.ID, + FollowID: "01H0MKP6RR8VEHN3GVWFBP2H30", // random id, doesn't exist + }, + { + ID: "01H0MKPPP2DT68FRBMR1FJM32T", + ListID: testList.ID, + FollowID: "01H0MKQ0KA29C6NFJ27GTZD16J", // random id, doesn't exist + }, + } + + if err := suite.db.PutListEntries(ctx, listEntries); err != nil { + suite.FailNow(err.Error()) + } + + // Add these entries to the test list, sort it again + // to reflect what we'd expect to get from the db. + testList.ListEntries = append(testList.ListEntries, listEntries...) + slices.SortFunc(testList.ListEntries, func(a, b *gtsmodel.ListEntry) bool { + return b.ID < a.ID + }) + + // Now get all list entries from the db. + // Use barebones for this because the ones + // we just added will fail if we try to get + // the nonexistent follows. + dbListEntries, err := suite.db.GetListEntries( + gtscontext.SetBarebones(ctx), + testList.ID, + "", "", "", 0) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.checkListEntries(testList.ListEntries, dbListEntries) +} + +func (suite *ListTestSuite) TestDeleteListEntry() { + ctx := context.Background() + testList, _ := suite.testStructs() + + // Get List in the cache first. + if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil { + suite.FailNow(err.Error()) + } + + // Delete the first entry. + if err := suite.db.DeleteListEntry(ctx, testList.ListEntries[0].ID); err != nil { + suite.FailNow(err.Error()) + } + + // Get list from the db again. + dbList, err := suite.db.GetListByID(ctx, testList.ID) + if err != nil { + suite.FailNow(err.Error()) + } + + // Bodge the testlist as though + // we'd removed the first entry. + testList.ListEntries = testList.ListEntries[1:] + suite.checkList(testList, dbList) +} + +func (suite *ListTestSuite) TestDeleteListEntriesForFollowID() { + ctx := context.Background() + testList, _ := suite.testStructs() + + // Get List in the cache first. + if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil { + suite.FailNow(err.Error()) + } + + // Delete the first entry. + if err := suite.db.DeleteListEntriesForFollowID(ctx, testList.ListEntries[0].FollowID); err != nil { + suite.FailNow(err.Error()) + } + + // Get list from the db again. + dbList, err := suite.db.GetListByID(ctx, testList.ID) + if err != nil { + suite.FailNow(err.Error()) + } + + // Bodge the testlist as though + // we'd removed the first entry. + testList.ListEntries = testList.ListEntries[1:] + suite.checkList(testList, dbList) +} + +func TestListTestSuite(t *testing.T) { + suite.Run(t, new(ListTestSuite)) +} diff --git a/internal/db/bundb/migrations/20230515173919_lists.go b/internal/db/bundb/migrations/20230515173919_lists.go new file mode 100644 index 000000000..e0ea5c7b6 --- /dev/null +++ b/internal/db/bundb/migrations/20230515173919_lists.go @@ -0,0 +1,92 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package migrations + +import ( + "context" + + gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +func init() { + up := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + // List table. + if _, err := tx. + NewCreateTable(). + Model(>smodel.List{}). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + + // Add indexes to the List table. + for index, columns := range map[string][]string{ + "lists_id_idx": {"id"}, + "lists_account_id_idx": {"account_id"}, + } { + if _, err := tx. + NewCreateIndex(). + Table("lists"). + Index(index). + Column(columns...). + Exec(ctx); err != nil { + return err + } + } + + // List entry table. + if _, err := tx. + NewCreateTable(). + Model(>smodel.ListEntry{}). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + + // Add indexes to the List entry table. + for index, columns := range map[string][]string{ + "list_entries_id_idx": {"id"}, + "list_entries_list_id_idx": {"list_id"}, + "list_entries_follow_id_idx": {"follow_id"}, + } { + if _, err := tx. + NewCreateIndex(). + Table("list_entries"). + Index(index). + Column(columns...). + Exec(ctx); err != nil { + return err + } + } + + return nil + }) + } + + down := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + return nil + }) + } + + if err := Migrations.Register(up, down); err != nil { + panic(err) + } +} diff --git a/internal/db/bundb/relationship_follow.go b/internal/db/bundb/relationship_follow.go index fe1f26bf1..39b85075c 100644 --- a/internal/db/bundb/relationship_follow.go +++ b/internal/db/bundb/relationship_follow.go @@ -25,6 +25,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/uptrace/bun" @@ -149,27 +150,44 @@ func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery f return follow, nil } - // Set the follow source account - follow.Account, err = r.state.DB.GetAccountByID( - gtscontext.SetBarebones(ctx), - follow.AccountID, - ) - if err != nil { - return nil, fmt.Errorf("error getting follow source account: %w", err) - } - - // Set the follow target account - follow.TargetAccount, err = r.state.DB.GetAccountByID( - gtscontext.SetBarebones(ctx), - follow.TargetAccountID, - ) - if err != nil { - return nil, fmt.Errorf("error getting follow target account: %w", err) + if err := r.state.DB.PopulateFollow(ctx, follow); err != nil { + return nil, err } return follow, nil } +func (r *relationshipDB) PopulateFollow(ctx context.Context, follow *gtsmodel.Follow) error { + var ( + err error + errs = make(gtserror.MultiError, 0, 2) + ) + + if follow.Account == nil { + // Follow account is not set, fetch from the database. + follow.Account, err = r.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + follow.AccountID, + ) + if err != nil { + errs.Append(fmt.Errorf("error populating follow account: %w", err)) + } + } + + if follow.TargetAccount == nil { + // Follow target account is not set, fetch from the database. + follow.TargetAccount, err = r.state.DB.GetAccountByID( + gtscontext.SetBarebones(ctx), + follow.TargetAccountID, + ) + if err != nil { + errs.Append(fmt.Errorf("error populating follow target account: %w", err)) + } + } + + return errs.Combine() +} + func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error { return r.state.Caches.GTS.Follow().Store(follow, func() error { _, err := r.conn.NewInsert().Model(follow).Exec(ctx) @@ -197,27 +215,40 @@ func (r *relationshipDB) UpdateFollow(ctx context.Context, follow *gtsmodel.Foll }) } +func (r *relationshipDB) deleteFollow(ctx context.Context, id string) error { + // Delete the follow itself using the given ID. + if _, err := r.conn.NewDelete(). + Table("follows"). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx); err != nil { + return r.conn.ProcessError(err) + } + + // Delete every list entry that used this followID. + if err := r.state.DB.DeleteListEntriesForFollowID(ctx, id); err != nil { + return fmt.Errorf("deleteFollow: error deleting list entries: %w", err) + } + + return nil +} + func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error { defer r.state.Caches.GTS.Follow().Invalidate("ID", id) // Load follow into cache before attempting a delete, // as we need it cached in order to trigger the invalidate // callback. This in turn invalidates others. - _, err := r.GetFollowByID(gtscontext.SetBarebones(ctx), id) + follow, err := r.GetFollowByID(gtscontext.SetBarebones(ctx), id) if err != nil { if errors.Is(err, db.ErrNoEntries) { - // not an issue. - err = nil + // Already gone. + return nil } return err } // Finally delete follow from DB. - _, err = r.conn.NewDelete(). - Table("follows"). - Where("? = ?", bun.Ident("id"), id). - Exec(ctx) - return r.conn.ProcessError(err) + return r.deleteFollow(ctx, follow.ID) } func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) error { @@ -226,21 +257,17 @@ func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) erro // Load follow into cache before attempting a delete, // as we need it cached in order to trigger the invalidate // callback. This in turn invalidates others. - _, err := r.GetFollowByURI(gtscontext.SetBarebones(ctx), uri) + follow, err := r.GetFollowByURI(gtscontext.SetBarebones(ctx), uri) if err != nil { if errors.Is(err, db.ErrNoEntries) { - // not an issue. - err = nil + // Already gone. + return nil } return err } // Finally delete follow from DB. - _, err = r.conn.NewDelete(). - Table("follows"). - Where("? = ?", bun.Ident("uri"), uri). - Exec(ctx) - return r.conn.ProcessError(err) + return r.deleteFollow(ctx, follow.ID) } func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID string) error { @@ -272,16 +299,16 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str // but it is the only way we can ensure we invalidate all // related caches correctly (e.g. visibility). for _, id := range followIDs { - _, err := r.GetFollowByID(ctx, id) + follow, err := r.GetFollowByID(ctx, id) if err != nil && !errors.Is(err, db.ErrNoEntries) { return err } + + // Delete each follow from DB. + if err := r.deleteFollow(ctx, follow.ID); err != nil && !errors.Is(err, db.ErrNoEntries) { + return err + } } - // Finally delete all from DB. - _, err := r.conn.NewDelete(). - Table("follows"). - Where("? IN (?)", bun.Ident("id"), bun.In(followIDs)). - Exec(ctx) - return r.conn.ProcessError(err) + return nil } diff --git a/internal/db/bundb/relationship_test.go b/internal/db/bundb/relationship_test.go index 0e38d19fe..63fdb9632 100644 --- a/internal/db/bundb/relationship_test.go +++ b/internal/db/bundb/relationship_test.go @@ -807,16 +807,27 @@ func (suite *RelationshipTestSuite) TestUnfollowExisting() { follow, err := suite.db.GetFollow(context.Background(), originAccount.ID, targetAccount.ID) suite.NoError(err) suite.NotNil(follow) + followID := follow.ID - err = suite.db.DeleteFollowByID(context.Background(), follow.ID) + // We should have list entries for this follow. + listEntries, err := suite.db.GetListEntriesForFollowID(context.Background(), followID) + suite.NoError(err) + suite.NotEmpty(listEntries) + + err = suite.db.DeleteFollowByID(context.Background(), followID) suite.NoError(err) follow, err = suite.db.GetFollow(context.Background(), originAccount.ID, targetAccount.ID) suite.EqualError(err, db.ErrNoEntries.Error()) suite.Nil(follow) + + // ListEntries pertaining to this follow should be deleted too. + listEntries, err = suite.db.GetListEntriesForFollowID(context.Background(), followID) + suite.NoError(err) + suite.Empty(listEntries) } -func (suite *RelationshipTestSuite) TestUnfollowNotExisting() { +func (suite *RelationshipTestSuite) TestGetFollowNotExisting() { originAccount := suite.testAccounts["local_account_1"] targetAccountID := "01GTVD9N484CZ6AM90PGGNY7GQ" diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go index 87e7751d2..d33840a7b 100644 --- a/internal/db/bundb/timeline.go +++ b/internal/db/bundb/timeline.go @@ -19,9 +19,11 @@ package bundb import ( "context" + "fmt" "time" "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/log" @@ -281,3 +283,130 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max prevMinID := faves[0].ID return statuses, nextMaxID, prevMinID, nil } + +func (t *timelineDB) GetListTimeline( + ctx context.Context, + listID string, + maxID string, + sinceID string, + minID string, + limit int, +) ([]*gtsmodel.Status, error) { + // Ensure reasonable + if limit < 0 { + limit = 0 + } + + // Make educated guess for slice size + var ( + statusIDs = make([]string, 0, limit) + frontToBack = true + ) + + // Fetch all listEntries entries from the database. + listEntries, err := t.state.DB.GetListEntries( + // Don't need actual follows + // for this, just the IDs. + gtscontext.SetBarebones(ctx), + listID, + "", "", "", 0, + ) + if err != nil { + return nil, fmt.Errorf("error getting entries for list %s: %w", listID, err) + } + + // Extract just the IDs of each follow. + followIDs := make([]string, 0, len(listEntries)) + for _, listEntry := range listEntries { + followIDs = append(followIDs, listEntry.FollowID) + } + + // Select target account IDs from follows. + subQ := t.conn. + NewSelect(). + TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). + Column("follow.target_account_id"). + Where("? IN (?)", bun.Ident("follow.id"), bun.In(followIDs)) + + // Select only status IDs created + // by one of the followed accounts. + q := t.conn. + NewSelect(). + TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). + // Select only IDs from table + Column("status.id"). + Where("? IN (?)", bun.Ident("status.account_id"), subQ) + + if maxID == "" || maxID >= id.Highest { + const future = 24 * time.Hour + + var err error + + // don't return statuses more than 24hr in the future + maxID, err = id.NewULIDFromTime(time.Now().Add(future)) + if err != nil { + return nil, err + } + } + + // return only statuses LOWER (ie., older) than maxID + q = q.Where("? < ?", bun.Ident("status.id"), maxID) + + if sinceID != "" { + // return only statuses HIGHER (ie., newer) than sinceID + q = q.Where("? > ?", bun.Ident("status.id"), sinceID) + } + + if minID != "" { + // return only statuses HIGHER (ie., newer) than minID + q = q.Where("? > ?", bun.Ident("status.id"), minID) + + // page up + frontToBack = false + } + + if limit > 0 { + // limit amount of statuses returned + q = q.Limit(limit) + } + + if frontToBack { + // Page down. + q = q.Order("status.id DESC") + } else { + // Page up. + q = q.Order("status.id ASC") + } + + if err := q.Scan(ctx, &statusIDs); err != nil { + return nil, t.conn.ProcessError(err) + } + + if len(statusIDs) == 0 { + return nil, nil + } + + // If we're paging up, we still want statuses + // to be sorted by ID desc, so reverse ids slice. + // https://zchee.github.io/golang-wiki/SliceTricks/#reversing + if !frontToBack { + for l, r := 0, len(statusIDs)-1; l < r; l, r = l+1, r-1 { + statusIDs[l], statusIDs[r] = statusIDs[r], statusIDs[l] + } + } + + statuses := make([]*gtsmodel.Status, 0, len(statusIDs)) + for _, id := range statusIDs { + // Fetch status from db for ID + status, err := t.state.DB.GetStatusByID(ctx, id) + if err != nil { + log.Errorf(ctx, "error fetching status %q: %v", id, err) + continue + } + + // Append status to slice + statuses = append(statuses, status) + } + + return statuses, nil +} diff --git a/internal/db/bundb/timeline_test.go b/internal/db/bundb/timeline_test.go index f954c78dd..7e8fd0838 100644 --- a/internal/db/bundb/timeline_test.go +++ b/internal/db/bundb/timeline_test.go @@ -33,99 +33,6 @@ type TimelineTestSuite struct { BunDBStandardTestSuite } -func (suite *TimelineTestSuite) TestGetPublicTimeline() { - var count int - - for _, status := range suite.testStatuses { - if status.Visibility == gtsmodel.VisibilityPublic && - status.BoostOfID == "" { - count++ - } - } - - ctx := context.Background() - s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false) - suite.NoError(err) - - suite.Len(s, count) -} - -func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() { - var count int - - for _, status := range suite.testStatuses { - if status.Visibility == gtsmodel.VisibilityPublic && - status.BoostOfID == "" { - count++ - } - } - - ctx := context.Background() - - futureStatus := getFutureStatus() - err := suite.db.PutStatus(ctx, futureStatus) - suite.NoError(err) - - s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false) - suite.NoError(err) - - suite.NotContains(s, futureStatus) - suite.Len(s, count) -} - -func (suite *TimelineTestSuite) TestGetHomeTimeline() { - ctx := context.Background() - - viewingAccount := suite.testAccounts["local_account_1"] - - s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", "", 20, false) - suite.NoError(err) - - suite.Len(s, 16) -} - -func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() { - ctx := context.Background() - - viewingAccount := suite.testAccounts["local_account_1"] - - futureStatus := getFutureStatus() - err := suite.db.PutStatus(ctx, futureStatus) - suite.NoError(err) - - s, err := suite.db.GetHomeTimeline(context.Background(), viewingAccount.ID, "", "", "", 20, false) - suite.NoError(err) - - suite.NotContains(s, futureStatus) - suite.Len(s, 16) -} - -func (suite *TimelineTestSuite) TestGetHomeTimelineBackToFront() { - ctx := context.Background() - - viewingAccount := suite.testAccounts["local_account_1"] - - s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", id.Lowest, 5, false) - suite.NoError(err) - - suite.Len(s, 5) - suite.Equal("01F8MHAYFKS4KMXF8K5Y1C0KRN", s[0].ID) - suite.Equal("01F8MH75CBF9JFX4ZAD54N0W0R", s[len(s)-1].ID) -} - -func (suite *TimelineTestSuite) TestGetHomeTimelineFromHighest() { - ctx := context.Background() - - viewingAccount := suite.testAccounts["local_account_1"] - - s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, id.Highest, "", "", 5, false) - suite.NoError(err) - - suite.Len(s, 5) - suite.Equal("01G36SF3V6Y6V5BF9P4R7PQG7G", s[0].ID) - suite.Equal("01FCTA44PW9H1TB328S9AQXKDS", s[len(s)-1].ID) -} - func getFutureStatus() *gtsmodel.Status { theDistantFuture := time.Now().Add(876600 * time.Hour) id, err := id.NewULIDFromTime(theDistantFuture) @@ -163,6 +70,208 @@ func getFutureStatus() *gtsmodel.Status { } } +func (suite *TimelineTestSuite) publicCount() int { + var publicCount int + + for _, status := range suite.testStatuses { + if status.Visibility == gtsmodel.VisibilityPublic && + status.BoostOfID == "" { + publicCount++ + } + } + + return publicCount +} + +func (suite *TimelineTestSuite) checkStatuses(statuses []*gtsmodel.Status, maxID string, minID string, expectedLength int) { + if l := len(statuses); l != expectedLength { + suite.FailNow("", "expected %d statuses in slice, got %d", expectedLength, l) + } else if l == 0 { + // Can't test empty slice. + return + } + + // Check ordering + bounds of statuses. + highest := statuses[0].ID + for _, status := range statuses { + id := status.ID + + if id >= maxID { + suite.FailNow("", "%s greater than maxID %s", id, maxID) + } + + if id <= minID { + suite.FailNow("", "%s smaller than minID %s", id, minID) + } + + if id > highest { + suite.FailNow("", "statuses in slice were not ordered highest -> lowest ID") + } + + highest = id + } +} + +func (suite *TimelineTestSuite) TestGetPublicTimeline() { + ctx := context.Background() + + s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.checkStatuses(s, id.Highest, id.Lowest, suite.publicCount()) +} + +func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() { + ctx := context.Background() + + // Insert a status set far in the + // future, it shouldn't be retrieved. + futureStatus := getFutureStatus() + if err := suite.db.PutStatus(ctx, futureStatus); err != nil { + suite.FailNow(err.Error()) + } + + s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.NotContains(s, futureStatus) + suite.checkStatuses(s, id.Highest, id.Lowest, suite.publicCount()) +} + +func (suite *TimelineTestSuite) TestGetHomeTimeline() { + var ( + ctx = context.Background() + viewingAccount = suite.testAccounts["local_account_1"] + ) + + s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", "", 20, false) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.checkStatuses(s, id.Highest, id.Lowest, 16) +} + +func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() { + var ( + ctx = context.Background() + viewingAccount = suite.testAccounts["local_account_1"] + ) + + // Insert a status set far in the + // future, it shouldn't be retrieved. + futureStatus := getFutureStatus() + if err := suite.db.PutStatus(ctx, futureStatus); err != nil { + suite.FailNow(err.Error()) + } + + s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", "", 20, false) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.NotContains(s, futureStatus) + suite.checkStatuses(s, id.Highest, id.Lowest, 16) +} + +func (suite *TimelineTestSuite) TestGetHomeTimelineBackToFront() { + var ( + ctx = context.Background() + viewingAccount = suite.testAccounts["local_account_1"] + ) + + s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", id.Lowest, 5, false) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.checkStatuses(s, id.Highest, id.Lowest, 5) + suite.Equal("01F8MHAYFKS4KMXF8K5Y1C0KRN", s[0].ID) + suite.Equal("01F8MH75CBF9JFX4ZAD54N0W0R", s[len(s)-1].ID) +} + +func (suite *TimelineTestSuite) TestGetHomeTimelineFromHighest() { + var ( + ctx = context.Background() + viewingAccount = suite.testAccounts["local_account_1"] + ) + + s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, id.Highest, "", "", 5, false) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.checkStatuses(s, id.Highest, id.Lowest, 5) + suite.Equal("01G36SF3V6Y6V5BF9P4R7PQG7G", s[0].ID) + suite.Equal("01FCTA44PW9H1TB328S9AQXKDS", s[len(s)-1].ID) +} + +func (suite *TimelineTestSuite) TestGetListTimelineNoParams() { + var ( + ctx = context.Background() + list = suite.testLists["local_account_1_list_1"] + ) + + s, err := suite.db.GetListTimeline(ctx, list.ID, "", "", "", 20) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.checkStatuses(s, id.Highest, id.Lowest, 11) +} + +func (suite *TimelineTestSuite) TestGetListTimelineMaxID() { + var ( + ctx = context.Background() + list = suite.testLists["local_account_1_list_1"] + ) + + s, err := suite.db.GetListTimeline(ctx, list.ID, id.Highest, "", "", 5) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.checkStatuses(s, id.Highest, id.Lowest, 5) + suite.Equal("01G36SF3V6Y6V5BF9P4R7PQG7G", s[0].ID) + suite.Equal("01FCQSQ667XHJ9AV9T27SJJSX5", s[len(s)-1].ID) +} + +func (suite *TimelineTestSuite) TestGetListTimelineMinID() { + var ( + ctx = context.Background() + list = suite.testLists["local_account_1_list_1"] + ) + + s, err := suite.db.GetListTimeline(ctx, list.ID, "", "", id.Lowest, 5) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.checkStatuses(s, id.Highest, id.Lowest, 5) + suite.Equal("01F8MHC8VWDRBQR0N1BATDDEM5", s[0].ID) + suite.Equal("01F8MH75CBF9JFX4ZAD54N0W0R", s[len(s)-1].ID) +} + +func (suite *TimelineTestSuite) TestGetListTimelineMinIDPagingUp() { + var ( + ctx = context.Background() + list = suite.testLists["local_account_1_list_1"] + ) + + s, err := suite.db.GetListTimeline(ctx, list.ID, "", "", "01F8MHC8VWDRBQR0N1BATDDEM5", 5) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.checkStatuses(s, id.Highest, "01F8MHC8VWDRBQR0N1BATDDEM5", 5) + suite.Equal("01G20ZM733MGN8J344T4ZDDFY1", s[0].ID) + suite.Equal("01F8MHCP5P2NWYQ416SBA0XSEV", s[len(s)-1].ID) +} + func TestTimelineTestSuite(t *testing.T) { suite.Run(t, new(TimelineTestSuite)) } diff --git a/internal/db/db.go b/internal/db/db.go index 7b25b3dae..f47a35bb3 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -36,6 +36,7 @@ type DB interface { Domain Emoji Instance + List Media Mention Notification diff --git a/internal/db/list.go b/internal/db/list.go new file mode 100644 index 000000000..4472589dc --- /dev/null +++ b/internal/db/list.go @@ -0,0 +1,67 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package db + +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +type List interface { + // GetListByID gets one list with the given id. + GetListByID(ctx context.Context, id string) (*gtsmodel.List, error) + + // GetListsForAccountID gets all lists owned by the given accountID. + GetListsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error) + + // PopulateList ensures that the list's struct fields are populated. + PopulateList(ctx context.Context, list *gtsmodel.List) error + + // PutList puts a new list in the database. + PutList(ctx context.Context, list *gtsmodel.List) error + + // UpdateList updates the given list. + // Columns is optional, if not specified all will be updated. + UpdateList(ctx context.Context, list *gtsmodel.List, columns ...string) error + + // DeleteListByID deletes one list with the given ID. + DeleteListByID(ctx context.Context, id string) error + + // GetListEntryByID gets one list entry with the given ID. + GetListEntryByID(ctx context.Context, id string) (*gtsmodel.ListEntry, error) + + // GetListEntries gets list entries from the given listID, using the given parameters. + GetListEntries(ctx context.Context, listID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.ListEntry, error) + + // GetListEntriesForFollowID returns all listEntries that pertain to the given followID. + GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error) + + // PopulateListEntry ensures that the listEntry's struct fields are populated. + PopulateListEntry(ctx context.Context, listEntry *gtsmodel.ListEntry) error + + // PutListEntries inserts a slice of listEntries into the database. + // It uses a transaction to ensure no partial updates. + PutListEntries(ctx context.Context, listEntries []*gtsmodel.ListEntry) error + + // DeleteListEntry deletes one list entry with the given id. + DeleteListEntry(ctx context.Context, id string) error + + // DeleteListEntryForFollowID deletes all list entries with the given followID. + DeleteListEntriesForFollowID(ctx context.Context, followID string) error +} diff --git a/internal/db/relationship.go b/internal/db/relationship.go index ae879b5d2..99093591c 100644 --- a/internal/db/relationship.go +++ b/internal/db/relationship.go @@ -64,6 +64,9 @@ type Relationship interface { // GetFollow retrieves a follow if it exists between source and target accounts. GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error) + // PopulateFollow populates the struct pointers on the given follow. + PopulateFollow(ctx context.Context, follow *gtsmodel.Follow) error + // GetFollowRequestByID fetches follow request with given ID from the database. GetFollowRequestByID(ctx context.Context, id string) (*gtsmodel.FollowRequest, error) diff --git a/internal/db/timeline.go b/internal/db/timeline.go index 10149cc09..2635bece2 100644 --- a/internal/db/timeline.go +++ b/internal/db/timeline.go @@ -44,4 +44,8 @@ type Timeline interface { // // Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers. GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error) + + // GetListTimeline returns a slice of statuses from followed accounts collected within the list with the given listID. + // Statuses should be returned in descending order of when they were created (newest first). + GetListTimeline(ctx context.Context, listID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Status, error) } diff --git a/internal/federation/dereferencing/dereferencer_test.go b/internal/federation/dereferencing/dereferencer_test.go index 3cec176fe..6e9aa729c 100644 --- a/internal/federation/dereferencing/dereferencer_test.go +++ b/internal/federation/dereferencing/dereferencer_test.go @@ -25,6 +25,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -61,6 +62,13 @@ func (suite *DereferencerStandardTestSuite) SetupTest() { testrig.StartWorkers(&suite.state) suite.db = testrig.NewTestDB(&suite.state) + + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + testrig.NewTestTypeConverter(suite.db), + ) + suite.storage = testrig.NewInMemoryStorage() suite.state.DB = suite.db suite.state.Storage = suite.storage diff --git a/internal/federation/federatingdb/federatingdb_test.go b/internal/federation/federatingdb/federatingdb_test.go index 32495a2f8..6a6cb9262 100644 --- a/internal/federation/federatingdb/federatingdb_test.go +++ b/internal/federation/federatingdb/federatingdb_test.go @@ -28,6 +28,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -76,8 +77,16 @@ func (suite *FederatingDBTestSuite) SetupTest() { } suite.db = testrig.NewTestDB(&suite.state) + suite.testActivities = testrig.NewTestActivities(suite.testAccounts) suite.tc = testrig.NewTestTypeConverter(suite.db) + + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + suite.tc, + ) + suite.federatingDB = testrig.NewTestFederatingDB(&suite.state) testrig.StandardDBSetup(suite.db, suite.testAccounts) diff --git a/internal/federation/federator_test.go b/internal/federation/federator_test.go index abb462ee6..fdcfda19c 100644 --- a/internal/federation/federator_test.go +++ b/internal/federation/federator_test.go @@ -25,6 +25,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -58,6 +59,13 @@ func (suite *FederatorStandardTestSuite) SetupTest() { suite.db = testrig.NewTestDB(&suite.state) suite.tc = testrig.NewTestTypeConverter(suite.db) suite.state.DB = suite.db + + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + suite.tc, + ) + suite.testActivities = testrig.NewTestActivities(suite.testAccounts) testrig.StandardDBSetup(suite.db, suite.testAccounts) } diff --git a/internal/gtsmodel/list.go b/internal/gtsmodel/list.go new file mode 100644 index 000000000..98188b113 --- /dev/null +++ b/internal/gtsmodel/list.go @@ -0,0 +1,51 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package gtsmodel + +import "time" + +// List refers to a list of follows for which the owning account wants to view a timeline of posts. +type List struct { + ID string `validate:"required,ulid" bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // id of this item in the database + CreatedAt time.Time `validate:"-" bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created + UpdatedAt time.Time `validate:"-" bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated + Title string `validate:"required" bun:",nullzero,notnull,unique:listaccounttitle"` // Title of this list. + AccountID string `validate:"required,ulid" bun:"type:CHAR(26),notnull,nullzero,unique:listaccounttitle"` // Account that created/owns the list + Account *Account `validate:"-" bun:"-"` // Account corresponding to accountID + ListEntries []*ListEntry `validate:"-" bun:"-"` // Entries contained by this list. + RepliesPolicy RepliesPolicy `validate:"-" bun:",nullzero,notnull,default:'followed'"` // RepliesPolicy for this list. +} + +// ListEntry refers to a single follow entry in a list. +type ListEntry struct { + ID string `validate:"required,ulid" bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // id of this item in the database + CreatedAt time.Time `validate:"-" bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created + UpdatedAt time.Time `validate:"-" bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated + ListID string `validate:"required,ulid" bun:"type:CHAR(26),notnull,nullzero,unique:listentrylistfollow"` // ID of the list that this entry belongs to. + FollowID string `validate:"required,ulid" bun:"type:CHAR(26),notnull,nullzero,unique:listentrylistfollow"` // Follow that the account owning this entry wants to see posts of in the timeline. + Follow *Follow `validate:"-" bun:"-"` // Follow corresponding to followID. +} + +// RepliesPolicy denotes which replies should be shown in the list. +type RepliesPolicy string + +const ( + RepliesPolicyFollowed RepliesPolicy = "followed" // Show replies to any followed user. + RepliesPolicyList RepliesPolicy = "list" // Show replies to members of the list only. + RepliesPolicyNone RepliesPolicy = "none" // Don't show replies. +) diff --git a/internal/media/manager_test.go b/internal/media/manager_test.go index 4dc0c4fa2..fb0784034 100644 --- a/internal/media/manager_test.go +++ b/internal/media/manager_test.go @@ -33,6 +33,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/state" gtsstorage "github.com/superseriousbusiness/gotosocial/internal/storage" + "github.com/superseriousbusiness/gotosocial/testrig" ) type ManagerTestSuite struct { @@ -395,9 +396,6 @@ func (suite *ManagerTestSuite) TestSlothVineProcessBlocking() { // fetch the attachment id from the processing media attachmentID := processingMedia.AttachmentID() - // Give time for processing - time.Sleep(time.Second * 3) - // do a blocking call to fetch the attachment attachment, err := processingMedia.LoadAttachment(ctx) suite.NoError(err) @@ -1027,13 +1025,14 @@ func (suite *ManagerTestSuite) TestSimpleJpegProcessAsync() { // fetch the attachment id from the processing media attachmentID := processingMedia.AttachmentID() - // Give time for processing to happen. - time.Sleep(time.Second * 3) - - // fetch the attachment from the database - attachment, err := suite.db.GetAttachmentByID(ctx, attachmentID) - suite.NoError(err) - suite.NotNil(attachment) + // wait for processing to complete + var attachment *gtsmodel.MediaAttachment + if !testrig.WaitFor(func() bool { + attachment, err = suite.db.GetAttachmentByID(ctx, attachmentID) + return err == nil && attachment != nil + }) { + suite.FailNow("timed out waiting for attachment to process") + } // make sure it's got the stuff set on it that we expect // the attachment ID and accountID we expect diff --git a/internal/media/media_test.go b/internal/media/media_test.go index e522fbb90..323f87bf4 100644 --- a/internal/media/media_test.go +++ b/internal/media/media_test.go @@ -25,6 +25,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/transport" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -41,22 +42,27 @@ type MediaStandardTestSuite struct { testEmojis map[string]*gtsmodel.Emoji } -func (suite *MediaStandardTestSuite) SetupSuite() { +func (suite *MediaStandardTestSuite) SetupTest() { testrig.InitTestConfig() testrig.InitTestLog() + suite.state.Caches.Init() + testrig.StartWorkers(&suite.state) + suite.db = testrig.NewTestDB(&suite.state) suite.storage = testrig.NewInMemoryStorage() suite.state.DB = suite.db suite.state.Storage = suite.storage -} - -func (suite *MediaStandardTestSuite) SetupTest() { - suite.state.Caches.Init() - testrig.StartWorkers(&suite.state) testrig.StandardStorageSetup(suite.storage, "../../testrig/media") testrig.StandardDBSetup(suite.db, nil) + + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + testrig.NewTestTypeConverter(suite.db), + ) + suite.testAttachments = testrig.NewTestAttachments() suite.testAccounts = testrig.NewTestAccounts() suite.testEmojis = testrig.NewTestEmojis() diff --git a/internal/processing/account/account_test.go b/internal/processing/account/account_test.go index 5d48d1210..3fa8c8991 100644 --- a/internal/processing/account/account_test.go +++ b/internal/processing/account/account_test.go @@ -88,6 +88,13 @@ func (suite *AccountStandardTestSuite) SetupTest() { suite.db = testrig.NewTestDB(&suite.state) suite.state.DB = suite.db suite.tc = testrig.NewTestTypeConverter(suite.db) + + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + suite.tc, + ) + suite.storage = testrig.NewInMemoryStorage() suite.state.Storage = suite.storage suite.mediaManager = testrig.NewTestMediaManager(&suite.state) diff --git a/internal/processing/account/lists.go b/internal/processing/account/lists.go new file mode 100644 index 000000000..167ed3358 --- /dev/null +++ b/internal/processing/account/lists.go @@ -0,0 +1,107 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package account + +import ( + "context" + "errors" + "fmt" + + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" +) + +var noLists = make([]*apimodel.List, 0) + +// ListsGet returns all lists owned by requestingAccount, which contain a follow for targetAccountID. +func (p *Processor) ListsGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]*apimodel.List, gtserror.WithCode) { + targetAccount, err := p.state.DB.GetAccountByID(ctx, targetAccountID) + if err != nil { + if errors.Is(err, db.ErrNoEntries) { + return nil, gtserror.NewErrorNotFound(errors.New("account not found")) + } + return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error: %w", err)) + } + + visible, err := p.filter.AccountVisible(ctx, requestingAccount, targetAccount) + if err != nil { + return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error: %w", err)) + } + + if !visible { + return nil, gtserror.NewErrorNotFound(errors.New("account not found")) + } + + // Requester has to follow targetAccount + // for them to be in any of their lists. + follow, err := p.state.DB.GetFollow( + // Don't populate follow. + gtscontext.SetBarebones(ctx), + requestingAccount.ID, + targetAccountID, + ) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error: %w", err)) + } + + if follow == nil { + return noLists, nil // by definition we know they're in no lists + } + + listEntries, err := p.state.DB.GetListEntriesForFollowID( + // Don't populate entries. + gtscontext.SetBarebones(ctx), + follow.ID, + ) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error: %w", err)) + } + + count := len(listEntries) + if count == 0 { + return noLists, nil + } + + apiLists := make([]*apimodel.List, 0, count) + for _, listEntry := range listEntries { + list, err := p.state.DB.GetListByID( + // Don't populate list. + gtscontext.SetBarebones(ctx), + listEntry.ListID, + ) + + if err != nil { + log.Debugf(ctx, "skipping list %s due to error %q", listEntry.ListID, err) + continue + } + + apiList, err := p.tc.ListToAPIList(ctx, list) + if err != nil { + log.Debugf(ctx, "skipping list %s due to error %q", listEntry.ListID, err) + continue + } + + apiLists = append(apiLists, apiList) + } + + return apiLists, nil +} diff --git a/internal/processing/fromclientapi.go b/internal/processing/fromclientapi.go index 082a5ba2e..41bf6ee40 100644 --- a/internal/processing/fromclientapi.go +++ b/internal/processing/fromclientapi.go @@ -217,10 +217,10 @@ func (p *Processor) processCreateBlockFromClientAPI(ctx context.Context, clientM } // remove any of the blocking account's statuses from the blocked account's timeline, and vice versa - if err := p.statusTimelines.WipeItemsFromAccountID(ctx, block.AccountID, block.TargetAccountID); err != nil { + if err := p.state.Timelines.Home.WipeItemsFromAccountID(ctx, block.AccountID, block.TargetAccountID); err != nil { return err } - if err := p.statusTimelines.WipeItemsFromAccountID(ctx, block.TargetAccountID, block.AccountID); err != nil { + if err := p.state.Timelines.Home.WipeItemsFromAccountID(ctx, block.TargetAccountID, block.AccountID); err != nil { return err } diff --git a/internal/processing/fromclientapi_test.go b/internal/processing/fromclientapi_test.go index 0b641c091..808f02cd6 100644 --- a/internal/processing/fromclientapi_test.go +++ b/internal/processing/fromclientapi_test.go @@ -20,6 +20,7 @@ package processing_test import ( "context" "encoding/json" + "errors" "testing" "github.com/stretchr/testify/suite" @@ -36,24 +37,21 @@ type FromClientAPITestSuite struct { ProcessingStandardTestSuite } +// This test ensures that when admin_account posts a new +// status, it ends up in the correct streaming timelines +// of local_account_1, which follows it. func (suite *FromClientAPITestSuite) TestProcessStreamNewStatus() { - ctx := context.Background() + var ( + ctx = context.Background() + postingAccount = suite.testAccounts["admin_account"] + receivingAccount = suite.testAccounts["local_account_1"] + testList = suite.testLists["local_account_1_list_1"] + streams = suite.openStreams(ctx, receivingAccount, []string{testList.ID}) + homeStream = streams[stream.TimelineHome] + listStream = streams[stream.TimelineList+":"+testList.ID] + ) - // let's say that the admin account posts a new status: it should end up in the - // timeline of any account that follows it and has a stream open - postingAccount := suite.testAccounts["admin_account"] - receivingAccount := suite.testAccounts["local_account_1"] - - // open a home timeline stream for zork - wssStream, errWithCode := suite.processor.Stream().Open(ctx, receivingAccount, stream.TimelineHome) - suite.NoError(errWithCode) - - // open another stream for zork, but for a different timeline; - // this shouldn't get stuff streamed into it, since it's for the public timeline - irrelevantStream, errWithCode := suite.processor.Stream().Open(ctx, receivingAccount, stream.TimelinePublic) - suite.NoError(errWithCode) - - // make a new status from admin account + // Make a new status from admin account. newStatus := >smodel.Status{ ID: "01FN4B2F88TF9676DYNXWE1WSS", URI: "http://localhost:8080/users/admin/statuses/01FN4B2F88TF9676DYNXWE1WSS", @@ -82,87 +80,110 @@ func (suite *FromClientAPITestSuite) TestProcessStreamNewStatus() { ActivityStreamsType: ap.ObjectNote, } - // put the status in the db first, to mimic what would have already happened earlier up the flow - err := suite.db.PutStatus(ctx, newStatus) - suite.NoError(err) + // Put the status in the db first, to mimic what + // would have already happened earlier up the flow. + if err := suite.db.PutStatus(ctx, newStatus); err != nil { + suite.FailNow(err.Error()) + } - // process the new status - err = suite.processor.ProcessFromClientAPI(ctx, messages.FromClientAPI{ + // Process the new status. + if err := suite.processor.ProcessFromClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ObjectNote, APActivityType: ap.ActivityCreate, GTSModel: newStatus, OriginAccount: postingAccount, - }) - suite.NoError(err) + }); err != nil { + suite.FailNow(err.Error()) + } - // zork's stream should have the newly created status in it now - msg := <-wssStream.Messages - suite.Equal(stream.EventTypeUpdate, msg.Event) - suite.NotEmpty(msg.Payload) - suite.EqualValues([]string{stream.TimelineHome}, msg.Stream) - statusStreamed := &apimodel.Status{} - err = json.Unmarshal([]byte(msg.Payload), statusStreamed) - suite.NoError(err) - suite.Equal("01FN4B2F88TF9676DYNXWE1WSS", statusStreamed.ID) - suite.Equal("this status should stream :)", statusStreamed.Content) + // Check message in home stream. + homeMsg := <-homeStream.Messages + suite.Equal(stream.EventTypeUpdate, homeMsg.Event) + suite.EqualValues([]string{stream.TimelineHome}, homeMsg.Stream) + suite.Empty(homeStream.Messages) // Stream should now be empty. - // and stream should now be empty - suite.Empty(wssStream.Messages) + // Check status from home stream. + homeStreamStatus := &apimodel.Status{} + if err := json.Unmarshal([]byte(homeMsg.Payload), homeStreamStatus); err != nil { + suite.FailNow(err.Error()) + } + suite.Equal(newStatus.ID, homeStreamStatus.ID) + suite.Equal(newStatus.Content, homeStreamStatus.Content) - // the irrelevant messages stream should also be empty - suite.Empty(irrelevantStream.Messages) + // Check message in list stream. + listMsg := <-listStream.Messages + suite.Equal(stream.EventTypeUpdate, listMsg.Event) + suite.EqualValues([]string{stream.TimelineList + ":" + testList.ID}, listMsg.Stream) + suite.Empty(listStream.Messages) // Stream should now be empty. + + // Check status from list stream. + listStreamStatus := &apimodel.Status{} + if err := json.Unmarshal([]byte(listMsg.Payload), listStreamStatus); err != nil { + suite.FailNow(err.Error()) + } + suite.Equal(newStatus.ID, listStreamStatus.ID) + suite.Equal(newStatus.Content, listStreamStatus.Content) } func (suite *FromClientAPITestSuite) TestProcessStatusDelete() { - ctx := context.Background() + var ( + ctx = context.Background() + deletingAccount = suite.testAccounts["local_account_1"] + receivingAccount = suite.testAccounts["local_account_2"] + deletedStatus = suite.testStatuses["local_account_1_status_1"] + boostOfDeletedStatus = suite.testStatuses["admin_account_status_4"] + streams = suite.openStreams(ctx, receivingAccount, nil) + homeStream = streams[stream.TimelineHome] + ) - deletingAccount := suite.testAccounts["local_account_1"] - receivingAccount := suite.testAccounts["local_account_2"] + // Delete the status from the db first, to mimic what + // would have already happened earlier up the flow + if err := suite.db.DeleteStatusByID(ctx, deletedStatus.ID); err != nil { + suite.FailNow(err.Error()) + } - deletedStatus := suite.testStatuses["local_account_1_status_1"] - boostOfDeletedStatus := suite.testStatuses["admin_account_status_4"] - - // open a home timeline stream for turtle, who follows zork - wssStream, errWithCode := suite.processor.Stream().Open(ctx, receivingAccount, stream.TimelineHome) - suite.NoError(errWithCode) - - // delete the status from the db first, to mimic what would have already happened earlier up the flow - err := suite.db.DeleteStatusByID(ctx, deletedStatus.ID) - suite.NoError(err) - - // process the status delete - err = suite.processor.ProcessFromClientAPI(ctx, messages.FromClientAPI{ + // Process the status delete. + if err := suite.processor.ProcessFromClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ObjectNote, APActivityType: ap.ActivityDelete, GTSModel: deletedStatus, OriginAccount: deletingAccount, - }) - suite.NoError(err) + }); err != nil { + suite.FailNow(err.Error()) + } - // turtle's stream should have the delete of admin's boost in it now - msg := <-wssStream.Messages + // Stream should have the delete of admin's boost in it now. + msg := <-homeStream.Messages suite.Equal(stream.EventTypeDelete, msg.Event) suite.Equal(boostOfDeletedStatus.ID, msg.Payload) suite.EqualValues([]string{stream.TimelineHome}, msg.Stream) - // turtle's stream should also have the delete of the message itself in it - msg = <-wssStream.Messages + // Stream should also have the delete of the message itself in it. + msg = <-homeStream.Messages suite.Equal(stream.EventTypeDelete, msg.Event) suite.Equal(deletedStatus.ID, msg.Payload) suite.EqualValues([]string{stream.TimelineHome}, msg.Stream) - // stream should now be empty - suite.Empty(wssStream.Messages) + // Stream should now be empty. + suite.Empty(homeStream.Messages) - // the boost should no longer be in the database - _, err = suite.db.GetStatusByID(ctx, boostOfDeletedStatus.ID) - suite.ErrorIs(err, db.ErrNoEntries) + // Boost should no longer be in the database. + if !testrig.WaitFor(func() bool { + _, err := suite.db.GetStatusByID(ctx, boostOfDeletedStatus.ID) + return errors.Is(err, db.ErrNoEntries) + }) { + suite.FailNow("timed out waiting for status delete") + } } func (suite *FromClientAPITestSuite) TestProcessNewStatusWithNotification() { - ctx := context.Background() - postingAccount := suite.testAccounts["admin_account"] - receivingAccount := suite.testAccounts["local_account_1"] + var ( + ctx = context.Background() + postingAccount = suite.testAccounts["admin_account"] + receivingAccount = suite.testAccounts["local_account_1"] + streams = suite.openStreams(ctx, receivingAccount, nil) + notifStream = streams[stream.TimelineNotifications] + ) // Update the follow from receiving account -> posting account so // that receiving account wants notifs when posting account posts. @@ -204,8 +225,9 @@ func (suite *FromClientAPITestSuite) TestProcessNewStatusWithNotification() { // Put the status in the db first, to mimic what // would have already happened earlier up the flow. - err := suite.db.PutStatus(ctx, newStatus) - suite.NoError(err) + if err := suite.db.PutStatus(ctx, newStatus); err != nil { + suite.FailNow(err.Error()) + } // Process the new status. if err := suite.processor.ProcessFromClientAPI(ctx, messages.FromClientAPI{ @@ -230,6 +252,19 @@ func (suite *FromClientAPITestSuite) TestProcessNewStatusWithNotification() { }) { suite.FailNow("timed out waiting for new status notification") } + + // Check message in notification stream. + notifMsg := <-notifStream.Messages + suite.Equal(stream.EventTypeNotification, notifMsg.Event) + suite.EqualValues([]string{stream.TimelineNotifications}, notifMsg.Stream) + suite.Empty(notifStream.Messages) // Stream should now be empty. + + // Check notif. + notif := &apimodel.Notification{} + if err := json.Unmarshal([]byte(notifMsg.Payload), notif); err != nil { + suite.FailNow(err.Error()) + } + suite.Equal(newStatus.ID, notif.Status.ID) } func TestFromClientAPITestSuite(t *testing.T) { diff --git a/internal/processing/fromcommon.go b/internal/processing/fromcommon.go index a7ab0b330..0adb576bc 100644 --- a/internal/processing/fromcommon.go +++ b/internal/processing/fromcommon.go @@ -30,12 +30,14 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/stream" + "github.com/superseriousbusiness/gotosocial/internal/timeline" ) // timelineAndNotifyStatus processes the given new status and inserts it into -// the HOME timelines of accounts that follow the status author. It will also -// handle notifications for any mentions attached to the account, and also -// notifications for any local accounts that want a notif when this account posts. +// the HOME and LIST timelines of accounts that follow the status author. +// +// It will also handle notifications for any mentions attached to the account, and +// also notifications for any local accounts that want to know when this account posts. func (p *Processor) timelineAndNotifyStatus(ctx context.Context, status *gtsmodel.Status) error { // Ensure status fully populated; including account, mentions, etc. if err := p.state.DB.PopulateStatus(ctx, status); err != nil { @@ -89,10 +91,43 @@ func (p *Processor) timelineAndNotifyStatusForFollowers(ctx context.Context, sta continue } + // Add status to each list that this follow + // is included in, and stream it if applicable. + listEntries, err := p.state.DB.GetListEntriesForFollowID( + // We only need the list IDs. + gtscontext.SetBarebones(ctx), + follow.ID, + ) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + errs.Append(fmt.Errorf("timelineAndNotifyStatusForFollowers: error list timelining status: %w", err)) + continue + } + + for _, listEntry := range listEntries { + if _, err := p.timelineStatus( + ctx, + p.state.Timelines.List.IngestOne, + listEntry.ListID, // list timelines are keyed by list ID + follow.Account, + status, + stream.TimelineList+":"+listEntry.ListID, // key streamType to this specific list + ); err != nil { + errs.Append(fmt.Errorf("timelineAndNotifyStatusForFollowers: error list timelining status: %w", err)) + continue + } + } + // Add status to home timeline for this // follower, and stream it if applicable. - if timelined, err := p.timelineStatusForAccount(ctx, follow.Account, status); err != nil { - errs.Append(fmt.Errorf("timelineAndNotifyStatusForFollowers: error timelining status: %w", err)) + if timelined, err := p.timelineStatus( + ctx, + p.state.Timelines.Home.IngestOne, + follow.AccountID, // home timelines are keyed by account ID + follow.Account, + status, + stream.TimelineHome, + ); err != nil { + errs.Append(fmt.Errorf("timelineAndNotifyStatusForFollowers: error home timelining status: %w", err)) continue } else if !timelined { // Status wasn't added to home tomeline, @@ -133,13 +168,21 @@ func (p *Processor) timelineAndNotifyStatusForFollowers(ctx context.Context, sta return errs.Combine() } -// timelineStatusForAccount puts the given status in the HOME timeline -// of the account with given accountID, if it's HomeTimelineable. +// timelineStatus uses the provided ingest function to put the given +// status in a timeline with the given ID, if it's timelineable. // -// If the status was inserted into the home timeline of the given account, -// true will be returned + it will also be streamed via websockets to the user. -func (p *Processor) timelineStatusForAccount(ctx context.Context, account *gtsmodel.Account, status *gtsmodel.Status) (bool, error) { +// If the status was inserted into the timeline, true will be returned +// + it will also be streamed to the user using the given streamType. +func (p *Processor) timelineStatus( + ctx context.Context, + ingest func(context.Context, string, timeline.Timelineable) (bool, error), + timelineID string, + account *gtsmodel.Account, + status *gtsmodel.Status, + streamType string, +) (bool, error) { // Make sure the status is timelineable. + // This works for both home and list timelines. if timelineable, err := p.filter.StatusHomeTimelineable(ctx, account, status); err != nil { err = fmt.Errorf("timelineStatusForAccount: error getting timelineability for status for timeline with id %s: %w", account.ID, err) return false, err @@ -148,8 +191,8 @@ func (p *Processor) timelineStatusForAccount(ctx context.Context, account *gtsmo return false, nil } - // Insert status in the home timeline of account. - if inserted, err := p.statusTimelines.IngestOne(ctx, account.ID, status); err != nil { + // Ingest status into given timeline using provided function. + if inserted, err := ingest(ctx, timelineID, status); err != nil { err = fmt.Errorf("timelineStatusForAccount: error ingesting status %s: %w", status.ID, err) return false, err } else if !inserted { @@ -164,7 +207,7 @@ func (p *Processor) timelineStatusForAccount(ctx context.Context, account *gtsmo return true, err } - if err := p.stream.Update(apiStatus, account, stream.TimelineHome); err != nil { + if err := p.stream.Update(apiStatus, account, []string{streamType}); err != nil { err = fmt.Errorf("timelineStatusForAccount: error streaming update for status %s: %w", status.ID, err) return true, err } @@ -401,7 +444,7 @@ func (p *Processor) wipeStatus(ctx context.Context, statusToDelete *gtsmodel.Sta // deleteStatusFromTimelines completely removes the given status from all timelines. // It will also stream deletion of the status to all open streams. func (p *Processor) deleteStatusFromTimelines(ctx context.Context, status *gtsmodel.Status) error { - if err := p.statusTimelines.WipeItemFromAllTimelines(ctx, status.ID); err != nil { + if err := p.state.Timelines.Home.WipeItemFromAllTimelines(ctx, status.ID); err != nil { return err } diff --git a/internal/processing/fromfederator.go b/internal/processing/fromfederator.go index eccdbb894..ecb7934c9 100644 --- a/internal/processing/fromfederator.go +++ b/internal/processing/fromfederator.go @@ -342,10 +342,10 @@ func (p *Processor) processCreateBlockFromFederator(ctx context.Context, federat } // remove any of the blocking account's statuses from the blocked account's timeline, and vice versa - if err := p.statusTimelines.WipeItemsFromAccountID(ctx, block.AccountID, block.TargetAccountID); err != nil { + if err := p.state.Timelines.Home.WipeItemsFromAccountID(ctx, block.AccountID, block.TargetAccountID); err != nil { return err } - if err := p.statusTimelines.WipeItemsFromAccountID(ctx, block.TargetAccountID, block.AccountID); err != nil { + if err := p.state.Timelines.Home.WipeItemsFromAccountID(ctx, block.TargetAccountID, block.AccountID); err != nil { return err } // TODO: same with notifications diff --git a/internal/processing/list/create.go b/internal/processing/list/create.go new file mode 100644 index 000000000..10dec1050 --- /dev/null +++ b/internal/processing/list/create.go @@ -0,0 +1,50 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package list + +import ( + "context" + "errors" + + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/id" +) + +// Create creates one a new list for the given account, using the provided parameters. +// These params should have already been validated by the time they reach this function. +func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, title string, repliesPolicy gtsmodel.RepliesPolicy) (*apimodel.List, gtserror.WithCode) { + list := >smodel.List{ + ID: id.NewULID(), + Title: title, + AccountID: account.ID, + RepliesPolicy: repliesPolicy, + } + + if err := p.state.DB.PutList(ctx, list); err != nil { + if errors.Is(err, db.ErrAlreadyExists) { + err = errors.New("you already have a list with this title") + return nil, gtserror.NewErrorConflict(err, err.Error()) + } + return nil, gtserror.NewErrorInternalError(err) + } + + return p.apiList(ctx, list) +} diff --git a/internal/processing/list/delete.go b/internal/processing/list/delete.go new file mode 100644 index 000000000..1c8ee5700 --- /dev/null +++ b/internal/processing/list/delete.go @@ -0,0 +1,46 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package list + +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// Delete deletes one list for the given account. +func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, id string) gtserror.WithCode { + list, errWithCode := p.getList( + // Use barebones ctx; no embedded + // structs necessary for this call. + gtscontext.SetBarebones(ctx), + account.ID, + id, + ) + if errWithCode != nil { + return errWithCode + } + + if err := p.state.DB.DeleteListByID(ctx, list.ID); err != nil { + return gtserror.NewErrorInternalError(err) + } + + return nil +} diff --git a/internal/processing/list/get.go b/internal/processing/list/get.go new file mode 100644 index 000000000..3f124fe7c --- /dev/null +++ b/internal/processing/list/get.go @@ -0,0 +1,155 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package list + +import ( + "context" + "errors" + "fmt" + + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/util" +) + +// Get returns the api model of one list with the given ID. +func (p *Processor) Get(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.List, gtserror.WithCode) { + list, errWithCode := p.getList( + // Use barebones ctx; no embedded + // structs necessary for this call. + gtscontext.SetBarebones(ctx), + account.ID, + id, + ) + if errWithCode != nil { + return nil, errWithCode + } + + return p.apiList(ctx, list) +} + +// GetMultiple returns multiple lists created by the given account, sorted by list ID DESC (newest first). +func (p *Processor) GetAll(ctx context.Context, account *gtsmodel.Account) ([]*apimodel.List, gtserror.WithCode) { + lists, err := p.state.DB.GetListsForAccountID( + // Use barebones ctx; no embedded + // structs necessary for simple GET. + gtscontext.SetBarebones(ctx), + account.ID, + ) + if err != nil { + if errors.Is(err, db.ErrNoEntries) { + return nil, nil + } + return nil, gtserror.NewErrorInternalError(err) + } + + apiLists := make([]*apimodel.List, 0, len(lists)) + for _, list := range lists { + apiList, errWithCode := p.apiList(ctx, list) + if errWithCode != nil { + return nil, errWithCode + } + + apiLists = append(apiLists, apiList) + } + + return apiLists, nil +} + +// GetListAccounts returns accounts that are in the given list, owned by the given account. +// The additional parameters can be used for paging. +func (p *Processor) GetListAccounts( + ctx context.Context, + account *gtsmodel.Account, + listID string, + maxID string, + sinceID string, + minID string, + limit int, +) (*apimodel.PageableResponse, gtserror.WithCode) { + // Ensure list exists + is owned by requesting account. + if _, errWithCode := p.getList(ctx, account.ID, listID); errWithCode != nil { + return nil, errWithCode + } + + // To know which accounts are in the list, + // we need to first get requested list entries. + listEntries, err := p.state.DB.GetListEntries(ctx, listID, maxID, sinceID, minID, limit) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + err = fmt.Errorf("GetListAccounts: error getting list entries: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + + count := len(listEntries) + if count == 0 { + // No list entries means no accounts. + return util.EmptyPageableResponse(), nil + } + + var ( + items = make([]interface{}, count) + nextMaxIDValue string + prevMinIDValue string + ) + + // For each list entry, we want the account it points to. + // To get this, we need to first get the follow that the + // list entry pertains to, then extract the target account + // from that follow. + // + // We do paging not by account ID, but by list entry ID. + for i, listEntry := range listEntries { + if i == count-1 { + nextMaxIDValue = listEntry.ID + } + + if i == 0 { + prevMinIDValue = listEntry.ID + } + + if err := p.state.DB.PopulateListEntry(ctx, listEntry); err != nil { + log.Debugf(ctx, "skipping list entry because of error populating it: %q", err) + continue + } + + if err := p.state.DB.PopulateFollow(ctx, listEntry.Follow); err != nil { + log.Debugf(ctx, "skipping list entry because of error populating follow: %q", err) + continue + } + + apiAccount, err := p.tc.AccountToAPIAccountPublic(ctx, listEntry.Follow.TargetAccount) + if err != nil { + log.Debugf(ctx, "skipping list entry because of error converting follow target account: %q", err) + continue + } + + items[i] = apiAccount + } + + return util.PackagePageableResponse(util.PageableResponseParams{ + Items: items, + Path: "api/v1/lists/" + listID + "/accounts", + NextMaxIDValue: nextMaxIDValue, + PrevMinIDValue: prevMinIDValue, + Limit: limit, + }) +} diff --git a/internal/processing/list/list.go b/internal/processing/list/list.go new file mode 100644 index 000000000..f192beb60 --- /dev/null +++ b/internal/processing/list/list.go @@ -0,0 +1,35 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package list + +import ( + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/typeutils" +) + +type Processor struct { + state *state.State + tc typeutils.TypeConverter +} + +func New(state *state.State, tc typeutils.TypeConverter) Processor { + return Processor{ + state: state, + tc: tc, + } +} diff --git a/internal/processing/list/update.go b/internal/processing/list/update.go new file mode 100644 index 000000000..656af1f78 --- /dev/null +++ b/internal/processing/list/update.go @@ -0,0 +1,73 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package list + +import ( + "context" + "errors" + + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// Update updates one list for the given account, using the provided parameters. +// These params should have already been validated by the time they reach this function. +func (p *Processor) Update( + ctx context.Context, + account *gtsmodel.Account, + id string, + title *string, + repliesPolicy *gtsmodel.RepliesPolicy, +) (*apimodel.List, gtserror.WithCode) { + list, errWithCode := p.getList( + // Use barebones ctx; no embedded + // structs necessary for this call. + gtscontext.SetBarebones(ctx), + account.ID, + id, + ) + if errWithCode != nil { + return nil, errWithCode + } + + // Only update columns we're told to update. + columns := make([]string, 0, 2) + + if title != nil { + list.Title = *title + columns = append(columns, "title") + } + + if repliesPolicy != nil { + list.RepliesPolicy = *repliesPolicy + columns = append(columns, "replies_policy") + } + + if err := p.state.DB.UpdateList(ctx, list, columns...); err != nil { + if errors.Is(err, db.ErrAlreadyExists) { + err = errors.New("you already have a list with this title") + return nil, gtserror.NewErrorConflict(err, err.Error()) + } + return nil, gtserror.NewErrorInternalError(err) + } + + return p.apiList(ctx, list) +} diff --git a/internal/processing/list/updateentries.go b/internal/processing/list/updateentries.go new file mode 100644 index 000000000..6dcb951a7 --- /dev/null +++ b/internal/processing/list/updateentries.go @@ -0,0 +1,151 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package list + +import ( + "context" + "errors" + "fmt" + + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/id" +) + +// AddToList adds targetAccountIDs to the given list, if valid. +func (p *Processor) AddToList(ctx context.Context, account *gtsmodel.Account, listID string, targetAccountIDs []string) gtserror.WithCode { + // Ensure this list exists + account owns it. + list, errWithCode := p.getList(ctx, account.ID, listID) + if errWithCode != nil { + return errWithCode + } + + // Pre-assemble list of entries to add. We *could* add these + // one by one as we iterate through accountIDs, but according + // to the Mastodon API we should only add them all once we know + // they're all valid, no partial updates. + listEntries := make([]*gtsmodel.ListEntry, 0, len(targetAccountIDs)) + + // Check each targetAccountID is valid. + // - Follow must exist. + // - Follow must not already be in the given list. + for _, targetAccountID := range targetAccountIDs { + // Ensure follow exists. + follow, err := p.state.DB.GetFollow(ctx, account.ID, targetAccountID) + if err != nil { + if errors.Is(err, db.ErrNoEntries) { + err = fmt.Errorf("you do not follow account %s", targetAccountID) + return gtserror.NewErrorNotFound(err, err.Error()) + } + return gtserror.NewErrorInternalError(err) + } + + // Ensure followID not already in list. + // This particular call to isInList will + // never error, so just check entryID. + entryID, _ := isInList( + list, + follow.ID, + func(listEntry *gtsmodel.ListEntry) (string, error) { + // Looking for the listEntry follow ID. + return listEntry.FollowID, nil + }, + ) + + // Empty entryID means entry with given + // followID wasn't found in the list. + if entryID != "" { + err = fmt.Errorf("account with id %s is already in list %s with entryID %s", targetAccountID, listID, entryID) + return gtserror.NewErrorUnprocessableEntity(err, err.Error()) + } + + // Entry wasn't in the list, we can add it. + listEntries = append(listEntries, >smodel.ListEntry{ + ID: id.NewULID(), + ListID: listID, + FollowID: follow.ID, + }) + } + + // If we get to here we can assume all + // entries are valid, so try to add them. + if err := p.state.DB.PutListEntries(ctx, listEntries); err != nil { + if errors.Is(err, db.ErrAlreadyExists) { + err = fmt.Errorf("one or more errors inserting list entries: %w", err) + return gtserror.NewErrorUnprocessableEntity(err, err.Error()) + } + return gtserror.NewErrorInternalError(err) + } + + return nil +} + +// RemoveFromList removes targetAccountIDs from the given list, if valid. +func (p *Processor) RemoveFromList(ctx context.Context, account *gtsmodel.Account, listID string, targetAccountIDs []string) gtserror.WithCode { + // Ensure this list exists + account owns it. + list, errWithCode := p.getList(ctx, account.ID, listID) + if errWithCode != nil { + return errWithCode + } + + // For each targetAccountID, we want to check if + // a follow with that targetAccountID is in the + // given list. If it is in there, we want to remove + // it from the list. + for _, targetAccountID := range targetAccountIDs { + // Check if targetAccountID is + // on a follow in the list. + entryID, err := isInList( + list, + targetAccountID, + func(listEntry *gtsmodel.ListEntry) (string, error) { + // We need the follow so populate this + // entry, if it's not already populated. + if err := p.state.DB.PopulateListEntry(ctx, listEntry); err != nil { + return "", err + } + + // Looking for the list entry targetAccountID. + return listEntry.Follow.TargetAccountID, nil + }, + ) + + // Error may be returned here if there was an issue + // populating the list entry. We only return on proper + // DB errors, we can just skip no entry errors. + if err != nil && !errors.Is(err, db.ErrNoEntries) { + err = fmt.Errorf("error checking if targetAccountID %s was in list %s: %w", targetAccountID, listID, err) + return gtserror.NewErrorInternalError(err) + } + + if entryID == "" { + // There was an errNoEntries or targetAccount + // wasn't in this list anyway, so we can skip it. + continue + } + + // TargetAccount was in the list, remove the entry. + if err := p.state.DB.DeleteListEntry(ctx, entryID); err != nil && !errors.Is(err, db.ErrNoEntries) { + err = fmt.Errorf("error removing list entry %s from list %s: %w", entryID, listID, err) + return gtserror.NewErrorInternalError(err) + } + } + + return nil +} diff --git a/internal/processing/list/util.go b/internal/processing/list/util.go new file mode 100644 index 000000000..6186f58c7 --- /dev/null +++ b/internal/processing/list/util.go @@ -0,0 +1,85 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package list + +import ( + "context" + "errors" + "fmt" + + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// getList is a shortcut to get one list from the database and +// check that it's owned by the given accountID. Will return +// appropriate errors so caller doesn't need to bother. +func (p *Processor) getList(ctx context.Context, accountID string, listID string) (*gtsmodel.List, gtserror.WithCode) { + list, err := p.state.DB.GetListByID(ctx, listID) + if err != nil { + if errors.Is(err, db.ErrNoEntries) { + // List doesn't seem to exist. + return nil, gtserror.NewErrorNotFound(err) + } + // Real database error. + return nil, gtserror.NewErrorInternalError(err) + } + + if list.AccountID != accountID { + err = fmt.Errorf("list with id %s does not belong to account %s", list.ID, accountID) + return nil, gtserror.NewErrorNotFound(err) + } + + return list, nil +} + +// apiList is a shortcut to return the API version of the given +// list, or return an appropriate error if conversion fails. +func (p *Processor) apiList(ctx context.Context, list *gtsmodel.List) (*apimodel.List, gtserror.WithCode) { + apiList, err := p.tc.ListToAPIList(ctx, list) + if err != nil { + return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting list to api: %w", err)) + } + + return apiList, nil +} + +// isInList check if thisID is equal to the result of thatID +// for any entry in the given list. +// +// Will return the id of the listEntry if true, empty if false, +// or an error if the result of thatID returns an error. +func isInList( + list *gtsmodel.List, + thisID string, + getThatID func(listEntry *gtsmodel.ListEntry) (string, error), +) (string, error) { + for _, listEntry := range list.ListEntries { + thatID, err := getThatID(listEntry) + if err != nil { + return "", err + } + + if thisID == thatID { + return listEntry.ID, nil + } + } + return "", nil +} diff --git a/internal/processing/notification_test.go b/internal/processing/notification_test.go deleted file mode 100644 index bf69fc9bc..000000000 --- a/internal/processing/notification_test.go +++ /dev/null @@ -1,52 +0,0 @@ -// GoToSocial -// Copyright (C) GoToSocial Authors admin@gotosocial.org -// SPDX-License-Identifier: AGPL-3.0-or-later -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package processing_test - -import ( - "context" - "testing" - - "github.com/stretchr/testify/suite" - apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" -) - -type NotificationTestSuite struct { - ProcessingStandardTestSuite -} - -// get a notification where someone has liked our status -func (suite *NotificationTestSuite) TestGetNotifications() { - receivingAccount := suite.testAccounts["local_account_1"] - notifsResponse, err := suite.processor.NotificationsGet(context.Background(), suite.testAutheds["local_account_1"], "", "", "", 10, nil) - suite.NoError(err) - suite.Len(notifsResponse.Items, 1) - notif, ok := notifsResponse.Items[0].(*apimodel.Notification) - if !ok { - panic("notif in response wasn't *apimodel.Notification") - } - - suite.NotNil(notif.Status) - suite.NotNil(notif.Status) - suite.NotNil(notif.Status.Account) - suite.Equal(receivingAccount.ID, notif.Status.Account.ID) - suite.Equal(`; rel="next", ; rel="prev"`, notifsResponse.LinkHeader) -} - -func TestNotificationTestSuite(t *testing.T) { - suite.Run(t, &NotificationTestSuite{}) -} diff --git a/internal/processing/processor.go b/internal/processing/processor.go index 749987d6a..d5f88bfb2 100644 --- a/internal/processing/processor.go +++ b/internal/processing/processor.go @@ -29,39 +29,41 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/processing/account" "github.com/superseriousbusiness/gotosocial/internal/processing/admin" "github.com/superseriousbusiness/gotosocial/internal/processing/fedi" + "github.com/superseriousbusiness/gotosocial/internal/processing/list" "github.com/superseriousbusiness/gotosocial/internal/processing/media" "github.com/superseriousbusiness/gotosocial/internal/processing/report" "github.com/superseriousbusiness/gotosocial/internal/processing/status" "github.com/superseriousbusiness/gotosocial/internal/processing/stream" + "github.com/superseriousbusiness/gotosocial/internal/processing/timeline" "github.com/superseriousbusiness/gotosocial/internal/processing/user" "github.com/superseriousbusiness/gotosocial/internal/state" - "github.com/superseriousbusiness/gotosocial/internal/timeline" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/visibility" ) type Processor struct { - federator federation.Federator - tc typeutils.TypeConverter - oauthServer oauth.Server - mediaManager mm.Manager - statusTimelines timeline.Manager - state *state.State - emailSender email.Sender - filter *visibility.Filter + federator federation.Federator + tc typeutils.TypeConverter + oauthServer oauth.Server + mediaManager mm.Manager + state *state.State + emailSender email.Sender + filter *visibility.Filter /* SUB-PROCESSORS */ - account account.Processor - admin admin.Processor - fedi fedi.Processor - media media.Processor - report report.Processor - status status.Processor - stream stream.Processor - user user.Processor + account account.Processor + admin admin.Processor + fedi fedi.Processor + list list.Processor + media media.Processor + report report.Processor + status status.Processor + stream stream.Processor + timeline timeline.Processor + user user.Processor } func (p *Processor) Account() *account.Processor { @@ -76,6 +78,10 @@ func (p *Processor) Fedi() *fedi.Processor { return &p.fedi } +func (p *Processor) List() *list.Processor { + return &p.list +} + func (p *Processor) Media() *media.Processor { return &p.media } @@ -92,6 +98,10 @@ func (p *Processor) Stream() *stream.Processor { return &p.stream } +func (p *Processor) Timeline() *timeline.Processor { + return &p.timeline +} + func (p *Processor) User() *user.Processor { return &p.user } @@ -114,23 +124,19 @@ func NewProcessor( tc: tc, oauthServer: oauthServer, mediaManager: mediaManager, - statusTimelines: timeline.NewManager( - StatusGrabFunction(state.DB), - StatusFilterFunction(state.DB, filter), - StatusPrepareFunction(state.DB, tc), - StatusSkipInsertFunction(), - ), - state: state, - filter: filter, - emailSender: emailSender, + state: state, + filter: filter, + emailSender: emailSender, } - // sub processors + // Instantiate sub processors. processor.account = account.New(state, tc, mediaManager, oauthServer, federator, filter, parseMentionFunc) processor.admin = admin.New(state, tc, mediaManager, federator.TransportController(), emailSender) processor.fedi = fedi.New(state, tc, federator, filter) + processor.list = list.New(state, tc) processor.media = media.New(state, tc, mediaManager, federator.TransportController()) processor.report = report.New(state, tc) + processor.timeline = timeline.New(state, tc, filter) processor.status = status.New(state, federator, tc, filter, parseMentionFunc) processor.stream = stream.New(state, oauthServer) processor.user = user.New(state, emailSender) @@ -161,13 +167,3 @@ func (p *Processor) EnqueueFederator(ctx context.Context, msgs ...messages.FromF } }) } - -// Start starts the Processor. -func (p *Processor) Start() error { - return p.statusTimelines.Start() -} - -// Stop stops the processor cleanly. -func (p *Processor) Stop() error { - return p.statusTimelines.Stop() -} diff --git a/internal/processing/processor_test.go b/internal/processing/processor_test.go index e572593d1..68c33aa04 100644 --- a/internal/processing/processor_test.go +++ b/internal/processing/processor_test.go @@ -18,6 +18,8 @@ package processing_test import ( + "context" + "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" @@ -28,8 +30,10 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" + "github.com/superseriousbusiness/gotosocial/internal/stream" "github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -61,6 +65,7 @@ type ProcessingStandardTestSuite struct { testAutheds map[string]*oauth.Auth testBlocks map[string]*gtsmodel.Block testActivities map[string]testrig.ActivityWithSignature + testLists map[string]*gtsmodel.List processor *processing.Processor } @@ -84,6 +89,7 @@ func (suite *ProcessingStandardTestSuite) SetupSuite() { }, } suite.testBlocks = testrig.NewTestBlocks() + suite.testLists = testrig.NewTestLists() } func (suite *ProcessingStandardTestSuite) SetupTest() { @@ -99,6 +105,13 @@ func (suite *ProcessingStandardTestSuite) SetupTest() { suite.storage = testrig.NewInMemoryStorage() suite.state.Storage = suite.storage suite.typeconverter = testrig.NewTestTypeConverter(suite.db) + + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + suite.typeconverter, + ) + suite.httpClient = testrig.NewMockHTTPClient(nil, "../../testrig/media") suite.httpClient.TestRemotePeople = testrig.NewTestFediPeople() suite.httpClient.TestRemoteStatuses = testrig.NewTestFediStatuses() @@ -115,16 +128,40 @@ func (suite *ProcessingStandardTestSuite) SetupTest() { testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StandardStorageSetup(suite.storage, "../../testrig/media") - if err := suite.processor.Start(); err != nil { - panic(err) - } } func (suite *ProcessingStandardTestSuite) TearDownTest() { testrig.StandardDBTeardown(suite.db) testrig.StandardStorageTeardown(suite.storage) - if err := suite.processor.Stop(); err != nil { - panic(err) - } testrig.StopWorkers(&suite.state) } + +func (suite *ProcessingStandardTestSuite) openStreams(ctx context.Context, account *gtsmodel.Account, listIDs []string) map[string]*stream.Stream { + streams := make(map[string]*stream.Stream) + + for _, streamType := range []string{ + stream.TimelineHome, + stream.TimelinePublic, + stream.TimelineNotifications, + } { + stream, err := suite.processor.Stream().Open(ctx, account, streamType) + if err != nil { + suite.FailNow(err.Error()) + } + + streams[streamType] = stream + } + + for _, listID := range listIDs { + streamType := stream.TimelineList + ":" + listID + + stream, err := suite.processor.Stream().Open(ctx, account, streamType) + if err != nil { + suite.FailNow(err.Error()) + } + + streams[streamType] = stream + } + + return streams +} diff --git a/internal/processing/status/status_test.go b/internal/processing/status/status_test.go index 0de56c30e..01d8d3acd 100644 --- a/internal/processing/status/status_test.go +++ b/internal/processing/status/status_test.go @@ -88,6 +88,12 @@ func (suite *StatusStandardTestSuite) SetupTest() { suite.federator = testrig.NewTestFederator(&suite.state, suite.tc, suite.mediaManager) filter := visibility.NewFilter(&suite.state) + testrig.StartTimelines( + &suite.state, + filter, + testrig.NewTestTypeConverter(suite.db), + ) + suite.status = status.New(&suite.state, suite.federator, suite.typeConverter, filter, processing.GetParseMentionFunc(suite.db, suite.federator)) testrig.StandardDBSetup(suite.db, suite.testAccounts) diff --git a/internal/processing/statustimeline.go b/internal/processing/statustimeline.go deleted file mode 100644 index 39c5272b6..000000000 --- a/internal/processing/statustimeline.go +++ /dev/null @@ -1,309 +0,0 @@ -// GoToSocial -// Copyright (C) GoToSocial Authors admin@gotosocial.org -// SPDX-License-Identifier: AGPL-3.0-or-later -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package processing - -import ( - "context" - "errors" - "fmt" - - apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtserror" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/log" - "github.com/superseriousbusiness/gotosocial/internal/oauth" - "github.com/superseriousbusiness/gotosocial/internal/timeline" - "github.com/superseriousbusiness/gotosocial/internal/typeutils" - "github.com/superseriousbusiness/gotosocial/internal/util" - "github.com/superseriousbusiness/gotosocial/internal/visibility" -) - -const boostReinsertionDepth = 50 - -// StatusGrabFunction returns a function that satisfies the GrabFunction interface in internal/timeline. -func StatusGrabFunction(database db.DB) timeline.GrabFunction { - return func(ctx context.Context, timelineAccountID string, maxID string, sinceID string, minID string, limit int) ([]timeline.Timelineable, bool, error) { - statuses, err := database.GetHomeTimeline(ctx, timelineAccountID, maxID, sinceID, minID, limit, false) - if err != nil { - if errors.Is(err, db.ErrNoEntries) { - return nil, true, nil // we just don't have enough statuses left in the db so return stop = true - } - return nil, false, fmt.Errorf("statusGrabFunction: error getting statuses from db: %w", err) - } - - items := make([]timeline.Timelineable, len(statuses)) - for i, s := range statuses { - items[i] = s - } - - return items, false, nil - } -} - -// StatusFilterFunction returns a function that satisfies the FilterFunction interface in internal/timeline. -func StatusFilterFunction(database db.DB, filter *visibility.Filter) timeline.FilterFunction { - return func(ctx context.Context, timelineAccountID string, item timeline.Timelineable) (shouldIndex bool, err error) { - status, ok := item.(*gtsmodel.Status) - if !ok { - return false, errors.New("StatusFilterFunction: could not convert item to *gtsmodel.Status") - } - - requestingAccount, err := database.GetAccountByID(ctx, timelineAccountID) - if err != nil { - return false, fmt.Errorf("StatusFilterFunction: error getting account with id %s: %w", timelineAccountID, err) - } - - timelineable, err := filter.StatusHomeTimelineable(ctx, requestingAccount, status) - if err != nil { - return false, fmt.Errorf("StatusFilterFunction: error checking hometimelineability of status %s for account %s: %w", status.ID, timelineAccountID, err) - } - - return timelineable, nil - } -} - -// StatusPrepareFunction returns a function that satisfies the PrepareFunction interface in internal/timeline. -func StatusPrepareFunction(database db.DB, tc typeutils.TypeConverter) timeline.PrepareFunction { - return func(ctx context.Context, timelineAccountID string, itemID string) (timeline.Preparable, error) { - status, err := database.GetStatusByID(ctx, itemID) - if err != nil { - return nil, fmt.Errorf("StatusPrepareFunction: error getting status with id %s: %w", itemID, err) - } - - requestingAccount, err := database.GetAccountByID(ctx, timelineAccountID) - if err != nil { - return nil, fmt.Errorf("StatusPrepareFunction: error getting account with id %s: %w", timelineAccountID, err) - } - - return tc.StatusToAPIStatus(ctx, status, requestingAccount) - } -} - -// StatusSkipInsertFunction returns a function that satisifes the SkipInsertFunction interface in internal/timeline. -func StatusSkipInsertFunction() timeline.SkipInsertFunction { - return func( - ctx context.Context, - newItemID string, - newItemAccountID string, - newItemBoostOfID string, - newItemBoostOfAccountID string, - nextItemID string, - nextItemAccountID string, - nextItemBoostOfID string, - nextItemBoostOfAccountID string, - depth int, - ) (bool, error) { - // make sure we don't insert a duplicate - if newItemID == nextItemID { - return true, nil - } - - // check if it's a boost - if newItemBoostOfID != "" { - // skip if we've recently put another boost of this status in the timeline - if newItemBoostOfID == nextItemBoostOfID { - if depth < boostReinsertionDepth { - return true, nil - } - } - - // skip if we've recently put the original status in the timeline - if newItemBoostOfID == nextItemID { - if depth < boostReinsertionDepth { - return true, nil - } - } - } - - // insert the item - return false, nil - } -} - -func (p *Processor) HomeTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.PageableResponse, gtserror.WithCode) { - statuses, err := p.statusTimelines.GetTimeline(ctx, authed.Account.ID, maxID, sinceID, minID, limit, local) - if err != nil { - err = fmt.Errorf("HomeTimelineGet: error getting statuses: %w", err) - return nil, gtserror.NewErrorInternalError(err) - } - - count := len(statuses) - if count == 0 { - return util.EmptyPageableResponse(), nil - } - - var ( - items = make([]interface{}, count) - nextMaxIDValue string - prevMinIDValue string - ) - - for i, item := range statuses { - if i == count-1 { - nextMaxIDValue = item.GetID() - } - - if i == 0 { - prevMinIDValue = item.GetID() - } - - items[i] = item - } - - return util.PackagePageableResponse(util.PageableResponseParams{ - Items: items, - Path: "api/v1/timelines/home", - NextMaxIDValue: nextMaxIDValue, - PrevMinIDValue: prevMinIDValue, - Limit: limit, - }) -} - -func (p *Processor) PublicTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.PageableResponse, gtserror.WithCode) { - statuses, err := p.state.DB.GetPublicTimeline(ctx, maxID, sinceID, minID, limit, local) - if err != nil { - if errors.Is(err, db.ErrNoEntries) { - // No statuses (left) in public timeline. - return util.EmptyPageableResponse(), nil - } - // An actual error has occurred. - err = fmt.Errorf("PublicTimelineGet: db error getting statuses: %w", err) - return nil, gtserror.NewErrorInternalError(err) - } - - count := len(statuses) - if count == 0 { - return util.EmptyPageableResponse(), nil - } - - var ( - items = make([]interface{}, 0, count) - nextMaxIDValue string - prevMinIDValue string - ) - - for i, s := range statuses { - // Set next + prev values before filtering and API - // converting, so caller can still page properly. - if i == count-1 { - nextMaxIDValue = s.ID - } - - if i == 0 { - prevMinIDValue = s.ID - } - - timelineable, err := p.filter.StatusPublicTimelineable(ctx, authed.Account, s) - if err != nil { - log.Debugf(ctx, "skipping status %s because of an error checking StatusPublicTimelineable: %s", s.ID, err) - continue - } - - if !timelineable { - continue - } - - apiStatus, err := p.tc.StatusToAPIStatus(ctx, s, authed.Account) - if err != nil { - log.Debugf(ctx, "skipping status %s because it couldn't be converted to its api representation: %s", s.ID, err) - continue - } - - items = append(items, apiStatus) - } - - return util.PackagePageableResponse(util.PageableResponseParams{ - Items: items, - Path: "api/v1/timelines/public", - NextMaxIDValue: nextMaxIDValue, - PrevMinIDValue: prevMinIDValue, - Limit: limit, - }) -} - -func (p *Processor) FavedTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.PageableResponse, gtserror.WithCode) { - statuses, nextMaxID, prevMinID, err := p.state.DB.GetFavedTimeline(ctx, authed.Account.ID, maxID, minID, limit) - if err != nil { - if errors.Is(err, db.ErrNoEntries) { - // There are just no entries (left). - return util.EmptyPageableResponse(), nil - } - // An actual error has occurred. - err = fmt.Errorf("FavedTimelineGet: db error getting statuses: %w", err) - return nil, gtserror.NewErrorInternalError(err) - } - - count := len(statuses) - if count == 0 { - return util.EmptyPageableResponse(), nil - } - - filtered, err := p.filterFavedStatuses(ctx, authed, statuses) - if err != nil { - err = fmt.Errorf("FavedTimelineGet: error filtering statuses: %w", err) - return nil, gtserror.NewErrorInternalError(err) - } - - items := make([]interface{}, len(filtered)) - for i, item := range filtered { - items[i] = item - } - - return util.PackagePageableResponse(util.PageableResponseParams{ - Items: items, - Path: "api/v1/favourites", - NextMaxIDValue: nextMaxID, - PrevMinIDValue: prevMinID, - Limit: limit, - }) -} - -func (p *Processor) filterFavedStatuses(ctx context.Context, authed *oauth.Auth, statuses []*gtsmodel.Status) ([]*apimodel.Status, error) { - apiStatuses := make([]*apimodel.Status, 0, len(statuses)) - - for _, s := range statuses { - if _, err := p.state.DB.GetAccountByID(ctx, s.AccountID); err != nil { - if errors.Is(err, db.ErrNoEntries) { - log.Debugf(ctx, "skipping status %s because account %s can't be found in the db", s.ID, s.AccountID) - continue - } - err = fmt.Errorf("filterFavedStatuses: db error getting status author: %w", err) - return nil, gtserror.NewErrorInternalError(err) - } - - timelineable, err := p.filter.StatusVisible(ctx, authed.Account, s) - if err != nil { - log.Debugf(ctx, "skipping status %s because of an error checking status visibility: %s", s.ID, err) - continue - } - if !timelineable { - continue - } - - apiStatus, err := p.tc.StatusToAPIStatus(ctx, s, authed.Account) - if err != nil { - log.Debugf(ctx, "skipping status %s because it couldn't be converted to its api representation: %s", s.ID, err) - continue - } - - apiStatuses = append(apiStatuses, apiStatus) - } - - return apiStatuses, nil -} diff --git a/internal/processing/stream/open.go b/internal/processing/stream/open.go index e43152b29..1c041309f 100644 --- a/internal/processing/stream/open.go +++ b/internal/processing/stream/open.go @@ -31,60 +31,65 @@ import ( ) // Open returns a new Stream for the given account, which will contain a channel for passing messages back to the caller. -func (p *Processor) Open(ctx context.Context, account *gtsmodel.Account, streamTimeline string) (*stream.Stream, gtserror.WithCode) { +func (p *Processor) Open(ctx context.Context, account *gtsmodel.Account, streamType string) (*stream.Stream, gtserror.WithCode) { l := log.WithContext(ctx).WithFields(kv.Fields{ {"account", account.ID}, - {"streamType", streamTimeline}, + {"streamType", streamType}, }...) l.Debug("received open stream request") - // each stream needs a unique ID so we know to close it - streamID, err := id.NewRandomULID() + var ( + streamID string + err error + ) + + // Each stream needs a unique ID so we know to close it. + streamID, err = id.NewRandomULID() if err != nil { - return nil, gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %s", err)) + return nil, gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %w", err)) } - // Each stream can be subscibed to multiple timelines. + // Each stream can be subscibed to multiple types. // Record them in a set, and include the initial one - // if it was given to us - timelines := map[string]bool{} - if streamTimeline != "" { - timelines[streamTimeline] = true + // if it was given to us. + streamTypes := map[string]any{} + if streamType != "" { + streamTypes[streamType] = true } - thisStream := &stream.Stream{ - ID: streamID, - Timelines: timelines, - Messages: make(chan *stream.Message, 100), - Hangup: make(chan interface{}, 1), - Connected: true, + newStream := &stream.Stream{ + ID: streamID, + StreamTypes: streamTypes, + Messages: make(chan *stream.Message, 100), + Hangup: make(chan interface{}, 1), + Connected: true, } - go p.waitToCloseStream(account, thisStream) + go p.waitToCloseStream(account, newStream) v, ok := p.streamMap.Load(account.ID) - if !ok || v == nil { - // there is no entry in the streamMap for this account yet, so make one and store it - streamsForAccount := &stream.StreamsForAccount{ - Streams: []*stream.Stream{ - thisStream, - }, - } - p.streamMap.Store(account.ID, streamsForAccount) - } else { - // there is an entry in the streamMap for this account - // parse the interface as a streamsForAccount + if ok { + // There is an entry in the streamMap + // for this account. Parse it out. streamsForAccount, ok := v.(*stream.StreamsForAccount) if !ok { return nil, gtserror.NewErrorInternalError(errors.New("stream map error")) } - // append this stream to it + // Append new stream to existing entry. streamsForAccount.Lock() - streamsForAccount.Streams = append(streamsForAccount.Streams, thisStream) + streamsForAccount.Streams = append(streamsForAccount.Streams, newStream) streamsForAccount.Unlock() + } else { + // There is no entry in the streamMap for + // this account yet. Create one and store it. + p.streamMap.Store(account.ID, &stream.StreamsForAccount{ + Streams: []*stream.Stream{ + newStream, + }, + }) } - return thisStream, nil + return newStream, nil } // waitToCloseStream waits until the hangup channel is closed for the given stream. diff --git a/internal/processing/stream/stream.go b/internal/processing/stream/stream.go index 4a4c92a00..bd49a330c 100644 --- a/internal/processing/stream/stream.go +++ b/internal/processing/stream/stream.go @@ -18,7 +18,6 @@ package stream import ( - "errors" "sync" "github.com/superseriousbusiness/gotosocial/internal/oauth" @@ -40,37 +39,38 @@ func New(state *state.State, oauthServer oauth.Server) Processor { } // toAccount streams the given payload with the given event type to any streams currently open for the given account ID. -func (p *Processor) toAccount(payload string, event string, timelines []string, accountID string) error { +func (p *Processor) toAccount(payload string, event string, streamTypes []string, accountID string) error { + // Load all streams open for this account. v, ok := p.streamMap.Load(accountID) if !ok { - // no open connections so nothing to stream - return nil - } - - streamsForAccount, ok := v.(*stream.StreamsForAccount) - if !ok { - return errors.New("stream map error") + return nil // No entry = nothing to stream. } + streamsForAccount := v.(*stream.StreamsForAccount) //nolint:forcetypeassert streamsForAccount.Lock() defer streamsForAccount.Unlock() + for _, s := range streamsForAccount.Streams { s.Lock() defer s.Unlock() + if !s.Connected { continue } - for _, t := range timelines { - if _, found := s.Timelines[t]; found { + typeLoop: + for _, streamType := range streamTypes { + if _, found := s.StreamTypes[streamType]; found { s.Messages <- &stream.Message{ - Stream: []string{string(t)}, + Stream: []string{streamType}, Event: string(event), Payload: payload, } - // break out to the outer loop, to avoid sending duplicates - // of the same event to the same stream - break + + // Break out to the outer loop, + // to avoid sending duplicates of + // the same event to the same stream. + break typeLoop } } } diff --git a/internal/processing/stream/update.go b/internal/processing/stream/update.go index dc575c636..ee70bda11 100644 --- a/internal/processing/stream/update.go +++ b/internal/processing/stream/update.go @@ -27,11 +27,11 @@ import ( ) // Update streams the given update to any open, appropriate streams belonging to the given account. -func (p *Processor) Update(s *apimodel.Status, account *gtsmodel.Account, timeline string) error { +func (p *Processor) Update(s *apimodel.Status, account *gtsmodel.Account, streamTypes []string) error { bytes, err := json.Marshal(s) if err != nil { return fmt.Errorf("error marshalling status to json: %s", err) } - return p.toAccount(string(bytes), stream.EventTypeUpdate, []string{timeline}, account.ID) + return p.toAccount(string(bytes), stream.EventTypeUpdate, streamTypes, account.ID) } diff --git a/internal/processing/timeline/common.go b/internal/processing/timeline/common.go new file mode 100644 index 000000000..6d29d81d6 --- /dev/null +++ b/internal/processing/timeline/common.go @@ -0,0 +1,71 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package timeline + +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/timeline" +) + +// SkipInsert returns a function that satisifes SkipInsertFunction. +func SkipInsert() timeline.SkipInsertFunction { + // Gap to allow between a status or boost of status, + // and reinsertion of a new boost of that status. + // This is useful to avoid a heavily boosted status + // showing up way too often in a user's timeline. + const boostReinsertionDepth = 50 + + return func( + ctx context.Context, + newItemID string, + newItemAccountID string, + newItemBoostOfID string, + newItemBoostOfAccountID string, + nextItemID string, + nextItemAccountID string, + nextItemBoostOfID string, + nextItemBoostOfAccountID string, + depth int, + ) (bool, error) { + if newItemID == nextItemID { + // Don't insert duplicates. + return true, nil + } + + if newItemBoostOfID != "" { + if newItemBoostOfID == nextItemBoostOfID && + depth < boostReinsertionDepth { + // Don't insert boosts of items + // we've seen boosted recently. + return true, nil + } + + if newItemBoostOfID == nextItemID && + depth < boostReinsertionDepth { + // Don't insert boosts of items when + // we've seen the original recently. + return true, nil + } + } + + // Proceed with insertion + // (that's what she said!). + return false, nil + } +} diff --git a/internal/processing/timeline/faved.go b/internal/processing/timeline/faved.go new file mode 100644 index 000000000..0fc92d8fa --- /dev/null +++ b/internal/processing/timeline/faved.go @@ -0,0 +1,73 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package timeline + +import ( + "context" + "errors" + "fmt" + + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/util" +) + +func (p *Processor) FavedTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.PageableResponse, gtserror.WithCode) { + statuses, nextMaxID, prevMinID, err := p.state.DB.GetFavedTimeline(ctx, authed.Account.ID, maxID, minID, limit) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + err = fmt.Errorf("FavedTimelineGet: db error getting statuses: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + + count := len(statuses) + if count == 0 { + return util.EmptyPageableResponse(), nil + } + + items := make([]interface{}, 0, count) + for _, s := range statuses { + visible, err := p.filter.StatusVisible(ctx, authed.Account, s) + if err != nil { + log.Debugf(ctx, "skipping status %s because of an error checking status visibility: %s", s.ID, err) + continue + } + + if !visible { + continue + } + + apiStatus, err := p.tc.StatusToAPIStatus(ctx, s, authed.Account) + if err != nil { + log.Debugf(ctx, "skipping status %s because it couldn't be converted to its api representation: %s", s.ID, err) + continue + } + + items = append(items, apiStatus) + } + + return util.PackagePageableResponse(util.PageableResponseParams{ + Items: items, + Path: "api/v1/favourites", + NextMaxIDValue: nextMaxID, + PrevMinIDValue: prevMinID, + Limit: limit, + }) +} diff --git a/internal/processing/timeline/home.go b/internal/processing/timeline/home.go new file mode 100644 index 000000000..e65f12e17 --- /dev/null +++ b/internal/processing/timeline/home.go @@ -0,0 +1,133 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package timeline + +import ( + "context" + "errors" + "fmt" + + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/timeline" + "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/superseriousbusiness/gotosocial/internal/visibility" +) + +// HomeTimelineGrab returns a function that satisfies GrabFunction for home timelines. +func HomeTimelineGrab(state *state.State) timeline.GrabFunction { + return func(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int) ([]timeline.Timelineable, bool, error) { + statuses, err := state.DB.GetHomeTimeline(ctx, accountID, maxID, sinceID, minID, limit, false) + if err != nil { + if errors.Is(err, db.ErrNoEntries) { + return nil, true, nil // we just don't have enough statuses left in the db so return stop = true + } + return nil, false, fmt.Errorf("HomeTimelineGrab: error getting statuses from db: %w", err) + } + + items := make([]timeline.Timelineable, len(statuses)) + for i, s := range statuses { + items[i] = s + } + + return items, false, nil + } +} + +// HomeTimelineFilter returns a function that satisfies FilterFunction for home timelines. +func HomeTimelineFilter(state *state.State, filter *visibility.Filter) timeline.FilterFunction { + return func(ctx context.Context, accountID string, item timeline.Timelineable) (shouldIndex bool, err error) { + status, ok := item.(*gtsmodel.Status) + if !ok { + return false, errors.New("HomeTimelineFilter: could not convert item to *gtsmodel.Status") + } + + requestingAccount, err := state.DB.GetAccountByID(ctx, accountID) + if err != nil { + return false, fmt.Errorf("HomeTimelineFilter: error getting account with id %s: %w", accountID, err) + } + + timelineable, err := filter.StatusHomeTimelineable(ctx, requestingAccount, status) + if err != nil { + return false, fmt.Errorf("HomeTimelineFilter: error checking hometimelineability of status %s for account %s: %w", status.ID, accountID, err) + } + + return timelineable, nil + } +} + +// HomeTimelineStatusPrepare returns a function that satisfies PrepareFunction for home timelines. +func HomeTimelineStatusPrepare(state *state.State, tc typeutils.TypeConverter) timeline.PrepareFunction { + return func(ctx context.Context, accountID string, itemID string) (timeline.Preparable, error) { + status, err := state.DB.GetStatusByID(ctx, itemID) + if err != nil { + return nil, fmt.Errorf("StatusPrepare: error getting status with id %s: %w", itemID, err) + } + + requestingAccount, err := state.DB.GetAccountByID(ctx, accountID) + if err != nil { + return nil, fmt.Errorf("StatusPrepare: error getting account with id %s: %w", accountID, err) + } + + return tc.StatusToAPIStatus(ctx, status, requestingAccount) + } +} + +func (p *Processor) HomeTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.PageableResponse, gtserror.WithCode) { + statuses, err := p.state.Timelines.Home.GetTimeline(ctx, authed.Account.ID, maxID, sinceID, minID, limit, local) + if err != nil { + err = fmt.Errorf("HomeTimelineGet: error getting statuses: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + + count := len(statuses) + if count == 0 { + return util.EmptyPageableResponse(), nil + } + + var ( + items = make([]interface{}, count) + nextMaxIDValue string + prevMinIDValue string + ) + + for i, item := range statuses { + if i == count-1 { + nextMaxIDValue = item.GetID() + } + + if i == 0 { + prevMinIDValue = item.GetID() + } + + items[i] = item + } + + return util.PackagePageableResponse(util.PageableResponseParams{ + Items: items, + Path: "api/v1/timelines/home", + NextMaxIDValue: nextMaxIDValue, + PrevMinIDValue: prevMinIDValue, + Limit: limit, + }) +} diff --git a/internal/processing/timeline/list.go b/internal/processing/timeline/list.go new file mode 100644 index 000000000..adad35197 --- /dev/null +++ b/internal/processing/timeline/list.go @@ -0,0 +1,157 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package timeline + +import ( + "context" + "errors" + "fmt" + + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/timeline" + "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/superseriousbusiness/gotosocial/internal/visibility" +) + +// ListTimelineGrab returns a function that satisfies GrabFunction for list timelines. +func ListTimelineGrab(state *state.State) timeline.GrabFunction { + return func(ctx context.Context, listID string, maxID string, sinceID string, minID string, limit int) ([]timeline.Timelineable, bool, error) { + statuses, err := state.DB.GetListTimeline(ctx, listID, maxID, sinceID, minID, limit) + if err != nil { + if errors.Is(err, db.ErrNoEntries) { + return nil, true, nil // we just don't have enough statuses left in the db so return stop = true + } + return nil, false, fmt.Errorf("ListTimelineGrab: error getting statuses from db: %w", err) + } + + items := make([]timeline.Timelineable, len(statuses)) + for i, s := range statuses { + items[i] = s + } + + return items, false, nil + } +} + +// HomeTimelineFilter returns a function that satisfies FilterFunction for list timelines. +func ListTimelineFilter(state *state.State, filter *visibility.Filter) timeline.FilterFunction { + return func(ctx context.Context, listID string, item timeline.Timelineable) (shouldIndex bool, err error) { + status, ok := item.(*gtsmodel.Status) + if !ok { + return false, errors.New("ListTimelineFilter: could not convert item to *gtsmodel.Status") + } + + list, err := state.DB.GetListByID(ctx, listID) + if err != nil { + return false, fmt.Errorf("ListTimelineFilter: error getting list with id %s: %w", listID, err) + } + + requestingAccount, err := state.DB.GetAccountByID(ctx, list.AccountID) + if err != nil { + return false, fmt.Errorf("ListTimelineFilter: error getting account with id %s: %w", list.AccountID, err) + } + + timelineable, err := filter.StatusHomeTimelineable(ctx, requestingAccount, status) + if err != nil { + return false, fmt.Errorf("ListTimelineFilter: error checking hometimelineability of status %s for account %s: %w", status.ID, list.AccountID, err) + } + + return timelineable, nil + } +} + +// ListTimelineStatusPrepare returns a function that satisfies PrepareFunction for list timelines. +func ListTimelineStatusPrepare(state *state.State, tc typeutils.TypeConverter) timeline.PrepareFunction { + return func(ctx context.Context, listID string, itemID string) (timeline.Preparable, error) { + status, err := state.DB.GetStatusByID(ctx, itemID) + if err != nil { + return nil, fmt.Errorf("ListTimelineStatusPrepare: error getting status with id %s: %w", itemID, err) + } + + list, err := state.DB.GetListByID(ctx, listID) + if err != nil { + return nil, fmt.Errorf("ListTimelineStatusPrepare: error getting list with id %s: %w", listID, err) + } + + requestingAccount, err := state.DB.GetAccountByID(ctx, list.AccountID) + if err != nil { + return nil, fmt.Errorf("ListTimelineStatusPrepare: error getting account with id %s: %w", list.AccountID, err) + } + + return tc.StatusToAPIStatus(ctx, status, requestingAccount) + } +} + +func (p *Processor) ListTimelineGet(ctx context.Context, authed *oauth.Auth, listID string, maxID string, sinceID string, minID string, limit int) (*apimodel.PageableResponse, gtserror.WithCode) { + // Ensure list exists + is owned by this account. + list, err := p.state.DB.GetListByID(ctx, listID) + if err != nil { + if errors.Is(err, db.ErrNoEntries) { + return nil, gtserror.NewErrorNotFound(err) + } + return nil, gtserror.NewErrorInternalError(err) + } + + if list.AccountID != authed.Account.ID { + err = fmt.Errorf("list with id %s does not belong to account %s", list.ID, authed.Account.ID) + return nil, gtserror.NewErrorNotFound(err) + } + + statuses, err := p.state.Timelines.List.GetTimeline(ctx, listID, maxID, sinceID, minID, limit, false) + if err != nil { + err = fmt.Errorf("ListTimelineGet: error getting statuses: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + + count := len(statuses) + if count == 0 { + return util.EmptyPageableResponse(), nil + } + + var ( + items = make([]interface{}, count) + nextMaxIDValue string + prevMinIDValue string + ) + + for i, item := range statuses { + if i == count-1 { + nextMaxIDValue = item.GetID() + } + + if i == 0 { + prevMinIDValue = item.GetID() + } + + items[i] = item + } + + return util.PackagePageableResponse(util.PageableResponseParams{ + Items: items, + Path: "api/v1/timelines/list/" + listID, + NextMaxIDValue: nextMaxIDValue, + PrevMinIDValue: prevMinIDValue, + Limit: limit, + }) +} diff --git a/internal/processing/notification.go b/internal/processing/timeline/notification.go similarity index 96% rename from internal/processing/notification.go rename to internal/processing/timeline/notification.go index 2e4e1788f..4a79fb82a 100644 --- a/internal/processing/notification.go +++ b/internal/processing/timeline/notification.go @@ -15,7 +15,7 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . -package processing +package timeline import ( "context" @@ -33,12 +33,7 @@ import ( func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, excludeTypes []string) (*apimodel.PageableResponse, gtserror.WithCode) { notifs, err := p.state.DB.GetAccountNotifications(ctx, authed.Account.ID, maxID, sinceID, minID, limit, excludeTypes) - if err != nil { - if errors.Is(err, db.ErrNoEntries) { - // No notifs (left). - return util.EmptyPageableResponse(), nil - } - // An actual error has occurred. + if err != nil && !errors.Is(err, db.ErrNoEntries) { err = fmt.Errorf("NotificationsGet: db error getting notifications: %w", err) return nil, gtserror.NewErrorInternalError(err) } @@ -73,6 +68,7 @@ func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, ma log.Debugf(ctx, "skipping notification %s because of an error checking notification visibility: %s", n.ID, err) continue } + if !visible { continue } @@ -85,6 +81,7 @@ func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, ma log.Debugf(ctx, "skipping notification %s because of an error checking notification visibility: %s", n.ID, err) continue } + if !visible { continue } diff --git a/internal/processing/timeline/public.go b/internal/processing/timeline/public.go new file mode 100644 index 000000000..67893ecfa --- /dev/null +++ b/internal/processing/timeline/public.go @@ -0,0 +1,88 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package timeline + +import ( + "context" + "errors" + "fmt" + + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/util" +) + +func (p *Processor) PublicTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.PageableResponse, gtserror.WithCode) { + statuses, err := p.state.DB.GetPublicTimeline(ctx, maxID, sinceID, minID, limit, local) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + err = fmt.Errorf("PublicTimelineGet: db error getting statuses: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + + count := len(statuses) + if count == 0 { + return util.EmptyPageableResponse(), nil + } + + var ( + items = make([]interface{}, 0, count) + nextMaxIDValue string + prevMinIDValue string + ) + + for i, s := range statuses { + // Set next + prev values before filtering and API + // converting, so caller can still page properly. + if i == count-1 { + nextMaxIDValue = s.ID + } + + if i == 0 { + prevMinIDValue = s.ID + } + + timelineable, err := p.filter.StatusPublicTimelineable(ctx, authed.Account, s) + if err != nil { + log.Debugf(ctx, "skipping status %s because of an error checking StatusPublicTimelineable: %s", s.ID, err) + continue + } + + if !timelineable { + continue + } + + apiStatus, err := p.tc.StatusToAPIStatus(ctx, s, authed.Account) + if err != nil { + log.Debugf(ctx, "skipping status %s because it couldn't be converted to its api representation: %s", s.ID, err) + continue + } + + items = append(items, apiStatus) + } + + return util.PackagePageableResponse(util.PageableResponseParams{ + Items: items, + Path: "api/v1/timelines/public", + NextMaxIDValue: nextMaxIDValue, + PrevMinIDValue: prevMinIDValue, + Limit: limit, + }) +} diff --git a/internal/processing/timeline/timeline.go b/internal/processing/timeline/timeline.go new file mode 100644 index 000000000..7a95f9a11 --- /dev/null +++ b/internal/processing/timeline/timeline.go @@ -0,0 +1,38 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package timeline + +import ( + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" +) + +type Processor struct { + state *state.State + tc typeutils.TypeConverter + filter *visibility.Filter +} + +func New(state *state.State, tc typeutils.TypeConverter, filter *visibility.Filter) Processor { + return Processor{ + state: state, + tc: tc, + filter: filter, + } +} diff --git a/internal/state/state.go b/internal/state/state.go index f374bc162..6ff1baa52 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -21,6 +21,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/storage" + "github.com/superseriousbusiness/gotosocial/internal/timeline" "github.com/superseriousbusiness/gotosocial/internal/workers" ) @@ -34,6 +35,9 @@ type State struct { // Caches provides access to this state's collection of caches. Caches cache.Caches + // Timelines provides access to this state's collection of timelines. + Timelines timeline.Timelines + // DB provides access to the database. DB db.DB diff --git a/internal/stream/stream.go b/internal/stream/stream.go index a5b5bd38b..ae815e029 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -39,6 +39,8 @@ const ( TimelineNotifications string = "user:notification" // TimelineDirect -- statuses sent to a user directly. TimelineDirect string = "direct" + // TimelineList -- statuses for a user's list timeline. + TimelineList string = "list" ) // AllStatusTimelines contains all Timelines that a status could conceivably be delivered to -- useful for doing deletes. @@ -47,6 +49,7 @@ var AllStatusTimelines = []string{ TimelinePublic, TimelineHome, TimelineDirect, + TimelineList, } // StreamsForAccount is a wrapper for the multiple streams that one account can have running at the same time. @@ -62,10 +65,9 @@ type StreamsForAccount struct { type Stream struct { // ID of this stream, generated during creation. ID string - // A set of timelines of this stream: user/public/etc - // a matching key means the timeline is subscribed. The value - // is ignored - Timelines map[string]bool + // A set of types subscribed to by this stream: user/public/etc. + // It's a map to ensure no duplicates; the value is ignored. + StreamTypes map[string]any // Channel of messages for the client to read from Messages chan *Message // Channel to close when the client drops away diff --git a/internal/timeline/get.go b/internal/timeline/get.go index 4ca9023f2..a93e2d1ed 100644 --- a/internal/timeline/get.go +++ b/internal/timeline/get.go @@ -39,7 +39,7 @@ func (t *timeline) LastGot() time.Time { func (t *timeline) Get(ctx context.Context, amount int, maxID string, sinceID string, minID string, prepareNext bool) ([]Preparable, error) { l := log.WithContext(ctx). WithFields(kv.Fields{ - {"accountID", t.accountID}, + {"accountID", t.timelineID}, {"amount", amount}, {"maxID", maxID}, {"sinceID", sinceID}, @@ -244,7 +244,7 @@ func (t *timeline) getXBetweenIDs(ctx context.Context, amount int, behindID stri if entry.prepared == nil { // Whoops, this entry isn't prepared yet; some // race condition? That's OK, we can do it now. - prepared, err := t.prepareFunction(ctx, t.accountID, entry.itemID) + prepared, err := t.prepareFunction(ctx, t.timelineID, entry.itemID) if err != nil { if errors.Is(err, db.ErrNoEntries) { // ErrNoEntries means something has been deleted, @@ -338,7 +338,7 @@ func (t *timeline) getXBetweenIDs(ctx context.Context, amount int, behindID stri if entry.prepared == nil { // Whoops, this entry isn't prepared yet; some // race condition? That's OK, we can do it now. - prepared, err := t.prepareFunction(ctx, t.accountID, entry.itemID) + prepared, err := t.prepareFunction(ctx, t.timelineID, entry.itemID) if err != nil { if errors.Is(err, db.ErrNoEntries) { // ErrNoEntries means something has been deleted, diff --git a/internal/timeline/get_test.go b/internal/timeline/get_test.go index 444c159c4..f99e58611 100644 --- a/internal/timeline/get_test.go +++ b/internal/timeline/get_test.go @@ -26,7 +26,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" - "github.com/superseriousbusiness/gotosocial/internal/processing" + tlprocessor "github.com/superseriousbusiness/gotosocial/internal/processing/timeline" "github.com/superseriousbusiness/gotosocial/internal/timeline" "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" @@ -58,10 +58,10 @@ func (suite *GetTestSuite) SetupTest() { tl := timeline.NewTimeline( context.Background(), suite.testAccounts["local_account_1"].ID, - processing.StatusGrabFunction(suite.db), - processing.StatusFilterFunction(suite.db, suite.filter), - processing.StatusPrepareFunction(suite.db, suite.tc), - processing.StatusSkipInsertFunction(), + tlprocessor.HomeTimelineGrab(&suite.state), + tlprocessor.HomeTimelineFilter(&suite.state, suite.filter), + tlprocessor.HomeTimelineStatusPrepare(&suite.state, suite.tc), + tlprocessor.SkipInsert(), ) // Put testrig statuses in a determinate order @@ -134,10 +134,10 @@ func (suite *GetTestSuite) TestGetNewTimelinePageDown() { tl := timeline.NewTimeline( context.Background(), suite.testAccounts["local_account_1"].ID, - processing.StatusGrabFunction(suite.db), - processing.StatusFilterFunction(suite.db, suite.filter), - processing.StatusPrepareFunction(suite.db, suite.tc), - processing.StatusSkipInsertFunction(), + tlprocessor.HomeTimelineGrab(&suite.state), + tlprocessor.HomeTimelineFilter(&suite.state, suite.filter), + tlprocessor.HomeTimelineStatusPrepare(&suite.state, suite.tc), + tlprocessor.SkipInsert(), ) // Get 5 from the top. @@ -163,10 +163,10 @@ func (suite *GetTestSuite) TestGetNewTimelinePageUp() { tl := timeline.NewTimeline( context.Background(), suite.testAccounts["local_account_1"].ID, - processing.StatusGrabFunction(suite.db), - processing.StatusFilterFunction(suite.db, suite.filter), - processing.StatusPrepareFunction(suite.db, suite.tc), - processing.StatusSkipInsertFunction(), + tlprocessor.HomeTimelineGrab(&suite.state), + tlprocessor.HomeTimelineFilter(&suite.state, suite.filter), + tlprocessor.HomeTimelineStatusPrepare(&suite.state, suite.tc), + tlprocessor.SkipInsert(), ) // Get 5 from the back. @@ -192,10 +192,10 @@ func (suite *GetTestSuite) TestGetNewTimelineMoreThanPossible() { tl := timeline.NewTimeline( context.Background(), suite.testAccounts["local_account_1"].ID, - processing.StatusGrabFunction(suite.db), - processing.StatusFilterFunction(suite.db, suite.filter), - processing.StatusPrepareFunction(suite.db, suite.tc), - processing.StatusSkipInsertFunction(), + tlprocessor.HomeTimelineGrab(&suite.state), + tlprocessor.HomeTimelineFilter(&suite.state, suite.filter), + tlprocessor.HomeTimelineStatusPrepare(&suite.state, suite.tc), + tlprocessor.SkipInsert(), ) // Get 100 from the top. @@ -213,10 +213,10 @@ func (suite *GetTestSuite) TestGetNewTimelineMoreThanPossiblePageUp() { tl := timeline.NewTimeline( context.Background(), suite.testAccounts["local_account_1"].ID, - processing.StatusGrabFunction(suite.db), - processing.StatusFilterFunction(suite.db, suite.filter), - processing.StatusPrepareFunction(suite.db, suite.tc), - processing.StatusSkipInsertFunction(), + tlprocessor.HomeTimelineGrab(&suite.state), + tlprocessor.HomeTimelineFilter(&suite.state, suite.filter), + tlprocessor.HomeTimelineStatusPrepare(&suite.state, suite.tc), + tlprocessor.SkipInsert(), ) // Get 100 from the back. diff --git a/internal/timeline/index.go b/internal/timeline/index.go index a45617134..2d556a3b2 100644 --- a/internal/timeline/index.go +++ b/internal/timeline/index.go @@ -167,7 +167,7 @@ func (t *timeline) grab(ctx context.Context, amount int, behindID string, before items, stop, err := t.grabFunction( ctx, - t.accountID, + t.timelineID, maxID, sinceID, minID, @@ -205,7 +205,7 @@ func (t *timeline) grab(ctx context.Context, amount int, behindID string, before } for _, item := range items { - ok, err := t.filterFunction(ctx, t.accountID, item) + ok, err := t.filterFunction(ctx, t.timelineID, item) if err != nil { if !errors.Is(err, db.ErrNoEntries) { // Real error here. @@ -244,7 +244,7 @@ func (t *timeline) IndexAndPrepareOne(ctx context.Context, statusID string, boos return false, nil } - preparable, err := t.prepareFunction(ctx, t.accountID, statusID) + preparable, err := t.prepareFunction(ctx, t.timelineID, statusID) if err != nil { return true, fmt.Errorf("IndexAndPrepareOne: error preparing: %w", err) } diff --git a/internal/timeline/index_test.go b/internal/timeline/index_test.go index 76b161171..f62c0a9c6 100644 --- a/internal/timeline/index_test.go +++ b/internal/timeline/index_test.go @@ -24,7 +24,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/processing" + tlprocessor "github.com/superseriousbusiness/gotosocial/internal/processing/timeline" "github.com/superseriousbusiness/gotosocial/internal/timeline" "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" @@ -55,10 +55,10 @@ func (suite *IndexTestSuite) SetupTest() { suite.timeline = timeline.NewTimeline( context.Background(), suite.testAccounts["local_account_1"].ID, - processing.StatusGrabFunction(suite.db), - processing.StatusFilterFunction(suite.db, suite.filter), - processing.StatusPrepareFunction(suite.db, suite.tc), - processing.StatusSkipInsertFunction(), + tlprocessor.HomeTimelineGrab(&suite.state), + tlprocessor.HomeTimelineFilter(&suite.state, suite.filter), + tlprocessor.HomeTimelineStatusPrepare(&suite.state, suite.tc), + tlprocessor.SkipInsert(), ) } diff --git a/internal/timeline/manager.go b/internal/timeline/manager.go index f34cee787..b70e6cf82 100644 --- a/internal/timeline/manager.go +++ b/internal/timeline/manager.go @@ -32,10 +32,10 @@ const ( pruneLengthPrepared = 50 ) -// Manager abstracts functions for creating timelines for multiple accounts, and adding, removing, and fetching entries from those timelines. +// Manager abstracts functions for creating multiple timelines, and adding, removing, and fetching entries from those timelines. // // By the time a timelineable hits the manager interface, it should already have been filtered and it should be established that the item indeed -// belongs in the timeline of the given account ID. +// belongs in the given timeline. // // The manager makes a distinction between *indexed* items and *prepared* items. // @@ -45,33 +45,36 @@ const ( // Prepared items consist of the item's database ID, the time it was created, AND the apimodel representation of that item, for quick serialization. // Prepared items of course take up more memory than indexed items, so they should be regularly pruned if they're not being actively served. type Manager interface { - // IngestOne takes one timelineable and indexes it into the timeline for the given account ID, and then immediately prepares it for serving. + // IngestOne takes one timelineable and indexes it into the given timeline, and then immediately prepares it for serving. // This is useful in cases where we know the item will need to be shown at the top of a user's timeline immediately (eg., a new status is created). // // It should already be established before calling this function that the item actually belongs in the timeline! // // The returned bool indicates whether the item was actually put in the timeline. This could be false in cases where // a status is a boost, but a boost of the original status or the status itself already exists recently in the timeline. - IngestOne(ctx context.Context, accountID string, item Timelineable) (bool, error) + IngestOne(ctx context.Context, timelineID string, item Timelineable) (bool, error) - // GetTimeline returns limit n amount of prepared entries from the timeline of the given account ID, in descending chronological order. - GetTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]Preparable, error) + // GetTimeline returns limit n amount of prepared entries from the given timeline, in descending chronological order. + GetTimeline(ctx context.Context, timelineID string, maxID string, sinceID string, minID string, limit int, local bool) ([]Preparable, error) // GetIndexedLength returns the amount of items that have been indexed for the given account ID. - GetIndexedLength(ctx context.Context, accountID string) int + GetIndexedLength(ctx context.Context, timelineID string) int - // GetOldestIndexedID returns the id ID for the oldest item that we have indexed for the given account. + // GetOldestIndexedID returns the id ID for the oldest item that we have indexed for the given timeline. // Will be an empty string if nothing is (yet) indexed. - GetOldestIndexedID(ctx context.Context, accountID string) string + GetOldestIndexedID(ctx context.Context, timelineID string) string - // Remove removes one item from the timeline of the given timelineAccountID - Remove(ctx context.Context, accountID string, itemID string) (int, error) + // Remove removes one item from the given timeline. + Remove(ctx context.Context, timelineID string, itemID string) (int, error) + + // RemoveTimeline completely removes one timeline. + RemoveTimeline(ctx context.Context, timelineID string) error // WipeItemFromAllTimelines removes one item from the index and prepared items of all timelines WipeItemFromAllTimelines(ctx context.Context, itemID string) error - // WipeStatusesFromAccountID removes all items by the given accountID from the timelineAccountID's timelines. - WipeItemsFromAccountID(ctx context.Context, timelineAccountID string, accountID string) error + // WipeStatusesFromAccountID removes all items by the given accountID from the given timeline. + WipeItemsFromAccountID(ctx context.Context, timelineID string, accountID string) error // Start starts hourly cleanup jobs for this timeline manager. Start() error @@ -83,7 +86,7 @@ type Manager interface { // NewManager returns a new timeline manager. func NewManager(grabFunction GrabFunction, filterFunction FilterFunction, prepareFunction PrepareFunction, skipInsertFunction SkipInsertFunction) Manager { return &manager{ - accountTimelines: sync.Map{}, + timelines: sync.Map{}, grabFunction: grabFunction, filterFunction: filterFunction, prepareFunction: prepareFunction, @@ -92,7 +95,7 @@ func NewManager(grabFunction GrabFunction, filterFunction FilterFunction, prepar } type manager struct { - accountTimelines sync.Map + timelines sync.Map grabFunction GrabFunction filterFunction FilterFunction prepareFunction PrepareFunction @@ -127,14 +130,14 @@ func (m *manager) Start() error { } if amountPruned := timeline.Prune(pruneLengthPrepared, pruneLengthIndexed); amountPruned > 0 { - log.WithField("accountID", timeline.AccountID()).Infof("pruned %d indexed and prepared items from timeline", amountPruned) + log.WithField("accountID", timeline.TimelineID()).Infof("pruned %d indexed and prepared items from timeline", amountPruned) } return true } // Execute the function for each timeline. - m.accountTimelines.Range(f) + m.timelines.Range(f) } }() @@ -145,8 +148,8 @@ func (m *manager) Stop() error { return nil } -func (m *manager) IngestOne(ctx context.Context, accountID string, item Timelineable) (bool, error) { - return m.getOrCreateTimeline(ctx, accountID).IndexAndPrepareOne( +func (m *manager) IngestOne(ctx context.Context, timelineID string, item Timelineable) (bool, error) { + return m.getOrCreateTimeline(ctx, timelineID).IndexAndPrepareOne( ctx, item.GetID(), item.GetBoostOfID(), @@ -155,27 +158,32 @@ func (m *manager) IngestOne(ctx context.Context, accountID string, item Timeline ) } -func (m *manager) Remove(ctx context.Context, accountID string, itemID string) (int, error) { - return m.getOrCreateTimeline(ctx, accountID).Remove(ctx, itemID) +func (m *manager) Remove(ctx context.Context, timelineID string, itemID string) (int, error) { + return m.getOrCreateTimeline(ctx, timelineID).Remove(ctx, itemID) } -func (m *manager) GetTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]Preparable, error) { - return m.getOrCreateTimeline(ctx, accountID).Get(ctx, limit, maxID, sinceID, minID, true) +func (m *manager) RemoveTimeline(ctx context.Context, timelineID string) error { + m.timelines.Delete(timelineID) + return nil } -func (m *manager) GetIndexedLength(ctx context.Context, accountID string) int { - return m.getOrCreateTimeline(ctx, accountID).Len() +func (m *manager) GetTimeline(ctx context.Context, timelineID string, maxID string, sinceID string, minID string, limit int, local bool) ([]Preparable, error) { + return m.getOrCreateTimeline(ctx, timelineID).Get(ctx, limit, maxID, sinceID, minID, true) } -func (m *manager) GetOldestIndexedID(ctx context.Context, accountID string) string { - return m.getOrCreateTimeline(ctx, accountID).OldestIndexedItemID() +func (m *manager) GetIndexedLength(ctx context.Context, timelineID string) int { + return m.getOrCreateTimeline(ctx, timelineID).Len() } -func (m *manager) WipeItemFromAllTimelines(ctx context.Context, statusID string) error { +func (m *manager) GetOldestIndexedID(ctx context.Context, timelineID string) string { + return m.getOrCreateTimeline(ctx, timelineID).OldestIndexedItemID() +} + +func (m *manager) WipeItemFromAllTimelines(ctx context.Context, itemID string) error { errors := gtserror.MultiError{} - m.accountTimelines.Range(func(_ any, v any) bool { - if _, err := v.(Timeline).Remove(ctx, statusID); err != nil { + m.timelines.Range(func(_ any, v any) bool { + if _, err := v.(Timeline).Remove(ctx, itemID); err != nil { errors.Append(err) } @@ -183,22 +191,21 @@ func (m *manager) WipeItemFromAllTimelines(ctx context.Context, statusID string) }) if len(errors) > 0 { - return fmt.Errorf("WipeItemFromAllTimelines: one or more errors wiping status %s: %w", statusID, errors.Combine()) + return fmt.Errorf("WipeItemFromAllTimelines: one or more errors wiping status %s: %w", itemID, errors.Combine()) } return nil } -func (m *manager) WipeItemsFromAccountID(ctx context.Context, timelineAccountID string, accountID string) error { - _, err := m.getOrCreateTimeline(ctx, timelineAccountID).RemoveAllByOrBoosting(ctx, accountID) +func (m *manager) WipeItemsFromAccountID(ctx context.Context, timelineID string, accountID string) error { + _, err := m.getOrCreateTimeline(ctx, timelineID).RemoveAllByOrBoosting(ctx, accountID) return err } -// getOrCreateTimeline returns a timeline for the given -// accountID. If a timeline does not yet exist in the -// manager's sync.Map, it will be created and stored. -func (m *manager) getOrCreateTimeline(ctx context.Context, accountID string) Timeline { - i, ok := m.accountTimelines.Load(accountID) +// getOrCreateTimeline returns a timeline with the given id, +// creating a new timeline with that id if necessary. +func (m *manager) getOrCreateTimeline(ctx context.Context, timelineID string) Timeline { + i, ok := m.timelines.Load(timelineID) if ok { // Timeline already existed in sync.Map. return i.(Timeline) //nolint:forcetypeassert @@ -206,8 +213,8 @@ func (m *manager) getOrCreateTimeline(ctx context.Context, accountID string) Tim // Timeline did not yet exist in sync.Map. // Create + store it. - timeline := NewTimeline(ctx, accountID, m.grabFunction, m.filterFunction, m.prepareFunction, m.skipInsertFunction) - m.accountTimelines.Store(accountID, timeline) + timeline := NewTimeline(ctx, timelineID, m.grabFunction, m.filterFunction, m.prepareFunction, m.skipInsertFunction) + m.timelines.Store(timelineID, timeline) return timeline } diff --git a/internal/timeline/manager_test.go b/internal/timeline/manager_test.go index cf1f5be2b..652708ccd 100644 --- a/internal/timeline/manager_test.go +++ b/internal/timeline/manager_test.go @@ -22,7 +22,7 @@ import ( "testing" "github.com/stretchr/testify/suite" - "github.com/superseriousbusiness/gotosocial/internal/processing" + tlprocessor "github.com/superseriousbusiness/gotosocial/internal/processing/timeline" "github.com/superseriousbusiness/gotosocial/internal/timeline" "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" @@ -50,10 +50,10 @@ func (suite *ManagerTestSuite) SetupTest() { testrig.StandardDBSetup(suite.db, nil) manager := timeline.NewManager( - processing.StatusGrabFunction(suite.db), - processing.StatusFilterFunction(suite.db, suite.filter), - processing.StatusPrepareFunction(suite.db, suite.tc), - processing.StatusSkipInsertFunction(), + tlprocessor.HomeTimelineGrab(&suite.state), + tlprocessor.HomeTimelineFilter(&suite.state, suite.filter), + tlprocessor.HomeTimelineStatusPrepare(&suite.state, suite.tc), + tlprocessor.SkipInsert(), ) suite.manager = manager } diff --git a/internal/timeline/prepare.go b/internal/timeline/prepare.go index cc014037b..8fbcb15a7 100644 --- a/internal/timeline/prepare.go +++ b/internal/timeline/prepare.go @@ -119,7 +119,7 @@ func (t *timeline) prepareXBetweenIDs(ctx context.Context, amount int, behindID } for e, entry := range toPrepare { - prepared, err := t.prepareFunction(ctx, t.accountID, entry.itemID) + prepared, err := t.prepareFunction(ctx, t.timelineID, entry.itemID) if err != nil { if errors.Is(err, db.ErrNoEntries) { // ErrNoEntries means something has been deleted, diff --git a/internal/timeline/prune_test.go b/internal/timeline/prune_test.go index 89164563b..d70e1eb91 100644 --- a/internal/timeline/prune_test.go +++ b/internal/timeline/prune_test.go @@ -24,7 +24,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/processing" + tlprocessor "github.com/superseriousbusiness/gotosocial/internal/processing/timeline" "github.com/superseriousbusiness/gotosocial/internal/timeline" "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" @@ -55,10 +55,10 @@ func (suite *PruneTestSuite) SetupTest() { tl := timeline.NewTimeline( context.Background(), suite.testAccounts["local_account_1"].ID, - processing.StatusGrabFunction(suite.db), - processing.StatusFilterFunction(suite.db, suite.filter), - processing.StatusPrepareFunction(suite.db, suite.tc), - processing.StatusSkipInsertFunction(), + tlprocessor.HomeTimelineGrab(&suite.state), + tlprocessor.HomeTimelineFilter(&suite.state, suite.filter), + tlprocessor.HomeTimelineStatusPrepare(&suite.state, suite.tc), + tlprocessor.SkipInsert(), ) // put the status IDs in a determinate order since we can't trust a map to keep its order diff --git a/internal/timeline/remove.go b/internal/timeline/remove.go index e76913a2f..693c9f9b9 100644 --- a/internal/timeline/remove.go +++ b/internal/timeline/remove.go @@ -28,7 +28,7 @@ import ( func (t *timeline) Remove(ctx context.Context, statusID string) (int, error) { l := log.WithContext(ctx). WithFields(kv.Fields{ - {"accountTimeline", t.accountID}, + {"accountTimeline", t.timelineID}, {"statusID", statusID}, }...) @@ -64,7 +64,7 @@ func (t *timeline) RemoveAllByOrBoosting(ctx context.Context, accountID string) l := log. WithContext(ctx). WithFields(kv.Fields{ - {"accountTimeline", t.accountID}, + {"accountTimeline", t.timelineID}, {"accountID", accountID}, }...) diff --git a/internal/timeline/timeline.go b/internal/timeline/timeline.go index d3550e612..b973a3905 100644 --- a/internal/timeline/timeline.go +++ b/internal/timeline/timeline.go @@ -28,24 +28,24 @@ import ( // It should be provided to NewTimeline when the caller is creating a timeline // (of statuses, notifications, etc). // -// timelineAccountID: the owner of the timeline -// maxID: the maximum item ID desired. -// sinceID: the minimum item ID desired. -// minID: see sinceID -// limit: the maximum amount of items to be returned +// - timelineID: ID of the timeline. +// - maxID: the maximum item ID desired. +// - sinceID: the minimum item ID desired. +// - minID: see sinceID +// - limit: the maximum amount of items to be returned // // If an error is returned, the timeline will stop processing whatever request called GrabFunction, // and return the error. If no error is returned, but stop = true, this indicates to the caller of GrabFunction // that there are no more items to return, and processing should continue with the items already grabbed. -type GrabFunction func(ctx context.Context, timelineAccountID string, maxID string, sinceID string, minID string, limit int) (items []Timelineable, stop bool, err error) +type GrabFunction func(ctx context.Context, timelineID string, maxID string, sinceID string, minID string, limit int) (items []Timelineable, stop bool, err error) // FilterFunction is used by a Timeline to filter whether or not a grabbed item should be indexed. -type FilterFunction func(ctx context.Context, timelineAccountID string, item Timelineable) (shouldIndex bool, err error) +type FilterFunction func(ctx context.Context, timelineID string, item Timelineable) (shouldIndex bool, err error) // PrepareFunction converts a Timelineable into a Preparable. // // For example, this might result in the converstion of a *gtsmodel.Status with the given itemID into a serializable *apimodel.Status. -type PrepareFunction func(ctx context.Context, timelineAccountID string, itemID string) (Preparable, error) +type PrepareFunction func(ctx context.Context, timelineID string, itemID string) (Preparable, error) // SkipInsertFunction indicates whether a new item about to be inserted in the prepared list should be skipped, // based on the item itself, the next item in the timeline, and the depth at which nextItem has been found in the list. @@ -88,8 +88,8 @@ type Timeline interface { INFO FUNCTIONS */ - // AccountID returns the id of the account this timeline belongs to. - AccountID() string + // TimelineID returns the id of this timeline. + TimelineID() string // Len returns the length of the item index at this point in time. Len() int @@ -130,19 +130,20 @@ type timeline struct { grabFunction GrabFunction filterFunction FilterFunction prepareFunction PrepareFunction - accountID string + timelineID string lastGot time.Time sync.Mutex } -func (t *timeline) AccountID() string { - return t.accountID +func (t *timeline) TimelineID() string { + return t.timelineID } -// NewTimeline returns a new Timeline for the given account ID +// NewTimeline returns a new Timeline with +// the given ID, using the given functions. func NewTimeline( ctx context.Context, - timelineAccountID string, + timelineID string, grabFunction GrabFunction, filterFunction FilterFunction, prepareFunction PrepareFunction, @@ -155,7 +156,7 @@ func NewTimeline( grabFunction: grabFunction, filterFunction: filterFunction, prepareFunction: prepareFunction, - accountID: timelineAccountID, + timelineID: timelineID, lastGot: time.Time{}, } } diff --git a/internal/timeline/timelines.go b/internal/timeline/timelines.go new file mode 100644 index 000000000..8291fef5e --- /dev/null +++ b/internal/timeline/timelines.go @@ -0,0 +1,37 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package timeline + +type Timelines struct { + // Home provides access to account home timelines. + Home Manager + + // List provides access to list timelines. + List Manager + + // prevent pass-by-value. + _ nocopy +} + +// nocopy when embedded will signal linter to +// error on pass-by-value of parent struct. +type nocopy struct{} + +func (*nocopy) Lock() {} + +func (*nocopy) Unlock() {} diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 23cc6f44e..297abd73e 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -30,6 +30,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/transport" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -75,6 +76,12 @@ func (suite *TransportTestSuite) SetupTest() { suite.storage = testrig.NewInMemoryStorage() suite.state.Storage = suite.storage + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + testrig.NewTestTypeConverter(suite.db), + ) + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../testrig/media")), suite.mediaManager) suite.sentEmails = make(map[string]string) @@ -89,8 +96,6 @@ func (suite *TransportTestSuite) SetupTest() { suite.FailNow(err.Error()) } suite.transport = ts - - suite.NoError(suite.processor.Start()) } func (suite *TransportTestSuite) TearDownTest() { diff --git a/internal/typeutils/converter.go b/internal/typeutils/converter.go index 099414a10..00dbe26e8 100644 --- a/internal/typeutils/converter.go +++ b/internal/typeutils/converter.go @@ -92,6 +92,8 @@ type TypeConverter interface { ReportToAPIReport(ctx context.Context, r *gtsmodel.Report) (*apimodel.Report, error) // ReportToAdminAPIReport converts a gts model report into an admin view report, for serving at /api/v1/admin/reports ReportToAdminAPIReport(ctx context.Context, r *gtsmodel.Report, requestingAccount *gtsmodel.Account) (*apimodel.AdminReport, error) + // ListToAPIList converts one gts model list into an api model list, for serving at /api/v1/lists/{id} + ListToAPIList(ctx context.Context, l *gtsmodel.List) (*apimodel.List, error) /* INTERNAL (gts) MODEL TO FRONTEND (rss) MODEL diff --git a/internal/typeutils/converter_test.go b/internal/typeutils/converter_test.go index d92a30c13..a91e6a157 100644 --- a/internal/typeutils/converter_test.go +++ b/internal/typeutils/converter_test.go @@ -25,6 +25,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -515,6 +516,12 @@ func (suite *TypeUtilsTestSuite) TearDownTest() { // Useful when a test in the test suite needs to change some state. func (suite *TypeUtilsTestSuite) GetProcessor() *processing.Processor { testrig.StartWorkers(&suite.state) + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + testrig.NewTestTypeConverter(suite.db), + ) + httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") transportController := testrig.NewTestTransportController(&suite.state, httpClient) mediaManager := testrig.NewTestMediaManager(&suite.state) diff --git a/internal/typeutils/internaltofrontend.go b/internal/typeutils/internaltofrontend.go index 53c8af047..7d2056a4c 100644 --- a/internal/typeutils/internaltofrontend.go +++ b/internal/typeutils/internaltofrontend.go @@ -1142,6 +1142,14 @@ func (c *converter) ReportToAdminAPIReport(ctx context.Context, r *gtsmodel.Repo }, nil } +func (c *converter) ListToAPIList(ctx context.Context, l *gtsmodel.List) (*apimodel.List, error) { + return &apimodel.List{ + ID: l.ID, + Title: l.Title, + RepliesPolicy: string(l.RepliesPolicy), + }, nil +} + // convertAttachmentsToAPIAttachments will convert a slice of GTS model attachments to frontend API model attachments, falling back to IDs if no GTS models supplied. func (c *converter) convertAttachmentsToAPIAttachments(ctx context.Context, attachments []*gtsmodel.MediaAttachment, attachmentIDs []string) ([]apimodel.Attachment, error) { var errs gtserror.MultiError diff --git a/internal/validate/formvalidation.go b/internal/validate/formvalidation.go index 20d4aa782..f9328dc1f 100644 --- a/internal/validate/formvalidation.go +++ b/internal/validate/formvalidation.go @@ -45,6 +45,7 @@ const ( maximumEmojiCategoryLength = 64 maximumProfileFieldLength = 255 maximumProfileFields = 6 + maximumListTitleLength = 200 ) // NewPassword returns an error if the given password is not sufficiently strong, or nil if it's ok. @@ -257,3 +258,28 @@ func ProfileFields(fields []*gtsmodel.Field) error { return nil } + +// ListTitle validates the title of a new or updated List. +func ListTitle(title string) error { + if title == "" { + return fmt.Errorf("list title must be provided, and must be no more than %d chars", maximumListTitleLength) + } + + if length := len([]rune(title)); length > maximumListTitleLength { + return fmt.Errorf("list title length must be no more than %d chars, provided title was %d chars", maximumListTitleLength, length) + } + + return nil +} + +// ListRepliesPolicy validates the replies_policy of a new or updated list. +func ListRepliesPolicy(repliesPolicy gtsmodel.RepliesPolicy) error { + switch repliesPolicy { + case "", gtsmodel.RepliesPolicyFollowed, gtsmodel.RepliesPolicyList, gtsmodel.RepliesPolicyNone: + // No problem. + return nil + default: + // Uh oh. + return fmt.Errorf("list replies_policy must be either empty or one of 'followed', 'list', 'none'") + } +} diff --git a/test/envparsing.sh b/test/envparsing.sh index 688ff8d2c..8ceabde8f 100755 --- a/test/envparsing.sh +++ b/test/envparsing.sh @@ -39,6 +39,12 @@ EXPECT=$(cat <<"EOF" "follow-request-ttl": 1800000000000, "follow-sweep-freq": 60000000000, "follow-ttl": 1800000000000, + "list-entry-max-size": 2000, + "list-entry-sweep-freq": 60000000000, + "list-entry-ttl": 1800000000000, + "list-max-size": 2000, + "list-sweep-freq": 60000000000, + "list-ttl": 1800000000000, "media-max-size": 1000, "media-sweep-freq": 60000000000, "media-ttl": 1800000000000, diff --git a/testrig/config.go b/testrig/config.go index dea8ee641..aeac78e3f 100644 --- a/testrig/config.go +++ b/testrig/config.go @@ -33,7 +33,7 @@ func InitTestConfig() { } var testDefaults = config.Configuration{ - LogLevel: "info", + LogLevel: "trace", LogDbQueries: true, ApplicationName: "gotosocial", LandingPageUser: "", diff --git a/testrig/db.go b/testrig/db.go index d95c8f941..c169669d7 100644 --- a/testrig/db.go +++ b/testrig/db.go @@ -39,6 +39,8 @@ var testModels = []interface{}{ >smodel.EmailDomainBlock{}, >smodel.Follow{}, >smodel.FollowRequest{}, + >smodel.List{}, + >smodel.ListEntry{}, >smodel.MediaAttachment{}, >smodel.Mention{}, >smodel.Status{}, @@ -248,6 +250,18 @@ func StandardDBSetup(db db.DB, accounts map[string]*gtsmodel.Account) { } } + for _, v := range NewTestLists() { + if err := db.Put(ctx, v); err != nil { + log.Panic(nil, err) + } + } + + for _, v := range NewTestListEntries() { + if err := db.Put(ctx, v); err != nil { + log.Panic(nil, err) + } + } + for _, v := range NewTestNotifications() { if err := db.Put(ctx, v); err != nil { log.Panic(nil, err) diff --git a/testrig/testmodels.go b/testrig/testmodels.go index 15e204f85..c55b80e60 100644 --- a/testrig/testmodels.go +++ b/testrig/testmodels.go @@ -1961,6 +1961,38 @@ func NewTestFollows() map[string]*gtsmodel.Follow { } } +func NewTestLists() map[string]*gtsmodel.List { + return map[string]*gtsmodel.List{ + "local_account_1_list_1": { + ID: "01H0G8E4Q2J3FE3JDWJVWEDCD1", + CreatedAt: TimeMustParse("2022-05-14T13:21:09+02:00"), + UpdatedAt: TimeMustParse("2022-05-14T13:21:09+02:00"), + Title: "Cool Ass Posters From This Instance", + AccountID: "01F8MH1H7YV1Z7D2C8K2730QBF", + RepliesPolicy: gtsmodel.RepliesPolicyFollowed, + }, + } +} + +func NewTestListEntries() map[string]*gtsmodel.ListEntry { + return map[string]*gtsmodel.ListEntry{ + "local_account_1_list_1_entry_1": { + ID: "01H0G89MWVQE0M58VD2HQYMQWH", + CreatedAt: TimeMustParse("2022-05-14T13:21:09+02:00"), + UpdatedAt: TimeMustParse("2022-05-14T13:21:09+02:00"), + ListID: "01H0G8E4Q2J3FE3JDWJVWEDCD1", + FollowID: "01F8PYDCE8XE23GRE5DPZJDZDP", + }, + "local_account_1_list_1_entry_2": { + ID: "01H0G8FFM1AGQDRNGBGGX8CYJQ", + CreatedAt: TimeMustParse("2022-05-14T13:21:09+02:00"), + UpdatedAt: TimeMustParse("2022-05-14T13:21:09+02:00"), + ListID: "01H0G8E4Q2J3FE3JDWJVWEDCD1", + FollowID: "01F8PY8RHWRQZV038T4E8T9YK8", + }, + } +} + func NewTestBlocks() map[string]*gtsmodel.Block { return map[string]*gtsmodel.Block{ "local_account_2_block_remote_account_1": { diff --git a/testrig/util.go b/testrig/util.go index d7183df1c..4e52d12b5 100644 --- a/testrig/util.go +++ b/testrig/util.go @@ -20,6 +20,7 @@ package testrig import ( "bytes" "context" + "fmt" "io" "mime/multipart" "net/url" @@ -27,7 +28,11 @@ import ( "time" "github.com/superseriousbusiness/gotosocial/internal/messages" + tlprocessor "github.com/superseriousbusiness/gotosocial/internal/processing/timeline" "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/timeline" + "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" ) func StartWorkers(state *state.State) { @@ -47,6 +52,28 @@ func StopWorkers(state *state.State) { _ = state.Workers.Media.Stop() } +func StartTimelines(state *state.State, filter *visibility.Filter, typeConverter typeutils.TypeConverter) { + state.Timelines.Home = timeline.NewManager( + tlprocessor.HomeTimelineGrab(state), + tlprocessor.HomeTimelineFilter(state, filter), + tlprocessor.HomeTimelineStatusPrepare(state, typeConverter), + tlprocessor.SkipInsert(), + ) + if err := state.Timelines.Home.Start(); err != nil { + panic(fmt.Sprintf("error starting home timeline: %s", err)) + } + + state.Timelines.List = timeline.NewManager( + tlprocessor.ListTimelineGrab(state), + tlprocessor.ListTimelineFilter(state, filter), + tlprocessor.ListTimelineStatusPrepare(state, typeConverter), + tlprocessor.SkipInsert(), + ) + if err := state.Timelines.List.Start(); err != nil { + panic(fmt.Sprintf("error starting list timeline: %s", err)) + } +} + // CreateMultipartFormData is a handy function for taking a fieldname and a filename, and creating a multipart form bytes buffer // with the file contents set in the given fieldname. The extraFields param can be used to add extra FormFields to the request, as necessary. // The returned bytes.Buffer b can be used like so: