[feature] Add List functionality (#1802)

* start working on lists

* further list work

* test list db functions nicely

* more work on lists

* peepoopeepoo

* poke

* start list timeline func

* we're getting there lads

* couldn't be me working on stuff... could it?

* hook up handlers

* fiddling

* weeee

* woah

* screaming, pissing

* fix streaming being a whiny baby

* lint, small test fix, swagger

* tidying up, testing

* fucked! by the linter

* move timelines to state like a boss

* add timeline start to tests using state

* invalidate lists
This commit is contained in:
tobi 2023-05-25 10:37:38 +02:00 committed by GitHub
parent 282be6f26d
commit f5c004d67d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
123 changed files with 5654 additions and 970 deletions

View file

@ -32,7 +32,10 @@ import (
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/middleware"
tlprocessor "github.com/superseriousbusiness/gotosocial/internal/processing/timeline"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/tracing"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"go.uber.org/automaxprocs/maxprocs"
"github.com/superseriousbusiness/gotosocial/internal/config"
@ -72,7 +75,6 @@ var Start action.GTSAction = func(ctx context.Context) error {
defer state.Caches.Stop()
// Initialize Tracing
if err := tracing.Initialize(); err != nil {
return fmt.Errorf("error initializing tracing: %w", err)
}
@ -110,36 +112,56 @@ var Start action.GTSAction = func(ctx context.Context) error {
state.Workers.Start()
defer state.Workers.Stop()
// build backend handlers
// Build handlers used in later initializations.
mediaManager := media.NewManager(&state)
oauthServer := oauth.New(ctx, dbService)
typeConverter := typeutils.NewConverter(dbService)
filter := visibility.NewFilter(&state)
federatingDB := federatingdb.New(&state, typeConverter)
transportController := transport.NewController(&state, federatingDB, &federation.Clock{}, client)
federator := federation.NewFederator(&state, federatingDB, transportController, typeConverter, mediaManager)
// decide whether to create a noop email sender (won't send emails) or a real one
// Decide whether to create a noop email
// sender (won't send emails) or a real one.
var emailSender email.Sender
if smtpHost := config.GetSMTPHost(); smtpHost != "" {
// host is defined so create a proper sender
// Host is defined; create a proper sender.
emailSender, err = email.NewSender()
if err != nil {
return fmt.Errorf("error creating email sender: %s", err)
}
} else {
// no host is defined so create a noop sender
// No host is defined; create a noop sender.
emailSender, err = email.NewNoopSender(nil)
if err != nil {
return fmt.Errorf("error creating noop email sender: %s", err)
}
}
// create the message processor using the other services we've created so far
processor := processing.NewProcessor(typeConverter, federator, oauthServer, mediaManager, &state, emailSender)
if err := processor.Start(); err != nil {
return fmt.Errorf("error creating processor: %s", err)
// Initialize timelines.
state.Timelines.Home = timeline.NewManager(
tlprocessor.HomeTimelineGrab(&state),
tlprocessor.HomeTimelineFilter(&state, filter),
tlprocessor.HomeTimelineStatusPrepare(&state, typeConverter),
tlprocessor.SkipInsert(),
)
if err := state.Timelines.Home.Start(); err != nil {
return fmt.Errorf("error starting home timeline: %s", err)
}
state.Timelines.List = timeline.NewManager(
tlprocessor.ListTimelineGrab(&state),
tlprocessor.ListTimelineFilter(&state, filter),
tlprocessor.ListTimelineStatusPrepare(&state, typeConverter),
tlprocessor.SkipInsert(),
)
if err := state.Timelines.List.Start(); err != nil {
return fmt.Errorf("error starting list timeline: %s", err)
}
// Create the processor using all the other services we've created so far.
processor := processing.NewProcessor(typeConverter, federator, oauthServer, mediaManager, &state, emailSender)
// Set state client / federator worker enqueue functions
state.Workers.EnqueueClientAPI = processor.EnqueueClientAPI
state.Workers.EnqueueFederator = processor.EnqueueFederator

View file

@ -38,9 +38,12 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/middleware"
"github.com/superseriousbusiness/gotosocial/internal/oidc"
tlprocessor "github.com/superseriousbusiness/gotosocial/internal/processing/timeline"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/tracing"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/internal/web"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -89,11 +92,31 @@ var Start action.GTSAction = func(ctx context.Context) error {
federator := testrig.NewTestFederator(&state, transportController, mediaManager)
emailSender := testrig.NewEmailSender("./web/template/", nil)
typeConverter := testrig.NewTestTypeConverter(state.DB)
filter := visibility.NewFilter(&state)
// Initialize timelines.
state.Timelines.Home = timeline.NewManager(
tlprocessor.HomeTimelineGrab(&state),
tlprocessor.HomeTimelineFilter(&state, filter),
tlprocessor.HomeTimelineStatusPrepare(&state, typeConverter),
tlprocessor.SkipInsert(),
)
if err := state.Timelines.Home.Start(); err != nil {
return fmt.Errorf("error starting home timeline: %s", err)
}
state.Timelines.List = timeline.NewManager(
tlprocessor.ListTimelineGrab(&state),
tlprocessor.ListTimelineFilter(&state, filter),
tlprocessor.ListTimelineStatusPrepare(&state, typeConverter),
tlprocessor.SkipInsert(),
)
if err := state.Timelines.List.Start(); err != nil {
return fmt.Errorf("error starting list timeline: %s", err)
}
processor := testrig.NewTestProcessor(&state, federator, emailSender, mediaManager)
if err := processor.Start(); err != nil {
return fmt.Errorf("error starting processor: %s", err)
}
/*
HTTP router initialization

View file

@ -1635,6 +1635,28 @@ definitions:
type: object
x-go-name: InstanceV2Users
x-go-package: github.com/superseriousbusiness/gotosocial/internal/api/model
list:
properties:
id:
description: The ID of the list.
type: string
x-go-name: ID
replies_policy:
description: |-
RepliesPolicy for this list.
followed = Show replies to any followed user
list = Show replies to members of the list
none = Show replies to no one
type: string
x-go-name: RepliesPolicy
title:
description: The user-defined title of the list.
type: string
x-go-name: Title
title: List represents a user-created list of accounts that the user follows.
type: object
x-go-name: List
x-go-package: github.com/superseriousbusiness/gotosocial/internal/api/model
mediaDimensions:
properties:
aspect:
@ -2881,6 +2903,40 @@ paths:
summary: See accounts followed by given account id.
tags:
- accounts
/api/v1/accounts/{id}/lists:
get:
operationId: accountLists
parameters:
- description: Account ID.
in: path
name: id
required: true
type: string
produces:
- application/json
responses:
"200":
description: Array of all lists containing this account.
schema:
items:
$ref: '#/definitions/list'
type: array
"400":
description: bad request
"401":
description: unauthorized
"404":
description: not found
"406":
description: not acceptable
"500":
description: internal server error
security:
- OAuth2 Bearer:
- read:lists
summary: See all lists of yours that contain requested account.
tags:
- accounts
/api/v1/accounts/{id}/statuses:
get:
description: The statuses will be returned in descending chronological order (newest first), with sequential IDs (bigger = newer).
@ -3211,7 +3267,7 @@ paths:
name: id
required: true
type: string
- description: Type of action to be taken (`disable`, `silence`, or `suspend`).
- description: Type of action to be taken, currently only supports `suspend`.
in: formData
name: type
required: true
@ -4453,6 +4509,343 @@ paths:
description: internal server error
tags:
- instance
/api/v1/list:
post:
consumes:
- application/json
- application/xml
- application/x-www-form-urlencoded
operationId: listCreate
parameters:
- description: Title of this list.
example: Cool People
in: formData
name: title
required: true
type: string
x-go-name: Title
- default: list
description: |-
RepliesPolicy for this list.
followed = Show replies to any followed user
list = Show replies to members of the list
none = Show replies to no one
example: list
in: formData
name: replies_policy
type: string
x-go-name: RepliesPolicy
produces:
- application/json
responses:
"200":
description: The newly created list.
schema:
$ref: '#/definitions/list'
"400":
description: bad request
"401":
description: unauthorized
"403":
description: forbidden
"404":
description: not found
"406":
description: not acceptable
"500":
description: internal server error
security:
- OAuth2 Bearer:
- write:lists
summary: Create a new list.
tags:
- lists
put:
consumes:
- application/json
- application/xml
- application/x-www-form-urlencoded
operationId: listUpdate
parameters:
- description: ID of the list
example: Cool People
in: path
name: id
required: true
type: string
x-go-name: Title
- description: Title of this list.
example: Cool People
in: formData
name: title
type: string
x-go-name: RepliesPolicy
- description: |-
RepliesPolicy for this list.
followed = Show replies to any followed user
list = Show replies to members of the list
none = Show replies to no one
example: list
in: formData
name: replies_policy
type: string
produces:
- application/json
responses:
"200":
description: The newly updated list.
schema:
$ref: '#/definitions/list'
"400":
description: bad request
"401":
description: unauthorized
"403":
description: forbidden
"404":
description: not found
"406":
description: not acceptable
"500":
description: internal server error
security:
- OAuth2 Bearer:
- write:lists
summary: Update an existing list.
tags:
- lists
/api/v1/list/{id}:
delete:
operationId: listDelete
parameters:
- description: ID of the list
in: path
name: id
required: true
type: string
produces:
- application/json
responses:
"200":
description: list deleted
"400":
description: bad request
"401":
description: unauthorized
"404":
description: not found
"406":
description: not acceptable
"500":
description: internal server error
security:
- OAuth2 Bearer:
- write:lists
summary: Delete a single list with the given ID.
tags:
- lists
get:
operationId: list
parameters:
- description: ID of the list
in: path
name: id
required: true
type: string
produces:
- application/json
responses:
"200":
description: Requested list.
schema:
$ref: '#/definitions/list'
"400":
description: bad request
"401":
description: unauthorized
"404":
description: not found
"406":
description: not acceptable
"500":
description: internal server error
security:
- OAuth2 Bearer:
- read:lists
summary: Get a single list with the given ID.
tags:
- lists
/api/v1/list/{id}/accounts:
delete:
consumes:
- application/json
- application/xml
- application/x-www-form-urlencoded
operationId: removeListAccounts
parameters:
- description: ID of the list
in: path
name: id
required: true
type: string
- description: Array of accountIDs to modify. Each accountID must correspond to an account that the requesting account follows.
in: formData
items:
type: string
name: account_ids
required: true
type: array
produces:
- application/json
responses:
"200":
description: list accounts updated
"400":
description: bad request
"401":
description: unauthorized
"404":
description: not found
"406":
description: not acceptable
"500":
description: internal server error
security:
- OAuth2 Bearer:
- read:lists
summary: Remove one or more accounts from the given list.
tags:
- lists
get:
description: |-
The returned Link header can be used to generate the previous and next queries when scrolling up or down a timeline.
Example:
```
<https://example.org/api/v1/list/01H0W619198FX7J54NF7EH1NG2/accounts?limit=20&max_id=01FC3GSQ8A3MMJ43BPZSGEG29M>; rel="next", <https://example.org/api/v1/list/01H0W619198FX7J54NF7EH1NG2/accounts?limit=20&min_id=01FC3KJW2GYXSDDRA6RWNDM46M>; rel="prev"
````
operationId: listAccounts
parameters:
- description: ID of the list
in: path
name: id
required: true
type: string
- description: Return only list entries *OLDER* than the given max ID. The account from the list entry with the specified ID will not be included in the response.
in: query
name: max_id
type: string
- description: Return only list entries *NEWER* than the given since ID. The account from the list entry with the specified ID will not be included in the response.
in: query
name: since_id
type: string
- description: Return only list entries *IMMEDIATELY NEWER* than the given min ID. The account from the list entry with the specified ID will not be included in the response.
in: query
name: min_id
type: string
- default: 20
description: Number of accounts to return.
in: query
name: limit
type: integer
produces:
- application/json
responses:
"200":
description: Array of accounts.
headers:
Link:
description: Links to the next and previous queries.
type: string
schema:
items:
$ref: '#/definitions/account'
type: array
"400":
description: bad request
"401":
description: unauthorized
"404":
description: not found
"406":
description: not acceptable
"500":
description: internal server error
security:
- OAuth2 Bearer:
- read:lists
summary: Page through accounts in this list.
tags:
- lists
post:
consumes:
- application/json
- application/xml
- application/x-www-form-urlencoded
operationId: addListAccounts
parameters:
- description: ID of the list
in: path
name: id
required: true
type: string
- description: Array of accountIDs to modify. Each accountID must correspond to an account that the requesting account follows.
in: formData
items:
type: string
name: account_ids
required: true
type: array
produces:
- application/json
responses:
"200":
description: list accounts updated
"400":
description: bad request
"401":
description: unauthorized
"404":
description: not found
"406":
description: not acceptable
"500":
description: internal server error
security:
- OAuth2 Bearer:
- read:lists
summary: Add one or more accounts to the given list.
tags:
- lists
/api/v1/lists:
get:
operationId: lists
produces:
- application/json
responses:
"200":
description: Array of all lists owned by the requesting user.
schema:
items:
$ref: '#/definitions/list'
type: array
"400":
description: bad request
"401":
description: unauthorized
"404":
description: not found
"406":
description: not acceptable
"500":
description: internal server error
security:
- OAuth2 Bearer:
- read:lists
summary: Get all lists for owned by authorized user.
tags:
- lists
/api/v1/media/{id}:
get:
operationId: mediaGet
@ -5579,6 +5972,18 @@ paths:
name: stream
required: true
type: string
- description: |-
ID of the list to subscribe to.
Only used if stream type is 'list'.
in: query
name: list
type: string
- description: |-
Name of the tag to subscribe to.
Only used if stream type is 'hashtag' or 'hashtag:local'.
in: query
name: tag
type: string
produces:
- application/json
responses:
@ -5696,6 +6101,65 @@ paths:
summary: See statuses/posts by accounts you follow.
tags:
- timelines
/api/v1/timelines/list/{id}:
get:
description: |-
The statuses will be returned in descending chronological order (newest first), with sequential IDs (bigger = newer).
The returned Link header can be used to generate the previous and next queries when scrolling up or down a timeline.
Example:
```
<https://example.org/api/v1/timelines/list/01H0W619198FX7J54NF7EH1NG2?limit=20&max_id=01FC3GSQ8A3MMJ43BPZSGEG29M>; rel="next", <https://example.org/api/v1/timelines/list/01H0W619198FX7J54NF7EH1NG2?limit=20&min_id=01FC3KJW2GYXSDDRA6RWNDM46M>; rel="prev"
````
operationId: listTimeline
parameters:
- description: ID of the list
in: path
name: id
required: true
type: string
- description: Return only statuses *OLDER* than the given max status ID. The status with the specified ID will not be included in the response.
in: query
name: max_id
type: string
- description: Return only statuses *NEWER* than the given since status ID. The status with the specified ID will not be included in the response.
in: query
name: since_id
type: string
- description: Return only statuses *NEWER* than the given since status ID. The status with the specified ID will not be included in the response.
in: query
name: min_id
type: string
- default: 20
description: Number of statuses to return.
in: query
name: limit
type: integer
produces:
- application/json
responses:
"200":
description: Array of statuses.
headers:
Link:
description: Links to the next and previous queries.
type: string
schema:
items:
$ref: '#/definitions/status'
type: array
"400":
description: bad request
"401":
description: unauthorized
security:
- OAuth2 Bearer:
- read:lists
summary: See statuses/posts from the given list timeline.
tags:
- timelines
/api/v1/timelines/public:
get:
description: |-
@ -5980,6 +6444,7 @@ securityDefinitions:
read:custom_emojis: grant read access to custom_emojis
read:favourites: grant read access to favourites
read:follows: grant read access to follows
read:lists: grant read access to lists
read:media: grant read access to media
read:notifications: grants read access to notifications
read:search: grant read access to searches
@ -5990,6 +6455,7 @@ securityDefinitions:
write:accounts: grants write access to accounts
write:blocks: grants write access to blocks
write:follows: grants write access to follows
write:lists: grants write access to lists
write:media: grants write access to media
write:statuses: grants write access to statuses
write:user: grants write access to user-level info

View file

@ -37,6 +37,7 @@
// read:custom_emojis: grant read access to custom_emojis
// read:favourites: grant read access to favourites
// read:follows: grant read access to follows
// read:lists: grant read access to lists
// read:media: grant read access to media
// read:search: grant read access to searches
// read:statuses: grants read access to statuses
@ -47,6 +48,7 @@
// write:accounts: grants write access to accounts
// write:blocks: grants write access to blocks
// write:follows: grants write access to follows
// write:lists: grants write access to lists
// write:media: grants write access to media
// write:statuses: grants write access to statuses
// write:user: grants write access to user-level info

View file

@ -289,6 +289,14 @@ cache:
follow-request-ttl: "30m"
follow-request-sweep-freq: "1m"
list-max-size: 2000
list-ttl: "30m"
list-sweep-freq: "1m"
list-entry-max-size: 2000
list-entry-ttl: "30m"
list-entry-sweep-freq: "1m"
media-max-size: 1000
media-ttl: "30m"
media-sweep-freq: "1m"

View file

@ -36,6 +36,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -74,9 +75,14 @@ func (suite *EmojiGetTestSuite) SetupTest() {
suite.state.DB = suite.db
suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
@ -86,8 +92,6 @@ func (suite *EmojiGetTestSuite) SetupTest() {
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
suite.signatureCheck = middleware.SignatureCheck(suite.db.IsURIBlocked)
suite.NoError(suite.processor.Start())
}
func (suite *EmojiGetTestSuite) TearDownTest() {

View file

@ -89,7 +89,6 @@ func (suite *InboxPostTestSuite) TestPostBlock() {
emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor)
suite.NoError(processor.Start())
// setup request
recorder := httptest.NewRecorder()
@ -190,7 +189,6 @@ func (suite *InboxPostTestSuite) TestPostUnblock() {
emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor)
suite.NoError(processor.Start())
// setup request
recorder := httptest.NewRecorder()
@ -296,7 +294,6 @@ func (suite *InboxPostTestSuite) TestPostUpdate() {
emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor)
suite.NoError(processor.Start())
// setup request
recorder := httptest.NewRecorder()
@ -425,7 +422,6 @@ func (suite *InboxPostTestSuite) TestPostDelete() {
emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor)
suite.NoError(processor.Start())
// setup request
recorder := httptest.NewRecorder()

View file

@ -106,7 +106,6 @@ func (suite *OutboxGetTestSuite) TestGetOutboxFirstPage() {
emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor)
suite.NoError(processor.Start())
// setup request
recorder := httptest.NewRecorder()
@ -181,7 +180,6 @@ func (suite *OutboxGetTestSuite) TestGetOutboxNextPage() {
emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor)
suite.NoError(processor.Start())
// setup request
recorder := httptest.NewRecorder()

View file

@ -106,7 +106,6 @@ func (suite *RepliesGetTestSuite) TestGetRepliesNext() {
emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor)
suite.NoError(processor.Start())
// setup request
recorder := httptest.NewRecorder()
@ -171,7 +170,6 @@ func (suite *RepliesGetTestSuite) TestGetRepliesLast() {
emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor)
suite.NoError(processor.Start())
// setup request
recorder := httptest.NewRecorder()

View file

@ -31,6 +31,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -83,6 +84,13 @@ func (suite *UserStandardTestSuite) SetupTest() {
suite.db = testrig.NewTestDB(&suite.state)
suite.state.DB = suite.db
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
@ -94,8 +102,6 @@ func (suite *UserStandardTestSuite) SetupTest() {
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
suite.signatureCheck = middleware.SignatureCheck(suite.db.IsURIBlocked)
suite.NoError(suite.processor.Start())
}
func (suite *UserStandardTestSuite) TearDownTest() {

View file

@ -36,6 +36,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -86,6 +87,12 @@ func (suite *AccountStandardTestSuite) SetupTest() {
suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
testrig.NewTestTypeConverter(suite.db),
)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string)
@ -94,8 +101,6 @@ func (suite *AccountStandardTestSuite) SetupTest() {
suite.accountsModule = accounts.New(suite.processor)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
suite.NoError(suite.processor.Start())
}
func (suite *AccountStandardTestSuite) TearDownTest() {

View file

@ -70,6 +70,8 @@ const (
UnblockPath = BasePathWithID + "/unblock"
// DeleteAccountPath is for deleting one's account via the API
DeleteAccountPath = BasePath + "/delete"
// ListsPath is for seeing which lists an account is.
ListsPath = BasePathWithID + "/lists"
)
type Module struct {
@ -115,4 +117,7 @@ func (m *Module) Route(attachHandler func(method string, path string, f ...gin.H
// block or unblock account
attachHandler(http.MethodPost, BlockPath, m.AccountBlockPOSTHandler)
attachHandler(http.MethodPost, UnblockPath, m.AccountUnblockPOSTHandler)
// account lists
attachHandler(http.MethodGet, ListsPath, m.AccountListsGETHandler)
}

View file

@ -0,0 +1,97 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package accounts
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// AccountListsGETHandler swagger:operation GET /api/v1/accounts/{id}/lists accountLists
//
// See all lists of yours that contain requested account.
//
// ---
// tags:
// - accounts
//
// produces:
// - application/json
//
// parameters:
// -
// name: id
// type: string
// description: Account ID.
// in: path
// required: true
//
// security:
// - OAuth2 Bearer:
// - read:lists
//
// responses:
// '200':
// name: lists
// description: Array of all lists containing this account.
// schema:
// type: array
// items:
// "$ref": "#/definitions/list"
// '400':
// description: bad request
// '401':
// description: unauthorized
// '404':
// description: not found
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) AccountListsGETHandler(c *gin.Context) {
authed, err := oauth.Authed(c, false, false, false, false)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1)
return
}
targetAcctID := c.Param(IDKey)
if targetAcctID == "" {
err := errors.New("no account id specified")
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
lists, errWithCode := m.processor.Account().ListsGet(c.Request.Context(), authed.Account, targetAcctID)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
c.JSON(http.StatusOK, lists)
}

View file

@ -0,0 +1,103 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package accounts_test
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/suite"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type ListsTestSuite struct {
AccountStandardTestSuite
}
func (suite *ListsTestSuite) getLists(targetAccountID string, expectedHTTPStatus int, expectedBody string) []*apimodel.List {
var (
recorder = httptest.NewRecorder()
ctx, _ = testrig.CreateGinTestContext(recorder, nil)
request = httptest.NewRequest(http.MethodGet, "http://localhost:8080/api/v1/accounts/"+targetAccountID+"/lists", nil)
)
// Set up the test context.
ctx.Request = request
ctx.AddParam("id", targetAccountID)
ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"])
ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"]))
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"])
ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"])
// Trigger the handler.
suite.accountsModule.AccountListsGETHandler(ctx)
// Read the result.
result := recorder.Result()
defer result.Body.Close()
b, err := io.ReadAll(result.Body)
if err != nil {
suite.FailNow(err.Error())
}
errs := gtserror.MultiError{}
// Check expected code + body.
if resultCode := recorder.Code; expectedHTTPStatus != resultCode {
errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode))
}
// If we got an expected body, return early.
if expectedBody != "" && string(b) != expectedBody {
errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b)))
}
if err := errs.Combine(); err != nil {
suite.FailNow("", "%v (body %s)", err, string(b))
}
// Return list response.
resp := new([]*apimodel.List)
if err := json.Unmarshal(b, resp); err != nil {
suite.FailNow(err.Error())
}
return *resp
}
func (suite *ListsTestSuite) TestGetListsHit() {
targetAccount := suite.testAccounts["admin_account"]
suite.getLists(targetAccount.ID, http.StatusOK, `[{"id":"01H0G8E4Q2J3FE3JDWJVWEDCD1","title":"Cool Ass Posters From This Instance","replies_policy":"followed"}]`)
}
func (suite *ListsTestSuite) TestGetListsNoHit() {
targetAccount := suite.testAccounts["remote_account_1"]
suite.getLists(targetAccount.ID, http.StatusOK, `[]`)
}
func TestListsTestSuite(t *testing.T) {
suite.Run(t, new(ListsTestSuite))
}

View file

@ -36,6 +36,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -92,6 +93,12 @@ func (suite *AdminStandardTestSuite) SetupTest() {
suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
testrig.NewTestTypeConverter(suite.db),
)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string)

View file

@ -42,6 +42,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -98,6 +99,13 @@ func (suite *BookmarkTestSuite) SetupTest() {
suite.state.Storage = suite.storage
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@ -107,8 +115,6 @@ func (suite *BookmarkTestSuite) SetupTest() {
suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.statusModule = statuses.New(suite.processor)
suite.bookmarkModule = bookmarks.New(suite.processor)
suite.NoError(suite.processor.Start())
}
func (suite *BookmarkTestSuite) TearDownTest() {

View file

@ -29,6 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -82,6 +83,13 @@ func (suite *FavouritesStandardTestSuite) SetupTest() {
suite.state.Storage = suite.storage
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@ -90,8 +98,6 @@ func (suite *FavouritesStandardTestSuite) SetupTest() {
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.favModule = favourites.New(suite.processor)
suite.NoError(suite.processor.Start())
}
func (suite *FavouritesStandardTestSuite) TearDownTest() {

View file

@ -128,7 +128,7 @@ func (m *Module) FavouritesGETHandler(c *gin.Context) {
limit = int(i)
}
resp, errWithCode := m.processor.FavedTimelineGet(c.Request.Context(), authed, maxID, minID, limit)
resp, errWithCode := m.processor.Timeline().FavedTimelineGet(c.Request.Context(), authed, maxID, minID, limit)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return

View file

@ -35,6 +35,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -83,6 +84,12 @@ func (suite *FollowRequestStandardTestSuite) SetupTest() {
suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
testrig.NewTestTypeConverter(suite.db),
)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
@ -90,8 +97,6 @@ func (suite *FollowRequestStandardTestSuite) SetupTest() {
suite.followRequestModule = followrequests.New(suite.processor)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
suite.NoError(suite.processor.Start())
}
func (suite *FollowRequestStandardTestSuite) TearDownTest() {

View file

@ -35,6 +35,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -85,6 +86,12 @@ func (suite *InstanceStandardTestSuite) SetupTest() {
suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
testrig.NewTestTypeConverter(suite.db),
)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string)

View file

@ -25,8 +25,15 @@ import (
)
const (
IDKey = "id"
// BasePath is the base path for serving the lists API, minus the 'api' prefix
BasePath = "/v1/lists"
BasePathWithID = BasePath + "/:" + IDKey
AccountsPath = BasePathWithID + "/accounts"
MaxIDKey = "max_id"
LimitKey = "limit"
SinceIDKey = "since_id"
MinIDKey = "min_id"
)
type Module struct {
@ -40,5 +47,15 @@ func New(processor *processing.Processor) *Module {
}
func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) {
// create / get / update / delete lists
attachHandler(http.MethodPost, BasePath, m.ListCreatePOSTHandler)
attachHandler(http.MethodGet, BasePath, m.ListsGETHandler)
attachHandler(http.MethodGet, BasePathWithID, m.ListGETHandler)
attachHandler(http.MethodPut, BasePathWithID, m.ListUpdatePUTHandler)
attachHandler(http.MethodDelete, BasePathWithID, m.ListDELETEHandler)
// get / add / remove list accounts
attachHandler(http.MethodGet, AccountsPath, m.ListAccountsGETHandler)
attachHandler(http.MethodPost, AccountsPath, m.ListAccountsPOSTHandler)
attachHandler(http.MethodDelete, AccountsPath, m.ListAccountsDELETEHandler)
}

View file

@ -0,0 +1,156 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package lists
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// ListAccountsGETHandler swagger:operation GET /api/v1/list/{id}/accounts listAccounts
//
// Page through accounts in this list.
//
// The returned Link header can be used to generate the previous and next queries when scrolling up or down a timeline.
//
// Example:
//
// ```
// <https://example.org/api/v1/list/01H0W619198FX7J54NF7EH1NG2/accounts?limit=20&max_id=01FC3GSQ8A3MMJ43BPZSGEG29M>; rel="next", <https://example.org/api/v1/list/01H0W619198FX7J54NF7EH1NG2/accounts?limit=20&min_id=01FC3KJW2GYXSDDRA6RWNDM46M>; rel="prev"
// ````
//
// ---
// tags:
// - lists
//
// produces:
// - application/json
//
// parameters:
// -
// name: id
// type: string
// description: ID of the list
// in: path
// required: true
// -
// name: max_id
// type: string
// description: >-
// Return only list entries *OLDER* than the given max ID.
// The account from the list entry with the specified ID will not be included in the response.
// in: query
// required: false
// -
// name: since_id
// type: string
// description: >-
// Return only list entries *NEWER* than the given since ID.
// The account from the list entry with the specified ID will not be included in the response.
// in: query
// -
// name: min_id
// type: string
// description: >-
// Return only list entries *IMMEDIATELY NEWER* than the given min ID.
// The account from the list entry with the specified ID will not be included in the response.
// in: query
// required: false
// -
// name: limit
// type: integer
// description: Number of accounts to return.
// default: 20
// in: query
// required: false
//
// security:
// - OAuth2 Bearer:
// - read:lists
//
// responses:
// '200':
// headers:
// Link:
// type: string
// description: Links to the next and previous queries.
// name: accounts
// description: Array of accounts.
// schema:
// type: array
// items:
// "$ref": "#/definitions/account"
// '400':
// description: bad request
// '401':
// description: unauthorized
// '404':
// description: not found
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) ListAccountsGETHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1)
return
}
targetListID := c.Param(IDKey)
if targetListID == "" {
err := errors.New("no list id specified")
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
limit, errWithCode := apiutil.ParseLimit(c.Query(apiutil.LimitKey), 20)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
resp, errWithCode := m.processor.List().GetListAccounts(
c.Request.Context(),
authed.Account,
targetListID,
c.Query(MaxIDKey),
c.Query(SinceIDKey),
c.Query(MinIDKey),
limit,
)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
if resp.LinkHeader != "" {
c.Header("Link", resp.LinkHeader)
}
c.JSON(http.StatusOK, resp.Items)
}

View file

@ -0,0 +1,120 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package lists
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// ListAccountsPOSTHandler swagger:operation POST /api/v1/list/{id}/accounts addListAccounts
//
// Add one or more accounts to the given list.
//
// ---
// tags:
// - lists
//
// consumes:
// - application/json
// - application/xml
// - application/x-www-form-urlencoded
//
// produces:
// - application/json
//
// parameters:
// -
// name: id
// type: string
// description: ID of the list
// in: path
// required: true
// -
// name: account_ids
// type: array
// items:
// type: string
// description: >-
// Array of accountIDs to modify.
// Each accountID must correspond to an account
// that the requesting account follows.
// in: formData
// required: true
//
// security:
// - OAuth2 Bearer:
// - read:lists
//
// responses:
// '200':
// description: list accounts updated
// '400':
// description: bad request
// '401':
// description: unauthorized
// '404':
// description: not found
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) ListAccountsPOSTHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1)
return
}
targetListID := c.Param(IDKey)
if targetListID == "" {
err := errors.New("no list id specified")
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
form := &apimodel.ListAccountsChangeRequest{}
if err := c.ShouldBind(form); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
if len(form.AccountIDs) == 0 {
err := errors.New("no account IDs given")
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
if errWithCode := m.processor.List().AddToList(c.Request.Context(), authed.Account, targetListID, form.AccountIDs); errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
c.JSON(http.StatusOK, gin.H{})
}

View file

@ -0,0 +1,120 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package lists
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// ListAccountsDELETEHandler swagger:operation DELETE /api/v1/list/{id}/accounts removeListAccounts
//
// Remove one or more accounts from the given list.
//
// ---
// tags:
// - lists
//
// consumes:
// - application/json
// - application/xml
// - application/x-www-form-urlencoded
//
// produces:
// - application/json
//
// parameters:
// -
// name: id
// type: string
// description: ID of the list
// in: path
// required: true
// -
// name: account_ids
// type: array
// items:
// type: string
// description: >-
// Array of accountIDs to modify.
// Each accountID must correspond to an account
// that the requesting account follows.
// in: formData
// required: true
//
// security:
// - OAuth2 Bearer:
// - read:lists
//
// responses:
// '200':
// description: list accounts updated
// '400':
// description: bad request
// '401':
// description: unauthorized
// '404':
// description: not found
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) ListAccountsDELETEHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1)
return
}
targetListID := c.Param(IDKey)
if targetListID == "" {
err := errors.New("no list id specified")
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
form := &apimodel.ListAccountsChangeRequest{}
if err := c.ShouldBind(form); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
if len(form.AccountIDs) == 0 {
err := errors.New("no account IDs given")
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
if errWithCode := m.processor.List().RemoveFromList(c.Request.Context(), authed.Account, targetListID, form.AccountIDs); errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
c.JSON(http.StatusOK, gin.H{})
}

View file

@ -0,0 +1,106 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package lists
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/validate"
)
// ListCreatePOSTHandler swagger:operation POST /api/v1/list listCreate
//
// Create a new list.
//
// ---
// tags:
// - lists
//
// consumes:
// - application/json
// - application/xml
// - application/x-www-form-urlencoded
//
// produces:
// - application/json
//
// security:
// - OAuth2 Bearer:
// - write:lists
//
// responses:
// '200':
// description: "The newly created list."
// schema:
// "$ref": "#/definitions/list"
// '400':
// description: bad request
// '401':
// description: unauthorized
// '403':
// description: forbidden
// '404':
// description: not found
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) ListCreatePOSTHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1)
return
}
form := &apimodel.ListCreateRequest{}
if err := c.ShouldBind(form); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
if err := validate.ListTitle(form.Title); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
repliesPolicy := gtsmodel.RepliesPolicy(strings.ToLower(form.RepliesPolicy))
if err := validate.ListRepliesPolicy(repliesPolicy); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
apiList, errWithCode := m.processor.List().Create(c.Request.Context(), authed.Account, form.Title, repliesPolicy)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
c.JSON(http.StatusOK, apiList)
}

View file

@ -0,0 +1,91 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package lists
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// ListDELETEHandler swagger:operation DELETE /api/v1/list/{id} listDelete
//
// Delete a single list with the given ID.
//
// ---
// tags:
// - lists
//
// produces:
// - application/json
//
// parameters:
// -
// name: id
// type: string
// description: ID of the list
// in: path
// required: true
//
// security:
// - OAuth2 Bearer:
// - write:lists
//
// responses:
// '200':
// description: list deleted
// '400':
// description: bad request
// '401':
// description: unauthorized
// '404':
// description: not found
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) ListDELETEHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1)
return
}
targetListID := c.Param(IDKey)
if targetListID == "" {
err := errors.New("no list id specified")
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
if errWithCode := m.processor.List().Delete(c.Request.Context(), authed.Account, targetListID); errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
c.JSON(http.StatusOK, gin.H{})
}

View file

@ -0,0 +1,95 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package lists
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// ListGETHandler swagger:operation GET /api/v1/list/{id} list
//
// Get a single list with the given ID.
//
// ---
// tags:
// - lists
//
// produces:
// - application/json
//
// parameters:
// -
// name: id
// type: string
// description: ID of the list
// in: path
// required: true
//
// security:
// - OAuth2 Bearer:
// - read:lists
//
// responses:
// '200':
// name: list
// description: Requested list.
// schema:
// "$ref": "#/definitions/list"
// '400':
// description: bad request
// '401':
// description: unauthorized
// '404':
// description: not found
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) ListGETHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1)
return
}
targetListID := c.Param(IDKey)
if targetListID == "" {
err := errors.New("no list id specified")
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
resp, errWithCode := m.processor.List().Get(c.Request.Context(), authed.Account, targetListID)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
c.JSON(http.StatusOK, resp)
}

View file

@ -26,9 +26,42 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// ListsGETHandler returns a list of lists created by/for the authed account
// ListsGETHandler swagger:operation GET /api/v1/lists lists
//
// Get all lists for owned by authorized user.
//
// ---
// tags:
// - lists
//
// produces:
// - application/json
//
// security:
// - OAuth2 Bearer:
// - read:lists
//
// responses:
// '200':
// name: lists
// description: Array of all lists owned by the requesting user.
// schema:
// type: array
// items:
// "$ref": "#/definitions/list"
// '400':
// description: bad request
// '401':
// description: unauthorized
// '404':
// description: not found
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) ListsGETHandler(c *gin.Context) {
if _, err := oauth.Authed(c, true, true, true, true); err != nil {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return
}
@ -38,6 +71,11 @@ func (m *Module) ListsGETHandler(c *gin.Context) {
return
}
// todo: implement this; currently it's a no-op
c.JSON(http.StatusOK, []string{})
lists, errWithCode := m.processor.List().GetAll(c.Request.Context(), authed.Account)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
c.JSON(http.StatusOK, lists)
}

View file

@ -0,0 +1,152 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package lists
import (
"errors"
"net/http"
"strings"
"github.com/gin-gonic/gin"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/validate"
)
// ListUpdatePUTHandler swagger:operation PUT /api/v1/list listUpdate
//
// Update an existing list.
//
// ---
// tags:
// - lists
//
// consumes:
// - application/json
// - application/xml
// - application/x-www-form-urlencoded
//
// produces:
// - application/json
//
// parameters:
// -
// name: id
// type: string
// description: ID of the list
// in: path
// required: true
// -
// name: title
// type: string
// description: Title of this list.
// in: formData
// example: Cool People
// -
// name: replies_policy
// type: string
// description: |-
// RepliesPolicy for this list.
// followed = Show replies to any followed user
// list = Show replies to members of the list
// none = Show replies to no one
// in: formData
// example: list
//
// security:
// - OAuth2 Bearer:
// - write:lists
//
// responses:
// '200':
// description: "The newly updated list."
// schema:
// "$ref": "#/definitions/list"
// '400':
// description: bad request
// '401':
// description: unauthorized
// '403':
// description: forbidden
// '404':
// description: not found
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) ListUpdatePUTHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1)
return
}
targetListID := c.Param(IDKey)
if targetListID == "" {
err := errors.New("no list id specified")
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
form := &apimodel.ListUpdateRequest{}
if err := c.ShouldBind(form); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
if form.Title != nil {
if err := validate.ListTitle(*form.Title); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
}
var repliesPolicy *gtsmodel.RepliesPolicy
if form.RepliesPolicy != nil {
rp := gtsmodel.RepliesPolicy(strings.ToLower(*form.RepliesPolicy))
if err := validate.ListRepliesPolicy(rp); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
repliesPolicy = &rp
}
if form.Title == nil && repliesPolicy == nil {
err = errors.New("neither title nor replies_policy was set; nothing to update")
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
apiList, errWithCode := m.processor.List().Update(c.Request.Context(), authed.Account, targetListID, form.Title, repliesPolicy)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
c.JSON(http.StatusOK, apiList)
}

View file

@ -44,6 +44,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -90,6 +91,13 @@ func (suite *MediaCreateTestSuite) SetupSuite() {
suite.state.Storage = suite.storage
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.oauthServer = testrig.NewTestOauthServer(suite.db)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)

View file

@ -42,6 +42,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -87,6 +88,13 @@ func (suite *MediaUpdateTestSuite) SetupSuite() {
suite.state.Storage = suite.storage
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.oauthServer = testrig.NewTestOauthServer(suite.db)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)

View file

@ -77,7 +77,7 @@ func (m *Module) NotificationGETHandler(c *gin.Context) {
return
}
resp, errWithCode := m.processor.NotificationGet(c.Request.Context(), authed.Account, targetNotifID)
resp, errWithCode := m.processor.Timeline().NotificationGet(c.Request.Context(), authed.Account, targetNotifID)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return

View file

@ -69,7 +69,7 @@ func (m *Module) NotificationsClearPOSTHandler(c *gin.Context) {
return
}
errWithCode := m.processor.NotificationsClear(c.Request.Context(), authed)
errWithCode := m.processor.Timeline().NotificationsClear(c.Request.Context(), authed)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return

View file

@ -138,7 +138,7 @@ func (m *Module) NotificationsGETHandler(c *gin.Context) {
limit = int(i)
}
resp, errWithCode := m.processor.NotificationsGet(
resp, errWithCode := m.processor.Timeline().NotificationsGet(
c.Request.Context(),
authed,
c.Query(MaxIDKey),

View file

@ -28,6 +28,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -77,6 +78,12 @@ func (suite *ReportsStandardTestSuite) SetupTest() {
suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
testrig.NewTestTypeConverter(suite.db),
)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string)
@ -85,8 +92,6 @@ func (suite *ReportsStandardTestSuite) SetupTest() {
suite.reportsModule = reports.New(suite.processor)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
suite.NoError(suite.processor.Start())
}
func (suite *ReportsStandardTestSuite) TearDownTest() {

View file

@ -35,6 +35,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -81,6 +82,12 @@ func (suite *SearchStandardTestSuite) SetupTest() {
suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
testrig.NewTestTypeConverter(suite.db),
)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string)
@ -89,8 +96,6 @@ func (suite *SearchStandardTestSuite) SetupTest() {
suite.searchModule = search.New(suite.processor)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
suite.NoError(suite.processor.Start())
}
func (suite *SearchStandardTestSuite) TearDownTest() {

View file

@ -29,6 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -83,6 +84,12 @@ func (suite *StatusStandardTestSuite) SetupTest() {
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@ -91,8 +98,6 @@ func (suite *StatusStandardTestSuite) SetupTest() {
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.statusModule = statuses.New(suite.processor)
suite.NoError(suite.processor.Start())
}
func (suite *StatusStandardTestSuite) TearDownTest() {

View file

@ -82,6 +82,20 @@ import (
// `direct`: receive updates for direct messages.
// in: query
// required: true
// -
// name: list
// type: string
// description: |-
// ID of the list to subscribe to.
// Only used if stream type is 'list'.
// in: query
// -
// name: tag
// type: string
// description: |-
// Name of the tag to subscribe to.
// Only used if stream type is 'hashtag' or 'hashtag:local'.
// in: query
//
// security:
// - OAuth2 Bearer:
@ -164,8 +178,16 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
}
// Get the initial stream type, if there is one.
// streamType will be an empty string if one wasn't supplied. Open() will deal with this
// By appending other query params to the streamType,
// we can allow for streaming for specific list IDs
// or hashtags.
streamType := c.Query(StreamQueryKey)
if list := c.Query(StreamListKey); list != "" {
streamType += ":" + list
} else if tag := c.Query(StreamTagKey); tag != "" {
streamType += ":" + tag
}
stream, errWithCode := m.processor.Stream().Open(c.Request.Context(), account, streamType)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
@ -240,28 +262,41 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
// If the message contains 'stream' and 'type' fields, we can
// update the set of timelines that are subscribed for events.
// everything else is ignored.
action := msg["type"]
streamType := msg["stream"]
// Ignore if the streamType is unknown (or missing), so a bad
// client can't cause extra memory allocations
if !slices.Contains(streampkg.AllStatusTimelines, streamType) {
l.Warnf("Unknown 'stream' field: %v", msg)
updateType, ok := msg["type"]
if !ok {
l.Warn("'type' field not provided")
continue
}
switch action {
updateStream, ok := msg["stream"]
if !ok {
l.Warn("'stream' field not provided")
continue
}
// Ignore if the updateStreamType is unknown (or missing),
// so a bad client can't cause extra memory allocations
if !slices.Contains(streampkg.AllStatusTimelines, updateStream) {
l.Warnf("unknown 'stream' field: %v", msg)
continue
}
updateList, ok := msg["list"]
if ok {
updateStream += ":" + updateList
}
switch updateType {
case "subscribe":
stream.Lock()
stream.Timelines[streamType] = true
stream.StreamTypes[updateStream] = true
stream.Unlock()
case "unsubscribe":
stream.Lock()
delete(stream.Timelines, streamType)
delete(stream.StreamTypes, updateStream)
stream.Unlock()
default:
l.Warnf("Invalid 'type' field: %v", msg)
l.Warnf("invalid 'type' field: %v", msg)
}
}
}()
@ -276,7 +311,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
case msg := <-stream.Messages:
l.Tracef("sending message to websocket: %+v", msg)
if err := wsConn.WriteJSON(msg); err != nil {
l.Errorf("error writing json to websocket: %v", err)
l.Debugf("error writing json to websocket: %v", err)
return
}
@ -290,7 +325,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
websocket.PingMessage,
[]byte{},
); err != nil {
l.Errorf("error writing ping to websocket: %v", err)
l.Debugf("error writing ping to websocket: %v", err)
return
}
}

View file

@ -27,17 +27,12 @@ import (
)
const (
// BasePath is the path for the streaming api, minus the 'api' prefix
BasePath = "/v1/streaming"
// StreamQueryKey is the query key for the type of stream being requested
StreamQueryKey = "stream"
// AccessTokenQueryKey is the query key for an oauth access token that should be passed in streaming requests.
AccessTokenQueryKey = "access_token"
// AccessTokenHeader is the header for an oauth access token that can be passed in streaming requests instead of AccessTokenQueryKey
//nolint:gosec
AccessTokenHeader = "Sec-Websocket-Protocol"
BasePath = "/v1/streaming" // path for the streaming api, minus the 'api' prefix
StreamQueryKey = "stream" // type of stream being requested
StreamListKey = "list" // id of list being requested
StreamTagKey = "tag" // name of tag being requested
AccessTokenQueryKey = "access_token" // oauth access token
AccessTokenHeader = "Sec-Websocket-Protocol" //nolint:gosec
)
type Module struct {

View file

@ -41,6 +41,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -94,6 +95,13 @@ func (suite *StreamingTestSuite) SetupTest() {
suite.state.Storage = suite.storage
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@ -102,7 +110,6 @@ func (suite *StreamingTestSuite) SetupTest() {
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.streamingModule = streaming.New(suite.processor, 1, 4096)
suite.NoError(suite.processor.Start())
}
func (suite *StreamingTestSuite) TearDownTest() {

View file

@ -18,9 +18,7 @@
package timelines
import (
"fmt"
"net/http"
"strconv"
"github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
@ -120,49 +118,27 @@ func (m *Module) HomeTimelineGETHandler(c *gin.Context) {
return
}
maxID := ""
maxIDString := c.Query(MaxIDKey)
if maxIDString != "" {
maxID = maxIDString
}
sinceID := ""
sinceIDString := c.Query(SinceIDKey)
if sinceIDString != "" {
sinceID = sinceIDString
}
minID := ""
minIDString := c.Query(MinIDKey)
if minIDString != "" {
minID = minIDString
}
limit := 20
limitString := c.Query(LimitKey)
if limitString != "" {
i, err := strconv.ParseInt(limitString, 10, 32)
if err != nil {
err := fmt.Errorf("error parsing %s: %s", LimitKey, err)
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
limit, errWithCode := apiutil.ParseLimit(c.Query(apiutil.LimitKey), 20)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
limit = int(i)
}
local := false
localString := c.Query(LocalKey)
if localString != "" {
i, err := strconv.ParseBool(localString)
if err != nil {
err := fmt.Errorf("error parsing %s: %s", LocalKey, err)
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
local, errWithCode := apiutil.ParseLocal(c.Query(apiutil.LocalKey), false)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
local = i
}
resp, errWithCode := m.processor.HomeTimelineGet(c.Request.Context(), authed, maxID, sinceID, minID, limit, local)
resp, errWithCode := m.processor.Timeline().HomeTimelineGet(
c.Request.Context(),
authed,
c.Query(MaxIDKey),
c.Query(SinceIDKey),
c.Query(MinIDKey),
limit,
local,
)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return

View file

@ -0,0 +1,152 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package timelines
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// ListTimelineGETHandler swagger:operation GET /api/v1/timelines/list/{id} listTimeline
//
// See statuses/posts from the given list timeline.
//
// The statuses will be returned in descending chronological order (newest first), with sequential IDs (bigger = newer).
//
// The returned Link header can be used to generate the previous and next queries when scrolling up or down a timeline.
//
// Example:
//
// ```
// <https://example.org/api/v1/timelines/list/01H0W619198FX7J54NF7EH1NG2?limit=20&max_id=01FC3GSQ8A3MMJ43BPZSGEG29M>; rel="next", <https://example.org/api/v1/timelines/list/01H0W619198FX7J54NF7EH1NG2?limit=20&min_id=01FC3KJW2GYXSDDRA6RWNDM46M>; rel="prev"
// ````
//
// ---
// tags:
// - timelines
//
// produces:
// - application/json
//
// parameters:
// -
// name: id
// type: string
// description: ID of the list
// in: path
// required: true
// -
// name: max_id
// type: string
// description: >-
// Return only statuses *OLDER* than the given max status ID.
// The status with the specified ID will not be included in the response.
// in: query
// required: false
// -
// name: since_id
// type: string
// description: >-
// Return only statuses *NEWER* than the given since status ID.
// The status with the specified ID will not be included in the response.
// in: query
// -
// name: min_id
// type: string
// description: >-
// Return only statuses *NEWER* than the given since status ID.
// The status with the specified ID will not be included in the response.
// in: query
// required: false
// -
// name: limit
// type: integer
// description: Number of statuses to return.
// default: 20
// in: query
// required: false
//
// security:
// - OAuth2 Bearer:
// - read:lists
//
// responses:
// '200':
// name: statuses
// description: Array of statuses.
// schema:
// type: array
// items:
// "$ref": "#/definitions/status"
// headers:
// Link:
// type: string
// description: Links to the next and previous queries.
// '401':
// description: unauthorized
// '400':
// description: bad request
func (m *Module) ListTimelineGETHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1)
return
}
targetListID := c.Param(IDKey)
if targetListID == "" {
err := errors.New("no list id specified")
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
limit, errWithCode := apiutil.ParseLimit(c.Query(apiutil.LimitKey), 20)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
resp, errWithCode := m.processor.Timeline().ListTimelineGet(
c.Request.Context(),
authed,
targetListID,
c.Query(MaxIDKey),
c.Query(SinceIDKey),
c.Query(MinIDKey),
limit,
)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
if resp.LinkHeader != "" {
c.Header("Link", resp.LinkHeader)
}
c.JSON(http.StatusOK, resp.Items)
}

View file

@ -18,9 +18,7 @@
package timelines
import (
"fmt"
"net/http"
"strconv"
"github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
@ -131,49 +129,27 @@ func (m *Module) PublicTimelineGETHandler(c *gin.Context) {
return
}
maxID := ""
maxIDString := c.Query(MaxIDKey)
if maxIDString != "" {
maxID = maxIDString
}
sinceID := ""
sinceIDString := c.Query(SinceIDKey)
if sinceIDString != "" {
sinceID = sinceIDString
}
minID := ""
minIDString := c.Query(MinIDKey)
if minIDString != "" {
minID = minIDString
}
limit := 20
limitString := c.Query(LimitKey)
if limitString != "" {
i, err := strconv.ParseInt(limitString, 10, 32)
if err != nil {
err := fmt.Errorf("error parsing %s: %s", LimitKey, err)
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
limit, errWithCode := apiutil.ParseLimit(c.Query(apiutil.LimitKey), 20)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
limit = int(i)
}
local := false
localString := c.Query(LocalKey)
if localString != "" {
i, err := strconv.ParseBool(localString)
if err != nil {
err := fmt.Errorf("error parsing %s: %s", LocalKey, err)
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
local, errWithCode := apiutil.ParseLocal(c.Query(apiutil.LocalKey), false)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
local = i
}
resp, errWithCode := m.processor.PublicTimelineGet(c.Request.Context(), authed, maxID, sinceID, minID, limit, local)
resp, errWithCode := m.processor.Timeline().PublicTimelineGet(
c.Request.Context(),
authed,
c.Query(MaxIDKey),
c.Query(SinceIDKey),
c.Query(MinIDKey),
limit,
local,
)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return

View file

@ -27,10 +27,12 @@ import (
const (
// BasePath is the base URI path for serving timelines, minus the 'api' prefix.
BasePath = "/v1/timelines"
IDKey = "id"
// HomeTimeline is the path for the home timeline
HomeTimeline = BasePath + "/home"
// PublicTimeline is the path for the public (and public local) timeline
PublicTimeline = BasePath + "/public"
ListTimeline = BasePath + "/list/:" + IDKey
// MaxIDKey is the url query for setting a max status ID to return
MaxIDKey = "max_id"
// SinceIDKey is the url query for returning results newer than the given ID
@ -56,4 +58,5 @@ func New(processor *processing.Processor) *Module {
func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) {
attachHandler(http.MethodGet, HomeTimeline, m.HomeTimelineGETHandler)
attachHandler(http.MethodGet, PublicTimeline, m.PublicTimelineGETHandler)
attachHandler(http.MethodGet, ListTimeline, m.ListTimelineGETHandler)
}

View file

@ -29,6 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -73,6 +74,13 @@ func (suite *UserStandardTestSuite) SetupTest() {
suite.state.Storage = suite.storage
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string)
@ -81,8 +89,6 @@ func (suite *UserStandardTestSuite) SetupTest() {
suite.userModule = user.New(suite.processor)
testrig.StandardDBSetup(suite.db, suite.testAccounts)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
suite.NoError(suite.processor.Start())
}
func (suite *UserStandardTestSuite) TearDownTest() {

View file

@ -33,6 +33,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -81,6 +82,13 @@ func (suite *FileserverTestSuite) SetupSuite() {
suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.oauthServer = testrig.NewTestOauthServer(suite.db)
suite.emailSender = testrig.NewEmailSender("../../../web/template/", nil)

View file

@ -17,14 +17,57 @@
package model
// List represents a list of some users that the authenticated user follows.
// List represents a user-created list of accounts that the user follows.
//
// swagger:model list
type List struct {
// The internal database ID of the list.
// The ID of the list.
ID string `json:"id"`
// The user-defined title of the list.
Title string `json:"title"`
// RepliesPolicy for this list.
// followed = Show replies to any followed user
// list = Show replies to members of the list
// none = Show replies to no one
RepliesPolicy string `json:"replies_policy"`
}
// ListCreateRequest models list creation parameters.
//
// swagger:parameters listCreate
type ListCreateRequest struct {
// Title of this list.
// example: Cool People
// in: formData
// required: true
Title string `form:"title" json:"title" xml:"title"`
// RepliesPolicy for this list.
// followed = Show replies to any followed user
// list = Show replies to members of the list
// none = Show replies to no one
// example: list
// default: list
// in: formData
RepliesPolicy string `form:"replies_policy" json:"replies_policy" xml:"replies_policy"`
}
// ListUpdateRequest models list update parameters.
//
// swagger:parameters listUpdate
type ListUpdateRequest struct {
// Title of this list.
// example: Cool People
// in: formData
Title *string `form:"title" json:"title" xml:"title"`
// RepliesPolicy for this list.
// followed = Show replies to any followed user
// list = Show replies to members of the list
// none = Show replies to no one
// in: formData
RepliesPolicy *string `form:"replies_policy" json:"replies_policy" xml:"replies_policy"`
}
// swagger:ignore
type ListAccountsChangeRequest struct {
AccountIDs []string `form:"account_ids[]" json:"account_ids" xml:"account_ids"`
}

View file

@ -0,0 +1,58 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package util
import (
"fmt"
"strconv"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
)
const (
LimitKey = "limit"
LocalKey = "local"
)
func ParseLimit(limit string, defaultLimit int) (int, gtserror.WithCode) {
if limit == "" {
return defaultLimit, nil
}
i, err := strconv.Atoi(limit)
if err != nil {
err := fmt.Errorf("error parsing %s: %w", LimitKey, err)
return 0, gtserror.NewErrorBadRequest(err, err.Error())
}
return i, nil
}
func ParseLocal(local string, defaultLocal bool) (bool, gtserror.WithCode) {
if local == "" {
return defaultLocal, nil
}
i, err := strconv.ParseBool(local)
if err != nil {
err := fmt.Errorf("error parsing %s: %w", LocalKey, err)
return false, gtserror.NewErrorBadRequest(err, err.Error())
}
return i, nil
}

View file

@ -30,6 +30,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -79,6 +80,13 @@ func (suite *WebfingerStandardTestSuite) SetupTest() {
suite.db = testrig.NewTestDB(&suite.state)
suite.state.DB = suite.db
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
@ -89,8 +97,6 @@ func (suite *WebfingerStandardTestSuite) SetupTest() {
suite.oauthServer = testrig.NewTestOauthServer(suite.db)
testrig.StandardDBSetup(suite.db, suite.testAccounts)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
suite.NoError(suite.processor.Start())
}
func (suite *WebfingerStandardTestSuite) TearDownTest() {

43
internal/cache/gts.go vendored
View file

@ -35,6 +35,8 @@ type GTSCaches struct {
emojiCategory *result.Cache[*gtsmodel.EmojiCategory]
follow *result.Cache[*gtsmodel.Follow]
followRequest *result.Cache[*gtsmodel.FollowRequest]
list *result.Cache[*gtsmodel.List]
listEntry *result.Cache[*gtsmodel.ListEntry]
media *result.Cache[*gtsmodel.MediaAttachment]
mention *result.Cache[*gtsmodel.Mention]
notification *result.Cache[*gtsmodel.Notification]
@ -57,6 +59,8 @@ func (c *GTSCaches) Init() {
c.initEmojiCategory()
c.initFollow()
c.initFollowRequest()
c.initList()
c.initListEntry()
c.initMedia()
c.initMention()
c.initNotification()
@ -76,6 +80,8 @@ func (c *GTSCaches) Start() {
tryStart(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq())
tryStart(c.follow, config.GetCacheGTSFollowSweepFreq())
tryStart(c.followRequest, config.GetCacheGTSFollowRequestSweepFreq())
tryStart(c.list, config.GetCacheGTSListSweepFreq())
tryStart(c.listEntry, config.GetCacheGTSListEntrySweepFreq())
tryStart(c.media, config.GetCacheGTSMediaSweepFreq())
tryStart(c.mention, config.GetCacheGTSMentionSweepFreq())
tryStart(c.notification, config.GetCacheGTSNotificationSweepFreq())
@ -100,6 +106,8 @@ func (c *GTSCaches) Stop() {
tryStop(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq())
tryStop(c.follow, config.GetCacheGTSFollowSweepFreq())
tryStop(c.followRequest, config.GetCacheGTSFollowRequestSweepFreq())
tryStop(c.list, config.GetCacheGTSListSweepFreq())
tryStop(c.listEntry, config.GetCacheGTSListEntrySweepFreq())
tryStop(c.media, config.GetCacheGTSMediaSweepFreq())
tryStop(c.mention, config.GetCacheGTSNotificationSweepFreq())
tryStop(c.notification, config.GetCacheGTSNotificationSweepFreq())
@ -146,6 +154,16 @@ func (c *GTSCaches) FollowRequest() *result.Cache[*gtsmodel.FollowRequest] {
return c.followRequest
}
// List provides access to the gtsmodel List database cache.
func (c *GTSCaches) List() *result.Cache[*gtsmodel.List] {
return c.list
}
// ListEntry provides access to the gtsmodel ListEntry database cache.
func (c *GTSCaches) ListEntry() *result.Cache[*gtsmodel.ListEntry] {
return c.listEntry
}
// Media provides access to the gtsmodel Media database cache.
func (c *GTSCaches) Media() *result.Cache[*gtsmodel.MediaAttachment] {
return c.media
@ -283,6 +301,30 @@ func (c *GTSCaches) initFollowRequest() {
c.followRequest.SetTTL(config.GetCacheGTSFollowRequestTTL(), true)
}
func (c *GTSCaches) initList() {
c.list = result.New([]result.Lookup{
{Name: "ID"},
}, func(l1 *gtsmodel.List) *gtsmodel.List {
l2 := new(gtsmodel.List)
*l2 = *l1
return l2
}, config.GetCacheGTSListMaxSize())
c.list.SetTTL(config.GetCacheGTSListTTL(), true)
c.list.IgnoreErrors(ignoreErrors)
}
func (c *GTSCaches) initListEntry() {
c.listEntry = result.New([]result.Lookup{
{Name: "ID"},
}, func(l1 *gtsmodel.ListEntry) *gtsmodel.ListEntry {
l2 := new(gtsmodel.ListEntry)
*l2 = *l1
return l2
}, config.GetCacheGTSListEntryMaxSize())
c.list.SetTTL(config.GetCacheGTSListEntryTTL(), true)
c.list.IgnoreErrors(ignoreErrors)
}
func (c *GTSCaches) initMedia() {
c.media = result.New([]result.Lookup{
{Name: "ID"},
@ -359,7 +401,6 @@ func (c *GTSCaches) initStatusFave() {
c.status.IgnoreErrors(ignoreErrors)
}
// initTombstone will initialize the gtsmodel.Tombstone cache.
func (c *GTSCaches) initTombstone() {
c.tombstone = result.New([]result.Lookup{
{Name: "ID"},

View file

@ -199,6 +199,14 @@ type GTSCacheConfiguration struct {
FollowRequestTTL time.Duration `name:"follow-request-ttl"`
FollowRequestSweepFreq time.Duration `name:"follow-request-sweep-freq"`
ListMaxSize int `name:"list-max-size"`
ListTTL time.Duration `name:"list-ttl"`
ListSweepFreq time.Duration `name:"list-sweep-freq"`
ListEntryMaxSize int `name:"list-entry-max-size"`
ListEntryTTL time.Duration `name:"list-entry-ttl"`
ListEntrySweepFreq time.Duration `name:"list-entry-sweep-freq"`
MediaMaxSize int `name:"media-max-size"`
MediaTTL time.Duration `name:"media-ttl"`
MediaSweepFreq time.Duration `name:"media-sweep-freq"`

View file

@ -153,6 +153,14 @@ var Defaults = Configuration{
FollowRequestTTL: time.Minute * 30,
FollowRequestSweepFreq: time.Minute,
ListMaxSize: 2000,
ListTTL: time.Minute * 30,
ListSweepFreq: time.Minute,
ListEntryMaxSize: 2000,
ListEntryTTL: time.Minute * 30,
ListEntrySweepFreq: time.Minute,
MediaMaxSize: 1000,
MediaTTL: time.Minute * 30,
MediaSweepFreq: time.Minute,

View file

@ -2778,6 +2778,156 @@ func GetCacheGTSFollowRequestSweepFreq() time.Duration {
// SetCacheGTSFollowRequestSweepFreq safely sets the value for global configuration 'Cache.GTS.FollowRequestSweepFreq' field
func SetCacheGTSFollowRequestSweepFreq(v time.Duration) { global.SetCacheGTSFollowRequestSweepFreq(v) }
// GetCacheGTSListMaxSize safely fetches the Configuration value for state's 'Cache.GTS.ListMaxSize' field
func (st *ConfigState) GetCacheGTSListMaxSize() (v int) {
st.mutex.Lock()
v = st.config.Cache.GTS.ListMaxSize
st.mutex.Unlock()
return
}
// SetCacheGTSListMaxSize safely sets the Configuration value for state's 'Cache.GTS.ListMaxSize' field
func (st *ConfigState) SetCacheGTSListMaxSize(v int) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.GTS.ListMaxSize = v
st.reloadToViper()
}
// CacheGTSListMaxSizeFlag returns the flag name for the 'Cache.GTS.ListMaxSize' field
func CacheGTSListMaxSizeFlag() string { return "cache-gts-list-max-size" }
// GetCacheGTSListMaxSize safely fetches the value for global configuration 'Cache.GTS.ListMaxSize' field
func GetCacheGTSListMaxSize() int { return global.GetCacheGTSListMaxSize() }
// SetCacheGTSListMaxSize safely sets the value for global configuration 'Cache.GTS.ListMaxSize' field
func SetCacheGTSListMaxSize(v int) { global.SetCacheGTSListMaxSize(v) }
// GetCacheGTSListTTL safely fetches the Configuration value for state's 'Cache.GTS.ListTTL' field
func (st *ConfigState) GetCacheGTSListTTL() (v time.Duration) {
st.mutex.Lock()
v = st.config.Cache.GTS.ListTTL
st.mutex.Unlock()
return
}
// SetCacheGTSListTTL safely sets the Configuration value for state's 'Cache.GTS.ListTTL' field
func (st *ConfigState) SetCacheGTSListTTL(v time.Duration) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.GTS.ListTTL = v
st.reloadToViper()
}
// CacheGTSListTTLFlag returns the flag name for the 'Cache.GTS.ListTTL' field
func CacheGTSListTTLFlag() string { return "cache-gts-list-ttl" }
// GetCacheGTSListTTL safely fetches the value for global configuration 'Cache.GTS.ListTTL' field
func GetCacheGTSListTTL() time.Duration { return global.GetCacheGTSListTTL() }
// SetCacheGTSListTTL safely sets the value for global configuration 'Cache.GTS.ListTTL' field
func SetCacheGTSListTTL(v time.Duration) { global.SetCacheGTSListTTL(v) }
// GetCacheGTSListSweepFreq safely fetches the Configuration value for state's 'Cache.GTS.ListSweepFreq' field
func (st *ConfigState) GetCacheGTSListSweepFreq() (v time.Duration) {
st.mutex.Lock()
v = st.config.Cache.GTS.ListSweepFreq
st.mutex.Unlock()
return
}
// SetCacheGTSListSweepFreq safely sets the Configuration value for state's 'Cache.GTS.ListSweepFreq' field
func (st *ConfigState) SetCacheGTSListSweepFreq(v time.Duration) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.GTS.ListSweepFreq = v
st.reloadToViper()
}
// CacheGTSListSweepFreqFlag returns the flag name for the 'Cache.GTS.ListSweepFreq' field
func CacheGTSListSweepFreqFlag() string { return "cache-gts-list-sweep-freq" }
// GetCacheGTSListSweepFreq safely fetches the value for global configuration 'Cache.GTS.ListSweepFreq' field
func GetCacheGTSListSweepFreq() time.Duration { return global.GetCacheGTSListSweepFreq() }
// SetCacheGTSListSweepFreq safely sets the value for global configuration 'Cache.GTS.ListSweepFreq' field
func SetCacheGTSListSweepFreq(v time.Duration) { global.SetCacheGTSListSweepFreq(v) }
// GetCacheGTSListEntryMaxSize safely fetches the Configuration value for state's 'Cache.GTS.ListEntryMaxSize' field
func (st *ConfigState) GetCacheGTSListEntryMaxSize() (v int) {
st.mutex.Lock()
v = st.config.Cache.GTS.ListEntryMaxSize
st.mutex.Unlock()
return
}
// SetCacheGTSListEntryMaxSize safely sets the Configuration value for state's 'Cache.GTS.ListEntryMaxSize' field
func (st *ConfigState) SetCacheGTSListEntryMaxSize(v int) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.GTS.ListEntryMaxSize = v
st.reloadToViper()
}
// CacheGTSListEntryMaxSizeFlag returns the flag name for the 'Cache.GTS.ListEntryMaxSize' field
func CacheGTSListEntryMaxSizeFlag() string { return "cache-gts-list-entry-max-size" }
// GetCacheGTSListEntryMaxSize safely fetches the value for global configuration 'Cache.GTS.ListEntryMaxSize' field
func GetCacheGTSListEntryMaxSize() int { return global.GetCacheGTSListEntryMaxSize() }
// SetCacheGTSListEntryMaxSize safely sets the value for global configuration 'Cache.GTS.ListEntryMaxSize' field
func SetCacheGTSListEntryMaxSize(v int) { global.SetCacheGTSListEntryMaxSize(v) }
// GetCacheGTSListEntryTTL safely fetches the Configuration value for state's 'Cache.GTS.ListEntryTTL' field
func (st *ConfigState) GetCacheGTSListEntryTTL() (v time.Duration) {
st.mutex.Lock()
v = st.config.Cache.GTS.ListEntryTTL
st.mutex.Unlock()
return
}
// SetCacheGTSListEntryTTL safely sets the Configuration value for state's 'Cache.GTS.ListEntryTTL' field
func (st *ConfigState) SetCacheGTSListEntryTTL(v time.Duration) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.GTS.ListEntryTTL = v
st.reloadToViper()
}
// CacheGTSListEntryTTLFlag returns the flag name for the 'Cache.GTS.ListEntryTTL' field
func CacheGTSListEntryTTLFlag() string { return "cache-gts-list-entry-ttl" }
// GetCacheGTSListEntryTTL safely fetches the value for global configuration 'Cache.GTS.ListEntryTTL' field
func GetCacheGTSListEntryTTL() time.Duration { return global.GetCacheGTSListEntryTTL() }
// SetCacheGTSListEntryTTL safely sets the value for global configuration 'Cache.GTS.ListEntryTTL' field
func SetCacheGTSListEntryTTL(v time.Duration) { global.SetCacheGTSListEntryTTL(v) }
// GetCacheGTSListEntrySweepFreq safely fetches the Configuration value for state's 'Cache.GTS.ListEntrySweepFreq' field
func (st *ConfigState) GetCacheGTSListEntrySweepFreq() (v time.Duration) {
st.mutex.Lock()
v = st.config.Cache.GTS.ListEntrySweepFreq
st.mutex.Unlock()
return
}
// SetCacheGTSListEntrySweepFreq safely sets the Configuration value for state's 'Cache.GTS.ListEntrySweepFreq' field
func (st *ConfigState) SetCacheGTSListEntrySweepFreq(v time.Duration) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.GTS.ListEntrySweepFreq = v
st.reloadToViper()
}
// CacheGTSListEntrySweepFreqFlag returns the flag name for the 'Cache.GTS.ListEntrySweepFreq' field
func CacheGTSListEntrySweepFreqFlag() string { return "cache-gts-list-entry-sweep-freq" }
// GetCacheGTSListEntrySweepFreq safely fetches the value for global configuration 'Cache.GTS.ListEntrySweepFreq' field
func GetCacheGTSListEntrySweepFreq() time.Duration { return global.GetCacheGTSListEntrySweepFreq() }
// SetCacheGTSListEntrySweepFreq safely sets the value for global configuration 'Cache.GTS.ListEntrySweepFreq' field
func SetCacheGTSListEntrySweepFreq(v time.Duration) { global.SetCacheGTSListEntrySweepFreq(v) }
// GetCacheGTSMediaMaxSize safely fetches the Configuration value for state's 'Cache.GTS.MediaMaxSize' field
func (st *ConfigState) GetCacheGTSMediaMaxSize() (v int) {
st.mutex.Lock()

View file

@ -65,6 +65,7 @@ type DBService struct {
db.Domain
db.Emoji
db.Instance
db.List
db.Media
db.Mention
db.Notification
@ -179,6 +180,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
Instance: &instanceDB{
conn: conn,
},
List: &listDB{
conn: conn,
state: state,
},
Media: &mediaDB{
conn: conn,
state: state,

View file

@ -22,6 +22,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -46,6 +47,8 @@ type BunDBStandardTestSuite struct {
testReports map[string]*gtsmodel.Report
testBookmarks map[string]*gtsmodel.StatusBookmark
testFaves map[string]*gtsmodel.StatusFave
testLists map[string]*gtsmodel.List
testListEntries map[string]*gtsmodel.ListEntry
}
func (suite *BunDBStandardTestSuite) SetupSuite() {
@ -63,6 +66,8 @@ func (suite *BunDBStandardTestSuite) SetupSuite() {
suite.testReports = testrig.NewTestReports()
suite.testBookmarks = testrig.NewTestBookmarks()
suite.testFaves = testrig.NewTestFaves()
suite.testLists = testrig.NewTestLists()
suite.testListEntries = testrig.NewTestListEntries()
}
func (suite *BunDBStandardTestSuite) SetupTest() {
@ -70,6 +75,7 @@ func (suite *BunDBStandardTestSuite) SetupTest() {
testrig.InitTestLog()
suite.state.Caches.Init()
suite.db = testrig.NewTestDB(&suite.state)
testrig.StartTimelines(&suite.state, visibility.NewFilter(&suite.state), testrig.NewTestTypeConverter(suite.db))
testrig.StandardDBSetup(suite.db, suite.testAccounts)
}

467
internal/db/bundb/list.go Normal file
View file

@ -0,0 +1,467 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"errors"
"fmt"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/uptrace/bun"
)
type listDB struct {
conn *DBConn
state *state.State
}
/*
LIST FUNCTIONS
*/
func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmodel.List) error, keyParts ...any) (*gtsmodel.List, error) {
list, err := l.state.Caches.GTS.List().Load(lookup, func() (*gtsmodel.List, error) {
var list gtsmodel.List
// Not cached! Perform database query.
if err := dbQuery(&list); err != nil {
return nil, l.conn.ProcessError(err)
}
return &list, nil
}, keyParts...)
if err != nil {
return nil, err // already processed
}
if gtscontext.Barebones(ctx) {
// Only a barebones model was requested.
return list, nil
}
if err := l.state.DB.PopulateList(ctx, list); err != nil {
return nil, err
}
return list, nil
}
func (l *listDB) GetListByID(ctx context.Context, id string) (*gtsmodel.List, error) {
return l.getList(
ctx,
"ID",
func(list *gtsmodel.List) error {
return l.conn.NewSelect().
Model(list).
Where("? = ?", bun.Ident("list.id"), id).
Scan(ctx)
},
id,
)
}
func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error) {
// Fetch IDs of all lists owned by this account.
var listIDs []string
if err := l.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("lists"), bun.Ident("list")).
Column("list.id").
Where("? = ?", bun.Ident("list.account_id"), accountID).
Order("list.id DESC").
Scan(ctx, &listIDs); err != nil {
return nil, l.conn.ProcessError(err)
}
if len(listIDs) == 0 {
return nil, nil
}
// Select each list using its ID to ensure cache used.
lists := make([]*gtsmodel.List, 0, len(listIDs))
for _, id := range listIDs {
list, err := l.state.DB.GetListByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error fetching list %q: %v", id, err)
continue
}
// Append list.
lists = append(lists, list)
}
return lists, nil
}
func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error {
var (
err error
errs = make(gtserror.MultiError, 0, 2)
)
if list.Account == nil {
// List account is not set, fetch from the database.
list.Account, err = l.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
list.AccountID,
)
if err != nil {
errs.Append(fmt.Errorf("error populating list account: %w", err))
}
}
if list.ListEntries == nil {
// List entries are not set, fetch from the database.
list.ListEntries, err = l.state.DB.GetListEntries(
gtscontext.SetBarebones(ctx),
list.ID,
"", "", "", 0,
)
if err != nil {
errs.Append(fmt.Errorf("error populating list entries: %w", err))
}
}
return errs.Combine()
}
func (l *listDB) PutList(ctx context.Context, list *gtsmodel.List) error {
return l.state.Caches.GTS.List().Store(list, func() error {
_, err := l.conn.NewInsert().Model(list).Exec(ctx)
return l.conn.ProcessError(err)
})
}
func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns ...string) error {
list.UpdatedAt = time.Now()
if len(columns) > 0 {
// If we're updating by column, ensure "updated_at" is included.
columns = append(columns, "updated_at")
}
return l.state.Caches.GTS.List().Store(list, func() error {
if _, err := l.conn.NewUpdate().
Model(list).
Where("? = ?", bun.Ident("list.id"), list.ID).
Column(columns...).
Exec(ctx); err != nil {
return l.conn.ProcessError(err)
}
return nil
})
}
func (l *listDB) DeleteListByID(ctx context.Context, id string) error {
defer l.state.Caches.GTS.List().Invalidate("ID", id)
// Select all entries that belong to this list.
listEntries, err := l.state.DB.GetListEntries(ctx, id, "", "", "", 0)
if err != nil {
return fmt.Errorf("error selecting entries from list %q: %w", id, err)
}
// Delete each list entry. This will
// invalidate the list timeline too.
for _, listEntry := range listEntries {
err := l.state.DB.DeleteListEntry(ctx, listEntry.ID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
}
// Finally delete list itself from DB.
_, err = l.conn.NewDelete().
Table("lists").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx)
return l.conn.ProcessError(err)
}
/*
LIST ENTRY functions
*/
func (l *listDB) getListEntry(ctx context.Context, lookup string, dbQuery func(*gtsmodel.ListEntry) error, keyParts ...any) (*gtsmodel.ListEntry, error) {
listEntry, err := l.state.Caches.GTS.ListEntry().Load(lookup, func() (*gtsmodel.ListEntry, error) {
var listEntry gtsmodel.ListEntry
// Not cached! Perform database query.
if err := dbQuery(&listEntry); err != nil {
return nil, l.conn.ProcessError(err)
}
return &listEntry, nil
}, keyParts...)
if err != nil {
return nil, err // already processed
}
if gtscontext.Barebones(ctx) {
// Only a barebones model was requested.
return listEntry, nil
}
// Further populate the list entry fields where applicable.
if err := l.state.DB.PopulateListEntry(ctx, listEntry); err != nil {
return nil, err
}
return listEntry, nil
}
func (l *listDB) GetListEntryByID(ctx context.Context, id string) (*gtsmodel.ListEntry, error) {
return l.getListEntry(
ctx,
"ID",
func(listEntry *gtsmodel.ListEntry) error {
return l.conn.NewSelect().
Model(listEntry).
Where("? = ?", bun.Ident("list_entry.id"), id).
Scan(ctx)
},
id,
)
}
func (l *listDB) GetListEntries(ctx context.Context,
listID string,
maxID string,
sinceID string,
minID string,
limit int,
) ([]*gtsmodel.ListEntry, error) {
// Ensure reasonable
if limit < 0 {
limit = 0
}
// Make educated guess for slice size
var (
entryIDs = make([]string, 0, limit)
frontToBack = true
)
q := l.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("entry")).
// Select only IDs from table
Column("entry.id").
// Select only entries belonging to listID.
Where("? = ?", bun.Ident("entry.list_id"), listID)
if maxID != "" {
// return only entries LOWER (ie., older) than maxID
q = q.Where("? < ?", bun.Ident("entry.id"), maxID)
}
if sinceID != "" {
// return only entries HIGHER (ie., newer) than sinceID
q = q.Where("? > ?", bun.Ident("entry.id"), sinceID)
}
if minID != "" {
// return only entries HIGHER (ie., newer) than minID
q = q.Where("? > ?", bun.Ident("entry.id"), minID)
// page up
frontToBack = false
}
if limit > 0 {
// limit amount of entries returned
q = q.Limit(limit)
}
if frontToBack {
// Page down.
q = q.Order("entry.id DESC")
} else {
// Page up.
q = q.Order("entry.id ASC")
}
if err := q.Scan(ctx, &entryIDs); err != nil {
return nil, l.conn.ProcessError(err)
}
if len(entryIDs) == 0 {
return nil, nil
}
// If we're paging up, we still want entries
// to be sorted by ID desc, so reverse ids slice.
// https://zchee.github.io/golang-wiki/SliceTricks/#reversing
if !frontToBack {
for l, r := 0, len(entryIDs)-1; l < r; l, r = l+1, r-1 {
entryIDs[l], entryIDs[r] = entryIDs[r], entryIDs[l]
}
}
// Select each list entry using its ID to ensure cache used.
listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs))
for _, id := range entryIDs {
listEntry, err := l.state.DB.GetListEntryByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error fetching list entry %q: %v", id, err)
continue
}
// Append list entries.
listEntries = append(listEntries, listEntry)
}
return listEntries, nil
}
func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error) {
entryIDs := []string{}
if err := l.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("entry")).
// Select only IDs from table
Column("entry.id").
// Select only entries belonging with given followID.
Where("? = ?", bun.Ident("entry.follow_id"), followID).
Scan(ctx, &entryIDs); err != nil {
return nil, l.conn.ProcessError(err)
}
if len(entryIDs) == 0 {
return nil, nil
}
// Select each list entry using its ID to ensure cache used.
listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs))
for _, id := range entryIDs {
listEntry, err := l.state.DB.GetListEntryByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error fetching list entry %q: %v", id, err)
continue
}
// Append list entries.
listEntries = append(listEntries, listEntry)
}
return listEntries, nil
}
func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.ListEntry) error {
var err error
if listEntry.Follow == nil {
// ListEntry follow is not set, fetch from the database.
listEntry.Follow, err = l.state.DB.GetFollowByID(
gtscontext.SetBarebones(ctx),
listEntry.FollowID,
)
if err != nil {
return fmt.Errorf("error populating listEntry follow: %w", err)
}
}
return nil
}
func (l *listDB) PutListEntries(ctx context.Context, listEntries []*gtsmodel.ListEntry) error {
return l.conn.RunInTx(ctx, func(tx bun.Tx) error {
for _, listEntry := range listEntries {
if _, err := tx.
NewInsert().
Model(listEntry).
Exec(ctx); err != nil {
return err
}
// Invalidate the timeline for the list this entry belongs to.
if err := l.state.Timelines.List.RemoveTimeline(ctx, listEntry.ListID); err != nil {
log.Errorf(ctx, "PutListEntries: error invalidating list timeline: %q", err)
}
}
return nil
})
}
func (l *listDB) DeleteListEntry(ctx context.Context, id string) error {
defer l.state.Caches.GTS.ListEntry().Invalidate("ID", id)
// Load list entry into cache before attempting a delete,
// as we need the followID from it in order to trigger
// timeline invalidation.
listEntry, err := l.GetListEntryByID(
// Don't populate the entry;
// we only want the list ID.
gtscontext.SetBarebones(ctx),
id,
)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
// Already gone.
return nil
}
return err
}
defer func() {
// Invalidate the timeline for the list this entry belongs to.
if err := l.state.Timelines.List.RemoveTimeline(ctx, listEntry.ListID); err != nil {
log.Errorf(ctx, "DeleteListEntry: error invalidating list timeline: %q", err)
}
}()
if _, err := l.conn.NewDelete().
Table("list_entries").
Where("? = ?", bun.Ident("id"), listEntry.ID).
Exec(ctx); err != nil {
return l.conn.ProcessError(err)
}
return nil
}
func (l *listDB) DeleteListEntriesForFollowID(ctx context.Context, followID string) error {
// Fetch IDs of all entries that pertain to this follow.
var listEntryIDs []string
if err := l.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("list_entry")).
Column("list_entry.id").
Where("? = ?", bun.Ident("list_entry.follow_id"), followID).
Order("list_entry.id DESC").
Scan(ctx, &listEntryIDs); err != nil {
return l.conn.ProcessError(err)
}
for _, id := range listEntryIDs {
if err := l.DeleteListEntry(ctx, id); err != nil {
return err
}
}
return nil
}

View file

@ -0,0 +1,315 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb_test
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"golang.org/x/exp/slices"
)
type ListTestSuite struct {
BunDBStandardTestSuite
}
func (suite *ListTestSuite) testStructs() (*gtsmodel.List, *gtsmodel.Account) {
testList := &gtsmodel.List{}
*testList = *suite.testLists["local_account_1_list_1"]
// Populate entries on this list as we'd expect them back from the db.
entries := make([]*gtsmodel.ListEntry, 0, len(suite.testListEntries))
for _, entry := range suite.testListEntries {
entries = append(entries, entry)
}
// Sort by ID descending (again, as we'd expect from the db).
slices.SortFunc(entries, func(a, b *gtsmodel.ListEntry) bool {
return b.ID < a.ID
})
testList.ListEntries = entries
testAccount := &gtsmodel.Account{}
*testAccount = *suite.testAccounts["local_account_1"]
return testList, testAccount
}
func (suite *ListTestSuite) checkList(expected *gtsmodel.List, actual *gtsmodel.List) {
suite.Equal(expected.ID, actual.ID)
suite.Equal(expected.Title, actual.Title)
suite.Equal(expected.AccountID, actual.AccountID)
suite.Equal(expected.RepliesPolicy, actual.RepliesPolicy)
suite.NotNil(actual.Account)
}
func (suite *ListTestSuite) checkListEntry(expected *gtsmodel.ListEntry, actual *gtsmodel.ListEntry) {
suite.Equal(expected.ID, actual.ID)
suite.Equal(expected.ListID, actual.ListID)
suite.Equal(expected.FollowID, actual.FollowID)
}
func (suite *ListTestSuite) checkListEntries(expected []*gtsmodel.ListEntry, actual []*gtsmodel.ListEntry) {
var (
lExpected = len(expected)
lActual = len(actual)
)
if lExpected != lActual {
suite.FailNow("", "expected %d list entries, got %d", lExpected, lActual)
}
var topID string
for i, expectedEntry := range expected {
actualEntry := actual[i]
// Ensure ID descending.
if topID == "" {
topID = actualEntry.ID
} else {
suite.Less(actualEntry.ID, topID)
}
suite.checkListEntry(expectedEntry, actualEntry)
}
}
func (suite *ListTestSuite) TestGetListByID() {
testList, _ := suite.testStructs()
dbList, err := suite.db.GetListByID(context.Background(), testList.ID)
if err != nil {
suite.FailNow(err.Error())
}
suite.checkList(testList, dbList)
suite.checkListEntries(testList.ListEntries, dbList.ListEntries)
}
func (suite *ListTestSuite) TestGetListsForAccountID() {
testList, testAccount := suite.testStructs()
dbLists, err := suite.db.GetListsForAccountID(context.Background(), testAccount.ID)
if err != nil {
suite.FailNow(err.Error())
}
if l := len(dbLists); l != 1 {
suite.FailNow("", "expected %d lists, got %d", 1, l)
}
suite.checkList(testList, dbLists[0])
}
func (suite *ListTestSuite) TestGetListEntries() {
testList, _ := suite.testStructs()
dbListEntries, err := suite.db.GetListEntries(context.Background(), testList.ID, "", "", "", 0)
if err != nil {
suite.FailNow(err.Error())
}
suite.checkListEntries(testList.ListEntries, dbListEntries)
}
func (suite *ListTestSuite) TestPutList() {
ctx := context.Background()
_, testAccount := suite.testStructs()
testList := &gtsmodel.List{
ID: "01H0J2PMYM54618VCV8Y8QYAT4",
Title: "Test List!",
AccountID: testAccount.ID,
}
if err := suite.db.PutList(ctx, testList); err != nil {
suite.FailNow(err.Error())
}
dbList, err := suite.db.GetListByID(ctx, testList.ID)
if err != nil {
suite.FailNow(err.Error())
}
// Bodge testlist as though default had been set.
testList.RepliesPolicy = gtsmodel.RepliesPolicyFollowed
suite.checkList(testList, dbList)
}
func (suite *ListTestSuite) TestUpdateList() {
ctx := context.Background()
testList, _ := suite.testStructs()
// Get List in the cache first.
dbList, err := suite.db.GetListByID(ctx, testList.ID)
if err != nil {
suite.FailNow(err.Error())
}
// Now do the update.
testList.Title = "New Title!"
if err := suite.db.UpdateList(ctx, testList, "title"); err != nil {
suite.FailNow(err.Error())
}
// Cache should be invalidated
// + we should have updated list.
dbList, err = suite.db.GetListByID(ctx, testList.ID)
if err != nil {
suite.FailNow(err.Error())
}
suite.checkList(testList, dbList)
}
func (suite *ListTestSuite) TestDeleteList() {
ctx := context.Background()
testList, _ := suite.testStructs()
// Get List in the cache first.
if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil {
suite.FailNow(err.Error())
}
// Now do the delete.
if err := suite.db.DeleteListByID(ctx, testList.ID); err != nil {
suite.FailNow(err.Error())
}
// Cache should be invalidated
// + we should have no list.
_, err := suite.db.GetListByID(ctx, testList.ID)
suite.ErrorIs(err, db.ErrNoEntries)
// All entries belonging to this
// list should now be deleted.
listEntries, err := suite.db.GetListEntries(ctx, testList.ID, "", "", "", 0)
if err != nil {
suite.FailNow(err.Error())
}
suite.Empty(listEntries)
}
func (suite *ListTestSuite) TestPutListEntries() {
ctx := context.Background()
testList, _ := suite.testStructs()
listEntries := []*gtsmodel.ListEntry{
{
ID: "01H0MKMQY69HWDSDR2SWGA17R4",
ListID: testList.ID,
FollowID: "01H0MKNFRFZS8R9WV6DBX31Y03", // random id, doesn't exist
},
{
ID: "01H0MKPGQF0E7QAVW5BKTHZ630",
ListID: testList.ID,
FollowID: "01H0MKP6RR8VEHN3GVWFBP2H30", // random id, doesn't exist
},
{
ID: "01H0MKPPP2DT68FRBMR1FJM32T",
ListID: testList.ID,
FollowID: "01H0MKQ0KA29C6NFJ27GTZD16J", // random id, doesn't exist
},
}
if err := suite.db.PutListEntries(ctx, listEntries); err != nil {
suite.FailNow(err.Error())
}
// Add these entries to the test list, sort it again
// to reflect what we'd expect to get from the db.
testList.ListEntries = append(testList.ListEntries, listEntries...)
slices.SortFunc(testList.ListEntries, func(a, b *gtsmodel.ListEntry) bool {
return b.ID < a.ID
})
// Now get all list entries from the db.
// Use barebones for this because the ones
// we just added will fail if we try to get
// the nonexistent follows.
dbListEntries, err := suite.db.GetListEntries(
gtscontext.SetBarebones(ctx),
testList.ID,
"", "", "", 0)
if err != nil {
suite.FailNow(err.Error())
}
suite.checkListEntries(testList.ListEntries, dbListEntries)
}
func (suite *ListTestSuite) TestDeleteListEntry() {
ctx := context.Background()
testList, _ := suite.testStructs()
// Get List in the cache first.
if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil {
suite.FailNow(err.Error())
}
// Delete the first entry.
if err := suite.db.DeleteListEntry(ctx, testList.ListEntries[0].ID); err != nil {
suite.FailNow(err.Error())
}
// Get list from the db again.
dbList, err := suite.db.GetListByID(ctx, testList.ID)
if err != nil {
suite.FailNow(err.Error())
}
// Bodge the testlist as though
// we'd removed the first entry.
testList.ListEntries = testList.ListEntries[1:]
suite.checkList(testList, dbList)
}
func (suite *ListTestSuite) TestDeleteListEntriesForFollowID() {
ctx := context.Background()
testList, _ := suite.testStructs()
// Get List in the cache first.
if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil {
suite.FailNow(err.Error())
}
// Delete the first entry.
if err := suite.db.DeleteListEntriesForFollowID(ctx, testList.ListEntries[0].FollowID); err != nil {
suite.FailNow(err.Error())
}
// Get list from the db again.
dbList, err := suite.db.GetListByID(ctx, testList.ID)
if err != nil {
suite.FailNow(err.Error())
}
// Bodge the testlist as though
// we'd removed the first entry.
testList.ListEntries = testList.ListEntries[1:]
suite.checkList(testList, dbList)
}
func TestListTestSuite(t *testing.T) {
suite.Run(t, new(ListTestSuite))
}

View file

@ -0,0 +1,92 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package migrations
import (
"context"
gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
func init() {
up := func(ctx context.Context, db *bun.DB) error {
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// List table.
if _, err := tx.
NewCreateTable().
Model(&gtsmodel.List{}).
IfNotExists().
Exec(ctx); err != nil {
return err
}
// Add indexes to the List table.
for index, columns := range map[string][]string{
"lists_id_idx": {"id"},
"lists_account_id_idx": {"account_id"},
} {
if _, err := tx.
NewCreateIndex().
Table("lists").
Index(index).
Column(columns...).
Exec(ctx); err != nil {
return err
}
}
// List entry table.
if _, err := tx.
NewCreateTable().
Model(&gtsmodel.ListEntry{}).
IfNotExists().
Exec(ctx); err != nil {
return err
}
// Add indexes to the List entry table.
for index, columns := range map[string][]string{
"list_entries_id_idx": {"id"},
"list_entries_list_id_idx": {"list_id"},
"list_entries_follow_id_idx": {"follow_id"},
} {
if _, err := tx.
NewCreateIndex().
Table("list_entries").
Index(index).
Column(columns...).
Exec(ctx); err != nil {
return err
}
}
return nil
})
}
down := func(ctx context.Context, db *bun.DB) error {
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
return nil
})
}
if err := Migrations.Register(up, down); err != nil {
panic(err)
}
}

View file

@ -25,6 +25,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/uptrace/bun"
@ -149,25 +150,42 @@ func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery f
return follow, nil
}
// Set the follow source account
if err := r.state.DB.PopulateFollow(ctx, follow); err != nil {
return nil, err
}
return follow, nil
}
func (r *relationshipDB) PopulateFollow(ctx context.Context, follow *gtsmodel.Follow) error {
var (
err error
errs = make(gtserror.MultiError, 0, 2)
)
if follow.Account == nil {
// Follow account is not set, fetch from the database.
follow.Account, err = r.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
follow.AccountID,
)
if err != nil {
return nil, fmt.Errorf("error getting follow source account: %w", err)
errs.Append(fmt.Errorf("error populating follow account: %w", err))
}
}
// Set the follow target account
if follow.TargetAccount == nil {
// Follow target account is not set, fetch from the database.
follow.TargetAccount, err = r.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
follow.TargetAccountID,
)
if err != nil {
return nil, fmt.Errorf("error getting follow target account: %w", err)
errs.Append(fmt.Errorf("error populating follow target account: %w", err))
}
}
return follow, nil
return errs.Combine()
}
func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error {
@ -197,27 +215,40 @@ func (r *relationshipDB) UpdateFollow(ctx context.Context, follow *gtsmodel.Foll
})
}
func (r *relationshipDB) deleteFollow(ctx context.Context, id string) error {
// Delete the follow itself using the given ID.
if _, err := r.conn.NewDelete().
Table("follows").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx); err != nil {
return r.conn.ProcessError(err)
}
// Delete every list entry that used this followID.
if err := r.state.DB.DeleteListEntriesForFollowID(ctx, id); err != nil {
return fmt.Errorf("deleteFollow: error deleting list entries: %w", err)
}
return nil
}
func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error {
defer r.state.Caches.GTS.Follow().Invalidate("ID", id)
// Load follow into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate
// callback. This in turn invalidates others.
_, err := r.GetFollowByID(gtscontext.SetBarebones(ctx), id)
follow, err := r.GetFollowByID(gtscontext.SetBarebones(ctx), id)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
// not an issue.
err = nil
// Already gone.
return nil
}
return err
}
// Finally delete follow from DB.
_, err = r.conn.NewDelete().
Table("follows").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx)
return r.conn.ProcessError(err)
return r.deleteFollow(ctx, follow.ID)
}
func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) error {
@ -226,21 +257,17 @@ func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) erro
// Load follow into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate
// callback. This in turn invalidates others.
_, err := r.GetFollowByURI(gtscontext.SetBarebones(ctx), uri)
follow, err := r.GetFollowByURI(gtscontext.SetBarebones(ctx), uri)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
// not an issue.
err = nil
// Already gone.
return nil
}
return err
}
// Finally delete follow from DB.
_, err = r.conn.NewDelete().
Table("follows").
Where("? = ?", bun.Ident("uri"), uri).
Exec(ctx)
return r.conn.ProcessError(err)
return r.deleteFollow(ctx, follow.ID)
}
func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID string) error {
@ -272,16 +299,16 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str
// but it is the only way we can ensure we invalidate all
// related caches correctly (e.g. visibility).
for _, id := range followIDs {
_, err := r.GetFollowByID(ctx, id)
follow, err := r.GetFollowByID(ctx, id)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
// Delete each follow from DB.
if err := r.deleteFollow(ctx, follow.ID); err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
}
// Finally delete all from DB.
_, err := r.conn.NewDelete().
Table("follows").
Where("? IN (?)", bun.Ident("id"), bun.In(followIDs)).
Exec(ctx)
return r.conn.ProcessError(err)
return nil
}

View file

@ -807,16 +807,27 @@ func (suite *RelationshipTestSuite) TestUnfollowExisting() {
follow, err := suite.db.GetFollow(context.Background(), originAccount.ID, targetAccount.ID)
suite.NoError(err)
suite.NotNil(follow)
followID := follow.ID
err = suite.db.DeleteFollowByID(context.Background(), follow.ID)
// We should have list entries for this follow.
listEntries, err := suite.db.GetListEntriesForFollowID(context.Background(), followID)
suite.NoError(err)
suite.NotEmpty(listEntries)
err = suite.db.DeleteFollowByID(context.Background(), followID)
suite.NoError(err)
follow, err = suite.db.GetFollow(context.Background(), originAccount.ID, targetAccount.ID)
suite.EqualError(err, db.ErrNoEntries.Error())
suite.Nil(follow)
// ListEntries pertaining to this follow should be deleted too.
listEntries, err = suite.db.GetListEntriesForFollowID(context.Background(), followID)
suite.NoError(err)
suite.Empty(listEntries)
}
func (suite *RelationshipTestSuite) TestUnfollowNotExisting() {
func (suite *RelationshipTestSuite) TestGetFollowNotExisting() {
originAccount := suite.testAccounts["local_account_1"]
targetAccountID := "01GTVD9N484CZ6AM90PGGNY7GQ"

View file

@ -19,9 +19,11 @@ package bundb
import (
"context"
"fmt"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log"
@ -281,3 +283,130 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max
prevMinID := faves[0].ID
return statuses, nextMaxID, prevMinID, nil
}
func (t *timelineDB) GetListTimeline(
ctx context.Context,
listID string,
maxID string,
sinceID string,
minID string,
limit int,
) ([]*gtsmodel.Status, error) {
// Ensure reasonable
if limit < 0 {
limit = 0
}
// Make educated guess for slice size
var (
statusIDs = make([]string, 0, limit)
frontToBack = true
)
// Fetch all listEntries entries from the database.
listEntries, err := t.state.DB.GetListEntries(
// Don't need actual follows
// for this, just the IDs.
gtscontext.SetBarebones(ctx),
listID,
"", "", "", 0,
)
if err != nil {
return nil, fmt.Errorf("error getting entries for list %s: %w", listID, err)
}
// Extract just the IDs of each follow.
followIDs := make([]string, 0, len(listEntries))
for _, listEntry := range listEntries {
followIDs = append(followIDs, listEntry.FollowID)
}
// Select target account IDs from follows.
subQ := t.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
Column("follow.target_account_id").
Where("? IN (?)", bun.Ident("follow.id"), bun.In(followIDs))
// Select only status IDs created
// by one of the followed accounts.
q := t.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
// Select only IDs from table
Column("status.id").
Where("? IN (?)", bun.Ident("status.account_id"), subQ)
if maxID == "" || maxID >= id.Highest {
const future = 24 * time.Hour
var err error
// don't return statuses more than 24hr in the future
maxID, err = id.NewULIDFromTime(time.Now().Add(future))
if err != nil {
return nil, err
}
}
// return only statuses LOWER (ie., older) than maxID
q = q.Where("? < ?", bun.Ident("status.id"), maxID)
if sinceID != "" {
// return only statuses HIGHER (ie., newer) than sinceID
q = q.Where("? > ?", bun.Ident("status.id"), sinceID)
}
if minID != "" {
// return only statuses HIGHER (ie., newer) than minID
q = q.Where("? > ?", bun.Ident("status.id"), minID)
// page up
frontToBack = false
}
if limit > 0 {
// limit amount of statuses returned
q = q.Limit(limit)
}
if frontToBack {
// Page down.
q = q.Order("status.id DESC")
} else {
// Page up.
q = q.Order("status.id ASC")
}
if err := q.Scan(ctx, &statusIDs); err != nil {
return nil, t.conn.ProcessError(err)
}
if len(statusIDs) == 0 {
return nil, nil
}
// If we're paging up, we still want statuses
// to be sorted by ID desc, so reverse ids slice.
// https://zchee.github.io/golang-wiki/SliceTricks/#reversing
if !frontToBack {
for l, r := 0, len(statusIDs)-1; l < r; l, r = l+1, r-1 {
statusIDs[l], statusIDs[r] = statusIDs[r], statusIDs[l]
}
}
statuses := make([]*gtsmodel.Status, 0, len(statusIDs))
for _, id := range statusIDs {
// Fetch status from db for ID
status, err := t.state.DB.GetStatusByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error fetching status %q: %v", id, err)
continue
}
// Append status to slice
statuses = append(statuses, status)
}
return statuses, nil
}

View file

@ -33,99 +33,6 @@ type TimelineTestSuite struct {
BunDBStandardTestSuite
}
func (suite *TimelineTestSuite) TestGetPublicTimeline() {
var count int
for _, status := range suite.testStatuses {
if status.Visibility == gtsmodel.VisibilityPublic &&
status.BoostOfID == "" {
count++
}
}
ctx := context.Background()
s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false)
suite.NoError(err)
suite.Len(s, count)
}
func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() {
var count int
for _, status := range suite.testStatuses {
if status.Visibility == gtsmodel.VisibilityPublic &&
status.BoostOfID == "" {
count++
}
}
ctx := context.Background()
futureStatus := getFutureStatus()
err := suite.db.PutStatus(ctx, futureStatus)
suite.NoError(err)
s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false)
suite.NoError(err)
suite.NotContains(s, futureStatus)
suite.Len(s, count)
}
func (suite *TimelineTestSuite) TestGetHomeTimeline() {
ctx := context.Background()
viewingAccount := suite.testAccounts["local_account_1"]
s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", "", 20, false)
suite.NoError(err)
suite.Len(s, 16)
}
func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() {
ctx := context.Background()
viewingAccount := suite.testAccounts["local_account_1"]
futureStatus := getFutureStatus()
err := suite.db.PutStatus(ctx, futureStatus)
suite.NoError(err)
s, err := suite.db.GetHomeTimeline(context.Background(), viewingAccount.ID, "", "", "", 20, false)
suite.NoError(err)
suite.NotContains(s, futureStatus)
suite.Len(s, 16)
}
func (suite *TimelineTestSuite) TestGetHomeTimelineBackToFront() {
ctx := context.Background()
viewingAccount := suite.testAccounts["local_account_1"]
s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", id.Lowest, 5, false)
suite.NoError(err)
suite.Len(s, 5)
suite.Equal("01F8MHAYFKS4KMXF8K5Y1C0KRN", s[0].ID)
suite.Equal("01F8MH75CBF9JFX4ZAD54N0W0R", s[len(s)-1].ID)
}
func (suite *TimelineTestSuite) TestGetHomeTimelineFromHighest() {
ctx := context.Background()
viewingAccount := suite.testAccounts["local_account_1"]
s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, id.Highest, "", "", 5, false)
suite.NoError(err)
suite.Len(s, 5)
suite.Equal("01G36SF3V6Y6V5BF9P4R7PQG7G", s[0].ID)
suite.Equal("01FCTA44PW9H1TB328S9AQXKDS", s[len(s)-1].ID)
}
func getFutureStatus() *gtsmodel.Status {
theDistantFuture := time.Now().Add(876600 * time.Hour)
id, err := id.NewULIDFromTime(theDistantFuture)
@ -163,6 +70,208 @@ func getFutureStatus() *gtsmodel.Status {
}
}
func (suite *TimelineTestSuite) publicCount() int {
var publicCount int
for _, status := range suite.testStatuses {
if status.Visibility == gtsmodel.VisibilityPublic &&
status.BoostOfID == "" {
publicCount++
}
}
return publicCount
}
func (suite *TimelineTestSuite) checkStatuses(statuses []*gtsmodel.Status, maxID string, minID string, expectedLength int) {
if l := len(statuses); l != expectedLength {
suite.FailNow("", "expected %d statuses in slice, got %d", expectedLength, l)
} else if l == 0 {
// Can't test empty slice.
return
}
// Check ordering + bounds of statuses.
highest := statuses[0].ID
for _, status := range statuses {
id := status.ID
if id >= maxID {
suite.FailNow("", "%s greater than maxID %s", id, maxID)
}
if id <= minID {
suite.FailNow("", "%s smaller than minID %s", id, minID)
}
if id > highest {
suite.FailNow("", "statuses in slice were not ordered highest -> lowest ID")
}
highest = id
}
}
func (suite *TimelineTestSuite) TestGetPublicTimeline() {
ctx := context.Background()
s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false)
if err != nil {
suite.FailNow(err.Error())
}
suite.checkStatuses(s, id.Highest, id.Lowest, suite.publicCount())
}
func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() {
ctx := context.Background()
// Insert a status set far in the
// future, it shouldn't be retrieved.
futureStatus := getFutureStatus()
if err := suite.db.PutStatus(ctx, futureStatus); err != nil {
suite.FailNow(err.Error())
}
s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false)
if err != nil {
suite.FailNow(err.Error())
}
suite.NotContains(s, futureStatus)
suite.checkStatuses(s, id.Highest, id.Lowest, suite.publicCount())
}
func (suite *TimelineTestSuite) TestGetHomeTimeline() {
var (
ctx = context.Background()
viewingAccount = suite.testAccounts["local_account_1"]
)
s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", "", 20, false)
if err != nil {
suite.FailNow(err.Error())
}
suite.checkStatuses(s, id.Highest, id.Lowest, 16)
}
func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() {
var (
ctx = context.Background()
viewingAccount = suite.testAccounts["local_account_1"]
)
// Insert a status set far in the
// future, it shouldn't be retrieved.
futureStatus := getFutureStatus()
if err := suite.db.PutStatus(ctx, futureStatus); err != nil {
suite.FailNow(err.Error())
}
s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", "", 20, false)
if err != nil {
suite.FailNow(err.Error())
}
suite.NotContains(s, futureStatus)
suite.checkStatuses(s, id.Highest, id.Lowest, 16)
}
func (suite *TimelineTestSuite) TestGetHomeTimelineBackToFront() {
var (
ctx = context.Background()
viewingAccount = suite.testAccounts["local_account_1"]
)
s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", id.Lowest, 5, false)
if err != nil {
suite.FailNow(err.Error())
}
suite.checkStatuses(s, id.Highest, id.Lowest, 5)
suite.Equal("01F8MHAYFKS4KMXF8K5Y1C0KRN", s[0].ID)
suite.Equal("01F8MH75CBF9JFX4ZAD54N0W0R", s[len(s)-1].ID)
}
func (suite *TimelineTestSuite) TestGetHomeTimelineFromHighest() {
var (
ctx = context.Background()
viewingAccount = suite.testAccounts["local_account_1"]
)
s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, id.Highest, "", "", 5, false)
if err != nil {
suite.FailNow(err.Error())
}
suite.checkStatuses(s, id.Highest, id.Lowest, 5)
suite.Equal("01G36SF3V6Y6V5BF9P4R7PQG7G", s[0].ID)
suite.Equal("01FCTA44PW9H1TB328S9AQXKDS", s[len(s)-1].ID)
}
func (suite *TimelineTestSuite) TestGetListTimelineNoParams() {
var (
ctx = context.Background()
list = suite.testLists["local_account_1_list_1"]
)
s, err := suite.db.GetListTimeline(ctx, list.ID, "", "", "", 20)
if err != nil {
suite.FailNow(err.Error())
}
suite.checkStatuses(s, id.Highest, id.Lowest, 11)
}
func (suite *TimelineTestSuite) TestGetListTimelineMaxID() {
var (
ctx = context.Background()
list = suite.testLists["local_account_1_list_1"]
)
s, err := suite.db.GetListTimeline(ctx, list.ID, id.Highest, "", "", 5)
if err != nil {
suite.FailNow(err.Error())
}
suite.checkStatuses(s, id.Highest, id.Lowest, 5)
suite.Equal("01G36SF3V6Y6V5BF9P4R7PQG7G", s[0].ID)
suite.Equal("01FCQSQ667XHJ9AV9T27SJJSX5", s[len(s)-1].ID)
}
func (suite *TimelineTestSuite) TestGetListTimelineMinID() {
var (
ctx = context.Background()
list = suite.testLists["local_account_1_list_1"]
)
s, err := suite.db.GetListTimeline(ctx, list.ID, "", "", id.Lowest, 5)
if err != nil {
suite.FailNow(err.Error())
}
suite.checkStatuses(s, id.Highest, id.Lowest, 5)
suite.Equal("01F8MHC8VWDRBQR0N1BATDDEM5", s[0].ID)
suite.Equal("01F8MH75CBF9JFX4ZAD54N0W0R", s[len(s)-1].ID)
}
func (suite *TimelineTestSuite) TestGetListTimelineMinIDPagingUp() {
var (
ctx = context.Background()
list = suite.testLists["local_account_1_list_1"]
)
s, err := suite.db.GetListTimeline(ctx, list.ID, "", "", "01F8MHC8VWDRBQR0N1BATDDEM5", 5)
if err != nil {
suite.FailNow(err.Error())
}
suite.checkStatuses(s, id.Highest, "01F8MHC8VWDRBQR0N1BATDDEM5", 5)
suite.Equal("01G20ZM733MGN8J344T4ZDDFY1", s[0].ID)
suite.Equal("01F8MHCP5P2NWYQ416SBA0XSEV", s[len(s)-1].ID)
}
func TestTimelineTestSuite(t *testing.T) {
suite.Run(t, new(TimelineTestSuite))
}

View file

@ -36,6 +36,7 @@ type DB interface {
Domain
Emoji
Instance
List
Media
Mention
Notification

67
internal/db/list.go Normal file
View file

@ -0,0 +1,67 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package db
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type List interface {
// GetListByID gets one list with the given id.
GetListByID(ctx context.Context, id string) (*gtsmodel.List, error)
// GetListsForAccountID gets all lists owned by the given accountID.
GetListsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error)
// PopulateList ensures that the list's struct fields are populated.
PopulateList(ctx context.Context, list *gtsmodel.List) error
// PutList puts a new list in the database.
PutList(ctx context.Context, list *gtsmodel.List) error
// UpdateList updates the given list.
// Columns is optional, if not specified all will be updated.
UpdateList(ctx context.Context, list *gtsmodel.List, columns ...string) error
// DeleteListByID deletes one list with the given ID.
DeleteListByID(ctx context.Context, id string) error
// GetListEntryByID gets one list entry with the given ID.
GetListEntryByID(ctx context.Context, id string) (*gtsmodel.ListEntry, error)
// GetListEntries gets list entries from the given listID, using the given parameters.
GetListEntries(ctx context.Context, listID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.ListEntry, error)
// GetListEntriesForFollowID returns all listEntries that pertain to the given followID.
GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error)
// PopulateListEntry ensures that the listEntry's struct fields are populated.
PopulateListEntry(ctx context.Context, listEntry *gtsmodel.ListEntry) error
// PutListEntries inserts a slice of listEntries into the database.
// It uses a transaction to ensure no partial updates.
PutListEntries(ctx context.Context, listEntries []*gtsmodel.ListEntry) error
// DeleteListEntry deletes one list entry with the given id.
DeleteListEntry(ctx context.Context, id string) error
// DeleteListEntryForFollowID deletes all list entries with the given followID.
DeleteListEntriesForFollowID(ctx context.Context, followID string) error
}

View file

@ -64,6 +64,9 @@ type Relationship interface {
// GetFollow retrieves a follow if it exists between source and target accounts.
GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error)
// PopulateFollow populates the struct pointers on the given follow.
PopulateFollow(ctx context.Context, follow *gtsmodel.Follow) error
// GetFollowRequestByID fetches follow request with given ID from the database.
GetFollowRequestByID(ctx context.Context, id string) (*gtsmodel.FollowRequest, error)

View file

@ -44,4 +44,8 @@ type Timeline interface {
//
// Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers.
GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error)
// GetListTimeline returns a slice of statuses from followed accounts collected within the list with the given listID.
// Statuses should be returned in descending order of when they were created (newest first).
GetListTimeline(ctx context.Context, listID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Status, error)
}

View file

@ -25,6 +25,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -61,6 +62,13 @@ func (suite *DereferencerStandardTestSuite) SetupTest() {
testrig.StartWorkers(&suite.state)
suite.db = testrig.NewTestDB(&suite.state)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
testrig.NewTestTypeConverter(suite.db),
)
suite.storage = testrig.NewInMemoryStorage()
suite.state.DB = suite.db
suite.state.Storage = suite.storage

View file

@ -28,6 +28,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -76,8 +77,16 @@ func (suite *FederatingDBTestSuite) SetupTest() {
}
suite.db = testrig.NewTestDB(&suite.state)
suite.testActivities = testrig.NewTestActivities(suite.testAccounts)
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
suite.federatingDB = testrig.NewTestFederatingDB(&suite.state)
testrig.StandardDBSetup(suite.db, suite.testAccounts)

View file

@ -25,6 +25,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -58,6 +59,13 @@ func (suite *FederatorStandardTestSuite) SetupTest() {
suite.db = testrig.NewTestDB(&suite.state)
suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.state.DB = suite.db
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
suite.testActivities = testrig.NewTestActivities(suite.testAccounts)
testrig.StandardDBSetup(suite.db, suite.testAccounts)
}

51
internal/gtsmodel/list.go Normal file
View file

@ -0,0 +1,51 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package gtsmodel
import "time"
// List refers to a list of follows for which the owning account wants to view a timeline of posts.
type List struct {
ID string `validate:"required,ulid" bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // id of this item in the database
CreatedAt time.Time `validate:"-" bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created
UpdatedAt time.Time `validate:"-" bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated
Title string `validate:"required" bun:",nullzero,notnull,unique:listaccounttitle"` // Title of this list.
AccountID string `validate:"required,ulid" bun:"type:CHAR(26),notnull,nullzero,unique:listaccounttitle"` // Account that created/owns the list
Account *Account `validate:"-" bun:"-"` // Account corresponding to accountID
ListEntries []*ListEntry `validate:"-" bun:"-"` // Entries contained by this list.
RepliesPolicy RepliesPolicy `validate:"-" bun:",nullzero,notnull,default:'followed'"` // RepliesPolicy for this list.
}
// ListEntry refers to a single follow entry in a list.
type ListEntry struct {
ID string `validate:"required,ulid" bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // id of this item in the database
CreatedAt time.Time `validate:"-" bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created
UpdatedAt time.Time `validate:"-" bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated
ListID string `validate:"required,ulid" bun:"type:CHAR(26),notnull,nullzero,unique:listentrylistfollow"` // ID of the list that this entry belongs to.
FollowID string `validate:"required,ulid" bun:"type:CHAR(26),notnull,nullzero,unique:listentrylistfollow"` // Follow that the account owning this entry wants to see posts of in the timeline.
Follow *Follow `validate:"-" bun:"-"` // Follow corresponding to followID.
}
// RepliesPolicy denotes which replies should be shown in the list.
type RepliesPolicy string
const (
RepliesPolicyFollowed RepliesPolicy = "followed" // Show replies to any followed user.
RepliesPolicyList RepliesPolicy = "list" // Show replies to members of the list only.
RepliesPolicyNone RepliesPolicy = "none" // Don't show replies.
)

View file

@ -33,6 +33,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/state"
gtsstorage "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type ManagerTestSuite struct {
@ -395,9 +396,6 @@ func (suite *ManagerTestSuite) TestSlothVineProcessBlocking() {
// fetch the attachment id from the processing media
attachmentID := processingMedia.AttachmentID()
// Give time for processing
time.Sleep(time.Second * 3)
// do a blocking call to fetch the attachment
attachment, err := processingMedia.LoadAttachment(ctx)
suite.NoError(err)
@ -1027,13 +1025,14 @@ func (suite *ManagerTestSuite) TestSimpleJpegProcessAsync() {
// fetch the attachment id from the processing media
attachmentID := processingMedia.AttachmentID()
// Give time for processing to happen.
time.Sleep(time.Second * 3)
// fetch the attachment from the database
attachment, err := suite.db.GetAttachmentByID(ctx, attachmentID)
suite.NoError(err)
suite.NotNil(attachment)
// wait for processing to complete
var attachment *gtsmodel.MediaAttachment
if !testrig.WaitFor(func() bool {
attachment, err = suite.db.GetAttachmentByID(ctx, attachmentID)
return err == nil && attachment != nil
}) {
suite.FailNow("timed out waiting for attachment to process")
}
// make sure it's got the stuff set on it that we expect
// the attachment ID and accountID we expect

View file

@ -25,6 +25,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -41,22 +42,27 @@ type MediaStandardTestSuite struct {
testEmojis map[string]*gtsmodel.Emoji
}
func (suite *MediaStandardTestSuite) SetupSuite() {
func (suite *MediaStandardTestSuite) SetupTest() {
testrig.InitTestConfig()
testrig.InitTestLog()
suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
suite.db = testrig.NewTestDB(&suite.state)
suite.storage = testrig.NewInMemoryStorage()
suite.state.DB = suite.db
suite.state.Storage = suite.storage
}
func (suite *MediaStandardTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
testrig.StandardStorageSetup(suite.storage, "../../testrig/media")
testrig.StandardDBSetup(suite.db, nil)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
testrig.NewTestTypeConverter(suite.db),
)
suite.testAttachments = testrig.NewTestAttachments()
suite.testAccounts = testrig.NewTestAccounts()
suite.testEmojis = testrig.NewTestEmojis()

View file

@ -88,6 +88,13 @@ func (suite *AccountStandardTestSuite) SetupTest() {
suite.db = testrig.NewTestDB(&suite.state)
suite.state.DB = suite.db
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)

View file

@ -0,0 +1,107 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package account
import (
"context"
"errors"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
)
var noLists = make([]*apimodel.List, 0)
// ListsGet returns all lists owned by requestingAccount, which contain a follow for targetAccountID.
func (p *Processor) ListsGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]*apimodel.List, gtserror.WithCode) {
targetAccount, err := p.state.DB.GetAccountByID(ctx, targetAccountID)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorNotFound(errors.New("account not found"))
}
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error: %w", err))
}
visible, err := p.filter.AccountVisible(ctx, requestingAccount, targetAccount)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error: %w", err))
}
if !visible {
return nil, gtserror.NewErrorNotFound(errors.New("account not found"))
}
// Requester has to follow targetAccount
// for them to be in any of their lists.
follow, err := p.state.DB.GetFollow(
// Don't populate follow.
gtscontext.SetBarebones(ctx),
requestingAccount.ID,
targetAccountID,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error: %w", err))
}
if follow == nil {
return noLists, nil // by definition we know they're in no lists
}
listEntries, err := p.state.DB.GetListEntriesForFollowID(
// Don't populate entries.
gtscontext.SetBarebones(ctx),
follow.ID,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error: %w", err))
}
count := len(listEntries)
if count == 0 {
return noLists, nil
}
apiLists := make([]*apimodel.List, 0, count)
for _, listEntry := range listEntries {
list, err := p.state.DB.GetListByID(
// Don't populate list.
gtscontext.SetBarebones(ctx),
listEntry.ListID,
)
if err != nil {
log.Debugf(ctx, "skipping list %s due to error %q", listEntry.ListID, err)
continue
}
apiList, err := p.tc.ListToAPIList(ctx, list)
if err != nil {
log.Debugf(ctx, "skipping list %s due to error %q", listEntry.ListID, err)
continue
}
apiLists = append(apiLists, apiList)
}
return apiLists, nil
}

View file

@ -217,10 +217,10 @@ func (p *Processor) processCreateBlockFromClientAPI(ctx context.Context, clientM
}
// remove any of the blocking account's statuses from the blocked account's timeline, and vice versa
if err := p.statusTimelines.WipeItemsFromAccountID(ctx, block.AccountID, block.TargetAccountID); err != nil {
if err := p.state.Timelines.Home.WipeItemsFromAccountID(ctx, block.AccountID, block.TargetAccountID); err != nil {
return err
}
if err := p.statusTimelines.WipeItemsFromAccountID(ctx, block.TargetAccountID, block.AccountID); err != nil {
if err := p.state.Timelines.Home.WipeItemsFromAccountID(ctx, block.TargetAccountID, block.AccountID); err != nil {
return err
}

View file

@ -20,6 +20,7 @@ package processing_test
import (
"context"
"encoding/json"
"errors"
"testing"
"github.com/stretchr/testify/suite"
@ -36,24 +37,21 @@ type FromClientAPITestSuite struct {
ProcessingStandardTestSuite
}
// This test ensures that when admin_account posts a new
// status, it ends up in the correct streaming timelines
// of local_account_1, which follows it.
func (suite *FromClientAPITestSuite) TestProcessStreamNewStatus() {
ctx := context.Background()
var (
ctx = context.Background()
postingAccount = suite.testAccounts["admin_account"]
receivingAccount = suite.testAccounts["local_account_1"]
testList = suite.testLists["local_account_1_list_1"]
streams = suite.openStreams(ctx, receivingAccount, []string{testList.ID})
homeStream = streams[stream.TimelineHome]
listStream = streams[stream.TimelineList+":"+testList.ID]
)
// let's say that the admin account posts a new status: it should end up in the
// timeline of any account that follows it and has a stream open
postingAccount := suite.testAccounts["admin_account"]
receivingAccount := suite.testAccounts["local_account_1"]
// open a home timeline stream for zork
wssStream, errWithCode := suite.processor.Stream().Open(ctx, receivingAccount, stream.TimelineHome)
suite.NoError(errWithCode)
// open another stream for zork, but for a different timeline;
// this shouldn't get stuff streamed into it, since it's for the public timeline
irrelevantStream, errWithCode := suite.processor.Stream().Open(ctx, receivingAccount, stream.TimelinePublic)
suite.NoError(errWithCode)
// make a new status from admin account
// Make a new status from admin account.
newStatus := &gtsmodel.Status{
ID: "01FN4B2F88TF9676DYNXWE1WSS",
URI: "http://localhost:8080/users/admin/statuses/01FN4B2F88TF9676DYNXWE1WSS",
@ -82,87 +80,110 @@ func (suite *FromClientAPITestSuite) TestProcessStreamNewStatus() {
ActivityStreamsType: ap.ObjectNote,
}
// put the status in the db first, to mimic what would have already happened earlier up the flow
err := suite.db.PutStatus(ctx, newStatus)
suite.NoError(err)
// Put the status in the db first, to mimic what
// would have already happened earlier up the flow.
if err := suite.db.PutStatus(ctx, newStatus); err != nil {
suite.FailNow(err.Error())
}
// process the new status
err = suite.processor.ProcessFromClientAPI(ctx, messages.FromClientAPI{
// Process the new status.
if err := suite.processor.ProcessFromClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityCreate,
GTSModel: newStatus,
OriginAccount: postingAccount,
})
suite.NoError(err)
}); err != nil {
suite.FailNow(err.Error())
}
// zork's stream should have the newly created status in it now
msg := <-wssStream.Messages
suite.Equal(stream.EventTypeUpdate, msg.Event)
suite.NotEmpty(msg.Payload)
suite.EqualValues([]string{stream.TimelineHome}, msg.Stream)
statusStreamed := &apimodel.Status{}
err = json.Unmarshal([]byte(msg.Payload), statusStreamed)
suite.NoError(err)
suite.Equal("01FN4B2F88TF9676DYNXWE1WSS", statusStreamed.ID)
suite.Equal("this status should stream :)", statusStreamed.Content)
// Check message in home stream.
homeMsg := <-homeStream.Messages
suite.Equal(stream.EventTypeUpdate, homeMsg.Event)
suite.EqualValues([]string{stream.TimelineHome}, homeMsg.Stream)
suite.Empty(homeStream.Messages) // Stream should now be empty.
// and stream should now be empty
suite.Empty(wssStream.Messages)
// Check status from home stream.
homeStreamStatus := &apimodel.Status{}
if err := json.Unmarshal([]byte(homeMsg.Payload), homeStreamStatus); err != nil {
suite.FailNow(err.Error())
}
suite.Equal(newStatus.ID, homeStreamStatus.ID)
suite.Equal(newStatus.Content, homeStreamStatus.Content)
// the irrelevant messages stream should also be empty
suite.Empty(irrelevantStream.Messages)
// Check message in list stream.
listMsg := <-listStream.Messages
suite.Equal(stream.EventTypeUpdate, listMsg.Event)
suite.EqualValues([]string{stream.TimelineList + ":" + testList.ID}, listMsg.Stream)
suite.Empty(listStream.Messages) // Stream should now be empty.
// Check status from list stream.
listStreamStatus := &apimodel.Status{}
if err := json.Unmarshal([]byte(listMsg.Payload), listStreamStatus); err != nil {
suite.FailNow(err.Error())
}
suite.Equal(newStatus.ID, listStreamStatus.ID)
suite.Equal(newStatus.Content, listStreamStatus.Content)
}
func (suite *FromClientAPITestSuite) TestProcessStatusDelete() {
ctx := context.Background()
var (
ctx = context.Background()
deletingAccount = suite.testAccounts["local_account_1"]
receivingAccount = suite.testAccounts["local_account_2"]
deletedStatus = suite.testStatuses["local_account_1_status_1"]
boostOfDeletedStatus = suite.testStatuses["admin_account_status_4"]
streams = suite.openStreams(ctx, receivingAccount, nil)
homeStream = streams[stream.TimelineHome]
)
deletingAccount := suite.testAccounts["local_account_1"]
receivingAccount := suite.testAccounts["local_account_2"]
// Delete the status from the db first, to mimic what
// would have already happened earlier up the flow
if err := suite.db.DeleteStatusByID(ctx, deletedStatus.ID); err != nil {
suite.FailNow(err.Error())
}
deletedStatus := suite.testStatuses["local_account_1_status_1"]
boostOfDeletedStatus := suite.testStatuses["admin_account_status_4"]
// open a home timeline stream for turtle, who follows zork
wssStream, errWithCode := suite.processor.Stream().Open(ctx, receivingAccount, stream.TimelineHome)
suite.NoError(errWithCode)
// delete the status from the db first, to mimic what would have already happened earlier up the flow
err := suite.db.DeleteStatusByID(ctx, deletedStatus.ID)
suite.NoError(err)
// process the status delete
err = suite.processor.ProcessFromClientAPI(ctx, messages.FromClientAPI{
// Process the status delete.
if err := suite.processor.ProcessFromClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityDelete,
GTSModel: deletedStatus,
OriginAccount: deletingAccount,
})
suite.NoError(err)
}); err != nil {
suite.FailNow(err.Error())
}
// turtle's stream should have the delete of admin's boost in it now
msg := <-wssStream.Messages
// Stream should have the delete of admin's boost in it now.
msg := <-homeStream.Messages
suite.Equal(stream.EventTypeDelete, msg.Event)
suite.Equal(boostOfDeletedStatus.ID, msg.Payload)
suite.EqualValues([]string{stream.TimelineHome}, msg.Stream)
// turtle's stream should also have the delete of the message itself in it
msg = <-wssStream.Messages
// Stream should also have the delete of the message itself in it.
msg = <-homeStream.Messages
suite.Equal(stream.EventTypeDelete, msg.Event)
suite.Equal(deletedStatus.ID, msg.Payload)
suite.EqualValues([]string{stream.TimelineHome}, msg.Stream)
// stream should now be empty
suite.Empty(wssStream.Messages)
// Stream should now be empty.
suite.Empty(homeStream.Messages)
// the boost should no longer be in the database
_, err = suite.db.GetStatusByID(ctx, boostOfDeletedStatus.ID)
suite.ErrorIs(err, db.ErrNoEntries)
// Boost should no longer be in the database.
if !testrig.WaitFor(func() bool {
_, err := suite.db.GetStatusByID(ctx, boostOfDeletedStatus.ID)
return errors.Is(err, db.ErrNoEntries)
}) {
suite.FailNow("timed out waiting for status delete")
}
}
func (suite *FromClientAPITestSuite) TestProcessNewStatusWithNotification() {
ctx := context.Background()
postingAccount := suite.testAccounts["admin_account"]
receivingAccount := suite.testAccounts["local_account_1"]
var (
ctx = context.Background()
postingAccount = suite.testAccounts["admin_account"]
receivingAccount = suite.testAccounts["local_account_1"]
streams = suite.openStreams(ctx, receivingAccount, nil)
notifStream = streams[stream.TimelineNotifications]
)
// Update the follow from receiving account -> posting account so
// that receiving account wants notifs when posting account posts.
@ -204,8 +225,9 @@ func (suite *FromClientAPITestSuite) TestProcessNewStatusWithNotification() {
// Put the status in the db first, to mimic what
// would have already happened earlier up the flow.
err := suite.db.PutStatus(ctx, newStatus)
suite.NoError(err)
if err := suite.db.PutStatus(ctx, newStatus); err != nil {
suite.FailNow(err.Error())
}
// Process the new status.
if err := suite.processor.ProcessFromClientAPI(ctx, messages.FromClientAPI{
@ -230,6 +252,19 @@ func (suite *FromClientAPITestSuite) TestProcessNewStatusWithNotification() {
}) {
suite.FailNow("timed out waiting for new status notification")
}
// Check message in notification stream.
notifMsg := <-notifStream.Messages
suite.Equal(stream.EventTypeNotification, notifMsg.Event)
suite.EqualValues([]string{stream.TimelineNotifications}, notifMsg.Stream)
suite.Empty(notifStream.Messages) // Stream should now be empty.
// Check notif.
notif := &apimodel.Notification{}
if err := json.Unmarshal([]byte(notifMsg.Payload), notif); err != nil {
suite.FailNow(err.Error())
}
suite.Equal(newStatus.ID, notif.Status.ID)
}
func TestFromClientAPITestSuite(t *testing.T) {

View file

@ -30,12 +30,14 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/stream"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
)
// timelineAndNotifyStatus processes the given new status and inserts it into
// the HOME timelines of accounts that follow the status author. It will also
// handle notifications for any mentions attached to the account, and also
// notifications for any local accounts that want a notif when this account posts.
// the HOME and LIST timelines of accounts that follow the status author.
//
// It will also handle notifications for any mentions attached to the account, and
// also notifications for any local accounts that want to know when this account posts.
func (p *Processor) timelineAndNotifyStatus(ctx context.Context, status *gtsmodel.Status) error {
// Ensure status fully populated; including account, mentions, etc.
if err := p.state.DB.PopulateStatus(ctx, status); err != nil {
@ -89,10 +91,43 @@ func (p *Processor) timelineAndNotifyStatusForFollowers(ctx context.Context, sta
continue
}
// Add status to each list that this follow
// is included in, and stream it if applicable.
listEntries, err := p.state.DB.GetListEntriesForFollowID(
// We only need the list IDs.
gtscontext.SetBarebones(ctx),
follow.ID,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
errs.Append(fmt.Errorf("timelineAndNotifyStatusForFollowers: error list timelining status: %w", err))
continue
}
for _, listEntry := range listEntries {
if _, err := p.timelineStatus(
ctx,
p.state.Timelines.List.IngestOne,
listEntry.ListID, // list timelines are keyed by list ID
follow.Account,
status,
stream.TimelineList+":"+listEntry.ListID, // key streamType to this specific list
); err != nil {
errs.Append(fmt.Errorf("timelineAndNotifyStatusForFollowers: error list timelining status: %w", err))
continue
}
}
// Add status to home timeline for this
// follower, and stream it if applicable.
if timelined, err := p.timelineStatusForAccount(ctx, follow.Account, status); err != nil {
errs.Append(fmt.Errorf("timelineAndNotifyStatusForFollowers: error timelining status: %w", err))
if timelined, err := p.timelineStatus(
ctx,
p.state.Timelines.Home.IngestOne,
follow.AccountID, // home timelines are keyed by account ID
follow.Account,
status,
stream.TimelineHome,
); err != nil {
errs.Append(fmt.Errorf("timelineAndNotifyStatusForFollowers: error home timelining status: %w", err))
continue
} else if !timelined {
// Status wasn't added to home tomeline,
@ -133,13 +168,21 @@ func (p *Processor) timelineAndNotifyStatusForFollowers(ctx context.Context, sta
return errs.Combine()
}
// timelineStatusForAccount puts the given status in the HOME timeline
// of the account with given accountID, if it's HomeTimelineable.
// timelineStatus uses the provided ingest function to put the given
// status in a timeline with the given ID, if it's timelineable.
//
// If the status was inserted into the home timeline of the given account,
// true will be returned + it will also be streamed via websockets to the user.
func (p *Processor) timelineStatusForAccount(ctx context.Context, account *gtsmodel.Account, status *gtsmodel.Status) (bool, error) {
// If the status was inserted into the timeline, true will be returned
// + it will also be streamed to the user using the given streamType.
func (p *Processor) timelineStatus(
ctx context.Context,
ingest func(context.Context, string, timeline.Timelineable) (bool, error),
timelineID string,
account *gtsmodel.Account,
status *gtsmodel.Status,
streamType string,
) (bool, error) {
// Make sure the status is timelineable.
// This works for both home and list timelines.
if timelineable, err := p.filter.StatusHomeTimelineable(ctx, account, status); err != nil {
err = fmt.Errorf("timelineStatusForAccount: error getting timelineability for status for timeline with id %s: %w", account.ID, err)
return false, err
@ -148,8 +191,8 @@ func (p *Processor) timelineStatusForAccount(ctx context.Context, account *gtsmo
return false, nil
}
// Insert status in the home timeline of account.
if inserted, err := p.statusTimelines.IngestOne(ctx, account.ID, status); err != nil {
// Ingest status into given timeline using provided function.
if inserted, err := ingest(ctx, timelineID, status); err != nil {
err = fmt.Errorf("timelineStatusForAccount: error ingesting status %s: %w", status.ID, err)
return false, err
} else if !inserted {
@ -164,7 +207,7 @@ func (p *Processor) timelineStatusForAccount(ctx context.Context, account *gtsmo
return true, err
}
if err := p.stream.Update(apiStatus, account, stream.TimelineHome); err != nil {
if err := p.stream.Update(apiStatus, account, []string{streamType}); err != nil {
err = fmt.Errorf("timelineStatusForAccount: error streaming update for status %s: %w", status.ID, err)
return true, err
}
@ -401,7 +444,7 @@ func (p *Processor) wipeStatus(ctx context.Context, statusToDelete *gtsmodel.Sta
// deleteStatusFromTimelines completely removes the given status from all timelines.
// It will also stream deletion of the status to all open streams.
func (p *Processor) deleteStatusFromTimelines(ctx context.Context, status *gtsmodel.Status) error {
if err := p.statusTimelines.WipeItemFromAllTimelines(ctx, status.ID); err != nil {
if err := p.state.Timelines.Home.WipeItemFromAllTimelines(ctx, status.ID); err != nil {
return err
}

View file

@ -342,10 +342,10 @@ func (p *Processor) processCreateBlockFromFederator(ctx context.Context, federat
}
// remove any of the blocking account's statuses from the blocked account's timeline, and vice versa
if err := p.statusTimelines.WipeItemsFromAccountID(ctx, block.AccountID, block.TargetAccountID); err != nil {
if err := p.state.Timelines.Home.WipeItemsFromAccountID(ctx, block.AccountID, block.TargetAccountID); err != nil {
return err
}
if err := p.statusTimelines.WipeItemsFromAccountID(ctx, block.TargetAccountID, block.AccountID); err != nil {
if err := p.state.Timelines.Home.WipeItemsFromAccountID(ctx, block.TargetAccountID, block.AccountID); err != nil {
return err
}
// TODO: same with notifications

View file

@ -0,0 +1,50 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package list
import (
"context"
"errors"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
)
// Create creates one a new list for the given account, using the provided parameters.
// These params should have already been validated by the time they reach this function.
func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, title string, repliesPolicy gtsmodel.RepliesPolicy) (*apimodel.List, gtserror.WithCode) {
list := &gtsmodel.List{
ID: id.NewULID(),
Title: title,
AccountID: account.ID,
RepliesPolicy: repliesPolicy,
}
if err := p.state.DB.PutList(ctx, list); err != nil {
if errors.Is(err, db.ErrAlreadyExists) {
err = errors.New("you already have a list with this title")
return nil, gtserror.NewErrorConflict(err, err.Error())
}
return nil, gtserror.NewErrorInternalError(err)
}
return p.apiList(ctx, list)
}

View file

@ -0,0 +1,46 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package list
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Delete deletes one list for the given account.
func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, id string) gtserror.WithCode {
list, errWithCode := p.getList(
// Use barebones ctx; no embedded
// structs necessary for this call.
gtscontext.SetBarebones(ctx),
account.ID,
id,
)
if errWithCode != nil {
return errWithCode
}
if err := p.state.DB.DeleteListByID(ctx, list.ID); err != nil {
return gtserror.NewErrorInternalError(err)
}
return nil
}

View file

@ -0,0 +1,155 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package list
import (
"context"
"errors"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
// Get returns the api model of one list with the given ID.
func (p *Processor) Get(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.List, gtserror.WithCode) {
list, errWithCode := p.getList(
// Use barebones ctx; no embedded
// structs necessary for this call.
gtscontext.SetBarebones(ctx),
account.ID,
id,
)
if errWithCode != nil {
return nil, errWithCode
}
return p.apiList(ctx, list)
}
// GetMultiple returns multiple lists created by the given account, sorted by list ID DESC (newest first).
func (p *Processor) GetAll(ctx context.Context, account *gtsmodel.Account) ([]*apimodel.List, gtserror.WithCode) {
lists, err := p.state.DB.GetListsForAccountID(
// Use barebones ctx; no embedded
// structs necessary for simple GET.
gtscontext.SetBarebones(ctx),
account.ID,
)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
return nil, nil
}
return nil, gtserror.NewErrorInternalError(err)
}
apiLists := make([]*apimodel.List, 0, len(lists))
for _, list := range lists {
apiList, errWithCode := p.apiList(ctx, list)
if errWithCode != nil {
return nil, errWithCode
}
apiLists = append(apiLists, apiList)
}
return apiLists, nil
}
// GetListAccounts returns accounts that are in the given list, owned by the given account.
// The additional parameters can be used for paging.
func (p *Processor) GetListAccounts(
ctx context.Context,
account *gtsmodel.Account,
listID string,
maxID string,
sinceID string,
minID string,
limit int,
) (*apimodel.PageableResponse, gtserror.WithCode) {
// Ensure list exists + is owned by requesting account.
if _, errWithCode := p.getList(ctx, account.ID, listID); errWithCode != nil {
return nil, errWithCode
}
// To know which accounts are in the list,
// we need to first get requested list entries.
listEntries, err := p.state.DB.GetListEntries(ctx, listID, maxID, sinceID, minID, limit)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("GetListAccounts: error getting list entries: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
count := len(listEntries)
if count == 0 {
// No list entries means no accounts.
return util.EmptyPageableResponse(), nil
}
var (
items = make([]interface{}, count)
nextMaxIDValue string
prevMinIDValue string
)
// For each list entry, we want the account it points to.
// To get this, we need to first get the follow that the
// list entry pertains to, then extract the target account
// from that follow.
//
// We do paging not by account ID, but by list entry ID.
for i, listEntry := range listEntries {
if i == count-1 {
nextMaxIDValue = listEntry.ID
}
if i == 0 {
prevMinIDValue = listEntry.ID
}
if err := p.state.DB.PopulateListEntry(ctx, listEntry); err != nil {
log.Debugf(ctx, "skipping list entry because of error populating it: %q", err)
continue
}
if err := p.state.DB.PopulateFollow(ctx, listEntry.Follow); err != nil {
log.Debugf(ctx, "skipping list entry because of error populating follow: %q", err)
continue
}
apiAccount, err := p.tc.AccountToAPIAccountPublic(ctx, listEntry.Follow.TargetAccount)
if err != nil {
log.Debugf(ctx, "skipping list entry because of error converting follow target account: %q", err)
continue
}
items[i] = apiAccount
}
return util.PackagePageableResponse(util.PageableResponseParams{
Items: items,
Path: "api/v1/lists/" + listID + "/accounts",
NextMaxIDValue: nextMaxIDValue,
PrevMinIDValue: prevMinIDValue,
Limit: limit,
})
}

View file

@ -0,0 +1,35 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package list
import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
)
type Processor struct {
state *state.State
tc typeutils.TypeConverter
}
func New(state *state.State, tc typeutils.TypeConverter) Processor {
return Processor{
state: state,
tc: tc,
}
}

View file

@ -0,0 +1,73 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package list
import (
"context"
"errors"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Update updates one list for the given account, using the provided parameters.
// These params should have already been validated by the time they reach this function.
func (p *Processor) Update(
ctx context.Context,
account *gtsmodel.Account,
id string,
title *string,
repliesPolicy *gtsmodel.RepliesPolicy,
) (*apimodel.List, gtserror.WithCode) {
list, errWithCode := p.getList(
// Use barebones ctx; no embedded
// structs necessary for this call.
gtscontext.SetBarebones(ctx),
account.ID,
id,
)
if errWithCode != nil {
return nil, errWithCode
}
// Only update columns we're told to update.
columns := make([]string, 0, 2)
if title != nil {
list.Title = *title
columns = append(columns, "title")
}
if repliesPolicy != nil {
list.RepliesPolicy = *repliesPolicy
columns = append(columns, "replies_policy")
}
if err := p.state.DB.UpdateList(ctx, list, columns...); err != nil {
if errors.Is(err, db.ErrAlreadyExists) {
err = errors.New("you already have a list with this title")
return nil, gtserror.NewErrorConflict(err, err.Error())
}
return nil, gtserror.NewErrorInternalError(err)
}
return p.apiList(ctx, list)
}

View file

@ -0,0 +1,151 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package list
import (
"context"
"errors"
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
)
// AddToList adds targetAccountIDs to the given list, if valid.
func (p *Processor) AddToList(ctx context.Context, account *gtsmodel.Account, listID string, targetAccountIDs []string) gtserror.WithCode {
// Ensure this list exists + account owns it.
list, errWithCode := p.getList(ctx, account.ID, listID)
if errWithCode != nil {
return errWithCode
}
// Pre-assemble list of entries to add. We *could* add these
// one by one as we iterate through accountIDs, but according
// to the Mastodon API we should only add them all once we know
// they're all valid, no partial updates.
listEntries := make([]*gtsmodel.ListEntry, 0, len(targetAccountIDs))
// Check each targetAccountID is valid.
// - Follow must exist.
// - Follow must not already be in the given list.
for _, targetAccountID := range targetAccountIDs {
// Ensure follow exists.
follow, err := p.state.DB.GetFollow(ctx, account.ID, targetAccountID)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("you do not follow account %s", targetAccountID)
return gtserror.NewErrorNotFound(err, err.Error())
}
return gtserror.NewErrorInternalError(err)
}
// Ensure followID not already in list.
// This particular call to isInList will
// never error, so just check entryID.
entryID, _ := isInList(
list,
follow.ID,
func(listEntry *gtsmodel.ListEntry) (string, error) {
// Looking for the listEntry follow ID.
return listEntry.FollowID, nil
},
)
// Empty entryID means entry with given
// followID wasn't found in the list.
if entryID != "" {
err = fmt.Errorf("account with id %s is already in list %s with entryID %s", targetAccountID, listID, entryID)
return gtserror.NewErrorUnprocessableEntity(err, err.Error())
}
// Entry wasn't in the list, we can add it.
listEntries = append(listEntries, &gtsmodel.ListEntry{
ID: id.NewULID(),
ListID: listID,
FollowID: follow.ID,
})
}
// If we get to here we can assume all
// entries are valid, so try to add them.
if err := p.state.DB.PutListEntries(ctx, listEntries); err != nil {
if errors.Is(err, db.ErrAlreadyExists) {
err = fmt.Errorf("one or more errors inserting list entries: %w", err)
return gtserror.NewErrorUnprocessableEntity(err, err.Error())
}
return gtserror.NewErrorInternalError(err)
}
return nil
}
// RemoveFromList removes targetAccountIDs from the given list, if valid.
func (p *Processor) RemoveFromList(ctx context.Context, account *gtsmodel.Account, listID string, targetAccountIDs []string) gtserror.WithCode {
// Ensure this list exists + account owns it.
list, errWithCode := p.getList(ctx, account.ID, listID)
if errWithCode != nil {
return errWithCode
}
// For each targetAccountID, we want to check if
// a follow with that targetAccountID is in the
// given list. If it is in there, we want to remove
// it from the list.
for _, targetAccountID := range targetAccountIDs {
// Check if targetAccountID is
// on a follow in the list.
entryID, err := isInList(
list,
targetAccountID,
func(listEntry *gtsmodel.ListEntry) (string, error) {
// We need the follow so populate this
// entry, if it's not already populated.
if err := p.state.DB.PopulateListEntry(ctx, listEntry); err != nil {
return "", err
}
// Looking for the list entry targetAccountID.
return listEntry.Follow.TargetAccountID, nil
},
)
// Error may be returned here if there was an issue
// populating the list entry. We only return on proper
// DB errors, we can just skip no entry errors.
if err != nil && !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("error checking if targetAccountID %s was in list %s: %w", targetAccountID, listID, err)
return gtserror.NewErrorInternalError(err)
}
if entryID == "" {
// There was an errNoEntries or targetAccount
// wasn't in this list anyway, so we can skip it.
continue
}
// TargetAccount was in the list, remove the entry.
if err := p.state.DB.DeleteListEntry(ctx, entryID); err != nil && !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("error removing list entry %s from list %s: %w", entryID, listID, err)
return gtserror.NewErrorInternalError(err)
}
}
return nil
}

View file

@ -0,0 +1,85 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package list
import (
"context"
"errors"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// getList is a shortcut to get one list from the database and
// check that it's owned by the given accountID. Will return
// appropriate errors so caller doesn't need to bother.
func (p *Processor) getList(ctx context.Context, accountID string, listID string) (*gtsmodel.List, gtserror.WithCode) {
list, err := p.state.DB.GetListByID(ctx, listID)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
// List doesn't seem to exist.
return nil, gtserror.NewErrorNotFound(err)
}
// Real database error.
return nil, gtserror.NewErrorInternalError(err)
}
if list.AccountID != accountID {
err = fmt.Errorf("list with id %s does not belong to account %s", list.ID, accountID)
return nil, gtserror.NewErrorNotFound(err)
}
return list, nil
}
// apiList is a shortcut to return the API version of the given
// list, or return an appropriate error if conversion fails.
func (p *Processor) apiList(ctx context.Context, list *gtsmodel.List) (*apimodel.List, gtserror.WithCode) {
apiList, err := p.tc.ListToAPIList(ctx, list)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting list to api: %w", err))
}
return apiList, nil
}
// isInList check if thisID is equal to the result of thatID
// for any entry in the given list.
//
// Will return the id of the listEntry if true, empty if false,
// or an error if the result of thatID returns an error.
func isInList(
list *gtsmodel.List,
thisID string,
getThatID func(listEntry *gtsmodel.ListEntry) (string, error),
) (string, error) {
for _, listEntry := range list.ListEntries {
thatID, err := getThatID(listEntry)
if err != nil {
return "", err
}
if thisID == thatID {
return listEntry.ID, nil
}
}
return "", nil
}

View file

@ -1,52 +0,0 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package processing_test
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
)
type NotificationTestSuite struct {
ProcessingStandardTestSuite
}
// get a notification where someone has liked our status
func (suite *NotificationTestSuite) TestGetNotifications() {
receivingAccount := suite.testAccounts["local_account_1"]
notifsResponse, err := suite.processor.NotificationsGet(context.Background(), suite.testAutheds["local_account_1"], "", "", "", 10, nil)
suite.NoError(err)
suite.Len(notifsResponse.Items, 1)
notif, ok := notifsResponse.Items[0].(*apimodel.Notification)
if !ok {
panic("notif in response wasn't *apimodel.Notification")
}
suite.NotNil(notif.Status)
suite.NotNil(notif.Status)
suite.NotNil(notif.Status.Account)
suite.Equal(receivingAccount.ID, notif.Status.Account.ID)
suite.Equal(`<http://localhost:8080/api/v1/notifications?limit=10&max_id=01F8Q0ANPTWW10DAKTX7BRPBJP>; rel="next", <http://localhost:8080/api/v1/notifications?limit=10&min_id=01F8Q0ANPTWW10DAKTX7BRPBJP>; rel="prev"`, notifsResponse.LinkHeader)
}
func TestNotificationTestSuite(t *testing.T) {
suite.Run(t, &NotificationTestSuite{})
}

View file

@ -29,13 +29,14 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing/account"
"github.com/superseriousbusiness/gotosocial/internal/processing/admin"
"github.com/superseriousbusiness/gotosocial/internal/processing/fedi"
"github.com/superseriousbusiness/gotosocial/internal/processing/list"
"github.com/superseriousbusiness/gotosocial/internal/processing/media"
"github.com/superseriousbusiness/gotosocial/internal/processing/report"
"github.com/superseriousbusiness/gotosocial/internal/processing/status"
"github.com/superseriousbusiness/gotosocial/internal/processing/stream"
"github.com/superseriousbusiness/gotosocial/internal/processing/timeline"
"github.com/superseriousbusiness/gotosocial/internal/processing/user"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
)
@ -45,7 +46,6 @@ type Processor struct {
tc typeutils.TypeConverter
oauthServer oauth.Server
mediaManager mm.Manager
statusTimelines timeline.Manager
state *state.State
emailSender email.Sender
filter *visibility.Filter
@ -57,10 +57,12 @@ type Processor struct {
account account.Processor
admin admin.Processor
fedi fedi.Processor
list list.Processor
media media.Processor
report report.Processor
status status.Processor
stream stream.Processor
timeline timeline.Processor
user user.Processor
}
@ -76,6 +78,10 @@ func (p *Processor) Fedi() *fedi.Processor {
return &p.fedi
}
func (p *Processor) List() *list.Processor {
return &p.list
}
func (p *Processor) Media() *media.Processor {
return &p.media
}
@ -92,6 +98,10 @@ func (p *Processor) Stream() *stream.Processor {
return &p.stream
}
func (p *Processor) Timeline() *timeline.Processor {
return &p.timeline
}
func (p *Processor) User() *user.Processor {
return &p.user
}
@ -114,23 +124,19 @@ func NewProcessor(
tc: tc,
oauthServer: oauthServer,
mediaManager: mediaManager,
statusTimelines: timeline.NewManager(
StatusGrabFunction(state.DB),
StatusFilterFunction(state.DB, filter),
StatusPrepareFunction(state.DB, tc),
StatusSkipInsertFunction(),
),
state: state,
filter: filter,
emailSender: emailSender,
}
// sub processors
// Instantiate sub processors.
processor.account = account.New(state, tc, mediaManager, oauthServer, federator, filter, parseMentionFunc)
processor.admin = admin.New(state, tc, mediaManager, federator.TransportController(), emailSender)
processor.fedi = fedi.New(state, tc, federator, filter)
processor.list = list.New(state, tc)
processor.media = media.New(state, tc, mediaManager, federator.TransportController())
processor.report = report.New(state, tc)
processor.timeline = timeline.New(state, tc, filter)
processor.status = status.New(state, federator, tc, filter, parseMentionFunc)
processor.stream = stream.New(state, oauthServer)
processor.user = user.New(state, emailSender)
@ -161,13 +167,3 @@ func (p *Processor) EnqueueFederator(ctx context.Context, msgs ...messages.FromF
}
})
}
// Start starts the Processor.
func (p *Processor) Start() error {
return p.statusTimelines.Start()
}
// Stop stops the processor cleanly.
func (p *Processor) Stop() error {
return p.statusTimelines.Stop()
}

View file

@ -18,6 +18,8 @@
package processing_test
import (
"context"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
@ -28,8 +30,10 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/stream"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -61,6 +65,7 @@ type ProcessingStandardTestSuite struct {
testAutheds map[string]*oauth.Auth
testBlocks map[string]*gtsmodel.Block
testActivities map[string]testrig.ActivityWithSignature
testLists map[string]*gtsmodel.List
processor *processing.Processor
}
@ -84,6 +89,7 @@ func (suite *ProcessingStandardTestSuite) SetupSuite() {
},
}
suite.testBlocks = testrig.NewTestBlocks()
suite.testLists = testrig.NewTestLists()
}
func (suite *ProcessingStandardTestSuite) SetupTest() {
@ -99,6 +105,13 @@ func (suite *ProcessingStandardTestSuite) SetupTest() {
suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
suite.typeconverter = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.typeconverter,
)
suite.httpClient = testrig.NewMockHTTPClient(nil, "../../testrig/media")
suite.httpClient.TestRemotePeople = testrig.NewTestFediPeople()
suite.httpClient.TestRemoteStatuses = testrig.NewTestFediStatuses()
@ -115,16 +128,40 @@ func (suite *ProcessingStandardTestSuite) SetupTest() {
testrig.StandardDBSetup(suite.db, suite.testAccounts)
testrig.StandardStorageSetup(suite.storage, "../../testrig/media")
if err := suite.processor.Start(); err != nil {
panic(err)
}
}
func (suite *ProcessingStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage)
if err := suite.processor.Stop(); err != nil {
panic(err)
}
testrig.StopWorkers(&suite.state)
}
func (suite *ProcessingStandardTestSuite) openStreams(ctx context.Context, account *gtsmodel.Account, listIDs []string) map[string]*stream.Stream {
streams := make(map[string]*stream.Stream)
for _, streamType := range []string{
stream.TimelineHome,
stream.TimelinePublic,
stream.TimelineNotifications,
} {
stream, err := suite.processor.Stream().Open(ctx, account, streamType)
if err != nil {
suite.FailNow(err.Error())
}
streams[streamType] = stream
}
for _, listID := range listIDs {
streamType := stream.TimelineList + ":" + listID
stream, err := suite.processor.Stream().Open(ctx, account, streamType)
if err != nil {
suite.FailNow(err.Error())
}
streams[streamType] = stream
}
return streams
}

View file

@ -88,6 +88,12 @@ func (suite *StatusStandardTestSuite) SetupTest() {
suite.federator = testrig.NewTestFederator(&suite.state, suite.tc, suite.mediaManager)
filter := visibility.NewFilter(&suite.state)
testrig.StartTimelines(
&suite.state,
filter,
testrig.NewTestTypeConverter(suite.db),
)
suite.status = status.New(&suite.state, suite.federator, suite.typeConverter, filter, processing.GetParseMentionFunc(suite.db, suite.federator))
testrig.StandardDBSetup(suite.db, suite.testAccounts)

View file

@ -1,309 +0,0 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package processing
import (
"context"
"errors"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
)
const boostReinsertionDepth = 50
// StatusGrabFunction returns a function that satisfies the GrabFunction interface in internal/timeline.
func StatusGrabFunction(database db.DB) timeline.GrabFunction {
return func(ctx context.Context, timelineAccountID string, maxID string, sinceID string, minID string, limit int) ([]timeline.Timelineable, bool, error) {
statuses, err := database.GetHomeTimeline(ctx, timelineAccountID, maxID, sinceID, minID, limit, false)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
return nil, true, nil // we just don't have enough statuses left in the db so return stop = true
}
return nil, false, fmt.Errorf("statusGrabFunction: error getting statuses from db: %w", err)
}
items := make([]timeline.Timelineable, len(statuses))
for i, s := range statuses {
items[i] = s
}
return items, false, nil
}
}
// StatusFilterFunction returns a function that satisfies the FilterFunction interface in internal/timeline.
func StatusFilterFunction(database db.DB, filter *visibility.Filter) timeline.FilterFunction {
return func(ctx context.Context, timelineAccountID string, item timeline.Timelineable) (shouldIndex bool, err error) {
status, ok := item.(*gtsmodel.Status)
if !ok {
return false, errors.New("StatusFilterFunction: could not convert item to *gtsmodel.Status")
}
requestingAccount, err := database.GetAccountByID(ctx, timelineAccountID)
if err != nil {
return false, fmt.Errorf("StatusFilterFunction: error getting account with id %s: %w", timelineAccountID, err)
}
timelineable, err := filter.StatusHomeTimelineable(ctx, requestingAccount, status)
if err != nil {
return false, fmt.Errorf("StatusFilterFunction: error checking hometimelineability of status %s for account %s: %w", status.ID, timelineAccountID, err)
}
return timelineable, nil
}
}
// StatusPrepareFunction returns a function that satisfies the PrepareFunction interface in internal/timeline.
func StatusPrepareFunction(database db.DB, tc typeutils.TypeConverter) timeline.PrepareFunction {
return func(ctx context.Context, timelineAccountID string, itemID string) (timeline.Preparable, error) {
status, err := database.GetStatusByID(ctx, itemID)
if err != nil {
return nil, fmt.Errorf("StatusPrepareFunction: error getting status with id %s: %w", itemID, err)
}
requestingAccount, err := database.GetAccountByID(ctx, timelineAccountID)
if err != nil {
return nil, fmt.Errorf("StatusPrepareFunction: error getting account with id %s: %w", timelineAccountID, err)
}
return tc.StatusToAPIStatus(ctx, status, requestingAccount)
}
}
// StatusSkipInsertFunction returns a function that satisifes the SkipInsertFunction interface in internal/timeline.
func StatusSkipInsertFunction() timeline.SkipInsertFunction {
return func(
ctx context.Context,
newItemID string,
newItemAccountID string,
newItemBoostOfID string,
newItemBoostOfAccountID string,
nextItemID string,
nextItemAccountID string,
nextItemBoostOfID string,
nextItemBoostOfAccountID string,
depth int,
) (bool, error) {
// make sure we don't insert a duplicate
if newItemID == nextItemID {
return true, nil
}
// check if it's a boost
if newItemBoostOfID != "" {
// skip if we've recently put another boost of this status in the timeline
if newItemBoostOfID == nextItemBoostOfID {
if depth < boostReinsertionDepth {
return true, nil
}
}
// skip if we've recently put the original status in the timeline
if newItemBoostOfID == nextItemID {
if depth < boostReinsertionDepth {
return true, nil
}
}
}
// insert the item
return false, nil
}
}
func (p *Processor) HomeTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.PageableResponse, gtserror.WithCode) {
statuses, err := p.statusTimelines.GetTimeline(ctx, authed.Account.ID, maxID, sinceID, minID, limit, local)
if err != nil {
err = fmt.Errorf("HomeTimelineGet: error getting statuses: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
count := len(statuses)
if count == 0 {
return util.EmptyPageableResponse(), nil
}
var (
items = make([]interface{}, count)
nextMaxIDValue string
prevMinIDValue string
)
for i, item := range statuses {
if i == count-1 {
nextMaxIDValue = item.GetID()
}
if i == 0 {
prevMinIDValue = item.GetID()
}
items[i] = item
}
return util.PackagePageableResponse(util.PageableResponseParams{
Items: items,
Path: "api/v1/timelines/home",
NextMaxIDValue: nextMaxIDValue,
PrevMinIDValue: prevMinIDValue,
Limit: limit,
})
}
func (p *Processor) PublicTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.PageableResponse, gtserror.WithCode) {
statuses, err := p.state.DB.GetPublicTimeline(ctx, maxID, sinceID, minID, limit, local)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
// No statuses (left) in public timeline.
return util.EmptyPageableResponse(), nil
}
// An actual error has occurred.
err = fmt.Errorf("PublicTimelineGet: db error getting statuses: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
count := len(statuses)
if count == 0 {
return util.EmptyPageableResponse(), nil
}
var (
items = make([]interface{}, 0, count)
nextMaxIDValue string
prevMinIDValue string
)
for i, s := range statuses {
// Set next + prev values before filtering and API
// converting, so caller can still page properly.
if i == count-1 {
nextMaxIDValue = s.ID
}
if i == 0 {
prevMinIDValue = s.ID
}
timelineable, err := p.filter.StatusPublicTimelineable(ctx, authed.Account, s)
if err != nil {
log.Debugf(ctx, "skipping status %s because of an error checking StatusPublicTimelineable: %s", s.ID, err)
continue
}
if !timelineable {
continue
}
apiStatus, err := p.tc.StatusToAPIStatus(ctx, s, authed.Account)
if err != nil {
log.Debugf(ctx, "skipping status %s because it couldn't be converted to its api representation: %s", s.ID, err)
continue
}
items = append(items, apiStatus)
}
return util.PackagePageableResponse(util.PageableResponseParams{
Items: items,
Path: "api/v1/timelines/public",
NextMaxIDValue: nextMaxIDValue,
PrevMinIDValue: prevMinIDValue,
Limit: limit,
})
}
func (p *Processor) FavedTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.PageableResponse, gtserror.WithCode) {
statuses, nextMaxID, prevMinID, err := p.state.DB.GetFavedTimeline(ctx, authed.Account.ID, maxID, minID, limit)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
// There are just no entries (left).
return util.EmptyPageableResponse(), nil
}
// An actual error has occurred.
err = fmt.Errorf("FavedTimelineGet: db error getting statuses: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
count := len(statuses)
if count == 0 {
return util.EmptyPageableResponse(), nil
}
filtered, err := p.filterFavedStatuses(ctx, authed, statuses)
if err != nil {
err = fmt.Errorf("FavedTimelineGet: error filtering statuses: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
items := make([]interface{}, len(filtered))
for i, item := range filtered {
items[i] = item
}
return util.PackagePageableResponse(util.PageableResponseParams{
Items: items,
Path: "api/v1/favourites",
NextMaxIDValue: nextMaxID,
PrevMinIDValue: prevMinID,
Limit: limit,
})
}
func (p *Processor) filterFavedStatuses(ctx context.Context, authed *oauth.Auth, statuses []*gtsmodel.Status) ([]*apimodel.Status, error) {
apiStatuses := make([]*apimodel.Status, 0, len(statuses))
for _, s := range statuses {
if _, err := p.state.DB.GetAccountByID(ctx, s.AccountID); err != nil {
if errors.Is(err, db.ErrNoEntries) {
log.Debugf(ctx, "skipping status %s because account %s can't be found in the db", s.ID, s.AccountID)
continue
}
err = fmt.Errorf("filterFavedStatuses: db error getting status author: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
timelineable, err := p.filter.StatusVisible(ctx, authed.Account, s)
if err != nil {
log.Debugf(ctx, "skipping status %s because of an error checking status visibility: %s", s.ID, err)
continue
}
if !timelineable {
continue
}
apiStatus, err := p.tc.StatusToAPIStatus(ctx, s, authed.Account)
if err != nil {
log.Debugf(ctx, "skipping status %s because it couldn't be converted to its api representation: %s", s.ID, err)
continue
}
apiStatuses = append(apiStatuses, apiStatus)
}
return apiStatuses, nil
}

View file

@ -31,60 +31,65 @@ import (
)
// Open returns a new Stream for the given account, which will contain a channel for passing messages back to the caller.
func (p *Processor) Open(ctx context.Context, account *gtsmodel.Account, streamTimeline string) (*stream.Stream, gtserror.WithCode) {
func (p *Processor) Open(ctx context.Context, account *gtsmodel.Account, streamType string) (*stream.Stream, gtserror.WithCode) {
l := log.WithContext(ctx).WithFields(kv.Fields{
{"account", account.ID},
{"streamType", streamTimeline},
{"streamType", streamType},
}...)
l.Debug("received open stream request")
// each stream needs a unique ID so we know to close it
streamID, err := id.NewRandomULID()
var (
streamID string
err error
)
// Each stream needs a unique ID so we know to close it.
streamID, err = id.NewRandomULID()
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %s", err))
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %w", err))
}
// Each stream can be subscibed to multiple timelines.
// Each stream can be subscibed to multiple types.
// Record them in a set, and include the initial one
// if it was given to us
timelines := map[string]bool{}
if streamTimeline != "" {
timelines[streamTimeline] = true
// if it was given to us.
streamTypes := map[string]any{}
if streamType != "" {
streamTypes[streamType] = true
}
thisStream := &stream.Stream{
newStream := &stream.Stream{
ID: streamID,
Timelines: timelines,
StreamTypes: streamTypes,
Messages: make(chan *stream.Message, 100),
Hangup: make(chan interface{}, 1),
Connected: true,
}
go p.waitToCloseStream(account, thisStream)
go p.waitToCloseStream(account, newStream)
v, ok := p.streamMap.Load(account.ID)
if !ok || v == nil {
// there is no entry in the streamMap for this account yet, so make one and store it
streamsForAccount := &stream.StreamsForAccount{
Streams: []*stream.Stream{
thisStream,
},
}
p.streamMap.Store(account.ID, streamsForAccount)
} else {
// there is an entry in the streamMap for this account
// parse the interface as a streamsForAccount
if ok {
// There is an entry in the streamMap
// for this account. Parse it out.
streamsForAccount, ok := v.(*stream.StreamsForAccount)
if !ok {
return nil, gtserror.NewErrorInternalError(errors.New("stream map error"))
}
// append this stream to it
// Append new stream to existing entry.
streamsForAccount.Lock()
streamsForAccount.Streams = append(streamsForAccount.Streams, thisStream)
streamsForAccount.Streams = append(streamsForAccount.Streams, newStream)
streamsForAccount.Unlock()
} else {
// There is no entry in the streamMap for
// this account yet. Create one and store it.
p.streamMap.Store(account.ID, &stream.StreamsForAccount{
Streams: []*stream.Stream{
newStream,
},
})
}
return thisStream, nil
return newStream, nil
}
// waitToCloseStream waits until the hangup channel is closed for the given stream.

View file

@ -18,7 +18,6 @@
package stream
import (
"errors"
"sync"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
@ -40,37 +39,38 @@ func New(state *state.State, oauthServer oauth.Server) Processor {
}
// toAccount streams the given payload with the given event type to any streams currently open for the given account ID.
func (p *Processor) toAccount(payload string, event string, timelines []string, accountID string) error {
func (p *Processor) toAccount(payload string, event string, streamTypes []string, accountID string) error {
// Load all streams open for this account.
v, ok := p.streamMap.Load(accountID)
if !ok {
// no open connections so nothing to stream
return nil
}
streamsForAccount, ok := v.(*stream.StreamsForAccount)
if !ok {
return errors.New("stream map error")
return nil // No entry = nothing to stream.
}
streamsForAccount := v.(*stream.StreamsForAccount) //nolint:forcetypeassert
streamsForAccount.Lock()
defer streamsForAccount.Unlock()
for _, s := range streamsForAccount.Streams {
s.Lock()
defer s.Unlock()
if !s.Connected {
continue
}
for _, t := range timelines {
if _, found := s.Timelines[t]; found {
typeLoop:
for _, streamType := range streamTypes {
if _, found := s.StreamTypes[streamType]; found {
s.Messages <- &stream.Message{
Stream: []string{string(t)},
Stream: []string{streamType},
Event: string(event),
Payload: payload,
}
// break out to the outer loop, to avoid sending duplicates
// of the same event to the same stream
break
// Break out to the outer loop,
// to avoid sending duplicates of
// the same event to the same stream.
break typeLoop
}
}
}

View file

@ -27,11 +27,11 @@ import (
)
// Update streams the given update to any open, appropriate streams belonging to the given account.
func (p *Processor) Update(s *apimodel.Status, account *gtsmodel.Account, timeline string) error {
func (p *Processor) Update(s *apimodel.Status, account *gtsmodel.Account, streamTypes []string) error {
bytes, err := json.Marshal(s)
if err != nil {
return fmt.Errorf("error marshalling status to json: %s", err)
}
return p.toAccount(string(bytes), stream.EventTypeUpdate, []string{timeline}, account.ID)
return p.toAccount(string(bytes), stream.EventTypeUpdate, streamTypes, account.ID)
}

View file

@ -0,0 +1,71 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package timeline
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
)
// SkipInsert returns a function that satisifes SkipInsertFunction.
func SkipInsert() timeline.SkipInsertFunction {
// Gap to allow between a status or boost of status,
// and reinsertion of a new boost of that status.
// This is useful to avoid a heavily boosted status
// showing up way too often in a user's timeline.
const boostReinsertionDepth = 50
return func(
ctx context.Context,
newItemID string,
newItemAccountID string,
newItemBoostOfID string,
newItemBoostOfAccountID string,
nextItemID string,
nextItemAccountID string,
nextItemBoostOfID string,
nextItemBoostOfAccountID string,
depth int,
) (bool, error) {
if newItemID == nextItemID {
// Don't insert duplicates.
return true, nil
}
if newItemBoostOfID != "" {
if newItemBoostOfID == nextItemBoostOfID &&
depth < boostReinsertionDepth {
// Don't insert boosts of items
// we've seen boosted recently.
return true, nil
}
if newItemBoostOfID == nextItemID &&
depth < boostReinsertionDepth {
// Don't insert boosts of items when
// we've seen the original recently.
return true, nil
}
}
// Proceed with insertion
// (that's what she said!).
return false, nil
}
}

View file

@ -0,0 +1,73 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package timeline
import (
"context"
"errors"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
func (p *Processor) FavedTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.PageableResponse, gtserror.WithCode) {
statuses, nextMaxID, prevMinID, err := p.state.DB.GetFavedTimeline(ctx, authed.Account.ID, maxID, minID, limit)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("FavedTimelineGet: db error getting statuses: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
count := len(statuses)
if count == 0 {
return util.EmptyPageableResponse(), nil
}
items := make([]interface{}, 0, count)
for _, s := range statuses {
visible, err := p.filter.StatusVisible(ctx, authed.Account, s)
if err != nil {
log.Debugf(ctx, "skipping status %s because of an error checking status visibility: %s", s.ID, err)
continue
}
if !visible {
continue
}
apiStatus, err := p.tc.StatusToAPIStatus(ctx, s, authed.Account)
if err != nil {
log.Debugf(ctx, "skipping status %s because it couldn't be converted to its api representation: %s", s.ID, err)
continue
}
items = append(items, apiStatus)
}
return util.PackagePageableResponse(util.PageableResponseParams{
Items: items,
Path: "api/v1/favourites",
NextMaxIDValue: nextMaxID,
PrevMinIDValue: prevMinID,
Limit: limit,
})
}

View file

@ -0,0 +1,133 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package timeline
import (
"context"
"errors"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
)
// HomeTimelineGrab returns a function that satisfies GrabFunction for home timelines.
func HomeTimelineGrab(state *state.State) timeline.GrabFunction {
return func(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int) ([]timeline.Timelineable, bool, error) {
statuses, err := state.DB.GetHomeTimeline(ctx, accountID, maxID, sinceID, minID, limit, false)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
return nil, true, nil // we just don't have enough statuses left in the db so return stop = true
}
return nil, false, fmt.Errorf("HomeTimelineGrab: error getting statuses from db: %w", err)
}
items := make([]timeline.Timelineable, len(statuses))
for i, s := range statuses {
items[i] = s
}
return items, false, nil
}
}
// HomeTimelineFilter returns a function that satisfies FilterFunction for home timelines.
func HomeTimelineFilter(state *state.State, filter *visibility.Filter) timeline.FilterFunction {
return func(ctx context.Context, accountID string, item timeline.Timelineable) (shouldIndex bool, err error) {
status, ok := item.(*gtsmodel.Status)
if !ok {
return false, errors.New("HomeTimelineFilter: could not convert item to *gtsmodel.Status")
}
requestingAccount, err := state.DB.GetAccountByID(ctx, accountID)
if err != nil {
return false, fmt.Errorf("HomeTimelineFilter: error getting account with id %s: %w", accountID, err)
}
timelineable, err := filter.StatusHomeTimelineable(ctx, requestingAccount, status)
if err != nil {
return false, fmt.Errorf("HomeTimelineFilter: error checking hometimelineability of status %s for account %s: %w", status.ID, accountID, err)
}
return timelineable, nil
}
}
// HomeTimelineStatusPrepare returns a function that satisfies PrepareFunction for home timelines.
func HomeTimelineStatusPrepare(state *state.State, tc typeutils.TypeConverter) timeline.PrepareFunction {
return func(ctx context.Context, accountID string, itemID string) (timeline.Preparable, error) {
status, err := state.DB.GetStatusByID(ctx, itemID)
if err != nil {
return nil, fmt.Errorf("StatusPrepare: error getting status with id %s: %w", itemID, err)
}
requestingAccount, err := state.DB.GetAccountByID(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("StatusPrepare: error getting account with id %s: %w", accountID, err)
}
return tc.StatusToAPIStatus(ctx, status, requestingAccount)
}
}
func (p *Processor) HomeTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.PageableResponse, gtserror.WithCode) {
statuses, err := p.state.Timelines.Home.GetTimeline(ctx, authed.Account.ID, maxID, sinceID, minID, limit, local)
if err != nil {
err = fmt.Errorf("HomeTimelineGet: error getting statuses: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
count := len(statuses)
if count == 0 {
return util.EmptyPageableResponse(), nil
}
var (
items = make([]interface{}, count)
nextMaxIDValue string
prevMinIDValue string
)
for i, item := range statuses {
if i == count-1 {
nextMaxIDValue = item.GetID()
}
if i == 0 {
prevMinIDValue = item.GetID()
}
items[i] = item
}
return util.PackagePageableResponse(util.PageableResponseParams{
Items: items,
Path: "api/v1/timelines/home",
NextMaxIDValue: nextMaxIDValue,
PrevMinIDValue: prevMinIDValue,
Limit: limit,
})
}

View file

@ -0,0 +1,157 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package timeline
import (
"context"
"errors"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
)
// ListTimelineGrab returns a function that satisfies GrabFunction for list timelines.
func ListTimelineGrab(state *state.State) timeline.GrabFunction {
return func(ctx context.Context, listID string, maxID string, sinceID string, minID string, limit int) ([]timeline.Timelineable, bool, error) {
statuses, err := state.DB.GetListTimeline(ctx, listID, maxID, sinceID, minID, limit)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
return nil, true, nil // we just don't have enough statuses left in the db so return stop = true
}
return nil, false, fmt.Errorf("ListTimelineGrab: error getting statuses from db: %w", err)
}
items := make([]timeline.Timelineable, len(statuses))
for i, s := range statuses {
items[i] = s
}
return items, false, nil
}
}
// HomeTimelineFilter returns a function that satisfies FilterFunction for list timelines.
func ListTimelineFilter(state *state.State, filter *visibility.Filter) timeline.FilterFunction {
return func(ctx context.Context, listID string, item timeline.Timelineable) (shouldIndex bool, err error) {
status, ok := item.(*gtsmodel.Status)
if !ok {
return false, errors.New("ListTimelineFilter: could not convert item to *gtsmodel.Status")
}
list, err := state.DB.GetListByID(ctx, listID)
if err != nil {
return false, fmt.Errorf("ListTimelineFilter: error getting list with id %s: %w", listID, err)
}
requestingAccount, err := state.DB.GetAccountByID(ctx, list.AccountID)
if err != nil {
return false, fmt.Errorf("ListTimelineFilter: error getting account with id %s: %w", list.AccountID, err)
}
timelineable, err := filter.StatusHomeTimelineable(ctx, requestingAccount, status)
if err != nil {
return false, fmt.Errorf("ListTimelineFilter: error checking hometimelineability of status %s for account %s: %w", status.ID, list.AccountID, err)
}
return timelineable, nil
}
}
// ListTimelineStatusPrepare returns a function that satisfies PrepareFunction for list timelines.
func ListTimelineStatusPrepare(state *state.State, tc typeutils.TypeConverter) timeline.PrepareFunction {
return func(ctx context.Context, listID string, itemID string) (timeline.Preparable, error) {
status, err := state.DB.GetStatusByID(ctx, itemID)
if err != nil {
return nil, fmt.Errorf("ListTimelineStatusPrepare: error getting status with id %s: %w", itemID, err)
}
list, err := state.DB.GetListByID(ctx, listID)
if err != nil {
return nil, fmt.Errorf("ListTimelineStatusPrepare: error getting list with id %s: %w", listID, err)
}
requestingAccount, err := state.DB.GetAccountByID(ctx, list.AccountID)
if err != nil {
return nil, fmt.Errorf("ListTimelineStatusPrepare: error getting account with id %s: %w", list.AccountID, err)
}
return tc.StatusToAPIStatus(ctx, status, requestingAccount)
}
}
func (p *Processor) ListTimelineGet(ctx context.Context, authed *oauth.Auth, listID string, maxID string, sinceID string, minID string, limit int) (*apimodel.PageableResponse, gtserror.WithCode) {
// Ensure list exists + is owned by this account.
list, err := p.state.DB.GetListByID(ctx, listID)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorNotFound(err)
}
return nil, gtserror.NewErrorInternalError(err)
}
if list.AccountID != authed.Account.ID {
err = fmt.Errorf("list with id %s does not belong to account %s", list.ID, authed.Account.ID)
return nil, gtserror.NewErrorNotFound(err)
}
statuses, err := p.state.Timelines.List.GetTimeline(ctx, listID, maxID, sinceID, minID, limit, false)
if err != nil {
err = fmt.Errorf("ListTimelineGet: error getting statuses: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
count := len(statuses)
if count == 0 {
return util.EmptyPageableResponse(), nil
}
var (
items = make([]interface{}, count)
nextMaxIDValue string
prevMinIDValue string
)
for i, item := range statuses {
if i == count-1 {
nextMaxIDValue = item.GetID()
}
if i == 0 {
prevMinIDValue = item.GetID()
}
items[i] = item
}
return util.PackagePageableResponse(util.PageableResponseParams{
Items: items,
Path: "api/v1/timelines/list/" + listID,
NextMaxIDValue: nextMaxIDValue,
PrevMinIDValue: prevMinIDValue,
Limit: limit,
})
}

View file

@ -15,7 +15,7 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package processing
package timeline
import (
"context"
@ -33,12 +33,7 @@ import (
func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, excludeTypes []string) (*apimodel.PageableResponse, gtserror.WithCode) {
notifs, err := p.state.DB.GetAccountNotifications(ctx, authed.Account.ID, maxID, sinceID, minID, limit, excludeTypes)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
// No notifs (left).
return util.EmptyPageableResponse(), nil
}
// An actual error has occurred.
if err != nil && !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("NotificationsGet: db error getting notifications: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
@ -73,6 +68,7 @@ func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, ma
log.Debugf(ctx, "skipping notification %s because of an error checking notification visibility: %s", n.ID, err)
continue
}
if !visible {
continue
}
@ -85,6 +81,7 @@ func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, ma
log.Debugf(ctx, "skipping notification %s because of an error checking notification visibility: %s", n.ID, err)
continue
}
if !visible {
continue
}

View file

@ -0,0 +1,88 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package timeline
import (
"context"
"errors"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
func (p *Processor) PublicTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.PageableResponse, gtserror.WithCode) {
statuses, err := p.state.DB.GetPublicTimeline(ctx, maxID, sinceID, minID, limit, local)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("PublicTimelineGet: db error getting statuses: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
count := len(statuses)
if count == 0 {
return util.EmptyPageableResponse(), nil
}
var (
items = make([]interface{}, 0, count)
nextMaxIDValue string
prevMinIDValue string
)
for i, s := range statuses {
// Set next + prev values before filtering and API
// converting, so caller can still page properly.
if i == count-1 {
nextMaxIDValue = s.ID
}
if i == 0 {
prevMinIDValue = s.ID
}
timelineable, err := p.filter.StatusPublicTimelineable(ctx, authed.Account, s)
if err != nil {
log.Debugf(ctx, "skipping status %s because of an error checking StatusPublicTimelineable: %s", s.ID, err)
continue
}
if !timelineable {
continue
}
apiStatus, err := p.tc.StatusToAPIStatus(ctx, s, authed.Account)
if err != nil {
log.Debugf(ctx, "skipping status %s because it couldn't be converted to its api representation: %s", s.ID, err)
continue
}
items = append(items, apiStatus)
}
return util.PackagePageableResponse(util.PageableResponseParams{
Items: items,
Path: "api/v1/timelines/public",
NextMaxIDValue: nextMaxIDValue,
PrevMinIDValue: prevMinIDValue,
Limit: limit,
})
}

View file

@ -0,0 +1,38 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package timeline
import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
)
type Processor struct {
state *state.State
tc typeutils.TypeConverter
filter *visibility.Filter
}
func New(state *state.State, tc typeutils.TypeConverter, filter *visibility.Filter) Processor {
return Processor{
state: state,
tc: tc,
filter: filter,
}
}

Some files were not shown because too many files have changed in this diff Show more